Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 47 additions & 45 deletions vllm_ascend/compilation/acl_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,26 +193,26 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
graph_params = get_graph_params()
# FIXME: Behold! We are using a temporary hack here to update the args
# for each layer's attention op in the graph.
for key, param, handle, event in zip(
forward_context.attn_metadata,
graph_params.attn_params[runtime_shape],
graph_params.handles[runtime_shape],
graph_params.events[runtime_shape],
):
(
query,
key_cache,
value_cache,
num_kv_heads,
num_heads,
scale,
block_table,
seq_lens,
output,
) = param
seq_lens = forward_context.attn_metadata[key].seq_lens

with torch.npu.stream(update_stream):
with torch.npu.stream(update_stream):
for key, param, handle, event in zip(
forward_context.attn_metadata,
graph_params.attn_params[runtime_shape],
graph_params.handles[runtime_shape],
graph_params.events[runtime_shape],
):
(
query,
key_cache,
value_cache,
num_kv_heads,
num_heads,
scale,
block_table,
seq_lens,
output,
) = param
seq_lens = forward_context.attn_metadata[key].seq_lens

torch.npu.graph_task_update_begin(update_stream, handle)
torch_npu._npu_paged_attention(
query=query,
Expand All @@ -235,31 +235,33 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
graph_params = get_graph_params()
# FIXME: Behold! We are using a temporary hack here to update the args
# for each layer's attention op in the graph.
for key, param, handle, event in zip(
forward_context.attn_metadata,
graph_params.attn_params[runtime_shape],
graph_params.handles[runtime_shape],
graph_params.events[runtime_shape],
):
(q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
spec_attn_mask, sparse_mode, scale, block_table, block_size,
seq_lens_list, actual_seq_lengths, attn_output, softmax_lse) = param
seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list
if speculative_config and speculative_config.method == "deepseek_mtp":
actual_seq_lengths = forward_context.attn_metadata[
key].decode.actual_seq_lengths_q
spec_multiple = speculative_config.num_speculative_tokens + 1
seq_lens_list = seq_lens_list + [0] * (
runtime_shape // spec_multiple - len(seq_lens_list))
actual_seq_lengths = [
spec_multiple * (i + 1)
for i in range(runtime_shape // spec_multiple)
]
else:
seq_lens_list = seq_lens_list + [0] * (runtime_shape -
len(seq_lens_list))
with torch.npu.stream(update_stream):
torch.npu.graph_task_update_begin(update_stream, handle)
with torch.npu.stream(update_stream):
for key, param, handle, event in zip(
forward_context.attn_metadata,
graph_params.attn_params[runtime_shape],
graph_params.handles[runtime_shape],
graph_params.events[runtime_shape],
):
(q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
spec_attn_mask, sparse_mode, scale, block_table, block_size,
seq_lens_list, actual_seq_lengths, attn_output,
softmax_lse) = param
seq_lens_list = forward_context.attn_metadata[
key].decode.seq_lens_list
if speculative_config and speculative_config.method == "deepseek_mtp":
actual_seq_lengths = forward_context.attn_metadata[
key].decode.actual_seq_lengths_q
spec_multiple = speculative_config.num_speculative_tokens + 1
seq_lens_list = seq_lens_list + [0] * (
runtime_shape // spec_multiple - len(seq_lens_list))
actual_seq_lengths = [
spec_multiple * (i + 1)
for i in range(runtime_shape // spec_multiple)
]
else:
seq_lens_list = seq_lens_list + [0] * (runtime_shape -
len(seq_lens_list))
torch.npu.graph_task_update_begin(update_stream, handle)

Comment on lines +264 to 265
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The call to torch.npu.graph_task_update_begin() has been moved into a conditional block, but torch.npu.graph_task_update_end() is still called unconditionally. This will cause a runtime error when speculative_config and speculative_config.method == "deepseek_mtp" is true, as graph_task_update_end will be called without a corresponding graph_task_update_begin.

To fix this, torch.npu.graph_task_update_begin() should be moved out of the else block to ensure it is always called within the loop.

Suggested change
torch.npu.graph_task_update_begin(update_stream, handle)
torch.npu.graph_task_update_begin(update_stream, handle)

torch_npu.npu_fused_infer_attention_score.out(
q_nope,
Expand Down
Loading