File tree Expand file tree Collapse file tree 2 files changed +6
-9
lines changed Expand file tree Collapse file tree 2 files changed +6
-9
lines changed Original file line number Diff line number Diff line change @@ -377,10 +377,7 @@ def build(
377
377
seq_lens = seq_lens [:self ._num_decode_tokens ]
378
378
input_positions = input_positions [:self ._num_decode_tokens ]
379
379
block_table = block_table [:self ._num_decode_tokens , ...]
380
- if use_torchair_graph and self .runner .attn_state in [
381
- AscendAttentionState .DecodeOnly ,
382
- AscendAttentionState .SpecDecoding
383
- ]:
380
+ if use_torchair_graph and self .runner .attn_state == AscendAttentionState .DecodeOnly :
384
381
num_seqs = len (seq_lens )
385
382
if graph_pad_size != 0 :
386
383
pad_value = 1
Original file line number Diff line number Diff line change @@ -943,6 +943,11 @@ def _process_reqs(
943
943
self .input_ids_cpu [:total_num_scheduled_tokens ], non_blocking = True )
944
944
input_ids = self .input_ids [:num_input_tokens ]
945
945
946
+ if (envs_ascend .VLLM_ENABLE_MC2
947
+ or self .torchair_graph_enabled ) and not with_prefill :
948
+ input_ids = self .input_ids [:padded_batch_size ]
949
+ positions = self .positions [:padded_batch_size ]
950
+
946
951
# prepare the MRoPE for mllm if using multimodal
947
952
num_input_tokens = total_num_scheduled_tokens
948
953
# _prepare_inputs may reorder the batch, so we must gather multi
@@ -980,11 +985,6 @@ def _process_reqs(
980
985
else :
981
986
positions = self .positions [:num_input_tokens ]
982
987
983
- if (envs_ascend .VLLM_ENABLE_MC2
984
- or self .torchair_graph_enabled ) and not with_prefill :
985
- input_ids = self .input_ids [:padded_batch_size ]
986
- positions = self .positions [:padded_batch_size ]
987
-
988
988
# Run forward pass
989
989
with set_forward_context (attn_metadata ,
990
990
self .vllm_config ,
You can’t perform that action at this time.
0 commit comments