@@ -10,6 +10,7 @@
                                            MLAAttentionImpl)
 from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
+from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
 from vllm.utils import cdiv, round_down
@@ -21,6 +22,7 @@
     maybe_save_kv_layer_to_connector,
     split_decodes_and_prefills,
     wait_for_kv_layer_from_connector)
+from vllm_ascend.compilation.acl_graph import get_graph_params
 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
@@ -169,7 +171,7 @@ def split_metadata_for_multistream(
 class AscendMLAMetadataBuilder:
     # Does this backend/builder support ACL Graphs for attention (default: no).
     aclgraph_support: ClassVar[AttentionCGSupport] = \
-        AttentionCGSupport.NEVER
+        AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
     """
     NOTE: Please read the comment at the top of the file before trying to
     understand this class
@@ -314,6 +316,8 @@ def build(
                 self.model_config.dtype)  # type: ignore
             self.sin_cache = self.sin_cache.to(  # type: ignore
                 self.model_config.dtype)  # type: ignore
+        cos = common_attn_metadata.cos
+        sin = common_attn_metadata.sin
 
         query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
         query_lens = query_seq_lens_cpu[:num_reqs]
@@ -395,9 +399,12 @@ def build(
             block_table = block_table[:num_decodes, ...]
             seq_lens_list = seq_lens.tolist()
 
-            cos = self.cos_cache[input_positions].unsqueeze(  # type: ignore
+            assert self.cos_cache is not None
+            assert self.sin_cache is not None
+
+            cos[:num_decodes, ...] = self.cos_cache[input_positions].unsqueeze(
                 1).unsqueeze(2)
-            sin = self.sin_cache[input_positions].unsqueeze(  # type: ignore
+            sin[:num_decodes, ...] = self.sin_cache[input_positions].unsqueeze(
                 1).unsqueeze(2)
 
             decode_metadata = AscendMLADecodeMetadata(
@@ -408,8 +415,8 @@ def build(
                 max_seq_lens=max_seq_lens,
                 attn_mask=common_attn_metadata.spec_attn_mask,
                 actual_seq_lengths_q=actual_seq_lengths_q,
-                sin=sin,
-                cos=cos)
+                sin=sin[:num_decodes, ...],
+                cos=cos[:num_decodes, ...])
 
         return self.metadata_cls(  # type: ignore
             num_actual_tokens=num_actual_tokens,
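The rope handling above changes from allocating fresh `cos`/`sin` tensors on every build to writing in place into buffers carried on `common_attn_metadata`, then handing out slices. Graph replay reuses the device addresses recorded at capture time, so per-step inputs like these must keep a stable address across steps. A standalone sketch of the pattern (the sizes and `"npu"` device below are assumptions, not the vllm-ascend code):

```python
# Illustrative sketch: in-place refresh of preallocated rope buffers so their
# device addresses stay stable across ACL graph replays.
import torch

MAX_DECODE_TOKENS, ROPE_DIM = 512, 64  # illustrative capture-time sizes
cos_buf = torch.empty(MAX_DECODE_TOKENS, 1, 1, ROPE_DIM, device="npu")
sin_buf = torch.empty(MAX_DECODE_TOKENS, 1, 1, ROPE_DIM, device="npu")

def refresh_rope(cos_cache, sin_cache, positions, num_decodes):
    # Writing into the existing buffers keeps their addresses fixed; the
    # returned views are what the decode metadata hands to the kernel.
    cos_buf[:num_decodes, ...] = cos_cache[positions].unsqueeze(1).unsqueeze(2)
    sin_buf[:num_decodes, ...] = sin_cache[positions].unsqueeze(1).unsqueeze(2)
    return cos_buf[:num_decodes, ...], sin_buf[:num_decodes, ...]
```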
@@ -429,6 +436,26 @@ def build(
             enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
         )
 
+    def build_for_graph_capture(
+        self,
+        common_attn_metadata: AscendCommonAttentionMetadata,
+        attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
+        model: Optional[nn.Module] = None,
+    ):
+        if attn_state == AscendAttentionState.DecodeOnly:
+            attn_metadata = self.build(
+                common_prefix_len=0,
+                common_attn_metadata=common_attn_metadata,
+                model=model,
+            )
+        else:
+            raise NotImplementedError(
+                "Currently we only support building dummy metadata for DecodeOnly state"
+            )
+
+        attn_metadata.attn_state = attn_state
+        return attn_metadata
+
 
 class DecodeMLAPreprocessResult(NamedTuple):
     ql_nope: Optional[torch.Tensor] = None
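The new `build_for_graph_capture` hook produces dummy decode metadata for the capture pass by delegating to the regular `build`. A hypothetical capture loop showing how such a hook could be driven (the runner methods named below are assumptions for illustration, not vllm-ascend API):

```python
def capture_decode_graphs(runner, builder, capture_sizes):
    # For each batch size we want a graph for, build DecodeOnly dummy
    # metadata and trace one forward pass under graph capture.
    for num_tokens in capture_sizes:
        attn_metadata = builder.build_for_graph_capture(
            common_attn_metadata=runner.dummy_common_attn_metadata(num_tokens),
            attn_state=AscendAttentionState.DecodeOnly,
            model=runner.model,
        )
        with runner.acl_graph_capture(num_tokens):  # hypothetical context manager
            runner.model(**runner.dummy_inputs(num_tokens),
                         attn_metadata=attn_metadata)
```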
@@ -832,24 +859,69 @@ def _forward_decode(
             sparse_mode = 0
             spec_attn_mask = None
 
-        attn_output, _ = torch_npu.npu_fused_infer_attention_score(
-            q_nope,
-            k_nope,
-            k_nope,
-            query_rope=q_pe,
-            key_rope=k_pe,
-            num_heads=self.num_heads,
-            num_key_value_heads=self.num_kv_heads,
-            input_layout=input_layout,
-            atten_mask=spec_attn_mask,
-            sparse_mode=sparse_mode,
-            scale=self.scale,
-            antiquant_mode=0,
-            antiquant_scale=None,
-            block_table=decode_meta.block_table,
-            block_size=block_size,
-            actual_seq_lengths_kv=decode_meta.seq_lens_list,
-            actual_seq_lengths=actual_seq_lengths)
+        common_kwargs = {
+            'query_rope': q_pe,
+            'key_rope': k_pe,
+            'num_heads': self.num_heads,
+            'num_key_value_heads': self.num_kv_heads,
+            'input_layout': input_layout,
+            'atten_mask': spec_attn_mask,
+            'sparse_mode': sparse_mode,
+            'scale': self.scale,
+            'antiquant_mode': 0,
+            'antiquant_scale': None,
+            'block_table': decode_meta.block_table,
+            'block_size': block_size,
+            'actual_seq_lengths': actual_seq_lengths,
+            'actual_seq_lengths_kv': decode_meta.seq_lens_list,
+        }
+        graph_params = get_graph_params()
+        forward_context: ForwardContext = get_forward_context()
+        if forward_context.capturing:
+            stream = torch_npu.npu.current_stream()
+
+            event = torch.npu.ExternalEvent()
+            event.wait(stream)
+            event.reset(stream)
+            graph_params.events[num_tokens].append(event)
+
+            workspace = graph_params.workspaces.get(num_tokens)
+            if workspace is None:
+                workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+                    q_nope,
+                    k_nope,
+                    k_nope,
+                    **common_kwargs)
+                graph_params.workspaces[num_tokens] = workspace
+
+            attn_output = torch.empty_like(q_nope)
+            softmax_lse = torch.empty(num_tokens,
+                                      dtype=q_nope.dtype,
+                                      device=q_nope.device)
+
+            graph_params.attn_params[num_tokens].append(
+                (q_nope, k_nope, q_pe, k_pe, self.num_heads, self.num_kv_heads,
+                 input_layout, spec_attn_mask, sparse_mode, self.scale,
+                 decode_meta.block_table, block_size,
+                 decode_meta.seq_lens_list, actual_seq_lengths, workspace,
+                 attn_output, softmax_lse))
+
+            torch.npu.graph_task_group_begin(stream)
+            torch_npu.npu_fused_infer_attention_score.out(
+                q_nope,
+                k_nope,
+                k_nope,
+                **common_kwargs,
+                workspace=workspace,
+                out=[attn_output, softmax_lse])
+            handle = torch.npu.graph_task_group_end(stream)
+            graph_params.handles[num_tokens].append(handle)
+        else:
+            attn_output, _ = torch_npu.npu_fused_infer_attention_score(
+                q_nope,
+                k_nope,
+                k_nope,
+                **common_kwargs)
 
         current_ms_metadata = get_multistream_comm_context()
         if current_ms_metadata is None:
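In the capture branch, the kernel launch is wrapped in a graph task group so the recorded work item can be re-parameterized at replay time: the `ExternalEvent` lets replay stall the captured stream until fresh parameters are in place, the fused-infer-attention workspace is sized once per batch size and cached, and the full argument tuple plus output buffers are stashed in `graph_params` for the update pass. A sketch of the replay-side counterpart this bookkeeping implies, modeled on the `update_attn_params` idea in `vllm_ascend.compilation.acl_graph` (treat it as an assumption, not verbatim code):

```python
# Sketch: before each replay, re-issue the attention kernel with fresh
# sequence lengths into the captured workspace and output buffers, then
# record the ExternalEvent the captured stream is waiting on.
import torch
import torch_npu

def update_attn_params(update_stream, graph_params, num_tokens, seq_lens_list):
    with torch.npu.stream(update_stream):
        for params, handle, event in zip(graph_params.attn_params[num_tokens],
                                         graph_params.handles[num_tokens],
                                         graph_params.events[num_tokens]):
            (q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads,
             input_layout, spec_attn_mask, sparse_mode, scale, block_table,
             block_size, _stale_seq_lens, actual_seq_lengths, workspace,
             attn_output, softmax_lse) = params
            # Re-record the captured task group with the updated lengths.
            torch.npu.graph_task_update_begin(update_stream, handle)
            torch_npu.npu_fused_infer_attention_score.out(
                q_nope, k_nope, k_nope,
                query_rope=q_pe, key_rope=k_pe,
                num_heads=num_heads, num_key_value_heads=num_kv_heads,
                input_layout=input_layout, atten_mask=spec_attn_mask,
                sparse_mode=sparse_mode, scale=scale,
                antiquant_mode=0, antiquant_scale=None,
                block_table=block_table, block_size=block_size,
                actual_seq_lengths=actual_seq_lengths,
                actual_seq_lengths_kv=seq_lens_list,
                workspace=workspace,
                out=[attn_output, softmax_lse])
            torch.npu.graph_task_update_end(update_stream)
            # Unblock the captured stream that waits on this event.
            event.record(update_stream)
```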