
Conversation

@lilinsiman lilinsiman commented Sep 9, 2025

What this PR does / why we need it?

  1. Fixed an issue in the Qwen3 MoE model case where enabling DP used an extra stream, causing an ACL graph sizes capture error.
  2. Experiments showed that in many cases some operators occupy more streams than expected, so the stream buffer reserved for ACL graph capture was not large enough. After discussion, an extra 120 streams were added as a buffer (see the sketch after this list).
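
The core idea, as a minimal sketch (the stream limit and the per-graph stream cost below are illustrative assumptions; only the 120-stream buffer and the observation that MoE with DP takes one extra stream come from this PR):

```python
# Minimal sketch of the stream-budget reasoning behind this fix.
# Hypothetical numbers: only the 120-stream buffer and the extra stream
# taken by MoE + DP come from the PR itself.
ASSUMED_MAX_STREAMS = 2048      # assumed NPU stream limit, for illustration
STREAM_BUFFER = 120             # headroom added by this PR

def affordable_capture_sizes(streams_per_graph: int) -> int:
    """How many ACL graph batch sizes fit into the remaining stream budget."""
    usable = ASSUMED_MAX_STREAMS - STREAM_BUFFER
    return usable // streams_per_graph

# With DP enabled on a MoE model, each captured graph was observed to take one
# more stream than planned, so fewer batch sizes fit than the old budget assumed.
print(affordable_capture_sizes(streams_per_graph=96))  # 20
print(affordable_capture_sizes(streams_per_graph=97))  # 19
```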

Does this PR introduce any user-facing change?

no

How was this patch tested?

Unit tests (UT).

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request addresses an ACLgraph size capture error for MoE models with data parallelism by adjusting the stream buffer size and accounting for extra stream usage in MoE models. The changes look reasonable, but I've identified a potential issue in the newly introduced is_moe_model function. The current implementation for detecting MoE models is too broad and could lead to false positives. I've suggested a more robust implementation to improve accuracy.

Comment on lines +586 to +589
def is_moe_model(vllm_config: VllmConfig):
    config = vllm_config.model_config.hf_config
    return any('experts' in key.lower() for key in config.to_dict())

Severity: high

The current implementation to detect a Mixture-of-Experts (MoE) model is based on checking if 'experts' is a substring in any of the configuration keys. This approach is fragile and might lead to false positives if a non-MoE model has a configuration key containing this substring (e.g., use_shared_experts). This could cause incorrect logic to be applied, potentially leading to suboptimal performance or runtime errors.

A more robust approach would be to check for specific attributes that are characteristic of a MoE architecture, such as num_experts and num_experts_per_tok, and ensure they have meaningful values. This will make the detection more accurate and prevent incorrect behavior.

Suggested change
- def is_moe_model(vllm_config: VllmConfig):
-     config = vllm_config.model_config.hf_config
-     return any('experts' in key.lower() for key in config.to_dict())
+ def is_moe_model(vllm_config: VllmConfig):
+     config = vllm_config.model_config.hf_config
+     # A more robust check for MoE models by verifying specific attributes
+     # that are characteristic of such architectures.
+     return (hasattr(config, "num_experts") and
+             isinstance(config.num_experts, int) and config.num_experts > 0 and
+             hasattr(config, "num_experts_per_tok"))
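
To make the false-positive concern concrete, here is a small self-contained illustration (not from the PR; ToyConfig is a hypothetical stand-in for a HuggingFace config object) comparing the two checks on a dense model whose config happens to contain an 'experts' key:

```python
# Illustration only: toy configs standing in for HuggingFace model configs.
class ToyConfig:
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def to_dict(self):
        return dict(self.__dict__)

dense_cfg = ToyConfig(use_shared_experts=False, num_hidden_layers=48)  # not a MoE model
moe_cfg = ToyConfig(num_experts=128, num_experts_per_tok=8)            # MoE-like config

def substring_check(cfg):
    return any('experts' in key.lower() for key in cfg.to_dict())

def attribute_check(cfg):
    return (hasattr(cfg, "num_experts") and
            isinstance(cfg.num_experts, int) and cfg.num_experts > 0 and
            hasattr(cfg, "num_experts_per_tok"))

print(substring_check(dense_cfg))   # True  -> false positive on the dense model
print(attribute_check(dense_cfg))   # False -> correctly rejected
print(substring_check(moe_cfg))     # True
print(attribute_check(moe_cfg))     # True
```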


github-actions bot commented Sep 9, 2025

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by other future PRs.
  • Write the commit message by fulfilling the PR description to help reviewers and future developers understand.

If CI fails, you can run linting and testing checks locally according to Contributing and Testing.

@lilinsiman lilinsiman force-pushed the main branch 2 times, most recently from 79698d9 to d7755b8 Compare September 9, 2025 13:24
@lilinsiman lilinsiman force-pushed the main branch 3 times, most recently from fd26cec to 917f4de Compare September 10, 2025 06:13
Signed-off-by: lilinsiman <lilinsiman@gmail.com>
# TODO: Find out whether we need to take into account the pp_size
parallel_factor = 1 + num_comm_groups + int(
    parallel_config.enable_expert_parallel)
if is_moe_model(vllm_config):
Collaborator

we still need to check whether it is a MoE model, right?

Collaborator

Yeah, for MoE, we'll have to do prepare and finalize around quant_method.apply, which will have extra communication.

@MengqingCao
Collaborator

UT failed due to some known issues, which are unrelated to this PR. Let's merge this after the full e2e test passes.

@MengqingCao MengqingCao added the ready (read for review) and ready-for-test (start test by label for PR) labels Sep 10, 2025
@MengqingCao MengqingCao merged commit b7df04d into vllm-project:main Sep 10, 2025
33 of 34 checks passed
Collaborator

Yikun commented Sep 10, 2025

Please rewrite the PR title as a meaningful one. I didn't get any human-readable info from the original title.

@lilinsiman lilinsiman changed the title debug_aclgraph_sizes_capture [Bugfix] aclgraph_sizes_capture Sep 11, 2025
@lilinsiman lilinsiman changed the title [Bugfix] aclgraph_sizes_capture [Bugfix] Fix aclgraph sizes capture error Sep 11, 2025
@lilinsiman
Contributor Author

> Please rewrite the PR title as a meaningful one. I didn't get any human-readable info from the original title.

Already done

@lilinsiman lilinsiman changed the title [Bugfix] Fix aclgraph sizes capture error [Bugfix] Fix aclgraph sizes capture error in Qwen3 Moe model case Sep 11, 2025
yiz-liu pushed a commit to linfeng-yuan/vllm-ascend that referenced this pull request Sep 12, 2025
parallel_factor += (parallel_config.data_parallel_size > 1)
# Calculate maximum supported batch sizes considering model architecture on the A2 Hardware Device
# Assume the following case:
# MAX_CAPTURE_SIZE = 1920, num_hidden_layers = 48, data_parallel_size is 1, tensor_parallel_size is 4,
Collaborator

Also update the note.
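
For reference, the example constants from the note plug into the parallel_factor logic shown in this diff roughly as follows (a sketch: the derivation of num_comm_groups and the flag values are assumptions for illustration, not the exact vllm-ascend code):

```python
# Sketch: plugging the note's example into the parallel_factor logic above.
num_hidden_layers = 48          # from the note
data_parallel_size = 1          # from the note
tensor_parallel_size = 4        # from the note
enable_expert_parallel = False  # assumed for this example
is_moe = True                   # the Qwen3 MoE case this PR targets

# Assumed: one communication group per parallel dimension larger than 1.
num_comm_groups = sum(size > 1 for size in (data_parallel_size,
                                            tensor_parallel_size))

parallel_factor = 1 + num_comm_groups + int(enable_expert_parallel)
if is_moe:
    parallel_factor += int(data_parallel_size > 1)  # the extra stream under DP

print(parallel_factor)  # 2: the compute stream plus one TP communication group
```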

offline893 pushed a commit to offline893/vllm-ascend that referenced this pull request Sep 16, 2025
wangxiaoteng888 pushed a commit to LCAIZJ/vllm-ascend that referenced this pull request Sep 25, 2025
chopper0126 pushed a commit to chopper0126/vllm-ascend that referenced this pull request Sep 26, 2025
Labels
module:core, module:tests, ready (read for review), ready-for-test (start test by label for PR)