 from enum import Enum
 from typing import Any, Dict, List, Optional, Tuple, Type

+import numpy as np
 import torch
 import torch_npu
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
-from vllm.attention.backends.utils import CommonAttentionState
+from vllm.attention.backends.utils import PAD_SLOT_ID, CommonAttentionState
 from vllm.config import get_current_vllm_config
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.utils import direct_register_custom_op
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.worker.gpu_input_batch import InputBatch

+from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.utils import \
     AscendCommonAttentionMetadata as CommonAttentionMetadata
 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
@@ -52,7 +54,7 @@ def get_impl_cls() -> Type["AscendAttentionBackendImpl"]:
     def get_metadata_cls() -> Type["AscendMetadata"]:
         return AscendMetadata

     @staticmethod
     def get_state_cls() -> Type["CommonAttentionState"]:
         return CommonAttentionState

@@ -140,6 +142,8 @@ class AscendMetadata:
     num_input_tokens: int = 0  # Number of tokens including padding.

     enable_dbo_across_dp: bool = False
+    with_prefill_across_dp: bool = False
+    use_torchair_graph: bool = False

     def split_metadata_for_multistream(
         self,
@@ -163,6 +167,32 @@ def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
         return False

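+    # Copy per-request block tables into the fixed-shape block-table buffer
+    # used by the graph runner; tables longer than the captured max_blocks
+    # are truncated.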
+    def _get_graph_runner_block_tables(
+            self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor:
+
+        max_batch_size, max_blocks = self.runner.graph_block_tables.shape
+        assert max_batch_size >= num_seqs
+
+        if isinstance(self.runner.graph_block_tables, np.ndarray):
+            graph_block_tables = torch.zeros((max_batch_size, max_blocks),
+                                             dtype=block_tables.dtype,
+                                             device=block_tables.device)
+        else:
+            graph_block_tables = self.runner.graph_block_tables.to(
+                device=block_tables.device, dtype=block_tables.dtype)
+
+        num_blocks = block_tables.size(1)
+        if num_blocks <= max_blocks:
+            graph_block_tables[:num_seqs, :num_blocks] = \
+                block_tables[:num_seqs, :num_blocks]
+        else:
+            graph_block_tables[:num_seqs, :max_blocks] = \
+                block_tables[:num_seqs, :max_blocks]
+
+        return graph_block_tables[:num_seqs, :max_blocks]
+
     def build(self,
               num_reqs,
               num_actual_tokens,
@@ -178,7 +208,7 @@ def build(self,
                                           block_table[:num_reqs])

         query_start_loc = common_attn_metadata.query_start_loc
-        seq_lens = common_attn_metadata.seq_lens
+        seq_lens = common_attn_metadata.seq_lens  # type: ignore
         # TODO: Refactor these two param to common metadata in runners,
         # preparing for the hybrid KV groups feature
         query_lens = common_attn_metadata.query_lens or self.runner.query_lens
@@ -188,6 +218,41 @@ def build(self,
         slot_mapping = self.runner.slot_mapping[:num_actual_tokens]
         attn_mask = self.runner.attn_mask
         attn_state = self.runner.attn_state
+        query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1]
+        query_start_loc = query_start_loc_cpu.to(self.runner.device,
+                                                 non_blocking=True)
+
+        graph_pad_size = kwargs.get("graph_pad_size", -1)
+        with_prefill_across_dp = kwargs["with_prefill_across_dp"]
+        use_torchair_graph = graph_pad_size != -1
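+        # When no DP rank is running prefill and the torchair graph is used
+        # for decode / spec-decode, pad seq_lens, slot_mapping and the block
+        # table with graph_pad_size extra entries so the shapes match the
+        # captured graph.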
+        if not with_prefill_across_dp:
+            if use_torchair_graph and self.runner.attn_state in [
+                    AscendAttentionState.DecodeOnly,
+                    AscendAttentionState.SpecDecoding
+            ]:
+                num_seqs = len(seq_lens)
+                if graph_pad_size != 0:
+                    pad_value = 1
+                    padded_seq_lens = (seq_lens.tolist() +
+                                       [pad_value] * graph_pad_size)
+                else:
+                    padded_seq_lens = seq_lens.tolist()
+
+                seq_lens = torch.from_numpy(
+                    np.array(padded_seq_lens).astype(np.int32))
+                padding = torch.full((graph_pad_size, ),
+                                     PAD_SLOT_ID,
+                                     dtype=slot_mapping.dtype,
+                                     device=slot_mapping.device)
+                slot_mapping = torch.cat([slot_mapping, padding])
+                block_table_padding = torch.zeros(
+                    (graph_pad_size, ) + block_table.shape[1:],
+                    dtype=block_table.dtype,
+                    device=block_table.device)
+                block_table = torch.cat([block_table, block_table_padding],
+                                        dim=0)
+                block_table = self._get_graph_runner_block_tables(
+                    num_seqs + graph_pad_size, block_table)

         attn_metadata = AscendMetadata(
             num_actual_tokens=num_actual_tokens,
@@ -200,7 +265,44 @@ def build(self,
             slot_mapping=slot_mapping,
             attn_mask=attn_mask,
             attn_state=attn_state,
-            enable_dbo_across_dp=enable_dbo_across_dp)
+            enable_dbo_across_dp=enable_dbo_across_dp,
+            with_prefill_across_dp=with_prefill_across_dp,
+            use_torchair_graph=use_torchair_graph)
+        return attn_metadata
+
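+    # Build placeholder decode-only metadata (zero block table, unit
+    # seq_lens, PAD_SLOT_ID slot mapping) for torchair graph capture and
+    # dummy runs.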
+    def build_torchair_graph_dummy(self, num_reqs: int,
+                                   num_actual_tokens: int):
+        device = self.runner.device
+        _, max_blocks = self.runner.graph_block_tables.shape
+        block_table = torch.zeros((num_reqs, max_blocks),
+                                  dtype=torch.int32,
+                                  device=device)
+        block_table = self._get_graph_runner_block_tables(
+            num_reqs, block_table)
+        seq_lens = torch.ones(num_reqs, dtype=torch.int32, device=device)
+        slot_mapping = torch.full((num_reqs, ),
+                                  PAD_SLOT_ID,
+                                  dtype=torch.int32,
+                                  device=device)
+        query_start_loc = torch.full((num_reqs, ),
+                                     -1,
+                                     dtype=torch.int32,
+                                     device=device)
+
+        query_lens = torch.ones(num_reqs, dtype=torch.int32, device=device)
+        attn_mask = self.runner.attn_mask
+
+        attn_metadata = AscendMetadata(
+            num_actual_tokens=num_actual_tokens,
+            block_tables=block_table,
+            query_start_loc=query_start_loc,
+            query_lens=query_lens,
+            seq_lens=seq_lens,
+            seq_lens_list=seq_lens.tolist(),
+            max_query_len=query_lens.max().item(),
+            slot_mapping=slot_mapping,
+            attn_mask=attn_mask,
+            attn_state=AscendAttentionState.DecodeOnly)
         return attn_metadata

     def build_dummy_metadata(self, num_actual_tokens, num_reqs,
@@ -248,6 +350,7 @@ def __init__(
         attn_type: str = AttentionType.DECODER,
         kv_sharing_target_layer_name: Optional[str] = None,
         use_irope: bool = False,
+        prefix: Optional[str] = None,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
@@ -267,11 +370,30 @@ def __init__(
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
         self.key_cache = None
         self.value_cache = None
+        ascend_config = get_ascend_config()
+        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled

         vllm_config = get_current_vllm_config()
         self.full_graph = vllm_config.compilation_config.full_cuda_graph
         self.block_size = vllm_config.cache_config.block_size

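+    # Write key/value into the paged KV cache via scatter updates; used on
+    # the torchair graph path instead of the fused _npu_reshape_and_cache op.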
+    @staticmethod
+    def update_kv_cache(key: torch.Tensor, value: torch.Tensor,
+                        key_cache: torch.Tensor, value_cache: torch.Tensor,
+                        slot_indices: torch.Tensor) -> None:
+        # Convert flat slot indices into (block_idx, block_offset) pairs
+        # based on the cache block size.
+        block_size = key_cache.shape[1]
+        slot_indices = slot_indices.view(-1, 1, 1).to(torch.int64)
+        block_idx = torch.div(slot_indices, block_size, rounding_mode='floor')
+        block_offset = slot_indices % block_size
+        indices = torch.cat([block_idx, block_offset], dim=2)
+        indices = indices.npu()
+
+        # [blocknum, blocksize, numKvHeads, headDims]
+        # -> [blocknum, blocksize, numKvHeads * headDims]
+        torch_npu.npu_scatter_nd_update_(key_cache, indices, key)
+        torch_npu.npu_scatter_nd_update_(value_cache, indices, value)
+
     def forward(
         self,
         layer: AttentionLayer,
@@ -320,12 +442,19 @@ def forward(
             if self.key_cache is None:
                 self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
             slots = attn_metadata.slot_mapping
-            torch_npu._npu_reshape_and_cache(
-                key=key[:num_actual_tokens],
-                value=value[:num_actual_tokens],
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                slot_indices=slots)
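+            # Torchair graph mode without prefill across DP: write the cache
+            # with update_kv_cache; otherwise use the fused reshape-and-cache
+            # kernel.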
+            if not attn_metadata.with_prefill_across_dp and self.torchair_graph_enabled:
+                self.update_kv_cache(key=key,
+                                     value=value,
+                                     key_cache=self.key_cache,
+                                     value_cache=self.value_cache,
+                                     slot_indices=slots.to(torch.int64))
+            else:
+                torch_npu._npu_reshape_and_cache(
+                    key=key[:num_actual_tokens],
+                    value=value[:num_actual_tokens],
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    slot_indices=slots)

         if hasattr(layer, 'quant_method'):
             # TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata
@@ -363,10 +492,28 @@ def forward(
                 scale_value=self.scale,
                 out=output)
         elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
-            graph_params = get_graph_params()
-
-            forward_context = get_forward_context()
-            if not forward_context.capturing:
+            if self.torchair_graph_enabled:
+                # Reshape query to BSH layout: (batch, 1, num_heads * head_size).
+                query = query.view(-1, 1, self.num_heads * self.head_size)
+                # [blocknum, blocksize, numKvHeads, headDims]
+                # -> [blocknum, blocksize, numKvHeads * headDims]
+                key_cache = self.key_cache.view(  # type: ignore
+                    *self.key_cache.shape[:-2], -1)  # type: ignore
+                value_cache = self.value_cache.view(  # type: ignore
+                    *self.value_cache.shape[:-2], -1)  # type: ignore
+
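+                # Incremental (decode) attention over the paged KV cache;
+                # block_table and actual_seq_lengths select the cache blocks
+                # and valid lengths for each sequence.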
+                output = torch_npu.npu_incre_flash_attention(
+                    query=query,
+                    key=key_cache,
+                    value=value_cache,
+                    num_heads=self.num_heads,
+                    num_key_value_heads=self.num_kv_heads,
+                    input_layout='BSH',
+                    scale_value=self.scale,
+                    actual_seq_lengths=attn_metadata.seq_lens_list,
+                    block_table=attn_metadata.block_tables,
+                    block_size=kv_cache[0].shape[1],
+                )
+            elif not get_forward_context().capturing:
                 torch_npu._npu_paged_attention(
                     query=query,
                     key_cache=self.key_cache,
@@ -384,6 +531,7 @@ def forward(
                 event = torch.npu.ExternalEvent()
                 event.wait(stream)
                 event.reset(stream)
+                graph_params = get_graph_params()
                 graph_params.events[num_tokens].append(event)

                 graph_params.attn_params[num_tokens].append((