Commit f4d4c0a

[V0.9.1] torchair_graph bugfix when chunked_prefill is true (#1748)

### What this PR does / why we need it?
When torchair_graph and chunked_prefill are both enabled, the decode kv_cache entries were not being saved; this patch also writes the decode portion of the kv_cache.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

---------
Signed-off-by: fems14 <1804143737@qq.com>
Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
Co-authored-by: SlightwindSec <slightwindsec@gmail.com>
1 parent a8294b6 commit f4d4c0a

File tree: 1 file changed, +13 −2 lines


vllm_ascend/attention/mla_v1.py

Lines changed: 13 additions & 2 deletions

```diff
@@ -1197,7 +1197,7 @@ def forward(
             prefill_hs, cos, sin, kv_cache,
             attn_metadata.slot_mapping[num_decode_tokens:])

-        kv_c_normed = prefill_k_nope[:num_actual_toks, ...]
+        kv_c_normed_prefill = prefill_k_nope[:num_actual_toks, ...]
         prefill_k_c_normed = prefill_k_nope
         prefill_k_pe = prefill_k_pe.view(num_tokens, self.num_kv_heads,
                                          -1)
@@ -1215,12 +1215,23 @@ def forward(
         ) > 0 and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
             slots = attn_metadata.slot_mapping
             # NOTE: Separate the kv cache in advance to avoid OOM or other issues
-            torch_npu._npu_reshape_and_cache(key=kv_c_normed.view(
+            torch_npu._npu_reshape_and_cache(key=kv_c_normed_prefill.view(
                 num_tokens, self.num_kv_heads, -1),
                                              value=prefill_k_pe,
                                              key_cache=kv_cache[0],
                                              value_cache=kv_cache[1],
                                              slot_indices=slots)
+
+        if kv_cache[0].numel(
+        ) > 0 and attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill and has_decode:
+            slots = attn_metadata.slot_mapping[:num_decode_tokens]
+            k_c_normed_decode = kv_c_normed[:num_decode_tokens]
+            torch_npu._npu_reshape_and_cache(key=k_c_normed_decode.view(
+                num_decode_tokens, self.num_kv_heads, -1),
+                                             value=decode_k_pe,
+                                             key_cache=kv_cache[0],
+                                             value_cache=kv_cache[1],
+                                             slot_indices=slots)
         else:
             kv_c_normed = kv_c_normed.view(
                 [num_actual_toks, self.num_kv_heads, -1])
```
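The intent of the hunk above can be illustrated with a minimal, self-contained sketch. This is plain Python, not the real vllm-ascend code: `reshape_and_cache` and `write_kv` are hypothetical stand-ins for `torch_npu._npu_reshape_and_cache` and the `forward` path. The key assumption, taken from the diff, is that decode tokens occupy the first `num_decode_tokens` entries of `slot_mapping`, so in a mixed chunked-prefill batch the decode kv entries need their own cache write in addition to the prefill write.

```python
# Sketch only (mock cache as Python lists); shows the slot_mapping split
# the bugfix relies on, not the actual NPU kernel behavior.

def reshape_and_cache(keys, values, key_cache, value_cache, slot_indices):
    """Scatter per-token key/value entries into flat cache slots."""
    for k, v, slot in zip(keys, values, slot_indices):
        key_cache[slot] = k
        value_cache[slot] = v

def write_kv(slot_mapping, kv_c_normed, k_pe, num_decode_tokens,
             key_cache, value_cache, chunked_prefill, has_decode):
    # Prefill tokens come after the decode tokens in the batch.
    reshape_and_cache(kv_c_normed[num_decode_tokens:],
                      k_pe[num_decode_tokens:],
                      key_cache, value_cache,
                      slot_mapping[num_decode_tokens:])
    # The fix: with chunked prefill and decode in the same batch, the
    # decode kv entries must be written too, via the first
    # num_decode_tokens entries of slot_mapping.
    if chunked_prefill and has_decode:
        reshape_and_cache(kv_c_normed[:num_decode_tokens],
                          k_pe[:num_decode_tokens],
                          key_cache, value_cache,
                          slot_mapping[:num_decode_tokens])

# Tiny demo: 2 decode tokens (slots 5, 6) then 3 prefill tokens (slots 0-2).
key_cache = [None] * 8
value_cache = [None] * 8
slot_mapping = [5, 6, 0, 1, 2]
kv = [f"k{i}" for i in range(5)]
pe = [f"p{i}" for i in range(5)]
write_kv(slot_mapping, kv, pe, 2, key_cache, value_cache,
         chunked_prefill=True, has_decode=True)
print(key_cache[5], key_cache[0])  # → k0 k2
```

Without the `chunked_prefill and has_decode` branch, slots 5 and 6 would stay empty, which mirrors the missing decode kv_cache writes the patch addresses.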
