 from enum import Enum
 from typing import Any, Dict, List, Optional, Tuple, Type

+import numpy as np
 import torch
 import torch_npu
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
-from vllm.attention.backends.utils import CommonAttentionState
+from vllm.attention.backends.utils import CommonAttentionState, PAD_SLOT_ID
 from vllm.config import get_current_vllm_config
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.utils import direct_register_custom_op

 from vllm_ascend.attention.utils import \
     AscendCommonAttentionMetadata as CommonAttentionMetadata
-from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.ops.attention import vanilla_chunked_prefill
 from vllm_ascend.utils import get_graph_params
+from vllm_ascend.ascend_config import get_ascend_config


 class AscendAttentionBackend(AttentionBackend):
@@ -140,7 +141,8 @@ class AscendMetadata:
     num_input_tokens: int = 0  # Number of tokens including padding.

     enable_dbo_across_dp: bool = False
-
+    with_prefill_across_dp: bool = False
+    use_torchair_graph: bool = False
     def split_metadata_for_multistream(
         self,
         ms_split_config: MSAttentionMetadataSplitConfig,
@@ -153,7 +155,6 @@ def split_metadata_for_multistream(
             _metadata_cls=AscendMetadata,
         )

-
 class AscendAttentionMetadataBuilder:

     def __init__(self, runner):
@@ -163,6 +164,32 @@ def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
         return False

+    def _get_graph_runner_block_tables(
+            self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor:
+
+        max_batch_size, max_blocks = self.runner.graph_block_tables.shape
+        assert max_batch_size >= num_seqs
+
+        if isinstance(self.runner.graph_block_tables, np.ndarray):
+            graph_block_tables = torch.zeros((max_batch_size, max_blocks),
+                                             dtype=block_tables.dtype,
+                                             device=block_tables.device)
+        else:
+            graph_block_tables = self.runner.graph_block_tables.to(
+                device=block_tables.device, dtype=block_tables.dtype)
+
+        num_blocks = block_tables.size(1)
+        if num_blocks <= max_blocks:
+            graph_block_tables[:num_seqs, :num_blocks] = \
+                block_tables[:num_seqs, :num_blocks]
+
+        else:
+            graph_block_tables[:num_seqs, :max_blocks] = \
+                block_tables[:num_seqs, :max_blocks]
+
+
+        return graph_block_tables[:num_seqs, :max_blocks]
+
     def build(self,
               num_reqs,
               num_actual_tokens,
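The helper added above copies a per-step block table of varying width into the runner's fixed-shape graph_block_tables buffer, zero-padding on the right and truncating if a request owns more blocks than the buffer can hold. A minimal CPU sketch of that copy, using made-up shapes rather than the runner's real buffers:

import torch

# Assumed fixed buffer captured by the graph: 8 sequences x 4 blocks.
max_batch_size, max_blocks = 8, 4
graph_block_tables = torch.zeros((max_batch_size, max_blocks), dtype=torch.int32)

# Two live sequences, each owning two physical blocks.
block_tables = torch.tensor([[3, 7], [5, 9]], dtype=torch.int32)
num_seqs, num_blocks = block_tables.shape

# Copy at most max_blocks columns; extra buffer columns stay zero.
width = min(num_blocks, max_blocks)
graph_block_tables[:num_seqs, :width] = block_tables[:num_seqs, :width]

print(graph_block_tables[:num_seqs, :max_blocks])
# tensor([[3, 7, 0, 0],
#         [5, 9, 0, 0]], dtype=torch.int32)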
@@ -188,6 +215,41 @@ def build(self,
         slot_mapping = self.runner.slot_mapping[:num_actual_tokens]
         attn_mask = self.runner.attn_mask
         attn_state = self.runner.attn_state
+        query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1]
+        query_start_loc = query_start_loc_cpu.to(self.runner.device,
+                                                 non_blocking=True)
+
+        graph_pad_size = kwargs["graph_pad_size"]
+        with_prefill_across_dp = kwargs["with_prefill_across_dp"]
+        use_torchair_graph = graph_pad_size != -1
+        if not with_prefill_across_dp:
+            if use_torchair_graph and self.runner.attn_state in [
+                    AscendAttentionState.DecodeOnly,
+                    AscendAttentionState.SpecDecoding
+            ]:
+                num_seqs = len(seq_lens)
+                if graph_pad_size != 0:
+                    pad_value = 1
+                    padded_seq_lens = seq_lens.tolist() + \
+                        [pad_value] * graph_pad_size
+                else:
+                    padded_seq_lens = seq_lens.tolist()
+
+                seq_lens = torch.from_numpy(
+                    np.array(padded_seq_lens).astype(np.int32))
+                padding = torch.full((graph_pad_size, ),
+                                     PAD_SLOT_ID,
+                                     dtype=slot_mapping.dtype,
+                                     device=slot_mapping.device)
+                slot_mapping = torch.cat([slot_mapping, padding])
+                block_table_padding = torch.zeros(
+                    (graph_pad_size, ) + block_table.shape[1:],
+                    dtype=block_table.dtype,
+                    device=block_table.device)
+                block_table = torch.cat([block_table, block_table_padding],
+                                        dim=0)
+                block_table = self._get_graph_runner_block_tables(
+                    num_seqs + graph_pad_size, block_table)

         attn_metadata = AscendMetadata(
             num_actual_tokens=num_actual_tokens,
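When the torchair graph path is taken, build() pads seq_lens, slot_mapping and block_table out to the captured batch size so every graph replay sees identical tensor shapes: dummy sequences get length 1, their tokens write to PAD_SLOT_ID, and their block-table rows are all zeros. A small CPU sketch of just that padding, with invented sizes and a local stand-in for the PAD_SLOT_ID constant:

import numpy as np
import torch

PAD_SLOT_ID = -1    # stand-in for the constant imported from vllm.attention.backends.utils
graph_pad_size = 2  # two dummy sequences to reach the captured batch size

seq_lens = torch.tensor([5, 3], dtype=torch.int32)
slot_mapping = torch.tensor([64, 37], dtype=torch.int32)
block_table = torch.tensor([[4], [2]], dtype=torch.int32)

# Dummy sequences get length 1 so the kernel still reads a valid position.
padded_seq_lens = seq_lens.tolist() + [1] * graph_pad_size
seq_lens = torch.from_numpy(np.array(padded_seq_lens).astype(np.int32))

# Dummy tokens write to the pad slot and read from all-zero block rows.
slot_mapping = torch.cat([
    slot_mapping,
    torch.full((graph_pad_size, ), PAD_SLOT_ID, dtype=slot_mapping.dtype)
])
block_table = torch.cat([
    block_table,
    torch.zeros((graph_pad_size, ) + block_table.shape[1:], dtype=block_table.dtype)
])

print(seq_lens.tolist())      # [5, 3, 1, 1]
print(slot_mapping.tolist())  # [64, 37, -1, -1]
print(block_table.tolist())   # [[4], [2], [0], [0]]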
@@ -200,7 +262,44 @@ def build(self,
             slot_mapping=slot_mapping,
             attn_mask=attn_mask,
             attn_state=attn_state,
-            enable_dbo_across_dp=enable_dbo_across_dp)
+            enable_dbo_across_dp=enable_dbo_across_dp,
+            with_prefill_across_dp=with_prefill_across_dp,
+            use_torchair_graph=use_torchair_graph
+        )
+        return attn_metadata
+
+    def build_torchair_graph_dummy(self, num_reqs: int, num_actual_tokens: int):
+        device = self.runner.device
+        _, max_blocks = self.runner.graph_block_tables.shape
+        block_table = torch.zeros((num_reqs, max_blocks),
+                                  dtype=torch.int32,
+                                  device=device)
+        block_table = self._get_graph_runner_block_tables(
+            num_reqs, block_table)
+        seq_lens = torch.ones(num_reqs, dtype=torch.int32, device=device)
+        slot_mapping = torch.full((num_reqs, ),
+                                  PAD_SLOT_ID,
+                                  dtype=torch.int32,
+                                  device=device)
+        query_start_loc = torch.full((num_reqs, ),
+                                     -1,
+                                     dtype=torch.int32,
+                                     device=device)
+
+        query_lens = torch.ones(num_reqs, dtype=torch.int32, device=device)
+        attn_mask = self.runner.attn_mask
+
+        attn_metadata = AscendMetadata(
+            num_actual_tokens=num_actual_tokens,
+            block_tables=block_table,
+            query_start_loc=query_start_loc,
+            query_lens=query_lens,
+            seq_lens=seq_lens,
+            seq_lens_list=seq_lens.tolist(),
+            max_query_len=query_lens.max().item(),
+            slot_mapping=slot_mapping,
+            attn_mask=attn_mask,
+            attn_state=AscendAttentionState.DecodeOnly)
         return attn_metadata

     def build_dummy_metadata(self, num_actual_tokens, num_reqs,
@@ -248,6 +347,7 @@ def __init__(
         attn_type: str = AttentionType.DECODER,
         kv_sharing_target_layer_name: Optional[str] = None,
         use_irope: bool = False,
+        prefix: Optional[str] = None,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
@@ -267,11 +367,34 @@ def __init__(
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
         self.key_cache = None
         self.value_cache = None
+        ascend_config = get_ascend_config()
+        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled

         vllm_config = get_current_vllm_config()
         self.full_graph = vllm_config.compilation_config.full_cuda_graph
         self.block_size = vllm_config.cache_config.block_size

+    def update_kv_cache(
+        self,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        slot_indices: torch.Tensor
+    ) -> None:
+        # Convert flat slot indices into (block_idx, block_offset) pairs.
+        block_size = key_cache.shape[1]
+        slot_indices = slot_indices.view(-1, 1, 1).to(torch.int64)
+        block_idx = torch.div(slot_indices, block_size, rounding_mode='floor')
+        block_offset = slot_indices % block_size
+        indices = torch.cat([block_idx, block_offset], dim=2)
+        indices = indices.npu()
+
+        # Scatter new keys/values into the paged cache:
+        # [blocknum, blocksize, numKvHeads, headDims] -> [blocknum, blocksize, numKvHeads * headDims]
+        torch_npu.npu_scatter_nd_update_(key_cache, indices, key)
+        torch_npu.npu_scatter_nd_update_(value_cache, indices, value)
+
     def forward(
         self,
         layer: AttentionLayer,
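update_kv_cache turns flat slot indices into (block, offset) pairs and scatters the new key/value rows into the paged cache with npu_scatter_nd_update_. The same index arithmetic can be checked on CPU, with plain tensor indexing standing in for the NPU scatter op (all shapes here are invented):

import torch

num_blocks, block_size, kv_width = 4, 16, 8  # kv_width ~ num_kv_heads * head_dim
key_cache = torch.zeros(num_blocks, block_size, kv_width)
key = torch.randn(3, kv_width)               # three new tokens to write
slot_indices = torch.tensor([5, 16, 33], dtype=torch.int64)

block_idx = torch.div(slot_indices, block_size, rounding_mode='floor')  # [0, 1, 2]
block_offset = slot_indices % block_size                                # [5, 0, 1]

# Equivalent effect to the scatter: write each token into its (block, offset) row.
key_cache[block_idx, block_offset] = key
assert torch.equal(key_cache[1, 0], key[1])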
@@ -320,12 +443,18 @@ def forward(
             if self.key_cache is None:
                 self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
             slots = attn_metadata.slot_mapping
-            torch_npu._npu_reshape_and_cache(
-                key=key[:num_actual_tokens],
-                value=value[:num_actual_tokens],
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                slot_indices=slots)
+            if not attn_metadata.with_prefill_across_dp and self.torchair_graph_enabled:
+                self.update_kv_cache(key=key,
+                                     value=value,
+                                     key_cache=self.key_cache,
+                                     value_cache=self.value_cache,
+                                     slot_indices=slots.to(torch.int64))
+            else:
+                torch_npu._npu_reshape_and_cache(key=key[:num_actual_tokens],
+                                                 value=value[:num_actual_tokens],
+                                                 key_cache=self.key_cache,
+                                                 value_cache=self.value_cache,
+                                                 slot_indices=slots)

         if hasattr(layer, 'quant_method'):
             # TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata
@@ -363,10 +492,25 @@ def forward(
                 scale_value=self.scale,
                 out=output)
         elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
-            graph_params = get_graph_params()
-
-            forward_context = get_forward_context()
-            if not forward_context.capturing:
+            if self.torchair_graph_enabled:
+                # Reshape query to BSH: [num_tokens, 1, num_heads * head_size].
+                query = query.view(-1, 1, self.num_heads * self.head_size)
+                # [blocknum, blocksize, numKvHeads, headDims] -> [blocknum, blocksize, numKvHeads * headDims]
+                key_cache = self.key_cache.view(*self.key_cache.shape[:-2], -1)
+                value_cache = self.value_cache.view(*self.value_cache.shape[:-2], -1)
+
+                output = torch_npu.npu_incre_flash_attention(
+                    query=query,
+                    key=key_cache,
+                    value=value_cache,
+                    num_heads=self.num_heads,
+                    num_key_value_heads=self.num_kv_heads,
+                    input_layout='BSH',
+                    scale_value=self.scale,
+                    actual_seq_lengths=attn_metadata.seq_lens_list,
+                    block_table=attn_metadata.block_tables,
+                    block_size=kv_cache[0].shape[1])
+            elif not get_forward_context().capturing:
                 torch_npu._npu_paged_attention(
                     query=query,
                     key_cache=self.key_cache,
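In the torchair decode branch, the query and the caches are viewed into the 'BSH' layout that npu_incre_flash_attention expects: one token per sequence for the query, and the KV heads folded into the hidden dimension for the caches. A quick CPU check of just those reshapes, with toy dimensions:

import torch

num_tokens, num_heads, head_dim = 4, 8, 16
num_blocks, block_size, num_kv_heads = 6, 128, 2

query = torch.randn(num_tokens, num_heads, head_dim)
key_cache = torch.randn(num_blocks, block_size, num_kv_heads, head_dim)

query_bsh = query.view(-1, 1, num_heads * head_dim)
key_bsh = key_cache.view(*key_cache.shape[:-2], -1)

print(query_bsh.shape)  # torch.Size([4, 1, 128])
print(key_bsh.shape)    # torch.Size([6, 128, 32])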
@@ -384,6 +528,7 @@ def forward(
                     event = torch.npu.ExternalEvent()
                     event.wait(stream)
                     event.reset(stream)
+                    graph_params = get_graph_params()
                     graph_params.events[num_tokens].append(event)

                     graph_params.attn_params[num_tokens].append((