
Commit 8cd3c81

MLA layer: eliminate redundant index operations
Signed-off-by: huiying <chenhuiying4@huawei.com>
1 parent e2a0c19
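
In decode, every MLA layer previously re-derived the rotary cos/sin tensors for the current step: slice cos_cached/sin_cached, cast to the query dtype, index with attn_metadata.input_positions, then reshape. Since input_positions are identical for every layer within one forward step, that work is redundant past layer 0. This commit caches the indexed tensors on the attention module and guards the recomputation with the layer index, so only layer 0 (or a cold cache) pays for the indexing.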

2 files changed: +20, -13 lines


vllm_ascend/attention/attention.py

Lines changed: 16 additions & 11 deletions
@@ -1000,6 +1000,10 @@ def __init__(
         self.w_kc = None
         self.w_vc = None

+        self.cos = None
+        self.sin = None
+        self.debug_layer_idx = extra_impl_args.get('debug_layer_idx', 0)
+
         self.enable_graph_mode = False
         additional_config = get_current_vllm_config().additional_config
         if additional_config:
@@ -1128,17 +1132,18 @@ def forward(
         q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
                                dim=-1)
         if k_pe is None and attn_metadata.decode_metadata:
-            seq_len = self.rotary_emb.max_position_embeddings
-
-            cos = self.rotary_emb.cos_cached[:seq_len].to(dtype=q_pe.dtype)
-            sin = self.rotary_emb.sin_cached[:seq_len].to(dtype=q_pe.dtype)
-            cos = cos[attn_metadata.input_positions]
-            sin = sin[attn_metadata.input_positions]
-            cos = cos[:, None, None, :]
-            sin = sin[:, None, None, :]
-
-            q_pe = self.rope_single(q_pe, cos, sin)
-            k_pe, k_nope = self.exec_kv(hidden_states_or_kv_c_normed, cos, sin,
+            if self.debug_layer_idx == 0 or self.cos is None or self.sin is None:
+                seq_len = self.rotary_emb.max_position_embeddings
+
+                self.cos = self.rotary_emb.cos_cached[:seq_len].to(dtype=q_pe.dtype)
+                self.sin = self.rotary_emb.sin_cached[:seq_len].to(dtype=q_pe.dtype)
+                self.cos = self.cos[attn_metadata.input_positions]
+                self.sin = self.sin[attn_metadata.input_positions]
+                self.cos = self.cos[:, None, None, :]
+                self.sin = self.sin[:, None, None, :]
+
+            q_pe = self.rope_single(q_pe, self.cos, self.sin)
+            k_pe, k_nope = self.exec_kv(hidden_states_or_kv_c_normed, self.cos, self.sin,
                                         kv_cache, attn_metadata.slot_mapping)
         else:
             if k_pe is None:
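
To see the reuse pattern in isolation, here is a minimal, self-contained sketch; it is not the vllm-ascend code. For clarity it shares one cache object across all layers, whereas the diff above stores self.cos/self.sin per layer module and uses debug_layer_idx == 0 to force a refresh each step. All names (SharedRotaryCache, rot_dim, the toy shapes) are assumptions for illustration.

import torch

class SharedRotaryCache:
    """Illustrative stand-in for the self.cos/self.sin caching in this
    commit: the slice + index work is paid once per decode step instead
    of once per layer."""

    def __init__(self, cos_cached: torch.Tensor, sin_cached: torch.Tensor):
        self.cos_cached = cos_cached  # [max_position, rot_dim]
        self.sin_cached = sin_cached
        self.cos = None
        self.sin = None

    def get(self, layer_idx: int, positions: torch.Tensor,
            dtype: torch.dtype):
        # Layer 0 (or a cold cache) rebuilds the per-step tensors;
        # every other layer reuses them, since positions are the same
        # for all layers within one forward step.
        if layer_idx == 0 or self.cos is None:
            self.cos = self.cos_cached.to(dtype)[positions][:, None, None, :]
            self.sin = self.sin_cached.to(dtype)[positions][:, None, None, :]
        return self.cos, self.sin

# Toy usage: 4 layers, one decode step, batch of 2 tokens.
max_pos, rot_dim, num_layers = 1024, 64, 4
cache = SharedRotaryCache(torch.randn(max_pos, rot_dim),
                          torch.randn(max_pos, rot_dim))
positions = torch.tensor([17, 5])
for layer_idx in range(num_layers):
    cos, sin = cache.get(layer_idx, positions, torch.float16)
    # ... apply rotary embedding with cos/sin ...

The win is that the slice, dtype cast, fancy index, and reshape run once per decode step rather than once per layer: for an N-layer model, N-1 redundant gathers per step disappear.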

vllm_ascend/models/deepseek_v2.py

Lines changed: 4 additions & 2 deletions
@@ -364,6 +364,9 @@ def __init__(
             mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
             self.scaling = self.scaling * mscale * mscale

+        self.prefix = prefix
+        self.debug_layer_idx = int(self.prefix.split(".")[-2])
+
         # In the MLA backend, kv_cache includes both k_c and
         # pe (i.e. decoupled position embeddings). In particular,
         # the concat_and_cache_mla op requires
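
The layer index is recovered from the module's prefix: for a prefix of the form "model.layers.<idx>.self_attn", the second-to-last dot-separated component is the layer number. A quick check, where the prefix value is an assumed example rather than one taken from the diff:

prefix = "model.layers.7.self_attn"  # assumed example of a typical prefix
debug_layer_idx = int(prefix.split(".")[-2])
assert debug_layer_idx == 7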
@@ -392,10 +395,9 @@ def __init__(
             kv_a_layernorm=self.kv_a_layernorm,
             kv_b_proj=self.kv_b_proj,
             o_proj=self.o_proj,
+            debug_layer_idx=self.debug_layer_idx,
         )

-        self.prefix = prefix
-        self.debug_layer_idx = int(self.prefix.split(".")[-2])
         self.enable_graph_mode = False
         additional_config = get_current_vllm_config().additional_config
         if additional_config:
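
On the receiving side, the backend impl reads the keyword out of extra_impl_args with a default of 0 (see the attention.py hunk above). A hedged sketch of that handshake, with a hypothetical class name:

class ToyMLAImpl:
    def __init__(self, **extra_impl_args):
        # Mirrors the attention.py change: unknown layers default to 0,
        # so a layer that never passes the kwarg behaves like layer 0
        # and refreshes the cos/sin cache on every decode step.
        self.debug_layer_idx = extra_impl_args.get('debug_layer_idx', 0)
        self.cos = None
        self.sin = None

impl = ToyMLAImpl(debug_layer_idx=7)
assert impl.debug_layer_idx == 7

Defaulting to 0 is the conservative choice: recomputing is always correct, while reusing a stale cache is not.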
