
[WIP][Prefill Performance] Parallel Strategy Optimizations (VRAM-for-Speed Tradeoff) #1687


Status: Open
Wants to merge 35 commits into base: v0.9.1-dev
Commits (35):
23a7889
Prefill optimization draft
SlightwindSec Jul 9, 2025
a175c70
fix mc2 mask
SlightwindSec Jul 9, 2025
687dd2d
qkv dp
Jul 9, 2025
2fb03eb
fix mc2 mask
SlightwindSec Jul 10, 2025
925f65e
fix qkv dp with mtp
Jul 10, 2025
e8f9733
Merge remote-tracking branch 'upstream/v0.9.1-dev' into upstream_v0.9…
Jul 14, 2025
2adcc11
resolve conflicts with chunked mc2
Jul 14, 2025
9956a30
resolve conflicts with chunked mc2
Jul 14, 2025
ae401bb
feat: add flag to toggle prefill performance improvements
SlightwindSec Jul 15, 2025
0c7db37
remove unused import
SlightwindSec Jul 15, 2025
6029cde
[bugfix] prefill optimization support torchair graph mode
kunpengW-code Jul 16, 2025
92994fb
async h2d
Jul 16, 2025
9cd37ca
fix unpadding
Jul 17, 2025
8024f73
fix unpadding
Jul 17, 2025
d27702f
[prefill] access ATBMLAPrefill operator
kunpengW-code Jul 17, 2025
7d43361
fix unpadding
Jul 17, 2025
64183a2
Merge remote-tracking branch 'upstream/v0.9.1-dev' into upstream_v0.9…
Jul 18, 2025
802fc2a
support eplb with prefill optimization
Jul 18, 2025
f581038
merge upstream_v0.9.1-dev
SlightwindSec Jul 23, 2025
427ee25
Merge branch 'v0.9.1-dev' of https://github.com/vllm-project/vllm-asc…
kunpengW-code Jul 24, 2025
3615190
[bugfix] prefill optimization support prefix_cache
kunpengW-code Jul 24, 2025
38e73c8
[bugfix] support floating
kunpengW-code Jul 24, 2025
42d69be
fix mtp in torchair graph
SlightwindSec Jul 25, 2025
4684127
Merge remote-tracking branch 'upstream/v0.9.1-dev' into upstream_v0.9…
SlightwindSec Jul 26, 2025
ed1aeee
fix lint
SlightwindSec Jul 26, 2025
88c56fb
fix lint
SlightwindSec Jul 26, 2025
b93e6a0
fix lint
SlightwindSec Jul 26, 2025
26c005a
support decode_torchair
SlightwindSec Jul 29, 2025
34e3f9a
Merge remote-tracking branch 'upstream/v0.9.1-dev' into upstream_v0.9…
Jul 30, 2025
0f7d66b
Merge remote-tracking branch 'upstream/v0.9.1-dev' into upstream_v0.9…
SlightwindSec Jul 31, 2025
0dfc48c
remove unused var
SlightwindSec Jul 31, 2025
92b1f17
remove fused_experts_with_all2all_v2()
SlightwindSec Jul 31, 2025
1442e41
add [WIP]npu_moe_init_routing_quantv2
SlightwindSec Jul 31, 2025
b1b651f
Merge remote-tracking branch 'upstream/v0.9.1-dev' into upstream_v0.9…
SlightwindSec Aug 1, 2025
4c19ed3
Merge remote-tracking branch 'upstream/v0.9.1-dev' into upstream_v0.9…
SlightwindSec Aug 4, 2025
53 changes: 42 additions & 11 deletions vllm_ascend/attention/mla_v1.py
@@ -28,6 +28,12 @@
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch

try:
from torch_npu.atb import npu_mla_prefill # noqa: F401
ATB_MLA_PREFILL_ENABLED = True
except ImportError:
ATB_MLA_PREFILL_ENABLED = False


class AscendMLABackend(AttentionBackend):

@@ -623,6 +629,7 @@ def __init__(
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
self.enable_prefill_optimizations = ascend_config.enable_prefill_optimizations

# Adapt torch air graph mode with spec decoding.
speculative_config = get_current_vllm_config().speculative_config
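
The enable_prefill_optimizations attribute wired up here corresponds to the toggle added in commit ae401bb ("feat: add flag to toggle prefill performance improvements"). The config-parsing side is not part of these hunks; assuming the flag is read from vLLM's additional_config like other ascend_config options, switching it on might look like the sketch below (the model name and key placement are illustrative assumptions, not taken from this PR).

from vllm import LLM

# Sketch only: assumes enable_prefill_optimizations is a top-level key of
# additional_config; this diff does not show where the key is parsed.
llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # placeholder model
    additional_config={"enable_prefill_optimizations": True},
)
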
@@ -882,17 +889,41 @@ def _forward_prefill(
query, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)

elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
key = torch.cat((k_nope, k_pe), dim=-1)
torch_npu._npu_flash_attention(
query=query,
key=key,
value=value,
mask=attn_metadata.attn_mask,
seq_len=attn_metadata.prefill.context_lens,
scale_value=self.scale,
num_heads=self.num_heads,
num_kv_heads=self.num_heads,
out=attn_output)
if not self.enable_prefill_optimizations or not ATB_MLA_PREFILL_ENABLED:
key = torch.cat((k_nope, k_pe), dim=-1)
torch_npu._npu_flash_attention(
query=query,
key=key,
value=value,
mask=attn_metadata.attn_mask,
seq_len=attn_metadata.prefill.context_lens,
scale_value=self.scale,
num_heads=self.num_heads,
num_kv_heads=self.num_heads,
out=attn_output)
else:
q_pe = query[..., self.qk_nope_head_dim:]
q_nope = query[..., :self.qk_nope_head_dim]
mask = torch.triu(
torch.ones(512,
512,
device=query.device,
dtype=query.dtype),
1)  # 512: the mask layout only supports a fixed 512x512 shape
torch_npu.atb.npu_mla_prefill(
q=q_nope,
q_rope=q_pe,
k=k_nope,
k_rope=k_pe,
v=value,
q_seqlen=attn_metadata.prefill.context_lens,
kv_seqlen=attn_metadata.prefill.context_lens,
q_headnum=self.num_heads,
qk_scale=self.scale,
kv_headnum=self.num_heads,
mask=mask,
mask_type="mask_type_free",
output=attn_output)
attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim)
attn_output = attn_output.reshape(
[num_tokens, self.num_heads * self.v_head_dim])
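
In the optimized branch above, the query is split into its non-RoPE (q_nope) and RoPE (q_pe) halves and passed to torch_npu.atb.npu_mla_prefill together with a fixed 512x512 upper-triangular mask for the "mask_type_free" layout. A minimal sketch of that mask construction, runnable on CPU purely for illustration:

import torch

def build_free_mask(dtype: torch.dtype = torch.float16,
                    device: str = "cpu") -> torch.Tensor:
    # Ones strictly above the diagonal mark future positions as masked;
    # diagonal=1 keeps the diagonal at 0 so each token can attend to itself.
    return torch.triu(torch.ones(512, 512, dtype=dtype, device=device), 1)

mask = build_free_mask()
assert mask.shape == (512, 512)
assert mask[0, 0] == 0 and mask[0, 1] == 1  # causal: self visible, future masked
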
12 changes: 12 additions & 0 deletions vllm_ascend/quantization/w8a8_dynamic.py
@@ -520,6 +520,18 @@ def fused_experts_with_all2all(hidden_states: torch.Tensor,
expert_tokens_before_capacity_flag=False,
quant_mode=1,
)
elif hasattr(torch_npu, "npu_moe_init_routing_quantv2"): # TODO: Remove it
quantized_tokens, expanded_row_idx, global_expert_tokens, _, token_scales = torch_npu.npu_moe_init_routing_quantv2(
hidden_states,
expert_idx=topk_ids.to(torch.int32),
active_num=0,
expert_capacity=0,
expert_num=global_num_experts,
drop_pad_mode=0,
expert_tokens_count_or_cumsum_flag=2,
expert_tokens_before_capacity_flag=False,
quant_mode=1,
)
else:
quantized_tokens, expanded_row_idx, global_expert_tokens, token_scales = init_routing_quant(
hidden_states, top_k, topk_ids, global_num_experts)
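
The new elif branch prefers the fused npu_moe_init_routing_quantv2 operator when the installed torch_npu exposes it, discarding the extra fourth output, and otherwise keeps the existing init_routing_quant fallback. A hedged sketch of that dispatch, with the operator keywords copied from the diff and everything else (function name, parameters) treated as placeholders:

import torch

def init_routing_quant_compat(torch_npu_mod, fallback_fn,
                              hidden_states: torch.Tensor,
                              topk_ids: torch.Tensor,
                              top_k: int,
                              global_num_experts: int):
    # Prefer the v2 operator when the runtime provides it; keyword arguments
    # mirror the branch above, and the unused fourth output is dropped.
    if hasattr(torch_npu_mod, "npu_moe_init_routing_quantv2"):
        tokens, row_idx, expert_tokens, _, scales = (
            torch_npu_mod.npu_moe_init_routing_quantv2(
                hidden_states,
                expert_idx=topk_ids.to(torch.int32),
                active_num=0,
                expert_capacity=0,
                expert_num=global_num_experts,
                drop_pad_mode=0,
                expert_tokens_count_or_cumsum_flag=2,
                expert_tokens_before_capacity_flag=False,
                quant_mode=1,
            ))
    else:
        # Fallback path as in the diff; fallback_fn stands in for
        # init_routing_quant defined elsewhere in w8a8_dynamic.py.
        tokens, row_idx, expert_tokens, scales = fallback_fn(
            hidden_states, top_k, topk_ids, global_num_experts)
    return tokens, row_idx, expert_tokens, scales
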
2 changes: 1 addition & 1 deletion vllm_ascend/worker/model_runner_v1.py
@@ -712,7 +712,7 @@ def _make_attention_mask(self, seq_lens, query_lens, position,
seq_lens, query_lens, position, self.dtype, self.device)
# Prefill without cache situation.
elif attn_state == AscendAttentionState.PrefillNoCache:
max_seq_len = max(seq_lens, default=0)
max_seq_len = 128
return self.attn_mask_builder.get_attn_mask(
max_seq_len, self.dtype, self.device)
# Prefill with cache hit.
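
For the no-cache prefill case, the runner now requests a mask of constant size 128 instead of sizing it to the longest sequence in the batch. Whatever kernel-level convention the specific value 128 follows (this diff does not spell it out), a constant size means the causal mask can be allocated once and reused across batches. A hypothetical illustration of that caching pattern, not the project's AttentionMaskBuilder:

import torch

class FixedCausalMask:
    """Allocate one boolean causal mask up front and reuse it for every batch."""

    def __init__(self, size: int, device: torch.device = torch.device("cpu")):
        # True strictly above the diagonal marks future (masked) positions.
        self._mask = torch.triu(
            torch.ones(size, size, dtype=torch.bool, device=device), diagonal=1)

    def get(self) -> torch.Tensor:
        return self._mask

builder = FixedCausalMask(128)
assert builder.get().shape == (128, 128)
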