@@ -30,6 +30,8 @@
 from vllm.v1.worker.gpu_input_batch import InputBatch
 
 from vllm_ascend.ops.attention import vanilla_chunked_prefill
+from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
+                               nd_to_nz_2d, nd_to_nz_spec)
 
 
 class AscendAttentionBackend(AttentionBackend):
@@ -62,6 +64,9 @@ def get_kv_cache_shape(
         num_kv_heads: int,
         head_size: int,
     ) -> Tuple[int, ...]:
+        if is_310p():
+            return (2, num_blocks, num_kv_heads * head_size // 16, block_size,
+                    16)
         return (2, num_blocks, block_size, num_kv_heads, head_size)
 
     @staticmethod
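
On 310P the KV cache is laid out for the FRACTAL_NZ format, so the per-block head dimension is folded into 16-element fractals. A quick, self-contained sanity check of that shape arithmetic (the dimensions below are made-up examples, not values from this PR):

```python
# Illustrative only: compare the default layout with the 310P layout returned
# above, using example dimensions that are not taken from the PR.
num_blocks, block_size = 1024, 128
num_kv_heads, head_size = 8, 128

nd_shape = (2, num_blocks, block_size, num_kv_heads, head_size)
nz_shape = (2, num_blocks, num_kv_heads * head_size // 16, block_size, 16)

# Both layouts hold the same number of elements per block as long as
# num_kv_heads * head_size is a multiple of 16.
assert num_kv_heads * head_size == (num_kv_heads * head_size // 16) * 16
print(nd_shape)  # (2, 1024, 128, 8, 128)
print(nz_shape)  # (2, 1024, 64, 128, 16)
```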
@@ -166,6 +171,16 @@ def build(self,
         query_start_loc = query_start_loc_cpu.to(self.runner.device,
                                                  non_blocking=True)
 
+        if is_310p():
+            if attn_state == AscendAttentionState.PrefillNoCache:
+                mask_nz = nd_to_nz_2d(attn_mask)
+                attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
+                                                      ACL_FORMAT_FRACTAL_NZ)
+            elif attn_state == AscendAttentionState.ChunkedPrefill:
+                mask_nz = nd_to_nz_spec(attn_mask)
+                attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
+                                                      ACL_FORMAT_FRACTAL_NZ)
+
         attn_metadata = AscendMetadata(
             num_actual_tokens=num_actual_tokens,
             block_tables=block_table,
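
`nd_to_nz_2d` and `nd_to_nz_spec` come from `vllm_ascend.utils`; their implementations are not part of this diff. As a rough mental model only, FRACTAL_NZ re-tiles a 2-D matrix into 16x16 fractals. The sketch below illustrates that idea under that assumption and is not the vllm_ascend code (the name `nd_to_nz_2d_sketch` is hypothetical):

```python
import torch

def nd_to_nz_2d_sketch(mask: torch.Tensor, frac: int = 16) -> torch.Tensor:
    """Conceptual sketch of an ND -> NZ re-tiling for a 2-D mask.

    Pads both dimensions up to multiples of `frac`, then regroups the matrix
    into (N/frac, M/frac, frac, frac) tiles. This only illustrates the
    FRACTAL_NZ idea; the real helpers live in vllm_ascend.utils.
    """
    m, n = mask.shape
    m_pad = (m + frac - 1) // frac * frac
    n_pad = (n + frac - 1) // frac * frac
    padded = mask.new_zeros(m_pad, n_pad)
    padded[:m, :n] = mask
    # (M, N) -> (M/frac, frac, N/frac, frac) -> (N/frac, M/frac, frac, frac)
    tiles = padded.view(m_pad // frac, frac, n_pad // frac, frac)
    return tiles.permute(2, 0, 1, 3).contiguous()


mask = torch.ones(33, 45)
print(nd_to_nz_2d_sketch(mask).shape)  # torch.Size([3, 3, 16, 16])
```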
@@ -249,6 +264,7 @@ def forward(
                              self.head_size,
                              dtype=query.dtype,
                              device=query.device)
+        ori_output = output
         if trace_flag:
             torch.ops.vllm.unified_ascend_attention_with_output(
                 query=query,
@@ -293,6 +309,18 @@ def forward(
                 assert attn_metadata is not None
                 assert attn_metadata.attn_mask is not None
                 mask = attn_metadata.attn_mask
+                if is_310p():
+                    # align q k v output tensors
+                    query = aligned_16(query)
+                    key = aligned_16(key)
+                    value = aligned_16(value)
+                    output = aligned_16(output)
+
+                    # do reformat in case of broadcasted tensors
+                    mask = mask.repeat(attn_metadata.seq_lens.size(0), 1, 1, 1)
+                    mask = torch_npu.npu_format_cast(mask.contiguous(),
+                                                     ACL_FORMAT_FRACTAL_NZ)
+
                 torch_npu._npu_flash_attention(query=query,
                                                key=key,
                                                value=value,
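
`aligned_16` is another `vllm_ascend.utils` helper whose body is not shown here; judging from how it is used, the 310P kernels want the leading (token) dimension padded up to a multiple of 16. A minimal sketch under that assumption (the name `aligned_16_sketch` is hypothetical, and the real helper may differ in detail):

```python
import torch

def aligned_16_sketch(x: torch.Tensor) -> torch.Tensor:
    """Illustrative only: pad the first (token) dimension to a multiple of 16.

    Assumes this is roughly what aligned_16 does; the real helper is defined
    in vllm_ascend.utils.
    """
    n = x.size(0)
    n_pad = (n + 15) // 16 * 16
    if n_pad == n:
        return x
    padded = x.new_zeros(n_pad, *x.shape[1:])
    padded[:n] = x
    return padded


q = torch.randn(10, 8, 128)        # 10 tokens, 8 heads, head_size 128
print(aligned_16_sketch(q).shape)  # torch.Size([16, 8, 128])
```

Because padding allocates a new tensor, the kernel result is sliced back to `output[:num_tokens]` in the next hunk and copied into the originally allocated buffer at the end of `forward`.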
@@ -302,6 +330,7 @@ def forward(
                                                num_heads=self.num_heads,
                                                num_kv_heads=self.num_kv_heads,
                                                out=output)
+                output = output[:num_tokens, :, :]
             elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
                 assert attn_metadata is not None
                 assert attn_metadata.attn_mask is not None
@@ -319,6 +348,10 @@ def forward(
                     scale_value=self.scale,
                     out=output)
             elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+                if is_310p():
+                    # seq_lens_tensor needs to be transferred to the device for 310P
+                    attn_metadata.seq_lens = \
+                        attn_metadata.seq_lens.to(device=query.device)
                 torch_npu._npu_paged_attention(
                     query=query,
                     key_cache=self.key_cache,
@@ -352,6 +385,14 @@ def forward(
                                         self.scale, None, True)
             else:
                 # use paged attention
+                assert attn_metadata is not None
+                assert attn_metadata.attn_mask is not None
+                if is_310p():
+                    # do reformat in case of broadcasted tensors
+                    attn_metadata.attn_mask = \
+                        torch_npu.npu_format_cast(attn_metadata.attn_mask.contiguous(), ACL_FORMAT_FRACTAL_NZ)
+                    attn_metadata.seq_lens = \
+                        attn_metadata.seq_lens.to(device=query.device)
                 torch_npu._npu_paged_attention_splitfuse(
                     query=query,
                     key_cache=self.key_cache,
@@ -364,6 +405,10 @@ def forward(
                     num_heads=self.num_heads,
                     scale_value=self.scale,
                     out=output)
+
+        # to make in-place change to the output tensor
+        if not id(ori_output) == id(output):
+            ori_output[:, :, :] = output[:num_tokens, :, :]
         return output.view(num_tokens, self.hidden_size)
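
The final hunk closes the loop on the `ori_output = output` bookkeeping added earlier: on 310P, `aligned_16(output)` can rebind `output` to a new padded buffer, so the identity check copies the first `num_tokens` rows back into the tensor that was originally allocated and that callers may already hold. A small self-contained sketch of that pattern (the explicit padding below merely stands in for whatever `aligned_16` actually does):

```python
import torch

# Illustrative sketch of the copy-back pattern used above.
num_tokens, num_heads, head_size = 10, 8, 128
ori_output = torch.randn(num_tokens, num_heads, head_size)
output = ori_output

# On 310P the kernel path may rebind `output` to a new, 16-aligned buffer
# (shown here as a simple padded copy).
padded = output.new_zeros(16, num_heads, head_size)
padded[:num_tokens] = output
output = padded

# ... the attention kernel writes its results into `output` ...
output = output[:num_tokens, :, :]

# Copy back so code holding the original tensor sees the results in place.
if id(ori_output) != id(output):
    ori_output[:, :, :] = output[:num_tokens, :, :]
```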