[main] addrmsnorm + quant fusion optim in Dense Models #2772
Conversation
Code Review
This pull request introduces a fusion optimization for `addrmsnorm` and `w8a8` quantization on Ascend hardware. The changes add a new environment flag to control the feature and modify the forward context and layer-normalization logic to implement the fusion. While the overall direction is good for performance, I've found a few critical issues in the implementation. The core fusion logic in `AscendRMSNorm` appears to be reversed, and it fails to increment the layer index, which would lead to incorrect behavior for models with more than one layer. Additionally, there are missing arguments in function calls that would cause runtime errors. I've provided suggestions to fix these issues.
```python
if envs_ascend.VLLM_ASCEND_ENABLE_ADDRMSNORM_QUANT_FUSION:
    forward_context.prefetch_model = prefetch_model
    forward_context.layer_idx = 0
    forward_context.fusion_linear = "gate_up"
```
vllm_ascend/ops/layernorm.py (Outdated)
```python
forward_context = get_forward_context()
prefetch_model = forward_context.prefetch_model
layer_idx = forward_context.layer_idx
fusion_linear = forward_context.fusion_linear
if fusion_linear == "gate_up":
    fusion_linear = prefetch_model.model.layers[layer_idx].mlp.gate_up_proj
    forward_context.fusion_linear = "qkv"
elif fusion_linear == "qkv":
    fusion_linear = prefetch_model.model.layers[layer_idx].self_attn.qkv_proj
    forward_context.fusion_linear = "gate_up"
```
There appear to be two critical issues in this logic:

- Incorrect fusion order: The fusion logic seems to be reversed. In a standard transformer block, the first `RMSNorm` (pre-attention) should be fused with the `qkv_proj` of the attention layer, and the second `RMSNorm` (pre-MLP) should be fused with the `gate_up_proj` of the MLP. The current implementation does the opposite.
- Missing `layer_idx` increment: The `layer_idx` is fetched from the `forward_context` but is never incremented. This causes the fusion to always use the layers at index 0 (`prefetch_model.model.layers[0]`) for all subsequent transformer layers, which is incorrect.

To fix this, the fusion order should be corrected, and `layer_idx` should be incremented after the second RMSNorm of each layer. The initial value of `fusion_linear` in `ascend_forward_context.py` should also be adjusted.
Suggested change:

```python
forward_context = get_forward_context()
prefetch_model = forward_context.prefetch_model
layer_idx = forward_context.layer_idx
fusion_target = forward_context.fusion_linear
if fusion_target == "qkv":
    fusion_linear = prefetch_model.model.layers[layer_idx].self_attn.qkv_proj
    forward_context.fusion_linear = "gate_up"
elif fusion_target == "gate_up":
    fusion_linear = prefetch_model.model.layers[layer_idx].mlp.gate_up_proj
    forward_context.fusion_linear = "qkv"
    forward_context.layer_idx += 1
```
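To illustrate the intended ordering, here is a small, self-contained sketch of the suggested bookkeeping (the `Ctx` class and string targets are hypothetical stand-ins for the real forward context and prefetched linear modules). Walking it over a two-layer model shows each layer's pre-attention norm picking up `qkv_proj` and its pre-MLP norm picking up `gate_up_proj` before `layer_idx` advances:

```python
# Minimal simulation of the suggested fusion state machine (illustrative only).
class Ctx:
    def __init__(self):
        self.layer_idx = 0
        self.fusion_linear = "qkv"  # adjusted initial value, per the suggestion above

def pick_fusion_target(ctx):
    """Return (layer index, linear name) that the current RMSNorm call would fuse with."""
    if ctx.fusion_linear == "qkv":
        target = (ctx.layer_idx, "self_attn.qkv_proj")
        ctx.fusion_linear = "gate_up"
    else:  # "gate_up"
        target = (ctx.layer_idx, "mlp.gate_up_proj")
        ctx.fusion_linear = "qkv"
        ctx.layer_idx += 1  # advance only after the second RMSNorm of the layer
    return target

ctx = Ctx()
trace = [pick_fusion_target(ctx) for _ in range(4)]  # two layers -> four fused norm calls
print(trace)
# [(0, 'self_attn.qkv_proj'), (0, 'mlp.gate_up_proj'),
#  (1, 'self_attn.qkv_proj'), (1, 'mlp.gate_up_proj')]
```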
vllm_ascend/ops/layernorm.py (Outdated)
```python
        x, residual = _addrmsnorm_w8a8_quant_forward_oot(x, residual, fusion_linear)
    else:
        x, residual = _addrmsnorm_forward_oot(x, residual)
else:
    x, _, residual = torch_npu.npu_add_rms_norm(
        x, residual, self.weight, self.variance_epsilon)
x, residual = _addrmsnorm_forward_oot(x, residual)
```
The calls to `_addrmsnorm_w8a8_quant_forward_oot` and `_addrmsnorm_forward_oot` are missing the `self` argument. These helper functions are defined to take `self` as their first parameter, but it is not being passed in the calls. This will lead to a `TypeError` at runtime.
Suggested change:

```python
        x, residual = _addrmsnorm_w8a8_quant_forward_oot(self, x, residual, fusion_linear)
    else:
        x, residual = _addrmsnorm_forward_oot(self, x, residual)
else:
    x, residual = _addrmsnorm_forward_oot(self, x, residual)
```
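For context, the `TypeError` arises because these helpers are plain module-level functions whose first parameter happens to be named `self`, so the instance must be passed explicitly. A minimal sketch with hypothetical stand-ins (not the actual `vllm_ascend` definitions):

```python
import torch

# Hypothetical stand-in for a module-level helper that takes `self` as an
# ordinary first parameter; the real helper dispatches to an NPU kernel.
def _addrmsnorm_forward_oot(self, x, residual):
    added = x + residual
    rms = torch.rsqrt(added.pow(2).mean(-1, keepdim=True) + self.variance_epsilon)
    return added * rms * self.weight, added  # (normalized output, updated residual)

class DummyNorm:
    def __init__(self, hidden):
        self.weight = torch.ones(hidden)
        self.variance_epsilon = 1e-6

norm = DummyNorm(4)
x, residual = torch.randn(2, 4), torch.randn(2, 4)

# _addrmsnorm_forward_oot(x, residual) would raise a TypeError (missing positional
# argument), since `self` is just the first parameter of a free function.
out, new_residual = _addrmsnorm_forward_oot(norm, x, residual)
```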
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
- If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
This pull request has conflicts, please resolve those before we can evaluate the pull request.
Can this PR handle the situation where RMSNorm has an extra bias that is introduced by anti-outlier algorithms?
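For reference, the concern is that some anti-outlier quantization recipes fold an extra bias into the norm output, so the fused add-RMSNorm-quant path would also need to cover roughly the following computation (a conceptual sketch; the exact formulation depends on the algorithm):

```python
import torch

def add_rmsnorm_with_bias(x, residual, weight, bias, eps=1e-6):
    # Residual add followed by RMSNorm; `bias` is the extra term introduced
    # by the anti-outlier scheme and applied after the weight scaling.
    added = x + residual
    rms = torch.rsqrt(added.pow(2).mean(-1, keepdim=True) + eps)
    return added * rms * weight + bias, added
```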
This pull request has conflicts, please resolve those before we can evaluate the pull request.
Signed-off-by: rjg-lyh <1318825571@qq.com>
What this PR does / why we need it?
This PR fuses the addrmsnorm op and the w8a8 quant op to achieve better performance.
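For context, the unfused path performs the residual add + RMSNorm and the W8A8 activation quantization as separate steps, roughly as in the sketch below (plain-PyTorch reference math; per-tensor symmetric int8 scaling is an assumption, and the actual Ascend kernels and scaling scheme may differ):

```python
import torch

def add_rmsnorm_then_quant(x, residual, weight, scale, eps=1e-6):
    # Step 1: residual add + RMSNorm (what torch_npu.npu_add_rms_norm computes).
    added = x + residual
    rms = torch.rsqrt(added.pow(2).mean(-1, keepdim=True) + eps)
    normed = added * rms * weight

    # Step 2: W8A8 activation quantization to int8 (per-tensor scale assumed).
    quantized = torch.clamp(torch.round(normed / scale), -128, 127).to(torch.int8)
    return quantized, added  # int8 activations + updated residual
```

The fusion combines both steps into a single kernel call, which presumably avoids an extra round trip of the fp16/bf16 normalized tensor through memory.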
Does this PR introduce any user-facing change?
No.
How was this patch tested?
CI passed with newly added and existing tests.