
Conversation


@1Fire4 1Fire4 commented Sep 9, 2025

[BugFix] qwen_moe can't run when ep_size is less than 16


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request addresses an issue preventing qwen_moe models from running with an expert parallel size less than 16. The main fix involves removing the frozen_parameter configuration in TorchAir, which likely accommodates the dynamic nature of MoE models. While this change seems correct for MoE, it's applied globally and could potentially impact the performance of non-MoE models. Additionally, a critical issue has been introduced in the test suite, where correctness assertions have been removed, effectively turning a regression test into a smoke test. This significantly reduces test coverage and should be rectified.

Comment on lines 166 to 212
def _qwen_torchair_test_fixture(
    model,
    tp,
    enable_expert_parallel,
):
    # The current access control does not support 16 cards,
    # so the MC2 operator in Qwen's graph mode cannot run.
    # Once 16-card support is available,
    # this e2e can be switched to graph mode.
    example_prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    additional_config = {
        "torchair_graph_config": {
            "enabled": False,
        },
        "ascend_scheduler_config": {
            "enabled": True,
        },
        "refresh": True,
    }

    with VllmRunner(
            model,
            dtype="half",
            tensor_parallel_size=tp,
            distributed_executor_backend="mp",
            enforce_eager=True,
            additional_config=additional_config,
            enable_expert_parallel=enable_expert_parallel,
    ) as vllm_model:
        # use greedy sampler to make sure the generated results are fix
        vllm_output = vllm_model.generate_greedy(example_prompts, 5)

    def stubbed_get_state(ep_size, with_prefill, is_deepseek_v3_r1):
        return _get_fused_moe_state(16, with_prefill, is_deepseek_v3_r1)

    with patch(
            "vllm_ascend.ascend_forward_context._get_fused_moe_state",
            stubbed_get_state):
        # The current access control does not support 16 cards,
        # so the MC2 operator in Qwen's graph mode cannot run.
        # Once 16-card support is available,
        # this e2e can be switched to graph mode.
        example_prompts = [
            "Hello, my name is",
            "The president of the United States is",
            "The capital of France is",
            "The future of AI is",
        ]

        additional_config = {
            "torchair_graph_config": {
                "enabled": True,
            },
            "ascend_scheduler_config": {
                "enabled": True,
            },
            "refresh": True,
        }

        # NOTE: vllm-ascend/pangu-pro-moe-pruing is only part of PanguProMoE
        # with 2 hidden layers, thus the golden results seems inaccurate.
        # This will only change if accuracy changes with the official weights
        # of PanguProMoE.
        golden_results = [
            'Hello, my name is Remempondeprecatedmiot忱',
            'The president of the United States is Remem下的一个 rever ceremoni Segnali',
            'The capital of France is Rememvoud administrativ Remem投',
            'The future of AI isotope Segnali Zoeken精细化 supus',
        ]
        with VllmRunner(
                model,
                dtype="half",
                tensor_parallel_size=tp,
                distributed_executor_backend="mp",
                enforce_eager=False,
                additional_config=additional_config,
                enable_expert_parallel=enable_expert_parallel,
        ) as vllm_model:
            # use greedy sampler to make sure the generated results are fix
            vllm_output = vllm_model.generate_greedy(example_prompts, 5)

        assert len(golden_results) == len(vllm_output)
        for i in range(len(vllm_output)):
            print(f"Generated text: {vllm_output[i][1]!r}")

critical

This test fixture has been modified to use a mock to enable running qwen_moe models in graph mode, which is a good workaround for testing. However, the golden output assertions have been removed, turning this from a correctness regression test into a simple smoke test that only checks for crashes. This is a significant degradation of test quality. Please restore the correctness checks. If the outputs have changed, the golden values should be updated. If the model output is non-deterministic, the test should be adapted to handle this, for example by checking against a set of possible outputs or using a less strict comparison.
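One minimal way to restore the check, sketched here against the names already used in the fixture (golden_results, vllm_output). This is an illustrative fragment for the end of the fixture, not the exact code that was removed:

# Sketch: reinstate the correctness assertion alongside the debug prints.
# If graph mode legitimately changes the outputs, update golden_results
# instead of deleting the check.
assert len(golden_results) == len(vllm_output)
for i in range(len(vllm_output)):
    generated = vllm_output[i][1]
    print(f"Generated text: {generated!r}")
    assert generated == golden_results[i], (
        f"prompt {i}: expected {golden_results[i]!r}, got {generated!r}")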


config = torchair.CompilerConfig()
if get_ascend_config().torchair_graph_config.mode:
    config.mode = get_ascend_config().torchair_graph_config.mode

high

Removing config.experimental_config.frozen_parameter = True seems to be the core fix for qwen_moe models. However, this is a global configuration change in _get_torchair_lazy_compiled_model, which might affect all models compiled with TorchAir. This could potentially lead to performance regressions for non-MoE models if frozen_parameter enables important optimizations. To mitigate this risk, could this setting be disabled conditionally, only for MoE models? For example, you could add logic to detect MoE models and apply this configuration change only when necessary.
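A possible shape for that suggestion, as a minimal sketch: is_moe_model and build_compiler_config are hypothetical names introduced for illustration, the import path for get_ascend_config is assumed, the frozen_parameter line is the one quoted above, and probing the HF config for expert-count attributes is only one of several ways to detect an MoE model:

import torchair

from vllm_ascend.ascend_config import get_ascend_config  # import path assumed


def is_moe_model(vllm_config) -> bool:
    # Hypothetical helper: treat models whose HF config declares experts as MoE.
    hf_config = vllm_config.model_config.hf_config
    return any(
        hasattr(hf_config, attr)
        for attr in ("num_experts", "n_routed_experts", "num_local_experts"))


def build_compiler_config(vllm_config) -> torchair.CompilerConfig:
    config = torchair.CompilerConfig()
    if get_ascend_config().torchair_graph_config.mode:
        config.mode = get_ascend_config().torchair_graph_config.mode
    # Keep the optimization for dense models and skip it only where it is
    # suspected to break (MoE with expert parallelism), rather than dropping
    # it globally.
    if not is_moe_model(vllm_config):
        config.experimental_config.frozen_parameter = True
    return config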


github-actions bot commented Sep 9, 2025

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message and fill in the PR description to help reviewers and future developers understand.

If CI fails, you can run linting and testing checks locally according to Contributing and Testing.

@1Fire4 1Fire4 force-pushed the main branch 2 times, most recently from 93a1148 to 08234df Compare September 11, 2025 09:02
@1Fire4 1Fire4 changed the title from "[Fix] qwen_moe can't run by ep_size is less than 16" to "[Fix] replace npu_incre_flash_attention with npu_fused_infer_attention_score" Sep 11, 2025
@1Fire4 1Fire4 changed the title from "[Fix] replace npu_incre_flash_attention with npu_fused_infer_attention_score" to "Replace npu_incre_flash_attention with npu_fused_infer_attention_score" Sep 11, 2025
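For context, the new title describes swapping one torch_npu decode-attention kernel for another in vllm_ascend/torchair/torchair_attention.py. The sketch below only illustrates the general shape of such a swap; the wrapper function, tensor layout, and keyword argument names are assumptions to be checked against the torch_npu documentation, not code taken from this PR:

import torch_npu


def decode_attention(query, key, value, num_heads, num_kv_heads, scale,
                     block_table, seq_lens):
    # Before (roughly): the incremental flash-attention kernel.
    # output = torch_npu.npu_incre_flash_attention(
    #     query, key, value,
    #     num_heads=num_heads, num_key_value_heads=num_kv_heads,
    #     scale_value=scale, input_layout="BSH",
    #     block_table=block_table, actual_seq_lengths=seq_lens)

    # After (roughly): the fused inference kernel, which returns a tuple;
    # only the attention output is consumed here.
    output, _ = torch_npu.npu_fused_infer_attention_score(
        query, key, value,
        num_heads=num_heads, num_key_value_heads=num_kv_heads,
        scale=scale, input_layout="BSH",
        block_table=block_table, actual_seq_lengths=seq_lens)
    return output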
Signed-off-by: 1Fire4 <wangdingyi2@huawei.com>

codecov bot commented Sep 11, 2025

Codecov Report

❌ Patch coverage is 0% with 5 lines in your changes missing coverage. Please review.
✅ Project coverage is 75.29%. Comparing base (1bbb20e) to head (6dee3b4).
⚠️ Report is 28 commits behind head on main.

Files with missing lines | Patch % | Lines
vllm_ascend/torchair/torchair_attention.py | 0.00% | 5 Missing ⚠️

❌ Your patch check has failed because the patch coverage (0.00%) is below the target coverage (80.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2831      +/-   ##
==========================================
+ Coverage   74.76%   75.29%   +0.52%     
==========================================
  Files         150      155       +5     
  Lines       20891    21271     +380     
==========================================
+ Hits        15620    16015     +395     
+ Misses       5271     5256      -15     
Flag | Coverage Δ
unittests | 75.29% <0.00%> (+0.52%) ⬆️

Flags with carried forward coverage won't be shown.

☔ View full report in Codecov by Sentry.

Signed-off-by: p00465316 <panchao13@huawei.com>
@1Fire4 1Fire4 closed this Sep 15, 2025
