Skip to content

[V0.9.1] torchair_graph bugfix when chunked_prefill is true #1748

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 16, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions vllm_ascend/attention/mla_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -1200,7 +1200,7 @@ def forward(
prefill_hs, cos, sin, kv_cache,
attn_metadata.slot_mapping[num_decode_tokens:])

kv_c_normed = prefill_k_nope[:num_actual_toks, ...]
kv_c_normed_prefill = prefill_k_nope[:num_actual_toks, ...]
prefill_k_c_normed = prefill_k_nope
prefill_k_pe = prefill_k_pe.view(num_tokens, self.num_kv_heads,
-1)
Expand All @@ -1218,12 +1218,23 @@ def forward(
) > 0 and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
slots = attn_metadata.slot_mapping
# NOTE: Separate the kv cache in advance to avoid OOM or other issues
torch_npu._npu_reshape_and_cache(key=kv_c_normed.view(
torch_npu._npu_reshape_and_cache(key=kv_c_normed_prefill.view(
num_tokens, self.num_kv_heads, -1),
value=prefill_k_pe,
key_cache=kv_cache[0],
value_cache=kv_cache[1],
slot_indices=slots)

if kv_cache[0].numel(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you solve this problem for good, and remove all those redundant reshape_and_cache operation, merge all the code path into one. All the rope_k and nope_k are supposed to be reshape and cached into the kv_cache right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When the graph mode (without chunked prefill) is enabled, the reshapeAndCache path is not taken; instead, the fusion operator is used. The code for this part, including the entire file (mla_v1), has already been refactored by someone. It has been evaluated and is no longer on the current branch, but it will be merged into the main branch later.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fine

) > 0 and attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill and has_decode:
slots = attn_metadata.slot_mapping[:num_decode_tokens]
k_c_normed_decode = kv_c_normed[:num_decode_tokens]
torch_npu._npu_reshape_and_cache(key=k_c_normed_decode.view(
num_decode_tokens, self.num_kv_heads, -1),
value=decode_k_pe,
key_cache=kv_cache[0],
value_cache=kv_cache[1],
slot_indices=slots)
else:
kv_c_normed = kv_c_normed.view(
[num_actual_toks, self.num_kv_heads, -1])
Expand Down