@@ -1132,15 +1132,39 @@ def _mla_preprocess(self, hidden_states, kv_cache, attn_metadata,
             if self.cp_size * self.sp_size > 1:
                 decode_q_wo_k_up_pe = decode_q_wo_k_up[..., self.qk_nope_head_dim:]
                 decode_q_wo_k_up_pe = self.rope_single(decode_q_wo_k_up_pe, cos, sin)
+                decode_q_wo_k_up[..., self.qk_nope_head_dim:] = decode_q_wo_k_up_pe
+                decode_kv_no_split = kv_no_split[:num_decode_tokens]
+                kv_c, k_pe = decode_kv_no_split.split(
+                    [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+                k_pe = k_pe.unsqueeze(1)
+                decode_k_pe = k_pe[:num_decode_tokens]
+                decode_k_pe = self.rope_single(decode_k_pe, cos, sin)
+                k_pe[:num_decode_tokens] = decode_k_pe
             else:
-                # TODO: the sp/cp adaptation of decode is still incomplete here; how to handle the positional encoding needs further thought
                 decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
             decode_slots = attn_metadata.slot_mapping[:num_decode_tokens]
             decode_kv_no_split = kv_no_split[:num_decode_tokens]
-            decode_k_pe, decode_k_nope = self.exec_kv_decode(
+            if self.cp_size * self.sp_size > 1:
+                kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
+                assert len(
+                    kv_cache
+                ) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)"
+                kv_c_normed = kv_c_normed.view(
+                    [num_actual_tokens, self.num_kv_heads, -1])
+                torch_npu._npu_reshape_and_cache(
+                    key=kv_c_normed,
+                    value=k_pe,
+                    key_cache=kv_cache[0],
+                    value_cache=kv_cache[1],
+                    slot_indices=attn_metadata.slot_mapping)
+            else:
+                decode_k_pe, decode_k_nope = self.exec_kv_decode(
                 decode_kv_no_split, cos, sin, kv_cache, decode_slots)
-            decode_preprocess_res = DecodeMLAPreprocessResult(
-                decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe, decode_q_wo_k_up)
+            if self.cp_size * self.sp_size > 1:
+                decode_preprocess_res = DecodeMLAPreprocessResult(decode_q_wo_k_up=decode_q_wo_k_up)
+            else:
+                decode_preprocess_res = DecodeMLAPreprocessResult(
+                    decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe)
         # Preprocess for prefill tokens
         if has_prefill:
             prefill_kv_no_split = kv_no_split[
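Note on the decode branch above: when cp_size * sp_size > 1, the compressed latent (kv_c after kv_a_layernorm) and the rotated rope part (k_pe) are written directly into the two caches (nope_cache and rope_cache) via torch_npu._npu_reshape_and_cache instead of going through exec_kv_decode. Below is a minimal CPU-only sketch of the scatter that call is assumed to perform; the flat [num_slots, num_kv_heads, dim] cache layout and the helper name are illustrative assumptions, not the NPU kernel's actual API.

import torch

def reshape_and_cache_sketch(key, value, key_cache, value_cache, slot_indices):
    # Assumed semantics: scatter each token's latent (nope) and rope vectors
    # into the per-slot rows of the two caches selected by slot_mapping.
    key_cache[slot_indices] = key      # key:   [T, num_kv_heads, kv_lora_rank]
    value_cache[slot_indices] = value  # value: [T, num_kv_heads, qk_rope_head_dim]

T, H, lora, rope = 4, 1, 512, 64
nope_cache = torch.zeros(16, H, lora)   # stand-in for kv_cache[0]
rope_cache = torch.zeros(16, H, rope)   # stand-in for kv_cache[1]
slots = torch.tensor([3, 7, 8, 12])     # stand-in for attn_metadata.slot_mapping
reshape_and_cache_sketch(torch.randn(T, H, lora), torch.randn(T, H, rope),
                         nope_cache, rope_cache, slots)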
@@ -1154,22 +1178,46 @@ def _mla_preprocess(self, hidden_states, kv_cache, attn_metadata,
             sin = attn_metadata.prefill.sin
             prefill_slots = attn_metadata.slot_mapping[
                 num_decode_tokens:num_actual_tokens]
-            # TODO: confirm whether the positional encoding here affects how the KV cache is stored
             prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
             if self.cp_size > 1:
-                prefill_kv_no_split = get_cp_group().all_gather(prefill_kv_no_split, 0)
-                prefill_kv_no_split = torch.index_select(prefill_kv_no_split, 0, attn_metadata.prefill.cp_kv_recover_idx)
-            prefill_k_pe, prefill_k_c_normed = self.exec_kv_prefill(
-                prefill_kv_no_split, cos, sin, kv_cache, prefill_slots)
-            prefill_k_pe = prefill_k_pe.view(prefill_q_c.shape[0],
-                                             self.num_kv_heads, -1)
+                kv_c, k_pe = prefill_kv_no_split.split(
+                    [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+                kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
+                assert len(
+                    kv_cache
+                ) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)"
+                kv_c_normed = kv_c_normed.view(
+                    [num_actual_tokens, self.num_kv_heads, -1])
+                k_pe = k_pe.unsqueeze(1)
+                prefill_k_pe = k_pe[num_decode_tokens:]
+                prefill_k_pe = self.rope_single(prefill_k_pe, cos, sin)
+                prefill_k_c_normed = kv_c_normed[num_decode_tokens:]
+
+                prefill_kv_c_k_pe = torch.cat([prefill_k_c_normed, prefill_k_pe], dim=-1)
+                prefill_kv_c_k_pe = get_cp_group().all_gather(prefill_kv_c_k_pe, 0)
+                prefill_kv_c_k_pe = torch.index_select(prefill_kv_c_k_pe, 0, attn_metadata.prefill.cp_kv_recover_idx)
+                prefill_k_c_normed, prefill_k_pe = prefill_kv_c_k_pe.split([self.kv_lora_rank, self.qk_rope_head_dim],
+                                                                           dim=-1)
+                kv_c_normed, k_pe = prefill_k_c_normed, prefill_k_pe
+                prefill_k_c_normed = prefill_k_c_normed.squeeze()
+                torch_npu._npu_reshape_and_cache(
+                    key=kv_c_normed,
+                    value=k_pe,
+                    key_cache=kv_cache[0],
+                    value_cache=kv_cache[1],
+                    slot_indices=attn_metadata.slot_mapping)
+            else:
+                prefill_k_pe, prefill_k_c_normed = self.exec_kv_prefill(
+                    prefill_kv_no_split, cos, sin, kv_cache, prefill_slots)
             prefill_k_nope, prefill_value = self.kv_b_proj(
-                prefill_k_c_normed)[0].view(
-                    -1, self.num_heads,
+                prefill_k_c_normed)[0].view(
+                    -1, self.num_heads,
                     self.qk_nope_head_dim + self.v_head_dim).split(
                         [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
-            prefill_k_pe = prefill_k_pe.expand(
-                (*prefill_k_nope.shape[:-1], -1))
+            if not self.cp_size > 1:
+                prefill_k_pe = prefill_k_pe.view(prefill_q_c.shape[0],
+                                                 self.num_kv_heads, -1)
+            prefill_k_pe = prefill_k_pe.expand((*prefill_k_nope.shape[:-1], -1))
             prefill_preprocess_res = PrefillMLAPreprocessResult(
                 prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe,
                 prefill_value)
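Note on the prefill context-parallel branch: each CP rank holds only its shard of prefill tokens, so the normalized latent and the rotated k_pe are concatenated, all-gathered across the CP group along dim 0, and then put back into the original token order with torch.index_select and attn_metadata.prefill.cp_kv_recover_idx before being written to the caches. A single-process sketch of that reorder step, with torch.cat standing in for get_cp_group().all_gather and an assumed round-robin sharding plus hypothetical recover index:

import torch

kv = torch.arange(6).float().unsqueeze(-1)      # 6 prefill tokens; stand-in for cat([kv_c_normed, k_pe], dim=-1)
shards = [kv[0::2], kv[1::2]]                   # assumed sharding of tokens across 2 CP ranks
gathered = torch.cat(shards, dim=0)             # what all_gather(..., 0) returns: rank 0's tokens, then rank 1's
recover_idx = torch.tensor([0, 3, 1, 4, 2, 5])  # gathered position of each original token (cp_kv_recover_idx)
restored = torch.index_select(gathered, 0, recover_idx)
assert torch.equal(restored, kv)                # original global token order is recovered

After the reorder, the branch writes the full reordered latent/rope tensors into the nope and rope caches with torch_npu._npu_reshape_and_cache, mirroring the decode branch.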