Commit c07e6e1

MLA layer eliminates redundant index operators

Signed-off-by: huiying <chenhuiying4@huawei.com>

1 parent e2a0c19 commit c07e6e1
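In short: indexing the rotary cos/sin tables by the decode input positions yields identical tensors for every decoder layer within one decode step, so this change computes them once (on layer 0, or while the cache is still empty) and reuses the cached result elsewhere. A rough back-of-the-envelope illustration of the saving; the layer and op counts below are made-up assumptions for illustration, not measurements from this commit:

num_layers = 60          # hypothetical decoder layer count
index_ops_per_layer = 4  # slice the cos/sin tables, gather by positions (made-up tally)

before = num_layers * index_ops_per_layer  # every layer repeats the same indexing
after = 1 * index_ops_per_layer            # only layer 0 (or a cold cache) indexes

print(f"index ops per decode step: {before} -> {after}")  # 240 -> 4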

2 files changed, +21 -13 lines changed

vllm_ascend/attention/mla_v1.py

Lines changed: 17 additions & 11 deletions
@@ -408,6 +408,10 @@ def __init__(
         self.enable_graph_mode = additional_config.get(
             "enable_graph_mode", False)
 
+        self.cos = None
+        self.sin = None
+        self.debug_layer_idx = extra_impl_args.get('debug_layer_idx', 0)
+
     def _v_up_proj_and_o_proj(self, x):
         # Convert from (B, N, L) to (N, B, L)
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
@@ -700,18 +704,20 @@ def forward(
             decode_ql_nope, decode_q_pe = \
                 self._q_proj_and_k_up_proj(decode_hs_or_q_c)
             if self.running_in_graph:
-                seq_len = self.rotary_emb.max_position_embeddings
-                cos = self.rotary_emb.cos_cached[:seq_len].to(
-                    dtype=decode_q_pe.dtype)
-                sin = self.rotary_emb.sin_cached[:seq_len].to(
-                    dtype=decode_q_pe.dtype)
-                cos = cos[attn_metadata.decode.input_positions]
-                sin = sin[attn_metadata.decode.input_positions]
-                cos = cos[:, None, None, :]
-                sin = sin[:, None, None, :]
-                decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
+                # During the autoregressive decoding process, the cos and sin values are exactly the same for each layer
+                if self.debug_layer_idx == 0 or self.cos is None or self.sin is None:
+                    seq_len = self.rotary_emb.max_position_embeddings
+                    self.cos = self.rotary_emb.cos_cached[:seq_len].to(
+                        dtype=decode_q_pe.dtype)
+                    self.sin = self.rotary_emb.sin_cached[:seq_len].to(
+                        dtype=decode_q_pe.dtype)
+                    self.cos = self.cos[attn_metadata.decode.input_positions]
+                    self.sin = self.sin[attn_metadata.decode.input_positions]
+                    self.cos = self.cos[:, None, None, :]
+                    self.sin = self.sin[:, None, None, :]
+                decode_q_pe = self.rope_single(decode_q_pe, self.cos, self.sin)
                 decode_k_pe, decode_k_nope = self.exec_kv(
-                    hidden_states_or_kv_c_normed, cos, sin, kv_cache,
+                    hidden_states_or_kv_c_normed, self.cos, self.sin, kv_cache,
                     attn_metadata.slot_mapping)
             else:
                 decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
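The guard in the hunk above is easier to read in isolation. Below is a minimal, self-contained sketch of the same idea, assuming a cache object shared by all layers of a decode step; StepRopeCache, get, and the tensor names are illustrative placeholders rather than the vLLM Ascend API, and only the "layer 0 or cold cache" condition mirrors the diff:

import torch


class StepRopeCache:
    """Per-decode-step cache of the position-gathered rotary cos/sin tensors."""

    def __init__(self, cos_cached: torch.Tensor, sin_cached: torch.Tensor):
        self.cos_cached = cos_cached  # full (max_position, rot_dim // 2) tables
        self.sin_cached = sin_cached
        self.cos = None               # gathered values for the current step
        self.sin = None

    def get(self, positions: torch.Tensor, dtype: torch.dtype, layer_idx: int):
        # Only layer 0 (or a cold cache) pays for the dtype cast, gather and
        # reshape; later layers in the same decode step reuse the result,
        # since they all see the same input positions.
        if layer_idx == 0 or self.cos is None or self.sin is None:
            self.cos = self.cos_cached.to(dtype)[positions][:, None, None, :]
            self.sin = self.sin_cached.to(dtype)[positions][:, None, None, :]
        return self.cos, self.sin


# Toy usage: two "layers" of one decode step share the gathered tensors.
cos_cached = torch.randn(4096, 32)
sin_cached = torch.randn(4096, 32)
cache = StepRopeCache(cos_cached, sin_cached)

positions = torch.tensor([17, 42])  # one position per decoding sequence
cos0, sin0 = cache.get(positions, torch.float16, layer_idx=0)  # computed here
cos1, sin1 = cache.get(positions, torch.float16, layer_idx=1)  # cached, no gather
assert cos0 is cos1 and sin0 is sin1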

vllm_ascend/models/deepseek_v2.py

Lines changed: 4 additions & 2 deletions
@@ -364,6 +364,9 @@ def __init__(
             mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
             self.scaling = self.scaling * mscale * mscale
 
+        self.prefix = prefix
+        self.debug_layer_idx = int(self.prefix.split(".")[-2])
+
         # In the MLA backend, kv_cache includes both k_c and
         # pe (i.e. decoupled position embeddings). In particular,
         # the concat_and_cache_mla op requires
@@ -392,10 +395,9 @@ def __init__(
             kv_a_layernorm=self.kv_a_layernorm,
            kv_b_proj=self.kv_b_proj,
            o_proj=self.o_proj,
+            debug_layer_idx=self.debug_layer_idx,
        )
 
-        self.prefix = prefix
-        self.debug_layer_idx = int(self.prefix.split(".")[-2])
        self.enable_graph_mode = False
        additional_config = get_current_vllm_config().additional_config
        if additional_config:
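The debug_layer_idx that the guard keys on is recovered from the module's dotted prefix: the second-to-last component of the attention module's name is the decoder layer number. A tiny illustration; the prefix string below is a representative example of vLLM-style module naming, not taken from this diff:

prefix = "model.layers.3.self_attn"           # example prefix for decoder layer 3
debug_layer_idx = int(prefix.split(".")[-2])  # -> 3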
