Skip to content

Commit 5cba09d

Browse files
[long_seq_optim] support cp&sp
1 parent bb0ab43 commit 5cba09d

File tree

1 file changed

+63
-15
lines changed

1 file changed

+63
-15
lines changed

vllm_ascend/attention/mla_v1.py

Lines changed: 63 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1132,15 +1132,39 @@ def _mla_preprocess(self, hidden_states, kv_cache, attn_metadata,
11321132
if self.cp_size * self.sp_size > 1:
11331133
decode_q_wo_k_up_pe = decode_q_wo_k_up[..., self.qk_nope_head_dim:]
11341134
decode_q_wo_k_up_pe = self.rope_single(decode_q_wo_k_up_pe, cos, sin)
1135+
decode_q_wo_k_up[..., self.qk_nope_head_dim:] = decode_q_wo_k_up_pe
1136+
decode_kv_no_split = kv_no_split[:num_decode_tokens]
1137+
kv_c, k_pe = decode_kv_no_split.split(
1138+
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
1139+
k_pe = k_pe.unsqueeze(1)
1140+
decode_k_pe = k_pe[:num_decode_tokens]
1141+
decode_k_pe = self.rope_single(decode_k_pe, cos, sin)
1142+
k_pe[:num_decode_tokens] = decode_k_pe
11351143
else:
1136-
# TODO sp_cp的decode的适配这里还不完整,待思考位置编码的处理
11371144
decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
11381145
decode_slots = attn_metadata.slot_mapping[:num_decode_tokens]
11391146
decode_kv_no_split = kv_no_split[:num_decode_tokens]
1140-
decode_k_pe, decode_k_nope = self.exec_kv_decode(
1147+
if self.cp_size * self.sp_size > 1:
1148+
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
1149+
assert len(
1150+
kv_cache
1151+
) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)"
1152+
kv_c_normed = kv_c_normed.view(
1153+
[num_actual_tokens, self.num_kv_heads, -1])
1154+
torch_npu._npu_reshape_and_cache(
1155+
key=kv_c_normed,
1156+
value=k_pe,
1157+
key_cache=kv_cache[0],
1158+
value_cache=kv_cache[1],
1159+
slot_indices=attn_metadata.slot_mapping)
1160+
else:
1161+
decode_k_pe, decode_k_nope = self.exec_kv_decode(
11411162
decode_kv_no_split, cos, sin, kv_cache, decode_slots)
1142-
decode_preprocess_res = DecodeMLAPreprocessResult(
1143-
decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe, decode_q_wo_k_up)
1163+
if self.cp_size * self.sp_size > 1:
1164+
decode_preprocess_res = DecodeMLAPreprocessResult(decode_q_wo_k_up=decode_q_wo_k_up)
1165+
else:
1166+
decode_preprocess_res = DecodeMLAPreprocessResult(
1167+
decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe)
11441168
# Preprocess for prefill tokens
11451169
if has_prefill:
11461170
prefill_kv_no_split = kv_no_split[
@@ -1154,22 +1178,46 @@ def _mla_preprocess(self, hidden_states, kv_cache, attn_metadata,
11541178
sin = attn_metadata.prefill.sin
11551179
prefill_slots = attn_metadata.slot_mapping[
11561180
num_decode_tokens:num_actual_tokens]
1157-
# TODO 待确认这里的位置编码是否会影响到kv cathe的存储
11581181
prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
11591182
if self.cp_size > 1:
1160-
prefill_kv_no_split = get_cp_group().all_gather(prefill_kv_no_split, 0)
1161-
prefill_kv_no_split = torch.index_select(prefill_kv_no_split, 0, attn_metadata.prefill.cp_kv_recover_idx)
1162-
prefill_k_pe, prefill_k_c_normed = self.exec_kv_prefill(
1163-
prefill_kv_no_split, cos, sin, kv_cache, prefill_slots)
1164-
prefill_k_pe = prefill_k_pe.view(prefill_q_c.shape[0],
1165-
self.num_kv_heads, -1)
1183+
kv_c, k_pe = prefill_kv_no_split.split(
1184+
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
1185+
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
1186+
assert len(
1187+
kv_cache
1188+
) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)"
1189+
kv_c_normed = kv_c_normed.view(
1190+
[num_actual_tokens, self.num_kv_heads, -1])
1191+
k_pe = k_pe.unsqueeze(1)
1192+
prefill_k_pe = k_pe[num_decode_tokens:]
1193+
prefill_k_pe = self.rope_single(prefill_k_pe, cos, sin)
1194+
prefill_k_c_normed = kv_c_normed[num_decode_tokens:]
1195+
1196+
prefill_kv_c_k_pe = torch.cat([prefill_k_c_normed, prefill_k_pe], dim=-1)
1197+
prefill_kv_c_k_pe = get_cp_group().all_gather(prefill_kv_c_k_pe, 0)
1198+
prefill_kv_c_k_pe = torch.index_select(prefill_kv_c_k_pe, 0, attn_metadata.prefill.cp_kv_recover_idx)
1199+
prefill_k_c_normed, prefill_k_pe = prefill_kv_c_k_pe.split([self.kv_lora_rank, self.qk_rope_head_dim],
1200+
dim=-1)
1201+
kv_c_normed, k_pe = prefill_k_c_normed, prefill_k_pe
1202+
prefill_k_c_normed = prefill_k_c_normed.squeeze()
1203+
torch_npu._npu_reshape_and_cache(
1204+
key=kv_c_normed,
1205+
value=k_pe,
1206+
key_cache=kv_cache[0],
1207+
value_cache=kv_cache[1],
1208+
slot_indices=attn_metadata.slot_mapping)
1209+
else:
1210+
prefill_k_pe, prefill_k_c_normed = self.exec_kv_prefill(
1211+
prefill_kv_no_split, cos, sin, kv_cache, prefill_slots)
11661212
prefill_k_nope, prefill_value = self.kv_b_proj(
1167-
prefill_k_c_normed)[0].view(
1168-
-1, self.num_heads,
1213+
prefill_k_c_normed)[0].view(
1214+
-1, self.num_heads,
11691215
self.qk_nope_head_dim + self.v_head_dim).split(
11701216
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
1171-
prefill_k_pe = prefill_k_pe.expand(
1172-
(*prefill_k_nope.shape[:-1], -1))
1217+
if not self.cp_size > 1:
1218+
prefill_k_pe = prefill_k_pe.view(prefill_q_c.shape[0],
1219+
self.num_kv_heads, -1)
1220+
prefill_k_pe = prefill_k_pe.expand((*prefill_k_nope.shape[:-1], -1))
11731221
prefill_preprocess_res = PrefillMLAPreprocessResult(
11741222
prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe,
11751223
prefill_value)

0 commit comments

Comments
 (0)