@@ -213,8 +213,16 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
             output,
         ) = param
         seq_lens = forward_context.attn_metadata[key].seq_lens
-        torch.npu.graph_task_update_begin(update_stream, handle)
-        torch_npu._npu_paged_attention(
+
+        # In FULL_DECODE_ONLY mode with GQA there are rare bugs triggered by
+        # the workspace computation for _npu_paged_attention in torch_npu:
+        # a call with smaller seq_lens can occasionally require a larger
+        # workspace, while we currently size the maximum workspace from
+        # max_model_len during capture. An additional get_workspace call is
+        # added here to avoid such bugs.
+        # TODO(Angazenn): remove this once _npu_paged_attention is fully
+        # replaced by npu_fused_infer_attention_score, which has no such bugs.
+        workspace = torch_npu._npu_paged_attention_get_workspace(
             query=query,
             key_cache=key_cache,
             value_cache=value_cache,
@@ -223,8 +231,18 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
             scale_value=scale,
             block_table=block_table,
             context_lens=seq_lens,
-            out=output,
-            workspace=graph_params.workspaces.get(runtime_shape))
+            out=output)
+        torch.npu.graph_task_update_begin(update_stream, handle)
+        torch_npu._npu_paged_attention(query=query,
+                                       key_cache=key_cache,
+                                       value_cache=value_cache,
+                                       num_kv_heads=num_kv_heads,
+                                       num_heads=num_heads,
+                                       scale_value=scale,
+                                       block_table=block_table,
+                                       context_lens=seq_lens,
+                                       out=output,
+                                       workspace=workspace)
         torch.npu.graph_task_update_end(update_stream)
 
         event.record(update_stream)
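
As a quick illustration of the change, here is a minimal sketch (not the repository code; the argument plumbing is elided and `attn_kwargs` is a hypothetical dict standing in for the query/cache/head arguments unpacked from `param`): the workspace for the captured graph task is now re-queried from the current `seq_lens` on every update, instead of being taken from `graph_params.workspaces`, which was sized from `max_model_len` at capture time.

```python
import torch
import torch_npu

# Hypothetical helper, assuming attn_kwargs holds query, key_cache,
# value_cache, num_kv_heads, num_heads, scale_value and block_table,
# while seq_lens and output come from the forward context.
def update_paged_attention_task(update_stream, handle, attn_kwargs,
                                seq_lens, output):
    # Ask torch_npu for a workspace sized for the *current* seq_lens, since a
    # smaller seq_lens can occasionally need a larger workspace than the one
    # pre-computed from max_model_len during graph capture.
    workspace = torch_npu._npu_paged_attention_get_workspace(
        **attn_kwargs, context_lens=seq_lens, out=output)

    # Update the captured graph task using the freshly sized workspace.
    torch.npu.graph_task_update_begin(update_stream, handle)
    torch_npu._npu_paged_attention(**attn_kwargs,
                                   context_lens=seq_lens,
                                   out=output,
                                   workspace=workspace)
    torch.npu.graph_task_update_end(update_stream)
```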