
[Perf] Reduce memory usage by splitting tokens in fused_experts #1729


Open
wants to merge 2 commits into base: main

Conversation

@ApsarasX (Collaborator) commented on Jul 10, 2025

What this PR does / why we need it?

Reduce activation memory usage during the prefill phase in MOE for long-context scenarios.
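In long-context prefill, the fused-experts path materializes intermediate expert activations for every batched token at once; capping the number of tokens per invocation bounds that peak. The snippet below is only a sketch of the token-splitting idea, not the actual code in `vllm_ascend/ops/fused_moe.py` or `w8a8_dynamic.py`; the helper name `fused_experts_impl` and the tensor arguments are assumptions for illustration.

```python
import torch


def chunked_fused_experts(hidden_states: torch.Tensor,
                          topk_weights: torch.Tensor,
                          topk_ids: torch.Tensor,
                          fused_experts_impl,
                          max_chunk_size: int) -> torch.Tensor:
    """Illustrative sketch only: run the fused-experts computation on at most
    `max_chunk_size` tokens at a time instead of on the whole prefill batch."""
    num_tokens = hidden_states.shape[0]
    if num_tokens <= max_chunk_size:
        return fused_experts_impl(hidden_states, topk_weights, topk_ids)

    outputs = []
    for start in range(0, num_tokens, max_chunk_size):
        end = min(start + max_chunk_size, num_tokens)
        # Each call only materializes intermediate expert activations for
        # `end - start` tokens, which bounds peak activation memory.
        outputs.append(
            fused_experts_impl(hidden_states[start:end],
                               topk_weights[start:end],
                               topk_ids[start:end]))
    return torch.cat(outputs, dim=0)
```

Smaller chunks lower the activation peak at the cost of extra kernel launches, which is consistent with the memory numbers in the benchmark comment further down.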

Does this PR introduce any user-facing change?

Yes. To use this feature, the user must manually set the `fused_moe_max_chunk_size` field in the `additional_config` dictionary.
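A minimal usage sketch, assuming the standard vLLM `additional_config` engine argument; the model path and the chunk-size value 8192 are placeholders:

```python
from vllm import LLM

llm = LLM(
    model="path/to/DeepSeek-R1-W8A8",   # placeholder model path
    max_num_batched_tokens=32768,
    additional_config={
        # Split any fused-experts input larger than 8192 tokens into chunks.
        "fused_moe_max_chunk_size": 8192,
    },
)
```

When serving, the same dictionary can typically be passed as a JSON string via the `--additional-config` flag.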

How was this patch tested?

Yes.

ApsarasX added 2 commits July 10, 2025 12:48
Signed-off-by: ApsarasX <apsarax@outlook.com>
Signed-off-by: ApsarasX <apsarax@outlook.com>

codecov bot commented Jul 10, 2025

Codecov Report

Attention: Patch coverage is 9.67742% with 28 lines in your changes missing coverage. Please review.

Project coverage is 54.45%. Comparing base (c30ddb8) to head (2ef5601).
Report is 155 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|--------------------------|---------|-------|
| vllm_ascend/ops/fused_moe.py | 0.00% | 14 Missing ⚠️ |
| vllm_ascend/quantization/w8a8_dynamic.py | 0.00% | 14 Missing ⚠️ |

❌ Your patch check has failed because the patch coverage (9.67%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1729       +/-   ##
===========================================
+ Coverage   27.39%   54.45%   +27.06%     
===========================================
  Files          56       80       +24     
  Lines        6191     9995     +3804     
===========================================
+ Hits         1696     5443     +3747     
- Misses       4495     4552       +57     
| Flag | Coverage Δ |
|------|------------|
| unittests | 54.45% <9.67%> (+27.06%) ⬆️ |

Flags with carried forward coverage won't be shown.


@ApsarasX added the `ready` (read for review) label on Jul 10, 2025
@@ -33,6 +33,7 @@ The following table lists the additional configuration options available in vLLM
| `expert_map_path` | str | `None` | When using expert load balancing for the MOE model, an expert map path needs to be passed in. |
| `chunked_prefill_for_mla` | bool | `False` | Whether to enable the fused operator-like chunked_prefill. |
| `kv_cache_dtype` | str | `None` | When using the kv cache quantization method, kv cache dtype needs to be set, currently only int8 is supported. |
| `fused_moe_max_chunk_size` | int | `max_num_batched_tokens * data_parallel_size` | The maximum token chunk size for the fused MoE operation. Input exceeding this size is split into multiple chunks for processing. |
A Collaborator commented on this line:

Need to test different cases of data_parallel_size to make sure this change works as expected.

@ApsarasX (Collaborator, Author) commented on Jul 14, 2025

@wangxiyuan

DeepSeek-R1-W8A8, max_model_len=max_num_batched_tokens=32768

| strategy | fused_moe_max_chunk_size | Available memory | KVCache size |
|----------|--------------------------|------------------|--------------|
| attn: dp2tp8, moe: etp16 | default | 1.96 GB | 18688 tokens (crash) |
| attn: dp2tp8, moe: etp16 | 32768 | 7.21 GB | 109162 tokens |
| attn: dp2tp8, moe: etp16 | 16384 | 9.84 GB | 150438 tokens |
| attn: dp2tp8, moe: etp16 | 8192 | 11.15 GB | 169472 tokens |
| attn: dp4tp4, moe: etp16 | default | OOM | OOM |
| attn: dp4tp4, moe: etp16 | 32768 | 3.82 GB | 57428 tokens |
| attn: dp4tp4, moe: etp16 | 16384 | 6.37 GB | 97280 tokens |
| attn: dp4tp4, moe: etp16 | 8192 | 7.76 GB | 117626 tokens |

DeepSeek-R1-W8A8, max_model_len=32768, max_num_batched_tokens=8192

| strategy | fused_moe_max_chunk_size | Available memory | KVCache size |
|----------|--------------------------|------------------|--------------|
| attn: dp2tp8, moe: etp16 | default | 11.49 GB | 175616 tokens |
| attn: dp2tp8, moe: etp16 | 32768 | 11.48 GB | 175488 tokens |
| attn: dp2tp8, moe: etp16 | 16384 | 11.48 GB | 175488 tokens |
| attn: dp2tp8, moe: etp16 | 8192 | 12.80 GB | 195584 tokens |
| attn: dp4tp4, moe: etp16 | default | 6.19 GB | 94592 tokens |
| attn: dp4tp4, moe: etp16 | 32768 | 6.19 GB | 94592 tokens |
| attn: dp4tp4, moe: etp16 | 16384 | 8.83 GB | 134912 tokens |
| attn: dp4tp4, moe: etp16 | 8192 | 10.13 GB | 154752 tokens |


This pull request has conflicts, please resolve those before we can evaluate the pull request.

@github-actions bot added the merge-conflicts label and removed the `ready` (read for review) label on Jul 21, 2025