Skip to content

Commit 6029cde

Browse files
committed
[bugfix] prefill optimization supports torchair graph mode
Signed-off-by: Wang Kunpeng <1289706727@qq.com>
1 parent 0c7db37 commit 6029cde

File tree

2 files changed

+24
-6
lines changed

2 files changed

+24
-6
lines changed

vllm_ascend/attention/mla_v1.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -627,6 +627,7 @@ def __init__(
627627
ascend_config = get_ascend_config()
628628
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
629629
self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
630+
self.enable_prefill_optimizations = ascend_config.enable_prefill_optimizations
630631

631632
# Adapt torch air graph mode with spec decoding.
632633
speculative_config = get_current_vllm_config().speculative_config
@@ -1140,7 +1141,7 @@ def forward(
11401141
# Inputs and outputs may be padded for CUDA graphs
11411142
output_padded = output
11421143
output = output[:num_actual_toks, ...]
1143-
if not self.torchair_graph_enabled:
1144+
if not self.torchair_graph_enabled or self.enable_prefill_optimizations:
11441145
kv_c_normed = kv_c_normed[:num_actual_toks, ...]
11451146
prefill_k_c_normed = kv_c_normed[num_decode_tokens:]
11461147
if not self.running_in_graph:
@@ -1187,7 +1188,7 @@ def forward(
11871188
.view(-1, self.num_heads, self.qk_head_dim)
11881189
prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
11891190
prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
1190-
if self.torchair_graph_enabled:
1191+
if self.torchair_graph_enabled and not self.enable_prefill_optimizations:
11911192
num_tokens = prefill_hs_or_q_c.shape[0]
11921193
cos = attn_metadata.prefill.cos
11931194
sin = attn_metadata.prefill.sin
@@ -1203,6 +1204,7 @@ def forward(
12031204
-1)
12041205
prefill_q = torch.cat([prefill_q_nope, prefill_q_pe], dim=-1)
12051206
else:
1207+
num_tokens = prefill_hs_or_q_c.shape[0]
12061208
prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
12071209
attn_metadata.prefill.input_positions,
12081210
prefill_q_pe.contiguous(), prefill_k_pe)

vllm_ascend/models/deepseek_v2.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@
3939
from vllm.distributed import (get_dp_group, get_pp_group,
4040
get_tensor_model_parallel_rank,
4141
get_tensor_model_parallel_world_size,
42-
tensor_model_parallel_reduce_scatter,
43-
split_tensor_along_last_dim, get_tp_group)
42+
get_tp_group, split_tensor_along_last_dim,
43+
tensor_model_parallel_reduce_scatter)
4444
from vllm.forward_context import get_forward_context
4545
from vllm.model_executor.layers.activation import SiluAndMul
4646
from vllm.model_executor.layers.layernorm import RMSNorm
@@ -68,6 +68,7 @@
6868
from vllm.sequence import IntermediateTensors
6969

7070
from vllm_ascend.ascend_config import get_ascend_config
71+
from vllm_ascend.attention.attention_v1 import AscendAttentionState
7172
from vllm_ascend.ops.fused_moe import AscendFusedMoE
7273
from vllm_ascend.quantization.quant_config import AscendLinearMethod
7374
from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
@@ -539,9 +540,24 @@ def forward(
539540
else:
540541
hidden_states_or_q_c = hidden_states
541542
is_mtp_model = attn_metadata is not None and attn_metadata.is_mtp_model
542-
if self.torchair_graph_enabled and not is_mtp_model:
543+
with_decode = attn_metadata is not None and attn_metadata.attn_state in [
544+
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
545+
]
546+
with_optimization_prefill = self.enable_prefill_optimizations and not with_decode
547+
if self.torchair_graph_enabled and not is_mtp_model and not with_optimization_prefill:
548+
if self.enable_prefill_optimizations and self.debug_layer_idx > 3 and self.debug_layer_idx < 61:
549+
hidden_states_or_q_c = get_tp_group().all_gather(
550+
hidden_states_or_q_c, 0)
551+
hidden_states = get_tp_group().all_gather(hidden_states, 0)
543552
if envs.VLLM_USE_V1:
544-
output_shape = hidden_states.shape
553+
if not self.enable_prefill_optimizations or self.debug_layer_idx < 3:
554+
output_shape = hidden_states.shape
555+
else:
556+
num_tokens = hidden_states.shape[0]
557+
rows = num_tokens // self.tp_size
558+
if num_tokens % self.tp_size:
559+
rows += 1
560+
output_shape = (rows, hidden_states.shape[1])
545561
output = torch.empty(output_shape,
546562
dtype=hidden_states_or_q_c.dtype,
547563
device=hidden_states_or_q_c.device)

0 commit comments

Comments
 (0)