@@ -211,20 +211,38 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
             output,
         ) = param
         seq_lens = forward_context.attn_metadata[key].seq_lens
+
+        # In FULL_DECODE_ONLY mode with GQA there is a rare bug triggered by
+        # the workspace torch_npu computes for _npu_paged_attention: a smaller
+        # seq_lens can occasionally require a larger workspace, while the
+        # workspace captured for the graph is sized from max_model_len. Query
+        # the workspace again here with the actual seq_lens to avoid this.
+        # TODO(Angazenn): remove this workaround once _npu_paged_attention is
+        # fully replaced by npu_fused_infer_attention_score, which does not
+        # have this bug.
+        workspace = torch_npu._npu_paged_attention_get_workspace(
+            query=query,
+            key_cache=key_cache,
+            value_cache=value_cache,
+            num_kv_heads=num_kv_heads,
+            num_heads=num_heads,
+            scale_value=scale,
+            block_table=block_table,
+            context_lens=seq_lens,
+            out=output)
 
         with torch.npu.stream(update_stream):
             torch.npu.graph_task_update_begin(update_stream, handle)
-            torch_npu._npu_paged_attention(
-                query=query,
-                key_cache=key_cache,
-                value_cache=value_cache,
-                num_kv_heads=num_kv_heads,
-                num_heads=num_heads,
-                scale_value=scale,
-                block_table=block_table,
-                context_lens=seq_lens,
-                out=output,
-                workspace=graph_params.workspaces.get(runtime_shape))
+            torch_npu._npu_paged_attention(query=query,
+                                           key_cache=key_cache,
+                                           value_cache=value_cache,
+                                           num_kv_heads=num_kv_heads,
+                                           num_heads=num_heads,
+                                           scale_value=scale,
+                                           block_table=block_table,
+                                           context_lens=seq_lens,
+                                           out=output,
+                                           workspace=workspace)
             torch.npu.graph_task_update_end(update_stream)
 
         event.record(update_stream)