Skip to content

Commit 3fcce18

Browse files
committed
bugfix
Signed-off-by: Angazenn <supperccell@163.com>
1 parent 5f08e07 commit 3fcce18

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

vllm_ascend/compilation/acl_graph.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -213,8 +213,7 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
213213
output,
214214
) = param
215215
seq_lens = forward_context.attn_metadata[key].seq_lens
216-
torch.npu.graph_task_update_begin(update_stream, handle)
217-
torch_npu._npu_paged_attention(
216+
workspace = torch_npu._npu_paged_attention_get_workspace(
218217
query=query,
219218
key_cache=key_cache,
220219
value_cache=value_cache,
@@ -223,8 +222,19 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
223222
scale_value=scale,
224223
block_table=block_table,
225224
context_lens=seq_lens,
226-
out=output,
227-
workspace=graph_params.workspaces.get(runtime_shape))
225+
out=output)
226+
227+
torch.npu.graph_task_update_begin(update_stream, handle)
228+
torch_npu._npu_paged_attention(query=query,
229+
key_cache=key_cache,
230+
value_cache=value_cache,
231+
num_kv_heads=num_kv_heads,
232+
num_heads=num_heads,
233+
scale_value=scale,
234+
block_table=block_table,
235+
context_lens=seq_lens,
236+
out=output,
237+
workspace=workspace)
228238
torch.npu.graph_task_update_end(update_stream)
229239

230240
event.record(update_stream)

0 commit comments

Comments
 (0)