
Commit 5c57b04

bugfix
Signed-off-by: Angazenn <supperccell@163.com>
1 parent 954dab6 commit 5c57b04

1 file changed, 29 insertions(+), 11 deletions(-)


vllm_ascend/compilation/acl_graph.py

@@ -211,20 +211,38 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
             output,
         ) = param
         seq_lens = forward_context.attn_metadata[key].seq_lens
+
+        # When using FULL_DECODE_ONLY mode with GQA, there are rare bugs
+        # triggered by the workspace computation for _npu_paged_attention in
+        # torch_npu: in some rare cases, _npu_paged_attention with smaller
+        # seq_lens may require a bigger workspace, while capturing currently
+        # sizes the workspace from max_model_len. An additional get_workspace
+        # call is added here to avoid such bugs.
+        # TODO(Angazenn): we will remove this once _npu_paged_attention is fully
+        # replaced by npu_fused_infer_attention_score, which does not contain such bugs.
+        workspace = torch_npu._npu_paged_attention_get_workspace(
+            query=query,
+            key_cache=key_cache,
+            value_cache=value_cache,
+            num_kv_heads=num_kv_heads,
+            num_heads=num_heads,
+            scale_value=scale,
+            block_table=block_table,
+            context_lens=seq_lens,
+            out=output)
 
         with torch.npu.stream(update_stream):
             torch.npu.graph_task_update_begin(update_stream, handle)
-            torch_npu._npu_paged_attention(
-                query=query,
-                key_cache=key_cache,
-                value_cache=value_cache,
-                num_kv_heads=num_kv_heads,
-                num_heads=num_heads,
-                scale_value=scale,
-                block_table=block_table,
-                context_lens=seq_lens,
-                out=output,
-                workspace=graph_params.workspaces.get(runtime_shape))
+            torch_npu._npu_paged_attention(query=query,
+                                           key_cache=key_cache,
+                                           value_cache=value_cache,
+                                           num_kv_heads=num_kv_heads,
+                                           num_heads=num_heads,
+                                           scale_value=scale,
+                                           block_table=block_table,
+                                           context_lens=seq_lens,
+                                           out=output,
+                                           workspace=workspace)
             torch.npu.graph_task_update_end(update_stream)
 
         event.record(update_stream)
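
Why a capture-time workspace is not a safe upper bound: as the comment in the diff notes, the workspace requirement of _npu_paged_attention is not monotonic in seq_lens, so a workspace sized from max_model_len during graph capture can still be too small at update time. Below is a minimal, self-contained Python sketch of that failure mode. required_workspace_bytes is a purely hypothetical cost model invented for illustration; the real computation is internal to torch_npu and is not shown here.

# A runnable sketch of the capture-vs-update hazard this commit fixes.
# required_workspace_bytes is a HYPOTHETICAL cost model, not torch_npu's
# real formula; it only illustrates that workspace size need not grow
# monotonically with seq_len.

def required_workspace_bytes(seq_len: int, block_size: int = 128) -> int:
    # Hypothetical model: per-block scratch plus extra scratch for the
    # padded tail of the last block. The padding term breaks monotonicity.
    num_blocks = -(-seq_len // block_size)           # ceil(seq_len / block_size)
    padded_tail = num_blocks * block_size - seq_len  # unused slots in last block
    return num_blocks * 1024 + padded_tail * 64

max_model_len = 4096
captured = required_workspace_bytes(max_model_len)  # 32 blocks, no padding: 32768

# A SHORTER sequence can need MORE workspace than the one captured:
runtime = required_workspace_bytes(4095)            # 32 blocks, 1 padded slot: 32832
print(f"captured={captured}, runtime={runtime}, overflow={runtime > captured}")
# -> captured=32768, runtime=32832, overflow=True

This non-monotonicity is why the update path now queries _npu_paged_attention_get_workspace with the actual seq_lens on every graph update, instead of reusing the workspace captured for the shape via graph_params.workspaces.get(runtime_shape).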
