File tree Expand file tree Collapse file tree 2 files changed +9
-6
lines changed Expand file tree Collapse file tree 2 files changed +9
-6
lines changed Original file line number Diff line number Diff line change @@ -376,7 +376,10 @@ def build(
376
376
seq_lens = seq_lens [:self ._num_decode_tokens ]
377
377
input_positions = input_positions [:self ._num_decode_tokens ]
378
378
block_table = block_table [:self ._num_decode_tokens , ...]
379
- if use_torchair_graph and self .runner .attn_state == AscendAttentionState .DecodeOnly :
379
+ if use_torchair_graph and self .runner .attn_state in [
380
+ AscendAttentionState .DecodeOnly ,
381
+ AscendAttentionState .SpecDecoding
382
+ ]:
380
383
num_seqs = len (seq_lens )
381
384
if graph_pad_size != 0 :
382
385
pad_value = 1
Original file line number Diff line number Diff line change @@ -943,11 +943,6 @@ def _process_reqs(
943
943
self .input_ids_cpu [:total_num_scheduled_tokens ], non_blocking = True )
944
944
input_ids = self .input_ids [:num_input_tokens ]
945
945
946
- if (envs_ascend .VLLM_ENABLE_MC2
947
- or self .torchair_graph_enabled ) and not with_prefill :
948
- input_ids = self .input_ids [:padded_batch_size ]
949
- positions = self .positions [:padded_batch_size ]
950
-
951
946
# prepare the MRoPE for mllm if using multimodal
952
947
num_input_tokens = total_num_scheduled_tokens
953
948
# _prepare_inputs may reorder the batch, so we must gather multi
@@ -985,6 +980,11 @@ def _process_reqs(
985
980
else :
986
981
positions = self .positions [:num_input_tokens ]
987
982
983
+ if (envs_ascend .VLLM_ENABLE_MC2
984
+ or self .torchair_graph_enabled ) and not with_prefill :
985
+ input_ids = self .input_ids [:padded_batch_size ]
986
+ positions = self .positions [:padded_batch_size ]
987
+
988
988
# Run forward pass
989
989
with set_forward_context (attn_metadata ,
990
990
self .vllm_config ,
You can’t perform that action at this time.
0 commit comments