42 changes: 24 additions & 18 deletions vllm_ascend/attention/mla_v1.py
@@ -829,6 +829,8 @@ def forward(
k_pe: torch.Tensor, # value in unified attn
kv_cache: torch.Tensor,
attn_metadata: M,
rotary_cos: Optional[torch.Tensor] = None,
rotary_sin: Optional[torch.Tensor] = None,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert output is not None, "Output tensor must be provided."
@@ -875,24 +877,28 @@ def forward(
decode_k_nope = None
assert attn_metadata.decode is not None
if self.running_in_graph:
seq_len = self.rotary_emb.max_position_embeddings
cos = self.rotary_emb.cos_cached[:seq_len].to(
dtype=decode_hs_or_q_c.dtype)
sin = self.rotary_emb.sin_cached[:seq_len].to(
dtype=decode_hs_or_q_c.dtype)
cos = cos[attn_metadata.decode.input_positions]
sin = sin[attn_metadata.decode.input_positions]
cos = cos[:, None, None, :]
sin = sin[:, None, None, :]
# Without explicitly controlling the order, IndexByTensor operations
# would be placed after `matmul W_KV_T` hindering the overlapping of
# KvRmsNormRopeCache and SingleRope.
npu_wait_tensor(decode_hs_or_q_c,
cos,
enabled=self.enable_multistream_mla)
npu_wait_tensor(decode_hs_or_q_c,
sin,
enabled=self.enable_multistream_mla)
if rotary_cos is not None and rotary_sin is not None:
cos = rotary_cos.to(dtype=decode_hs_or_q_c.dtype)
sin = rotary_sin.to(dtype=decode_hs_or_q_c.dtype)
else:
seq_len = self.rotary_emb.max_position_embeddings
cos = self.rotary_emb.cos_cached[:seq_len].to(
dtype=decode_hs_or_q_c.dtype)
sin = self.rotary_emb.sin_cached[:seq_len].to(
dtype=decode_hs_or_q_c.dtype)
cos = cos[attn_metadata.decode.input_positions]
sin = sin[attn_metadata.decode.input_positions]
cos = cos[:, None, None, :]
sin = sin[:, None, None, :]
# Without explicitly controlling the order, IndexByTensor operations
# would be placed after `matmul W_KV_T` hindering the overlapping of
# KvRmsNormRopeCache and SingleRope.
npu_wait_tensor(decode_hs_or_q_c,
cos,
enabled=self.enable_multistream_mla)
npu_wait_tensor(decode_hs_or_q_c,
sin,
enabled=self.enable_multistream_mla)
decode_ql_nope, decode_q_pe = \
self._q_proj_and_k_up_proj(decode_hs_or_q_c)
if self.running_in_graph:
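For context, the hunk above changes the MLA decode path so it can consume cos/sin tables that were precomputed once per step, and otherwise falls back to slicing the rotary cache per layer as before. Below is a minimal, simplified sketch of that fallback pattern only; the helper and argument names (rotary_emb, input_positions, dtype) are illustrative rather than the PR's code, and the real code additionally issues npu_wait_tensor calls to order the ops for multi-stream MLA.

import torch

def _select_decode_rotary(rotary_emb, input_positions, dtype,
                          rotary_cos=None, rotary_sin=None
                          ) -> tuple[torch.Tensor, torch.Tensor]:
    # Sketch only (hypothetical helper, not in the PR): prefer tables the
    # model precomputed once for this step; otherwise gather them from the
    # rope cache in this layer, like the pre-change code path.
    if rotary_cos is not None and rotary_sin is not None:
        # Precomputed path: already indexed by position and reshaped upstream.
        return rotary_cos.to(dtype=dtype), rotary_sin.to(dtype=dtype)
    # Fallback path: index the cached tables by decode positions and add
    # broadcast dims, as the original per-layer code did.
    cos = rotary_emb.cos_cached.to(dtype=dtype)[input_positions]
    sin = rotary_emb.sin_cached.to(dtype=dtype)[input_positions]
    return cos[:, None, None, :], sin[:, None, None, :]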
65 changes: 58 additions & 7 deletions vllm_ascend/models/deepseek_v2.py
@@ -67,6 +67,7 @@

import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.distributed.parallel_state import get_ep_group
from vllm_ascend.ops.fused_moe import AscendFusedMoE
from vllm_ascend.quantization.quant_config import AscendLinearMethod
@@ -500,12 +501,13 @@ def __init__(
self.enable_multistream_mla = \
ascend_config.torchair_graph_config.enable_multistream_mla

def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: Optional[torch.Tensor] = None,
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
def forward(self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: Optional[torch.Tensor] = None,
attn_metadata: Optional[AttentionMetadata] = None,
rotary_cos: Optional[torch.Tensor] = None,
rotary_sin: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.q_lora_rank is not None:
ckq = self.q_a_proj(hidden_states)[0]
use_multistream_mla = (self.enable_multistream_mla
@@ -526,6 +528,8 @@ def forward(
dtype=hidden_states_or_q_c.dtype,
device=hidden_states_or_q_c.device)
forward_kwargs['output'] = output
forward_kwargs['rotary_cos'] = rotary_cos
forward_kwargs['rotary_sin'] = rotary_sin

output = self.mla_attn.impl.forward(self.mla_attn,
hidden_states_or_q_c,
@@ -617,6 +621,8 @@ def forward(
residual: Optional[torch.Tensor],
kv_cache: Optional[torch.Tensor] = None,
attn_metadata: Optional[AttentionMetadata] = None,
rotary_cos: Optional[torch.Tensor] = None,
rotary_sin: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# Self Attention
if residual is None:
@@ -636,6 +642,8 @@ def forward(
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
rotary_cos=rotary_cos,
rotary_sin=rotary_sin,
)

if hidden_states.dtype == torch.float16:
@@ -713,9 +721,47 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))

ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled

rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
if rope_scaling:
rope_scaling["rope_type"] = 'deepseek_yarn'
self.rotary_emb = get_rope(config.qk_rope_head_dim,
rotary_dim=config.qk_rope_head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
is_neox_style=False)

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)

def prepare_decoder_rotary_cos_sin(
self,
attn_metadata: Optional[AttentionMetadata] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
if (envs.VLLM_USE_V1 and attn_metadata is not None
and attn_metadata.num_decodes is not None
and attn_metadata.atten_state is not None):
Contributor commented: When dp != 1 and _dummy_run is triggered, attn_metadata is None, so attn_metadata.num_decodes and attn_metadata.atten_state do not exist.

Author (@huiyingCCCC, Jun 17, 2025) replied: If attn_metadata is None, this branch is not entered and the original code path is used.

Contributor replied: If attn_metadata is empty, it has no num_decodes or atten_state attributes. I changed the condition to if (envs.VLLM_USE_V1 and attn_metadata is not None): and it runs normally. How much performance gain do you get?

has_decode = attn_metadata.num_decodes > 0
running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
AscendAttentionState.DecodeOnly,
AscendAttentionState.SpecDecoding
]
if has_decode and running_in_graph:
cos = self.rotary_emb.cos_cached
sin = self.rotary_emb.sin_cached
cos = cos[attn_metadata.decode.input_positions]
sin = sin[attn_metadata.decode.input_positions]
cos = cos[:, None, None, :]
sin = sin[:, None, None, :]
return cos, sin
return None, None
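
On the review discussion above: a more defensive form of this guard, along the lines the reviewer suggests, would avoid reading attributes that may not exist on the metadata object during _dummy_run. The following is only a sketch of that suggestion, not code from the PR; the helper name and the getattr defaults are assumptions.

def _can_precompute_rotary(self, attn_metadata) -> bool:
    # Sketch: return False early when attn_metadata is None (e.g. _dummy_run
    # with dp != 1) instead of assuming num_decodes / attn_state exist.
    if not envs.VLLM_USE_V1 or attn_metadata is None:
        return False
    num_decodes = getattr(attn_metadata, "num_decodes", 0) or 0
    attn_state = getattr(attn_metadata, "attn_state", None)
    return (self.torchair_graph_enabled and num_decodes > 0
            and attn_state in (AscendAttentionState.DecodeOnly,
                               AscendAttentionState.SpecDecoding))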

def forward(
self,
input_ids: torch.Tensor,
@@ -736,13 +782,18 @@ def forward(
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]

# In graph mode with the v1 engine, precomputing cos and sin once here
# avoids recomputing them in every decode layer.
rotary_cos, rotary_sin = self.prepare_decoder_rotary_cos_sin(
attn_metadata)

for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(
positions, hidden_states, residual,
kv_caches[i -
self.start_layer] if kv_caches is not None else None,
attn_metadata)
attn_metadata, rotary_cos, rotary_sin)

if not get_pp_group().is_last_rank:
return IntermediateTensors({
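
Taken together, the intended per-step flow is: the model gathers the position-indexed cos/sin once in prepare_decoder_rotary_cos_sin, and every decoder layer reuses those tensors instead of re-indexing the rope cache. A rough illustration of where the saving comes from (layer count and variable names are illustrative, not taken from the PR):

# Before: each of the N decoder layers gathered cos/sin from the cache itself.
# After: one gather per step, shared by all layers.
rotary_cos, rotary_sin = model.prepare_decoder_rotary_cos_sin(attn_metadata)
for layer in model.layers:  # e.g. ~60 layers for DeepSeek-V2
    hidden_states, residual = layer(positions, hidden_states, residual,
                                    kv_cache, attn_metadata,
                                    rotary_cos, rotary_sin)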