from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Type

+import numpy as np
import torch
import torch_npu
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionLayer, AttentionType)
-from vllm.attention.backends.utils import CommonAttentionState
+from vllm.attention.backends.utils import PAD_SLOT_ID, CommonAttentionState
from vllm.config import get_current_vllm_config
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.utils import direct_register_custom_op
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch

+from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.utils import \
    AscendCommonAttentionMetadata as CommonAttentionMetadata
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
@@ -140,6 +142,8 @@ class AscendMetadata:
    num_input_tokens: int = 0  # Number of tokens including padding.

    enable_dbo_across_dp: bool = False
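+    # Set by the metadata builder: whether any DP rank is still running
+    # prefill, and whether this batch goes through the torchair graph.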
+    with_prefill_across_dp: bool = False
+    use_torchair_graph: bool = False

    def split_metadata_for_multistream(
        self,
@@ -163,6 +167,32 @@ def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        return False

+    def _get_graph_runner_block_tables(
+            self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor:
+
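+        # Copy each request's block table into the runner's fixed-size graph
+        # block-table buffer, truncating rows wider than max_blocks.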
+        max_batch_size, max_blocks = self.runner.graph_block_tables.shape
+        assert max_batch_size >= num_seqs
+
+        if isinstance(self.runner.graph_block_tables, np.ndarray):
+            graph_block_tables = torch.zeros((max_batch_size, max_blocks),
+                                             dtype=block_tables.dtype,
+                                             device=block_tables.device)
+        else:
+            graph_block_tables = self.runner.graph_block_tables.to(
+                device=block_tables.device, dtype=block_tables.dtype)
+
+        num_blocks = block_tables.size(1)
+        if num_blocks <= max_blocks:
+            graph_block_tables[:num_seqs, :num_blocks] = \
+                block_tables[:num_seqs, :num_blocks]
+        else:
+            graph_block_tables[:num_seqs, :max_blocks] = \
+                block_tables[:num_seqs, :max_blocks]
+
+        return graph_block_tables[:num_seqs, :max_blocks]
+
    def build(self,
              num_reqs,
              num_actual_tokens,
@@ -188,6 +218,41 @@ def build(self,
        slot_mapping = self.runner.slot_mapping[:num_actual_tokens]
        attn_mask = self.runner.attn_mask
        attn_state = self.runner.attn_state
+        query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1]
+        query_start_loc = query_start_loc_cpu.to(self.runner.device,
+                                                 non_blocking=True)
+
+        graph_pad_size = kwargs.get("graph_pad_size", -1)
+        with_prefill_across_dp = kwargs["with_prefill_across_dp"]
+        use_torchair_graph = graph_pad_size != -1
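+        # Torchair graph mode: when no DP rank is running prefill and the
+        # batch is decode-only (or spec-decode), pad seq_lens, slot_mapping
+        # and block_table to the graph pad size so the captured graph always
+        # sees fixed shapes.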
+        if not with_prefill_across_dp:
+            if use_torchair_graph and self.runner.attn_state in [
+                    AscendAttentionState.DecodeOnly,
+                    AscendAttentionState.SpecDecoding
+            ]:
+                num_seqs = len(seq_lens)
+                if graph_pad_size != 0:
+                    pad_value = 1
+                    padded_seq_lens = (seq_lens.tolist() +
+                                       [pad_value] * graph_pad_size)
+                else:
+                    padded_seq_lens = seq_lens.tolist()
+
+                seq_lens = torch.from_numpy(
+                    np.array(padded_seq_lens).astype(np.int32))
+                padding = torch.full((graph_pad_size, ),
+                                     PAD_SLOT_ID,
+                                     dtype=slot_mapping.dtype,
+                                     device=slot_mapping.device)
+                slot_mapping = torch.cat([slot_mapping, padding])
+                block_table_padding = torch.zeros(
+                    (graph_pad_size, ) + block_table.shape[1:],
+                    dtype=block_table.dtype,
+                    device=block_table.device)
+                block_table = torch.cat([block_table, block_table_padding],
+                                        dim=0)
+                block_table = self._get_graph_runner_block_tables(
+                    num_seqs + graph_pad_size, block_table)

        attn_metadata = AscendMetadata(
            num_actual_tokens=num_actual_tokens,
@@ -200,7 +265,44 @@ def build(self,
            slot_mapping=slot_mapping,
            attn_mask=attn_mask,
            attn_state=attn_state,
-            enable_dbo_across_dp=enable_dbo_across_dp)
+            enable_dbo_across_dp=enable_dbo_across_dp,
+            with_prefill_across_dp=with_prefill_across_dp,
+            use_torchair_graph=use_torchair_graph)
+        return attn_metadata
+
+    def build_torchair_graph_dummy(self, num_reqs: int,
+                                   num_actual_tokens: int):
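+        # Build placeholder decode-only metadata (zero block table, unit
+        # sequence lengths, PAD_SLOT_ID slot mapping) used as the dummy
+        # input when capturing the torchair graph.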
+        device = self.runner.device
+        _, max_blocks = self.runner.graph_block_tables.shape
+        block_table = torch.zeros((num_reqs, max_blocks),
+                                  dtype=torch.int32,
+                                  device=device)
+        block_table = self._get_graph_runner_block_tables(
+            num_reqs, block_table)
+        seq_lens = torch.ones(num_reqs, dtype=torch.int32, device=device)
+        slot_mapping = torch.full((num_reqs, ),
+                                  PAD_SLOT_ID,
+                                  dtype=torch.int32,
+                                  device=device)
+        query_start_loc = torch.full((num_reqs, ),
+                                     -1,
+                                     dtype=torch.int32,
+                                     device=device)
+
+        query_lens = torch.ones(num_reqs, dtype=torch.int32, device=device)
+        attn_mask = self.runner.attn_mask
+
+        attn_metadata = AscendMetadata(
+            num_actual_tokens=num_actual_tokens,
+            block_tables=block_table,
+            query_start_loc=query_start_loc,
+            query_lens=query_lens,
+            seq_lens=seq_lens,
+            seq_lens_list=seq_lens.tolist(),
+            max_query_len=query_lens.max().item(),
+            slot_mapping=slot_mapping,
+            attn_mask=attn_mask,
+            attn_state=AscendAttentionState.DecodeOnly)
        return attn_metadata

    def build_dummy_metadata(self, num_actual_tokens, num_reqs,
@@ -248,6 +350,7 @@ def __init__(
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        use_irope: bool = False,
+        prefix: Optional[str] = None,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
@@ -267,11 +370,29 @@ def __init__(
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.key_cache = None
        self.value_cache = None
+        ascend_config = get_ascend_config()
+        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled

        vllm_config = get_current_vllm_config()
        self.full_graph = vllm_config.compilation_config.full_cuda_graph
        self.block_size = vllm_config.cache_config.block_size

+    def update_kv_cache(self, key: torch.Tensor, value: torch.Tensor,
+                        key_cache: torch.Tensor, value_cache: torch.Tensor,
+                        slot_indices: torch.Tensor) -> None:
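+        # Scatter key/value into the paged caches in place, addressing each
+        # slot by its (block index, block offset) pair.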
+        # calc indices by block_size
+        block_size = key_cache.shape[1]
+        slot_indices = slot_indices.view(-1, 1, 1).to(torch.int64)
+        block_idx = torch.div(slot_indices, block_size, rounding_mode='floor')
+        block_offset = slot_indices % block_size
+        indices = torch.cat([block_idx, block_offset], dim=2)
+        indices = indices.npu()
+
+        # [blocknum, blocksize, numKvHeads, headDims]
+        # -> [blocknum, blocksize, numKvHeads * headDims]
+        torch_npu.npu_scatter_nd_update_(key_cache, indices, key)
+        torch_npu.npu_scatter_nd_update_(value_cache, indices, value)
+
    def forward(
        self,
        layer: AttentionLayer,
@@ -320,12 +441,19 @@ def forward(
            if self.key_cache is None:
                self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
            slots = attn_metadata.slot_mapping
-            torch_npu._npu_reshape_and_cache(
-                key=key[:num_actual_tokens],
-                value=value[:num_actual_tokens],
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                slot_indices=slots)
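+            # In torchair graph mode (no prefill on any DP rank), write the
+            # cache through the in-place scatter path; otherwise use the
+            # fused reshape-and-cache kernel.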
+            if not attn_metadata.with_prefill_across_dp and self.torchair_graph_enabled:
+                self.update_kv_cache(key=key,
+                                     value=value,
+                                     key_cache=self.key_cache,
+                                     value_cache=self.value_cache,
+                                     slot_indices=slots.to(torch.int64))
+            else:
+                torch_npu._npu_reshape_and_cache(
+                    key=key[:num_actual_tokens],
+                    value=value[:num_actual_tokens],
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    slot_indices=slots)

        if hasattr(layer, 'quant_method'):
            # TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata
@@ -363,10 +491,28 @@ def forward(
                scale_value=self.scale,
                out=output)
        elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
-            graph_params = get_graph_params()
-
-            forward_context = get_forward_context()
-            if not forward_context.capturing:
+            if self.torchair_graph_enabled:
+                # Reshape query to BSH layout:
+                # (batch, seq_len=1, num_heads * head_size).
+                query = query.view(-1, 1, self.num_heads * self.head_size)
+                # [blocknum, blocksize, numKvHeads, headDims]
+                # -> [blocknum, blocksize, numKvHeads * headDims]
+                key_cache = self.key_cache.view(*self.key_cache.shape[:-2],
+                                                -1)
+                value_cache = self.value_cache.view(
+                    *self.value_cache.shape[:-2], -1)
+
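+                # Decode-only attention over the paged KV cache via the fused
+                # incremental flash-attention kernel, driven by the padded
+                # per-sequence lengths and block table.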
+                output = torch_npu.npu_incre_flash_attention(
+                    query=query,
+                    key=key_cache,
+                    value=value_cache,
+                    num_heads=self.num_heads,
+                    num_key_value_heads=self.num_kv_heads,
+                    input_layout='BSH',
+                    scale_value=self.scale,
+                    actual_seq_lengths=attn_metadata.seq_lens_list,
+                    block_table=attn_metadata.block_tables,
+                    block_size=kv_cache[0].shape[1],
+                )
+            elif not get_forward_context().capturing:
                torch_npu._npu_paged_attention(
                    query=query,
                    key_cache=self.key_cache,
@@ -384,6 +530,7 @@ def forward(
                event = torch.npu.ExternalEvent()
                event.wait(stream)
                event.reset(stream)
+                graph_params = get_graph_params()
                graph_params.events[num_tokens].append(event)

                graph_params.attn_params[num_tokens].append((