from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Type

+ import numpy as np
import torch
import torch_npu
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionLayer, AttentionType)
- from vllm.attention.backends.utils import CommonAttentionState
+ from vllm.attention.backends.utils import CommonAttentionState, PAD_SLOT_ID
from vllm.config import get_current_vllm_config
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.utils import direct_register_custom_op

from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
from vllm_ascend.ops.attention import vanilla_chunked_prefill
from vllm_ascend.utils import get_graph_params
+ from vllm_ascend.ascend_config import get_ascend_config


class AscendAttentionBackend(AttentionBackend):
@@ -140,6 +142,8 @@ class AscendMetadata:
    num_input_tokens: int = 0  # Number of tokens including padding.

    enable_dbo_across_dp: bool = False
+     with_prefill_across_dp: bool = False
+     use_torchair_graph: bool = False

    def split_metadata_for_multistream(
        self,
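The two added flags record, per scheduling step, whether any data-parallel rank still has prefill work in flight (with_prefill_across_dp, judging by its use further down) and whether the batch was padded for the torchair graph path (use_torchair_graph); the forward pass later keys the KV-cache write path off the first flag together with the implementation-level torchair switch.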
@@ -153,7 +157,6 @@ def split_metadata_for_multistream(
            _metadata_cls=AscendMetadata,
        )

-
class AscendAttentionMetadataBuilder:

    def __init__(self, runner):
@@ -163,6 +166,32 @@ def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        return False

+     def _get_graph_runner_block_tables(
+             self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor:
+
+         max_batch_size, max_blocks = self.runner.graph_block_tables.shape
+         assert max_batch_size >= num_seqs
+
+         if isinstance(self.runner.graph_block_tables, np.ndarray):
+             graph_block_tables = torch.zeros((max_batch_size, max_blocks),
+                                              dtype=block_tables.dtype,
+                                              device=block_tables.device)
+         else:
+             graph_block_tables = self.runner.graph_block_tables.to(
+                 device=block_tables.device, dtype=block_tables.dtype)
+
+         num_blocks = block_tables.size(1)
+         if num_blocks <= max_blocks:
+             graph_block_tables[:num_seqs, :num_blocks] = block_tables[:num_seqs, :num_blocks]
+         else:
+             graph_block_tables[:num_seqs, :max_blocks] = block_tables[:num_seqs, :max_blocks]
+
+         return graph_block_tables[:num_seqs, :max_blocks]
+
    def build(self,
              num_reqs,
              num_actual_tokens,
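The new _get_graph_runner_block_tables helper copies each request's block table into the fixed-size buffer that was shaped at graph-capture time, truncating any request that owns more than max_blocks blocks. A standalone sketch of the same copy/truncate step; the buffer shape and block counts below are made up for illustration:

import torch

# Hypothetical capture-time buffer: room for 64 requests x 128 blocks each.
graph_block_tables = torch.zeros((64, 128), dtype=torch.int32)
# Three live decode requests, each currently holding 200 blocks.
block_tables = torch.arange(3 * 200, dtype=torch.int32).reshape(3, 200)

num_seqs = block_tables.shape[0]
max_blocks = graph_block_tables.shape[1]
copy_len = min(block_tables.shape[1], max_blocks)
graph_block_tables[:num_seqs, :copy_len] = block_tables[:num_seqs, :copy_len]

# The kernel then sees a (3, 128) view; block ids past column 128 are dropped.
padded = graph_block_tables[:num_seqs, :max_blocks]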
@@ -188,6 +217,41 @@ def build(self,
        slot_mapping = self.runner.slot_mapping[:num_actual_tokens]
        attn_mask = self.runner.attn_mask
        attn_state = self.runner.attn_state
+         query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1]
+         query_start_loc = query_start_loc_cpu.to(self.runner.device,
+                                                  non_blocking=True)
+
+         graph_pad_size = kwargs.get("graph_pad_size", -1)
+         with_prefill_across_dp = kwargs["with_prefill_across_dp"]
+         use_torchair_graph = graph_pad_size != -1
+         if not with_prefill_across_dp:
+             if use_torchair_graph and self.runner.attn_state in [
+                     AscendAttentionState.DecodeOnly,
+                     AscendAttentionState.SpecDecoding
+             ]:
+                 num_seqs = len(seq_lens)
+                 if graph_pad_size != 0:
+                     pad_value = 1
+                     padded_seq_lens = seq_lens.tolist() + [pad_value] * graph_pad_size
+                 else:
+                     padded_seq_lens = seq_lens.tolist()
+
+                 seq_lens = torch.from_numpy(
+                     np.array(padded_seq_lens).astype(np.int32))
+                 padding = torch.full((graph_pad_size, ),
+                                      PAD_SLOT_ID,
+                                      dtype=slot_mapping.dtype,
+                                      device=slot_mapping.device)
+                 slot_mapping = torch.cat([slot_mapping, padding])
+                 block_table_padding = torch.zeros(
+                     (graph_pad_size, ) + block_table.shape[1:],
+                     dtype=block_table.dtype,
+                     device=block_table.device)
+                 block_table = torch.cat([block_table, block_table_padding],
+                                         dim=0)
+                 block_table = self._get_graph_runner_block_tables(
+                     num_seqs + graph_pad_size, block_table)

        attn_metadata = AscendMetadata(
            num_actual_tokens=num_actual_tokens,
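When the torchair graph path is active and no data-parallel rank is prefilling, the decode batch is padded out to the captured batch size: padded sequences get length 1, their slot-mapping entries are filled with the PAD_SLOT_ID sentinel, and their block-table rows are zeros before the whole table is passed through _get_graph_runner_block_tables. A rough sketch of the seq_lens / slot_mapping padding under assumed sizes (three real requests, a captured batch of eight, and PAD_SLOT_ID taken as -1 purely for illustration):

import numpy as np
import torch

PAD_SLOT_ID = -1      # stand-in for the vllm constant
graph_pad_size = 5    # captured batch of 8 minus 3 real requests

seq_lens = torch.tensor([17, 5, 42], dtype=torch.int32)
slot_mapping = torch.tensor([100, 101, 102], dtype=torch.int32)

padded_seq_lens = seq_lens.tolist() + [1] * graph_pad_size
seq_lens = torch.from_numpy(np.array(padded_seq_lens).astype(np.int32))
padding = torch.full((graph_pad_size, ), PAD_SLOT_ID, dtype=slot_mapping.dtype)
slot_mapping = torch.cat([slot_mapping, padding])
# seq_lens is now [17, 5, 42, 1, 1, 1, 1, 1]; the last five slots hold PAD_SLOT_ID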
@@ -200,7 +264,44 @@ def build(self,
            slot_mapping=slot_mapping,
            attn_mask=attn_mask,
            attn_state=attn_state,
-             enable_dbo_across_dp=enable_dbo_across_dp)
+             enable_dbo_across_dp=enable_dbo_across_dp,
+             with_prefill_across_dp=with_prefill_across_dp,
+             use_torchair_graph=use_torchair_graph
+         )
+         return attn_metadata
+
+     def build_torchair_graph_dummy(self, num_reqs: int, num_actual_tokens: int):
+         device = self.runner.device
+         _, max_blocks = self.runner.graph_block_tables.shape
+         block_table = torch.zeros((num_reqs, max_blocks),
+                                   dtype=torch.int32,
+                                   device=device)
+         block_table = self._get_graph_runner_block_tables(
+             num_reqs, block_table)
+         seq_lens = torch.ones(num_reqs, dtype=torch.int32, device=device)
+         slot_mapping = torch.full((num_reqs, ),
+                                   PAD_SLOT_ID,
+                                   dtype=torch.int32,
+                                   device=device)
+         query_start_loc = torch.full((num_reqs, ),
+                                      -1,
+                                      dtype=torch.int32,
+                                      device=device)
+
+         query_lens = torch.ones(num_reqs, dtype=torch.int32, device=device)
+         attn_mask = self.runner.attn_mask
+
+         attn_metadata = AscendMetadata(
+             num_actual_tokens=num_actual_tokens,
+             block_tables=block_table,
+             query_start_loc=query_start_loc,
+             query_lens=query_lens,
+             seq_lens=seq_lens,
+             seq_lens_list=seq_lens.tolist(),
+             max_query_len=query_lens.max().item(),
+             slot_mapping=slot_mapping,
+             attn_mask=attn_mask,
+             attn_state=AscendAttentionState.DecodeOnly)
        return attn_metadata

    def build_dummy_metadata(self, num_actual_tokens, num_reqs,
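build_torchair_graph_dummy produces placeholder metadata for graph capture: every request gets a length-1 sequence, a PAD_SLOT_ID slot, an all-zero block table run through the same padding helper, and a query_start_loc filled with -1, with attn_state pinned to DecodeOnly.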
@@ -248,6 +349,7 @@ def __init__(
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        use_irope: bool = False,
+         prefix: Optional[str] = None,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
@@ -267,11 +369,34 @@ def __init__(
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.key_cache = None
        self.value_cache = None
+         ascend_config = get_ascend_config()
+         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled

        vllm_config = get_current_vllm_config()
        self.full_graph = vllm_config.compilation_config.full_cuda_graph
        self.block_size = vllm_config.cache_config.block_size

+     def update_kv_cache(
+         self,
+         key: torch.Tensor,
+         value: torch.Tensor,
+         key_cache: torch.Tensor,
+         value_cache: torch.Tensor,
+         slot_indices: torch.Tensor
+     ) -> None:
+         # Split each flat slot index into a (block, offset) pair using block_size.
+         block_size = key_cache.shape[1]
+         slot_indices = slot_indices.view(-1, 1, 1).to(torch.int64)
+         block_idx = torch.div(slot_indices, block_size, rounding_mode='floor')
+         block_offset = slot_indices % block_size
+         indices = torch.cat([block_idx, block_offset], dim=2)
+         indices = indices.npu()
+
+         # [blocknum, blocksize, numKvHeads, headDims]
+         # -> [blocknum, blocksize, numKvHeads * headDims]
+         torch_npu.npu_scatter_nd_update_(key_cache, indices, key)
+         torch_npu.npu_scatter_nd_update_(value_cache, indices, value)
+
    def forward(
        self,
        layer: AttentionLayer,
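update_kv_cache rewrites each flat slot index as a (block, offset) pair so that npu_scatter_nd_update_ can write the new key/value rows into the paged cache in place. The index arithmetic, checked on a toy value (the block size of 128 is an assumption for illustration):

import torch

block_size = 128
slots = torch.tensor([1027, 3], dtype=torch.int64).view(-1, 1, 1)

block_idx = torch.div(slots, block_size, rounding_mode='floor')
block_offset = slots % block_size
indices = torch.cat([block_idx, block_offset], dim=2)
# slot 1027 -> (block 8, offset 3); slot 3 -> (block 0, offset 3)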
@@ -320,12 +445,18 @@ def forward(
            if self.key_cache is None:
                self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
            slots = attn_metadata.slot_mapping
-             torch_npu._npu_reshape_and_cache(
-                 key=key[:num_actual_tokens],
-                 value=value[:num_actual_tokens],
-                 key_cache=self.key_cache,
-                 value_cache=self.value_cache,
-                 slot_indices=slots)
+             if not attn_metadata.with_prefill_across_dp and self.torchair_graph_enabled:
+                 self.update_kv_cache(key=key,
+                                      value=value,
+                                      key_cache=self.key_cache,
+                                      value_cache=self.value_cache,
+                                      slot_indices=slots.to(torch.int64))
+             else:
+                 torch_npu._npu_reshape_and_cache(key=key[:num_actual_tokens],
+                                                  value=value[:num_actual_tokens],
+                                                  key_cache=self.key_cache,
+                                                  value_cache=self.value_cache,
+                                                  slot_indices=slots)

        if hasattr(layer, 'quant_method'):
            # TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata
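With the torchair graph enabled and no prefill in flight anywhere in the DP group, the cache write goes through the new scatter-based update_kv_cache (with int64 slot indices, including any PAD_SLOT_ID padding added by the builder); otherwise the original _npu_reshape_and_cache path is kept unchanged.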
@@ -363,10 +494,25 @@ def forward(
                scale_value=self.scale,
                out=output)
        elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
-             graph_params = get_graph_params()
-
-             forward_context = get_forward_context()
-             if not forward_context.capturing:
+             if self.torchair_graph_enabled:
+                 # Reshape the query to BSH layout: [num_tokens, 1, num_heads * head_size].
+                 query = query.view(-1, 1, self.num_heads * self.head_size)
+                 # [blocknum, blocksize, numKvHeads, headDims]
+                 # -> [blocknum, blocksize, numKvHeads * headDims]
+                 key_cache = self.key_cache.view(*self.key_cache.shape[:-2], -1)
+                 value_cache = self.value_cache.view(*self.value_cache.shape[:-2], -1)
+
+                 output = torch_npu.npu_incre_flash_attention(
+                     query=query,
+                     key=key_cache,
+                     value=value_cache,
+                     num_heads=self.num_heads,
+                     num_key_value_heads=self.num_kv_heads,
+                     input_layout='BSH',
+                     scale_value=self.scale,
+                     actual_seq_lengths=attn_metadata.seq_lens_list,
+                     block_table=attn_metadata.block_tables,
+                     block_size=kv_cache[0].shape[1])
+             elif not get_forward_context().capturing:
                torch_npu._npu_paged_attention(
                    query=query,
                    key_cache=self.key_cache,
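On the torchair decode path, the query and caches are flattened to the 'BSH' layout consumed by npu_incre_flash_attention: the query becomes [num_tokens, 1, num_heads * head_size] and each cache collapses its last two axes into num_kv_heads * head_dim. A shape-only sketch with made-up dimensions:

import torch

num_heads, head_size, num_kv_heads = 32, 128, 8
query = torch.randn(4, num_heads, head_size)               # 4 decode tokens
key_cache = torch.randn(16, 128, num_kv_heads, head_size)  # [blocks, block_size, kv_heads, head_dim]

query = query.view(-1, 1, num_heads * head_size)       # -> [4, 1, 4096]
key_cache = key_cache.view(*key_cache.shape[:-2], -1)  # -> [16, 128, 1024]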
@@ -384,6 +530,7 @@ def forward(
                event = torch.npu.ExternalEvent()
                event.wait(stream)
                event.reset(stream)
+                 graph_params = get_graph_params()
                graph_params.events[num_tokens].append(event)

                graph_params.attn_params[num_tokens].append((
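get_graph_params() is now fetched where the capture-path bookkeeping actually uses it, rather than unconditionally at the top of the DecodeOnly branch, which the new torchair path would never need.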