Skip to content

Commit 1aace42

Browse files
committed
[bugfix] fix kv_nz accuracy problem and delete redundant reshape_and_cache operation
Signed-off-by: linfeng-yuan <1102311262@qq.com>
1 parent 12bcbd0 commit 1aace42

File tree

1 file changed

+3
-11
lines changed

1 file changed

+3
-11
lines changed

vllm_ascend/torchair/torchair_mla.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1221,17 +1221,9 @@ def forward(
12211221
assert len(
12221222
kv_cache
12231223
) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)"
1224-
if self.torchair_graph_enabled:
1225-
if kv_cache[0].numel() > 0 and has_prefill:
1226-
slots = attn_metadata.slot_mapping
1227-
# NOTE: Separate the kv cache in advance to avoid OOM or other issues
1228-
torch_npu._npu_reshape_and_cache(
1229-
key=kv_c_normed.view(num_tokens, self.num_kv_heads, -1),
1230-
value=prefill_k_pe,
1231-
key_cache=kv_cache[0],
1232-
value_cache=kv_cache[1],
1233-
slot_indices=slots[num_decode_tokens:])
1234-
else:
1224+
# NOTE: Since CP/SP and shared_expert dp features temporarily depend on torchair modeling
1225+
# and attention backend, cases without torchair_graph_enabled should be considered here.
1226+
if not self.torchair_graph_enabled:
12351227
kv_c_normed = kv_c_normed.view(
12361228
[num_actual_toks, self.num_kv_heads, -1])
12371229
torch_npu._npu_reshape_and_cache(

0 commit comments

Comments
 (0)