
Conversation

@rjg-lyh (Collaborator) commented Sep 5, 2025

What this PR does / why we need it?

This PR fuses the AddRMSNorm op and the W8A8 quant op to get better performance.

Does this PR introduce any user-facing change?

No.

How was this patch tested?

CI passed with newly added and existing tests.
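For context on what is being fused, here is a rough sketch of the unfused baseline: an AddRMSNorm kernel followed by a separate activation-quant step. The helper name `add_rmsnorm_then_quant` and the simplified int8 rounding are illustrative assumptions, not code from this PR; the point of the PR is to collapse these two launches into a single fused kernel.

```python
# Rough sketch of the unfused baseline this PR targets (illustrative only, not
# code from the diff): the add + RMSNorm step and the activation quant step run
# as separate kernels, so the normalized tensor makes an extra round trip
# through memory before it is quantized for the W8A8 matmul.
import torch
import torch_npu  # Ascend NPU extension; provides npu_add_rms_norm

def add_rmsnorm_then_quant(x, residual, weight, eps, act_scale):
    # Kernel 1: fused add + RMSNorm (this op already exists in torch_npu).
    normed, _, new_residual = torch_npu.npu_add_rms_norm(x, residual, weight, eps)
    # Kernel 2: simplified per-tensor int8 quantization standing in for the
    # real w8a8 activation quant used by the quantized linear layer.
    quantized = torch.clamp(torch.round(normed / act_scale), -128, 127).to(torch.int8)
    return quantized, new_residual
```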

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces a fusion optimization for addrmsnorm and w8a8 quantization on Ascend hardware. The changes add a new environment flag to control the feature and modify the forward context and layer normalization logic to implement the fusion. While the overall direction is good for performance, I've found a few critical issues in the implementation. The core fusion logic in AscendRMSNorm appears to be reversed, and it fails to increment the layer index, which would lead to incorrect behavior for models with more than one layer. Additionally, there are missing arguments in function calls that would cause runtime errors. I've provided suggestions to fix these issues.

if envs_ascend.VLLM_ASCEND_ENABLE_ADDRMSNORM_QUANT_FUSION:
    forward_context.prefetch_model = prefetch_model
    forward_context.layer_idx = 0
    forward_context.fusion_linear = "gate_up"
@gemini-code-assist bot (Contributor) commented:

critical

The initial value for fusion_linear should be "qkv" to correctly fuse with the pre-attention RMSNorm first. The current value "gate_up" causes the fusion logic to be reversed.

Suggested change
-    forward_context.fusion_linear = "gate_up"
+    forward_context.fusion_linear = "qkv"

Comment on lines 112 to 122
forward_context = get_forward_context()
prefetch_model = forward_context.prefetch_model
layer_idx = forward_context.layer_idx
fusion_linear = forward_context.fusion_linear
if fusion_linear == "gate_up":
    fusion_linear = prefetch_model.model.layers[layer_idx].mlp.gate_up_proj
    forward_context.fusion_linear = "qkv"
elif fusion_linear == "qkv":
    fusion_linear = prefetch_model.model.layers[layer_idx].self_attn.qkv_proj
    forward_context.fusion_linear = "gate_up"
@gemini-code-assist bot (Contributor) commented:

critical

There appear to be two critical issues in this logic:

  1. Incorrect Fusion Logic: The fusion logic seems to be reversed. In a standard transformer block, the first RMSNorm (pre-attention) should be fused with the qkv_proj of the attention layer, and the second RMSNorm (pre-MLP) should be fused with the gate_up_proj of the MLP. The current implementation does the opposite.

  2. Missing layer_idx Increment: The layer_idx is fetched from the forward_context but is never incremented. This will cause the fusion to always use the layers from index 0 (prefetch_model.model.layers[0]) for all subsequent transformer layers, which is incorrect.

To fix this, the fusion logic should be corrected, and layer_idx should be incremented after the second RMSNorm of each layer. The initial value for fusion_linear in ascend_forward_context.py should also be adjusted.

Suggested change
-forward_context = get_forward_context()
-prefetch_model = forward_context.prefetch_model
-layer_idx = forward_context.layer_idx
-fusion_linear = forward_context.fusion_linear
-if fusion_linear == "gate_up":
-    fusion_linear = prefetch_model.model.layers[layer_idx].mlp.gate_up_proj
-    forward_context.fusion_linear = "qkv"
-elif fusion_linear == "qkv":
-    fusion_linear = prefetch_model.model.layers[layer_idx].self_attn.qkv_proj
-    forward_context.fusion_linear = "gate_up"
+forward_context = get_forward_context()
+prefetch_model = forward_context.prefetch_model
+layer_idx = forward_context.layer_idx
+fusion_target = forward_context.fusion_linear
+if fusion_target == "qkv":
+    fusion_linear = prefetch_model.model.layers[layer_idx].self_attn.qkv_proj
+    forward_context.fusion_linear = "gate_up"
+elif fusion_target == "gate_up":
+    fusion_linear = prefetch_model.model.layers[layer_idx].mlp.gate_up_proj
+    forward_context.fusion_linear = "qkv"
+    forward_context.layer_idx += 1
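To make the intended alternation concrete, here is a minimal sketch of the per-layer state machine the suggestion describes. The `FakeContext` class and `next_fusion_target` helper are hypothetical stand-ins for the real forward context, used only to trace the expected (layer_idx, target) sequence.

```python
# Hypothetical trace of the suggested state machine: the pre-attention norm of
# each layer fuses with qkv_proj, the pre-MLP norm fuses with gate_up_proj, and
# layer_idx advances only after the second norm of the layer.
class FakeContext:
    def __init__(self):
        self.layer_idx = 0
        self.fusion_linear = "qkv"  # suggested initial value

def next_fusion_target(ctx):
    layer, target = ctx.layer_idx, ctx.fusion_linear
    if target == "qkv":
        ctx.fusion_linear = "gate_up"   # same layer, pre-MLP norm comes next
    else:  # "gate_up"
        ctx.fusion_linear = "qkv"       # next layer's pre-attention norm
        ctx.layer_idx += 1
    return layer, target

ctx = FakeContext()
print([next_fusion_target(ctx) for _ in range(4)])
# [(0, 'qkv'), (0, 'gate_up'), (1, 'qkv'), (1, 'gate_up')]
```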

Comment on lines 128 to 152
        x, residual = _addrmsnorm_w8a8_quant_forward_oot(x, residual, fusion_linear)
    else:
        x, residual = _addrmsnorm_forward_oot(x, residual)
else:
    x, _, residual = torch_npu.npu_add_rms_norm(
        x, residual, self.weight, self.variance_epsilon)
    x, residual = _addrmsnorm_forward_oot(x, residual)
@gemini-code-assist bot (Contributor) commented:

critical

The calls to _addrmsnorm_w8a8_quant_forward_oot and _addrmsnorm_forward_oot are missing the self argument. These helper functions are defined to take self as their first parameter, but it's not being passed in the calls. This will lead to a TypeError at runtime.

Suggested change
-        x, residual = _addrmsnorm_w8a8_quant_forward_oot(x, residual, fusion_linear)
-    else:
-        x, residual = _addrmsnorm_forward_oot(x, residual)
-else:
-    x, _, residual = torch_npu.npu_add_rms_norm(
-        x, residual, self.weight, self.variance_epsilon)
-    x, residual = _addrmsnorm_forward_oot(x, residual)
+        x, residual = _addrmsnorm_w8a8_quant_forward_oot(self, x, residual, fusion_linear)
+    else:
+        x, residual = _addrmsnorm_forward_oot(self, x, residual)
+else:
+    x, residual = _addrmsnorm_forward_oot(self, x, residual)
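The reason the explicit `self` is needed: if `_addrmsnorm_forward_oot` and `_addrmsnorm_w8a8_quant_forward_oot` are module-level functions rather than bound methods, Python does not inject the instance automatically. A tiny hedged sketch of the failure mode follows; the function body and the `NormLike` class are placeholders, not the real implementations.

```python
# Minimal reproduction of the missing-self failure, assuming the helper is a
# plain module-level function that reads attributes off the passed-in instance.
def _addrmsnorm_forward_oot(self, x, residual):
    # placeholder body; the real helper calls torch_npu kernels here
    return x + self.offset, residual

class NormLike:
    def __init__(self):
        self.offset = 1

    def forward_oot(self, x, residual):
        # _addrmsnorm_forward_oot(x, residual)  # TypeError: x is bound to `self`,
        #                                       # so one positional arg is missing
        return _addrmsnorm_forward_oot(self, x, residual)  # pass the instance explicitly

print(NormLike().forward_oot(2, 3))  # (3, 3)
```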

github-actions bot commented Sep 5, 2025

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling in the PR description to help reviewers and future developers understand it.

If CI fails, you can run the linting and testing checks locally according to Contributing and Testing.

@rjg-lyh changed the title from "[main] addrmsnorm + quant fusion optim" to "[main] addrmsnorm + quant fusion optim in Qwen Models" on Sep 5, 2025

github-actions bot commented Sep 8, 2025

This pull request has conflicts, please resolve those before we can evaluate the pull request.

@Angazenn (Contributor) commented Sep 9, 2025

Can this PR handle the situation where RMSNorm has an extra bias that is introduced by anti-outlier algorithms?

@rjg-lyh force-pushed the pr-rms-refactor branch 7 times, most recently from daa0827 to 9b2255c on September 10, 2025 06:26
github-actions bot commented: This pull request has conflicts, please resolve those before we can evaluate the pull request.

@rjg-lyh force-pushed the pr-rms-refactor branch 7 times, most recently from dc697af to 01756d1 on September 15, 2025 07:30
@rjg-lyh changed the title from "[main] addrmsnorm + quant fusion optim in Qwen Models" to "[main] addrmsnorm + quant fusion optim in Dense Models" on Sep 15, 2025
@rjg-lyh force-pushed the pr-rms-refactor branch 3 times, most recently from ce5eafb to ca10b7f on September 15, 2025 09:22
@rjg-lyh added the ready (read for review) and accuracy-test (enable all accuracy test for PR) labels on Sep 15, 2025
@MengqingCao added the ready-for-test (start test by label for PR) label on Sep 16, 2025
@rjg-lyh removed the accuracy-test (enable all accuracy test for PR) label on Sep 16, 2025
Signed-off-by: rjg-lyh <1318825571@qq.com>
@MengqingCao merged commit 6b7117d into vllm-project:main on Sep 16, 2025
17 checks passed