
Commit bfcccb2

MLA layer eliminates redundant index operators
Signed-off-by: huiying <chenhuiying4@huawei.com>
Parent: afc4c0c

2 files changed: +21 −14 lines


vllm_ascend/attention/mla_v1.py

Lines changed: 17 additions & 12 deletions
@@ -458,6 +458,10 @@ def __init__(
             self.enable_graph_mode = additional_config.get(
                 "enable_graph_mode", False)
 
+        self.cos = None
+        self.sin = None
+        self.debug_layer_idx = kwargs.get('debug_layer_idx', 0)
+
     def _v_up_proj_and_o_proj(self, x):
         # Convert from (B, N, L) to (N, B, L)
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
@@ -750,19 +754,20 @@ def forward(
             decode_ql_nope, decode_q_pe = \
                 self._q_proj_and_k_up_proj(decode_hs_or_q_c)
             if self.running_in_graph:
-                seq_len = self.rotary_emb.max_position_embeddings
-                cos = self.rotary_emb.cos_cached[:seq_len].to(
-                    dtype=decode_q_pe.dtype)
-                sin = self.rotary_emb.sin_cached[:seq_len].to(
-                    dtype=decode_q_pe.dtype)
-                cos = cos[attn_metadata.decode.input_positions]
-                sin = sin[attn_metadata.decode.input_positions]
-                cos = cos[:, None, None, :]
-                sin = sin[:, None, None, :]
-
-                decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
+                # During the autoregressive decoding process, the cos and sin values are exactly the same for each layer
+                if self.debug_layer_idx == 0 or self.cos is None or self.sin is None:
+                    seq_len = self.rotary_emb.max_position_embeddings
+                    self.cos = self.rotary_emb.cos_cached[:seq_len].to(
+                        dtype=decode_q_pe.dtype)
+                    self.sin = self.rotary_emb.sin_cached[:seq_len].to(
+                        dtype=decode_q_pe.dtype)
+                    self.cos = self.cos[attn_metadata.decode.input_positions]
+                    self.sin = self.sin[attn_metadata.decode.input_positions]
+                    self.cos = self.cos[:, None, None, :]
+                    self.sin = self.sin[:, None, None, :]
+                decode_q_pe = self.rope_single(decode_q_pe, self.cos, self.sin)
                 decode_k_pe, decode_k_nope = self.exec_kv(
-                    hidden_states_or_kv_c_normed, cos, sin, kv_cache,
+                    hidden_states_or_kv_c_normed, self.cos, self.sin, kv_cache,
                     attn_metadata.slot_mapping)
             else:
                 decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
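
For context, a minimal standalone sketch of the gather that the deleted per-layer code performed (this is not code from the repository; the table sizes and positions below are assumptions for illustration). The gathered cos/sin values depend only on the decode positions, not on the layer index, which is why the commit computes them once and reuses them across the MLA layers of the same decode step.

# Minimal sketch, assuming cached cos/sin tables of shape (max_pos, rot_dim).
import torch

max_position_embeddings, rot_dim = 16, 8           # assumed sizes
cos_cached = torch.randn(max_position_embeddings, rot_dim)
sin_cached = torch.randn(max_position_embeddings, rot_dim)
input_positions = torch.tensor([3, 7])              # current position of each decoding sequence

# The removed code repeated this gather and broadcast reshape in every layer;
# the commit stores the result and reuses it instead.
cos = cos_cached[input_positions][:, None, None, :]  # (batch, 1, 1, rot_dim)
sin = sin_cached[input_positions][:, None, None, :]
print(cos.shape, sin.shape)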

vllm_ascend/models/deepseek_v2.py

Lines changed: 4 additions & 2 deletions
@@ -391,6 +391,9 @@ def __init__(
             mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
             self.scaling = self.scaling * mscale * mscale
 
+        self.prefix = prefix
+        self.debug_layer_idx = int(self.prefix.split(".")[-2])
+
         # In the MLA backend, kv_cache includes both k_c and
         # pe (i.e. decoupled position embeddings). In particular,
         # the concat_and_cache_mla op requires
@@ -419,10 +422,9 @@ def __init__(
             kv_a_layernorm=self.kv_a_layernorm,
             kv_b_proj=self.kv_b_proj,
             o_proj=self.o_proj,
+            debug_layer_idx=self.debug_layer_idx,
         )
 
-        self.prefix = prefix
-        self.debug_layer_idx = int(self.prefix.split(".")[-2])
         self.enable_graph_mode = False
         additional_config = get_current_vllm_config().additional_config
         if additional_config:
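
The layer index handed to the MLA backend comes from parsing the module prefix string. A small hypothetical example of that parsing (the prefix value below is an assumption based on vLLM's usual "model.layers.<idx>.self_attn" module naming, not taken from this diff):

# Hypothetical prefix, following vLLM's "model.layers.<idx>.self_attn" naming.
prefix = "model.layers.7.self_attn"
debug_layer_idx = int(prefix.split(".")[-2])  # second-to-last component is the layer index
print(debug_layer_idx)                        # 7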
