Commit 9cdb554

MLA layer eliminates redundant index operators
Signed-off-by: huiying <chenhuiying4@huawei.com>
1 parent 94a52cf commit 9cdb554
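
For readers skimming the diff below: the change hoists the per-token gather of the rotary cos/sin caches out of the per-layer decode path and performs it once per model forward, passing the slices down to every MLA layer. The following is a minimal, self-contained sketch of that pattern; the module names (ToyLayer, ToyModel) and toy shapes are illustrative assumptions, not code from this repository.

import torch


class ToyLayer(torch.nn.Module):
    """Stand-in for a decode layer that needs per-token rotary cos/sin."""

    def __init__(self, cos_cache, sin_cache):
        super().__init__()
        self.cos_cache, self.sin_cache = cos_cache, sin_cache

    def forward(self, x, positions, rotary_cos=None, rotary_sin=None):
        if rotary_cos is not None and rotary_sin is not None:
            # Fast path: reuse the slices prepared once at the model level.
            cos, sin = rotary_cos, rotary_sin
        else:
            # Old path: every layer re-indexes the full caches
            # (one index/gather op per layer in a traced graph).
            cos = self.cos_cache[positions]
            sin = self.sin_cache[positions]
        # A real layer would rotate q/k here; we only touch the tensors.
        return x + cos.sum() + sin.sum()


class ToyModel(torch.nn.Module):
    def __init__(self, num_layers=4, max_pos=128, rope_dim=8):
        super().__init__()
        cos_cache = torch.randn(max_pos, rope_dim)
        sin_cache = torch.randn(max_pos, rope_dim)
        self.cos_cache, self.sin_cache = cos_cache, sin_cache
        self.layers = torch.nn.ModuleList(
            ToyLayer(cos_cache, sin_cache) for _ in range(num_layers))

    def forward(self, x, positions):
        # Index the caches once per forward instead of once per layer.
        rotary_cos = self.cos_cache[positions]
        rotary_sin = self.sin_cache[positions]
        for layer in self.layers:
            x = layer(x, positions, rotary_cos, rotary_sin)
        return x


print(ToyModel()(torch.zeros(3), torch.tensor([5, 17, 42])))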

File tree

2 files changed (+82, -25 lines)

vllm_ascend/attention/mla_v1.py

Lines changed: 24 additions & 18 deletions
@@ -829,6 +829,8 @@ def forward(
         k_pe: torch.Tensor,  # value in unified attn
         kv_cache: torch.Tensor,
         attn_metadata: M,
+        rotary_cos: Optional[torch.Tensor] = None,
+        rotary_sin: Optional[torch.Tensor] = None,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         assert output is not None, "Output tensor must be provided."
@@ -875,24 +877,28 @@ def forward(
             decode_k_nope = None
             assert attn_metadata.decode is not None
             if self.running_in_graph:
-                seq_len = self.rotary_emb.max_position_embeddings
-                cos = self.rotary_emb.cos_cached[:seq_len].to(
-                    dtype=decode_hs_or_q_c.dtype)
-                sin = self.rotary_emb.sin_cached[:seq_len].to(
-                    dtype=decode_hs_or_q_c.dtype)
-                cos = cos[attn_metadata.decode.input_positions]
-                sin = sin[attn_metadata.decode.input_positions]
-                cos = cos[:, None, None, :]
-                sin = sin[:, None, None, :]
-                # Without explicitly controlling the order, IndexByTensor operations
-                # would be placed after `matmul W_KV_T` hindering the overlapping of
-                # KvRmsNormRopeCache and SingleRope.
-                npu_wait_tensor(decode_hs_or_q_c,
-                                cos,
-                                enabled=self.enable_multistream_mla)
-                npu_wait_tensor(decode_hs_or_q_c,
-                                sin,
-                                enabled=self.enable_multistream_mla)
+                if rotary_cos is not None and rotary_sin is not None:
+                    cos = rotary_cos.to(dtype=decode_hs_or_q_c.dtype)
+                    sin = rotary_sin.to(dtype=decode_hs_or_q_c.dtype)
+                else:
+                    seq_len = self.rotary_emb.max_position_embeddings
+                    cos = self.rotary_emb.cos_cached[:seq_len].to(
+                        dtype=decode_hs_or_q_c.dtype)
+                    sin = self.rotary_emb.sin_cached[:seq_len].to(
+                        dtype=decode_hs_or_q_c.dtype)
+                    cos = cos[attn_metadata.decode.input_positions]
+                    sin = sin[attn_metadata.decode.input_positions]
+                    cos = cos[:, None, None, :]
+                    sin = sin[:, None, None, :]
+                # Without explicitly controlling the order, IndexByTensor operations
+                # would be placed after `matmul W_KV_T` hindering the overlapping of
+                # KvRmsNormRopeCache and SingleRope.
+                npu_wait_tensor(decode_hs_or_q_c,
+                                cos,
+                                enabled=self.enable_multistream_mla)
+                npu_wait_tensor(decode_hs_or_q_c,
+                                sin,
+                                enabled=self.enable_multistream_mla)
             decode_ql_nope, decode_q_pe = \
                 self._q_proj_and_k_up_proj(decode_hs_or_q_c)
         if self.running_in_graph:

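Aside: the gather-and-reshape that both branches of the hunk above perform can be checked in isolation. The snippet below uses assumed toy sizes (the actual head and rope dimensions in vllm-ascend may differ) to show why the `[:, None, None, :]` view is there: it adds singleton axes so the per-token cos row broadcasts against a query tensor shaped like `[num_tokens, num_heads, 1, rope_dim]`.

import torch

max_pos, rope_dim = 1024, 64                  # toy sizes, not the real config
num_tokens, num_heads = 4, 16

cos_cached = torch.randn(max_pos, rope_dim)
input_positions = torch.tensor([7, 42, 513, 900])   # one position per decode token

cos = cos_cached[input_positions]             # gather: [num_tokens, rope_dim]
cos = cos[:, None, None, :]                   # view:   [num_tokens, 1, 1, rope_dim]

q_pe = torch.randn(num_tokens, num_heads, 1, rope_dim)
print((q_pe * cos).shape)                     # broadcasts to [4, 16, 1, 64]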
vllm_ascend/models/deepseek_v2.py

Lines changed: 58 additions & 7 deletions
@@ -67,6 +67,7 @@
 
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.distributed.parallel_state import get_ep_group
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
 from vllm_ascend.quantization.quant_config import AscendLinearMethod
@@ -500,12 +501,13 @@ def __init__(
         self.enable_multistream_mla = \
             ascend_config.torchair_graph_config.enable_multistream_mla
 
-    def forward(
-        self,
-        positions: torch.Tensor,
-        hidden_states: torch.Tensor,
-        kv_cache: Optional[torch.Tensor] = None,
-        attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
+    def forward(self,
+                positions: torch.Tensor,
+                hidden_states: torch.Tensor,
+                kv_cache: Optional[torch.Tensor] = None,
+                attn_metadata: Optional[AttentionMetadata] = None,
+                rotary_cos: Optional[torch.Tensor] = None,
+                rotary_sin: Optional[torch.Tensor] = None) -> torch.Tensor:
         if self.q_lora_rank is not None:
             ckq = self.q_a_proj(hidden_states)[0]
             use_multistream_mla = (self.enable_multistream_mla
@@ -526,6 +528,8 @@ def forward(
                 dtype=hidden_states_or_q_c.dtype,
                 device=hidden_states_or_q_c.device)
             forward_kwargs['output'] = output
+            forward_kwargs['rotary_cos'] = rotary_cos
+            forward_kwargs['rotary_sin'] = rotary_sin
 
             output = self.mla_attn.impl.forward(self.mla_attn,
                                                 hidden_states_or_q_c,
@@ -617,6 +621,8 @@ def forward(
         residual: Optional[torch.Tensor],
         kv_cache: Optional[torch.Tensor] = None,
         attn_metadata: Optional[AttentionMetadata] = None,
+        rotary_cos: Optional[torch.Tensor] = None,
+        rotary_sin: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         # Self Attention
         if residual is None:
@@ -636,6 +642,8 @@ def forward(
             hidden_states=hidden_states,
             kv_cache=kv_cache,
             attn_metadata=attn_metadata,
+            rotary_cos=rotary_cos,
+            rotary_sin=rotary_sin,
         )
 
         if hidden_states.dtype == torch.float16:
@@ -713,9 +721,47 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], config.hidden_size))
 
+        ascend_config = get_ascend_config()
+        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        max_position_embeddings = getattr(config, "max_position_embeddings",
+                                          8192)
+        if rope_scaling:
+            rope_scaling["rope_type"] = 'deepseek_yarn'
+        self.rotary_emb = get_rope(config.qk_rope_head_dim,
+                                   rotary_dim=config.qk_rope_head_dim,
+                                   max_position=max_position_embeddings,
+                                   base=rope_theta,
+                                   rope_scaling=rope_scaling,
+                                   is_neox_style=False)
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
 
+    def prepare_decoder_rotary_cos_sin(
+        self,
+        attn_metadata: Optional[AttentionMetadata] = None
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if (envs.VLLM_USE_V1 and attn_metadata is not None
+                and attn_metadata.num_decodes is not None
+                and attn_metadata.atten_state is not None):
+            has_decode = attn_metadata.num_decodes > 0
+            running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
+                AscendAttentionState.DecodeOnly,
+                AscendAttentionState.SpecDecoding
+            ]
+            if has_decode and running_in_graph:
+                cos = self.rotary_emb.cos_cached
+                sin = self.rotary_emb.sin_cached
+                cos = cos[attn_metadata.decode.input_positions]
+                sin = sin[attn_metadata.decode.input_positions]
+                cos = cos[:, None, None, :]
+                sin = sin[:, None, None, :]
+                return cos, sin
+        return None, None
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -736,13 +782,18 @@ def forward(
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
 
+        # In graph mode and v1 engine,
+        # precomputing cos and sin can eliminate repeated calculations in each decode layer.
+        rotary_cos, rotary_sin = self.prepare_decoder_rotary_cos_sin(
+            attn_metadata)
+
         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions, hidden_states, residual,
                 kv_caches[i -
                           self.start_layer] if kv_caches is not None else None,
-                attn_metadata)
+                attn_metadata, rotary_cos, rotary_sin)
 
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({

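Why the fast path is semantically safe: prepare_decoder_rotary_cos_sin performs the same index-and-reshape the layer previously did, only once and before the layer loop, and the layer then merely casts the result to its working dtype. Below is a standalone equivalence check with assumed cache sizes; the helper names are hypothetical and not from the repository.

import torch

def per_layer(cos_cached, sin_cached, positions, dtype):
    # Old behaviour: each decode layer casts, indexes and reshapes the caches itself.
    cos = cos_cached.to(dtype=dtype)[positions][:, None, None, :]
    sin = sin_cached.to(dtype=dtype)[positions][:, None, None, :]
    return cos, sin

def precomputed(cos_cached, sin_cached, positions, dtype):
    # New behaviour: the model indexes and reshapes once; layers only cast.
    rotary_cos = cos_cached[positions][:, None, None, :]
    rotary_sin = sin_cached[positions][:, None, None, :]
    return rotary_cos.to(dtype=dtype), rotary_sin.to(dtype=dtype)

cos_cached = torch.randn(2048, 64)   # toy cache sizes
sin_cached = torch.randn(2048, 64)
positions = torch.tensor([0, 3, 1024, 2047])

a = per_layer(cos_cached, sin_cached, positions, torch.float16)
b = precomputed(cos_cached, sin_cached, positions, torch.float16)
assert all(torch.equal(x, y) for x, y in zip(a, b))
print("precomputed cos/sin match the per-layer result")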