Commit bfc94f3

MLA layer eliminates redundant index operators
Signed-off-by: huiying <chenhuiying4@huawei.com>
1 parent 3393d53 commit bfc94f3

2 files changed: +53 -11 lines changed


vllm_ascend/attention/mla_v1.py

Lines changed: 4 additions & 9 deletions
@@ -822,6 +822,8 @@ def forward(
         k_pe: torch.Tensor,  # value in unified attn
         kv_cache: torch.Tensor,
         attn_metadata: M,
+        rotary_cos: Optional[torch.Tensor] = None,
+        rotary_sin: Optional[torch.Tensor] = None,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         assert output is not None, "Output tensor must be provided."
@@ -870,15 +872,8 @@ def forward(
             decode_ql_nope, decode_q_pe = \
                 self._q_proj_and_k_up_proj(decode_hs_or_q_c)
             if self.running_in_graph:
-                seq_len = self.rotary_emb.max_position_embeddings
-                cos = self.rotary_emb.cos_cached[:seq_len].to(
-                    dtype=decode_q_pe.dtype)
-                sin = self.rotary_emb.sin_cached[:seq_len].to(
-                    dtype=decode_q_pe.dtype)
-                cos = cos[attn_metadata.decode.input_positions]
-                sin = sin[attn_metadata.decode.input_positions]
-                cos = cos[:, None, None, :]
-                sin = sin[:, None, None, :]
+                cos = rotary_cos.to(dtype=decode_q_pe.dtype)
+                sin = rotary_sin.to(dtype=decode_q_pe.dtype)
 
                 decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
                 decode_k_pe, decode_k_nope = self.exec_kv(
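For intuition, a minimal standalone sketch of what this hunk changes (the helper names below are hypothetical, not the vllm-ascend API): the old path cast the cached rotary tables, gathered them by position, and broadcast-reshaped them inside every MLA layer on every decode step; the new path only casts tensors that the model precomputes once per forward pass.

import torch

def rope_coeffs_old(cos_cached: torch.Tensor, sin_cached: torch.Tensor,
                    positions: torch.Tensor, dtype: torch.dtype):
    # Old path: cast the cached tables, gather by position, add broadcast
    # dims -- repeated in every layer, the "redundant index operators"
    # of the commit title.
    cos = cos_cached.to(dtype=dtype)[positions][:, None, None, :]
    sin = sin_cached.to(dtype=dtype)[positions][:, None, None, :]
    return cos, sin

def rope_coeffs_new(rotary_cos: torch.Tensor, rotary_sin: torch.Tensor,
                    dtype: torch.dtype):
    # New path: the gather/reshape already happened once at the model level;
    # each layer only casts to its working dtype.
    return rotary_cos.to(dtype=dtype), rotary_sin.to(dtype=dtype)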

vllm_ascend/models/deepseek_v2.py

Lines changed: 49 additions & 2 deletions
@@ -72,6 +72,7 @@
 from vllm_ascend.quantization.quant_config import AscendLinearMethod
 from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
 from vllm_ascend.utils import dispose_tensor
+from vllm_ascend.attention.attention_v1 import AscendAttentionState
 
 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
 
@@ -502,7 +503,9 @@ def forward(
             positions: torch.Tensor,
             hidden_states: torch.Tensor,
             kv_cache: Optional[torch.Tensor] = None,
-            attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
+            attn_metadata: Optional[AttentionMetadata] = None,
+            rotary_cos: Optional[torch.Tensor] = None,
+            rotary_sin: Optional[torch.Tensor] = None) -> torch.Tensor:
         if self.q_lora_rank is not None:
             ckq = self.q_a_proj(hidden_states)[0]
             hidden_states_or_q_c = self.q_a_layernorm(ckq)
@@ -516,6 +519,8 @@ def forward(
                                      dtype=hidden_states_or_q_c.dtype,
                                      device=hidden_states_or_q_c.device)
                 forward_kwargs['output'] = output
+                forward_kwargs['rotary_cos'] = rotary_cos
+                forward_kwargs['rotary_sin'] = rotary_sin
 
             output = self.mla_attn.impl.forward(self.mla_attn,
                                                 hidden_states_or_q_c,
@@ -607,6 +612,8 @@ def forward(
         residual: Optional[torch.Tensor],
         kv_cache: Optional[torch.Tensor] = None,
         attn_metadata: Optional[AttentionMetadata] = None,
+        rotary_cos: Optional[torch.Tensor] = None,
+        rotary_sin: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         # Self Attention
         if residual is None:
@@ -626,6 +633,8 @@ def forward(
             hidden_states=hidden_states,
             kv_cache=kv_cache,
             attn_metadata=attn_metadata,
+            rotary_cos=rotary_cos,
+            rotary_sin=rotary_sin,
         )
 
         if hidden_states.dtype == torch.float16:
@@ -703,9 +712,43 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], config.hidden_size))
 
+        ascend_config = get_ascend_config()
+        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        max_position_embeddings = getattr(config, "max_position_embeddings",
+                                          8192)
+        if rope_scaling:
+            rope_scaling["rope_type"] = 'deepseek_yarn'
+        self.rotary_emb = get_rope(config.qk_rope_head_dim,
+                                   rotary_dim=config.qk_rope_head_dim,
+                                   max_position=max_position_embeddings,
+                                   base=rope_theta,
+                                   rope_scaling=rope_scaling,
+                                   is_neox_style=False)
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
 
+    def prepare_decoder_rotary_cos_sin(
+            self, attn_metadata: Optional[AttentionMetadata] = None
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if (attn_metadata is not None and attn_metadata.num_decodes is not None and
+                attn_metadata.attn_state):
+            has_decode = attn_metadata.num_decodes > 0
+            running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
+                AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding]
+            if has_decode and running_in_graph:
+                cos = self.rotary_emb.cos_cached
+                sin = self.rotary_emb.sin_cached
+                cos = cos[attn_metadata.decode.input_positions]
+                sin = sin[attn_metadata.decode.input_positions]
+                cos = cos[:, None, None, :]
+                sin = sin[:, None, None, :]
+                return cos, sin
+        return None, None
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -726,13 +769,17 @@ def forward(
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
 
+        # In graph mode and v1 engine,
+        # precomputing cos and sin can eliminate repeated calculations in each decode layer.
+        rotary_cos, rotary_sin = self.prepare_decoder_rotary_cos_sin(attn_metadata)
+
         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions, hidden_states, residual,
                 kv_caches[i -
                           self.start_layer] if kv_caches is not None else None,
-                attn_metadata)
+                attn_metadata, rotary_cos, rotary_sin)
 
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
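Taken together, the deepseek_v2.py hunks hoist the position-based gather out of the per-layer code: the model builds its own rotary embedding, indexes cos_cached/sin_cached once per forward pass in prepare_decoder_rotary_cos_sin, and threads the result through every decoder layer down to the MLA implementation. A minimal sketch of that pattern with toy tensors (shapes and dtypes are illustrative assumptions, not the real rotary_emb or attn_metadata objects):

import torch

num_layers, num_tokens, rope_dim, max_pos = 4, 8, 64, 4096

# Stand-ins for rotary_emb.cos_cached / sin_cached and the decode positions.
cos_cached = torch.randn(max_pos, rope_dim)
sin_cached = torch.randn(max_pos, rope_dim)
input_positions = torch.randint(0, max_pos, (num_tokens,))

# Model level: gather and reshape once per forward pass.
rotary_cos = cos_cached[input_positions][:, None, None, :]
rotary_sin = sin_cached[input_positions][:, None, None, :]

# Layer level: every decode layer reuses the precomputed tensors and only
# casts to its working dtype, instead of repeating the gather num_layers times.
for _ in range(num_layers):
    cos = rotary_cos.to(dtype=torch.float16)
    sin = rotary_sin.to(dtype=torch.float16)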
