From f575cbb3bbb388f619cbd7b5ca7da45312a079de Mon Sep 17 00:00:00 2001
From: Angazenn
Date: Thu, 6 Nov 2025 17:22:53 +0800
Subject: [PATCH] bugfix: recompute _npu_paged_attention workspace in
 update_attn_params

Signed-off-by: Angazenn
---
 vllm_ascend/compilation/acl_graph.py | 26 ++++++++++++++++++++++----
 1 file changed, 22 insertions(+), 4 deletions(-)

diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py
index d9e08c84ca..be9aaad4c6 100644
--- a/vllm_ascend/compilation/acl_graph.py
+++ b/vllm_ascend/compilation/acl_graph.py
@@ -213,8 +213,16 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
             output,
         ) = param
         seq_lens = forward_context.attn_metadata[key].seq_lens
-        torch.npu.graph_task_update_begin(update_stream, handle)
-        torch_npu._npu_paged_attention(
+
+        # FULL_DECODE_ONLY mode with GQA occasionally hits failures caused by
+        # the workspace used for _npu_paged_attention in torch_npu: a call with
+        # smaller seq_lens can require a larger workspace than the one sized
+        # with max_model_len at capture time. Recompute the workspace here for
+        # the actual seq_lens to avoid such failures.
+        # TODO(Angazenn): remove this once _npu_paged_attention is fully
+        # replaced by npu_fused_infer_attention_score, which does not have
+        # this issue.
+        workspace = torch_npu._npu_paged_attention_get_workspace(
             query=query,
             key_cache=key_cache,
             value_cache=value_cache,
@@ -223,8 +231,18 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
             scale_value=scale,
             block_table=block_table,
             context_lens=seq_lens,
-            out=output,
-            workspace=graph_params.workspaces.get(runtime_shape))
+            out=output)
+        torch.npu.graph_task_update_begin(update_stream, handle)
+        torch_npu._npu_paged_attention(query=query,
+                                       key_cache=key_cache,
+                                       value_cache=value_cache,
+                                       num_kv_heads=num_kv_heads,
+                                       num_heads=num_heads,
+                                       scale_value=scale,
+                                       block_table=block_table,
+                                       context_lens=seq_lens,
+                                       out=output,
+                                       workspace=workspace)
         torch.npu.graph_task_update_end(update_stream)
         event.record(update_stream)
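
Reviewer note (not part of the patch): the replay-time flow introduced above can be
summarized by the minimal sketch below. The standalone function name and its flat
argument list are hypothetical, introduced only for illustration; the torch_npu and
torch.npu calls and their keyword arguments mirror the diff.

import torch
import torch_npu


def update_paged_attention_task(update_stream, handle, event, query, key_cache,
                                value_cache, num_kv_heads, num_heads, scale,
                                block_table, seq_lens, output):
    # Size the workspace for the actual seq_lens of this step instead of reusing
    # the workspace captured with max_model_len, since a smaller seq_lens can
    # occasionally demand a larger workspace (the bug this patch works around).
    workspace = torch_npu._npu_paged_attention_get_workspace(
        query=query,
        key_cache=key_cache,
        value_cache=value_cache,
        num_kv_heads=num_kv_heads,
        num_heads=num_heads,
        scale_value=scale,
        block_table=block_table,
        context_lens=seq_lens,
        out=output)

    # Re-record the paged attention task of the captured ACL graph on the
    # dedicated update stream, then record an event so other streams can wait
    # for the update to finish before replaying the graph.
    torch.npu.graph_task_update_begin(update_stream, handle)
    torch_npu._npu_paged_attention(query=query,
                                   key_cache=key_cache,
                                   value_cache=value_cache,
                                   num_kv_heads=num_kv_heads,
                                   num_heads=num_heads,
                                   scale_value=scale,
                                   block_table=block_table,
                                   context_lens=seq_lens,
                                   out=output,
                                   workspace=workspace)
    torch.npu.graph_task_update_end(update_stream)
    event.record(update_stream)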