From 5da64d7ffcc7b7d708bf6f7781c04b2337dc1368 Mon Sep 17 00:00:00 2001
From: wangxiaoxin-sherie
Date: Tue, 4 Nov 2025 21:00:36 +0800
Subject: [PATCH] optimize fullgraph.

Signed-off-by: wangxiaoxin-sherie
---
 vllm_ascend/compilation/acl_graph.py | 89 ++++++++++++++--------------
 1 file changed, 46 insertions(+), 43 deletions(-)

diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py
index d3e779e2bb..c96b348b4b 100644
--- a/vllm_ascend/compilation/acl_graph.py
+++ b/vllm_ascend/compilation/acl_graph.py
@@ -193,24 +193,25 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
     graph_params = get_graph_params()
     # FIXME: Behold! We are using a temporary hack here to update the args
     # for each layer's attention op in the graph.
-    for key, param, handle, event in zip(
-            forward_context.attn_metadata,
-            graph_params.attn_params[runtime_shape],
-            graph_params.handles[runtime_shape],
-            graph_params.events[runtime_shape],
-    ):
-        (
-            query,
-            key_cache,
-            value_cache,
-            num_kv_heads,
-            num_heads,
-            scale,
-            block_table,
-            seq_lens,
-            output,
-        ) = param
-        seq_lens = forward_context.attn_metadata[key].seq_lens
+    with torch.npu.stream(update_stream):
+        for key, param, handle, event in zip(
+                forward_context.attn_metadata,
+                graph_params.attn_params[runtime_shape],
+                graph_params.handles[runtime_shape],
+                graph_params.events[runtime_shape],
+        ):
+            (
+                query,
+                key_cache,
+                value_cache,
+                num_kv_heads,
+                num_heads,
+                scale,
+                block_table,
+                seq_lens,
+                output,
+            ) = param
+            seq_lens = forward_context.attn_metadata[key].seq_lens
 
         # When using FULL_DECODE_ONLY, there are some rare bugs for FULL_DECODE_ONLY
         # mode with GQA. This is triggered by getting workspace for _npu_paged_attention
@@ -253,31 +254,33 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
     graph_params = get_graph_params()
     # FIXME: Behold! We are using a temporary hack here to update the args
     # for each layer's attention op in the graph.
-    for key, param, handle, event in zip(
-            forward_context.attn_metadata,
-            graph_params.attn_params[runtime_shape],
-            graph_params.handles[runtime_shape],
-            graph_params.events[runtime_shape],
-    ):
-        (q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
-         spec_attn_mask, sparse_mode, scale, block_table, block_size,
-         seq_lens_list, actual_seq_lengths, attn_output, softmax_lse) = param
-        seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list
-        if speculative_config and speculative_config.method == "deepseek_mtp":
-            actual_seq_lengths = forward_context.attn_metadata[
-                key].decode.actual_seq_lengths_q
-            spec_multiple = speculative_config.num_speculative_tokens + 1
-            seq_lens_list = seq_lens_list + [0] * (
-                runtime_shape // spec_multiple - len(seq_lens_list))
-            actual_seq_lengths = [
-                spec_multiple * (i + 1)
-                for i in range(runtime_shape // spec_multiple)
-            ]
-        else:
-            seq_lens_list = seq_lens_list + [0] * (runtime_shape -
-                                                   len(seq_lens_list))
-        with torch.npu.stream(update_stream):
-            torch.npu.graph_task_update_begin(update_stream, handle)
+    with torch.npu.stream(update_stream):
+        for key, param, handle, event in zip(
+                forward_context.attn_metadata,
+                graph_params.attn_params[runtime_shape],
+                graph_params.handles[runtime_shape],
+                graph_params.events[runtime_shape],
+        ):
+            (q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
+             spec_attn_mask, sparse_mode, scale, block_table, block_size,
+             seq_lens_list, actual_seq_lengths, attn_output,
+             softmax_lse) = param
+            seq_lens_list = forward_context.attn_metadata[
+                key].decode.seq_lens_list
+            if speculative_config and speculative_config.method == "deepseek_mtp":
+                actual_seq_lengths = forward_context.attn_metadata[
+                    key].decode.actual_seq_lengths_q
+                spec_multiple = speculative_config.num_speculative_tokens + 1
+                seq_lens_list = seq_lens_list + [0] * (
+                    runtime_shape // spec_multiple - len(seq_lens_list))
+                actual_seq_lengths = [
+                    spec_multiple * (i + 1)
+                    for i in range(runtime_shape // spec_multiple)
+                ]
+            else:
+                seq_lens_list = seq_lens_list + [0] * (runtime_shape -
+                                                       len(seq_lens_list))
+            torch.npu.graph_task_update_begin(update_stream, handle)
             torch_npu.npu_fused_infer_attention_score.out(
                 q_nope,
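
Both hunks apply the same structural change: the `with torch.npu.stream(update_stream):` context manager is hoisted out of the per-layer loop, so the update stream is entered once per call rather than once per layer, and every layer's graph-task update is issued under that single stream context. Below is a minimal sketch of that before/after pattern, not the actual vllm_ascend code; `layers` and `update_one_layer` are hypothetical placeholders standing in for the real per-layer update bodies.

    import torch
    import torch_npu  # noqa: F401  # registers the torch.npu backend on Ascend


    def update_one_layer(layer):
        # Hypothetical stand-in for the per-layer body in the patch
        # (tuple unpacking, graph_task_update_begin(...) and the
        # attention-op update call).
        layer.update()


    def update_params_old(update_stream, layers):
        # Before: the stream context is entered and exited for every layer.
        for layer in layers:
            with torch.npu.stream(update_stream):
                update_one_layer(layer)


    def update_params_new(update_stream, layers):
        # After: the stream context is entered once around the whole loop,
        # and all per-layer updates run on that one stream.
        with torch.npu.stream(update_stream):
            for layer in layers:
                update_one_layer(layer)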