
Commit 6b7117d

[main] addrmsnorm + quant fusion optim in Dense Models (#2772)
### What this PR does / why we need it?
This PR fuses the addrmsnorm op and the w8a8 quant op to achieve better performance.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed with newly added and existing tests.

- vLLM version: v0.10.2
- vLLM main: vllm-project/vllm@0faf3cc

Signed-off-by: rjg-lyh <1318825571@qq.com>
1 parent 88ca8a0 commit 6b7117d
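For orientation, the fusion folds the residual add, RMSNorm, and the W8A8 activation quantization into the single `torch_npu.npu_add_rms_norm_quant` kernel that the tests below mock. The following pure-PyTorch sketch only illustrates the fused pattern; the per-tensor round-and-clamp quantization formula here is an assumption for illustration, not the exact kernel semantics.

```python
import torch


def add_rms_norm_quant_reference(x, residual, weight, quant_scale,
                                 quant_offset, epsilon=1e-6):
    # Residual add + RMSNorm + int8 activation quantization in one pass,
    # instead of npu_add_rms_norm followed by a separate quantize step.
    added = x + residual
    variance = added.float().pow(2).mean(dim=-1, keepdim=True)
    normed = added.float() * torch.rsqrt(variance + epsilon) * weight.float()
    # Assumed per-tensor quantization: scale, offset, clamp to the int8 range.
    quantized = torch.clamp(torch.round(normed / quant_scale + quant_offset),
                            -128, 127).to(torch.int8)
    # Quantized activation for the next linear, plus the updated residual.
    return quantized, added


x = torch.randn(4, 8)
residual = torch.randn(4, 8)
q_x, new_residual = add_rms_norm_quant_reference(x, residual,
                                                 weight=torch.ones(8),
                                                 quant_scale=0.1,
                                                 quant_offset=0.0)
```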

File tree

5 files changed: +219, -278 lines

tests/ut/ops/test_layernorm.py

Lines changed: 137 additions & 69 deletions
@@ -1,19 +1,17 @@
-from unittest.mock import patch
+import unittest
 
 import pytest
 import torch
+from pytest_mock import MockerFixture
 from vllm.model_executor.layers.layernorm import RMSNorm
 
-
-@pytest.fixture
-def dummy_tensor():
-    return torch.randn(4, 8, dtype=torch.float16)
+from tests.ut.base import PytestBase
+from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
 
 
 def mock_maybe_chunk_residual(x, residual):
     if x.size(0) != residual.size(0):
         return residual[:4]
-
     return residual
 
 
@@ -25,69 +23,139 @@ def mock_add_rms_norm(x, residual, weight, eps):
     return 2 * x, None, 2 * residual
 
 
-@pytest.mark.parametrize("is_310p_return", [True, False])
-@pytest.mark.parametrize("residual",
-                         [None, torch.randn(4, 8, dtype=torch.float32)])
-@patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm)
-@patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm)
-@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
-@patch("torch.ops.vllm.maybe_chunk_residual",
-       side_effect=mock_maybe_chunk_residual)
-def test_RMSNorm_forward(mock_maybe_chunk_residual,
-                         mock_maybe_wait_prefetch_done, mock_add_rmsnorm,
-                         mock_rmsnorm, is_310p_return, residual, dummy_tensor):
-
-    with patch("vllm_ascend.utils.is_310p", return_value=is_310p_return):
+def mock_add_rms_norm_quant(x, residual, weight, quant_scale, quant_offset,
+                            epsilon):
+    x_out = 2 * x
+    residual_out = 2 * residual
+    x_out_quant = x_out.to(torch.int8)
+    residual_out_quant = residual_out.to(torch.int8)
+    return x_out_quant, None, residual_out_quant
+
+
+class TestAscendRMSNorm(PytestBase):
+
+    @pytest.fixture(autouse=True)
+    def context(self, mocker: MockerFixture):
+        mocker.patch("torch.ops.vllm.maybe_chunk_residual",
                     side_effect=mock_maybe_chunk_residual)
+        mocker.patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm)
+        mocker.patch("torch_npu.npu_add_rms_norm",
                     side_effect=mock_add_rms_norm)
+        mocker.patch("torch_npu.npu_add_rms_norm_quant",
                     side_effect=mock_add_rms_norm_quant)
+        mocker.patch("torch.ops.vllm.maybe_wait_prefetch_done",
                     side_effect=lambda x: None)
+
+    # Test case for the most common and basic scenario
+    @pytest.mark.parametrize(
+        "residual", [None, torch.randn(4, 8, dtype=torch.float16)])
+    def test_forward_oot_basic(self, residual):
         layer = RMSNorm(hidden_size=8, eps=1e-05)
+        x = torch.randn(4, 8, dtype=torch.float16)
         if residual is not None:
-            out_x, out_residual = layer.forward_oot(dummy_tensor, residual)
-
-            if is_310p_return:
-                expected_arg_x = dummy_tensor + residual.to(dummy_tensor.dtype)
-                expected_out_x = expected_arg_x + 1
-                expected_out_residual = expected_arg_x.to(residual.dtype)
-
-                mock_maybe_chunk_residual.assert_called_once()
-                mock_rmsnorm.assert_called_once()
-                mock_maybe_wait_prefetch_done.assert_called_once()
-                assert torch.allclose(out_x, expected_out_x)
-                assert torch.allclose(out_residual, expected_out_residual)
-            else:
-                expected_out_x = 2 * dummy_tensor
-                expected_out_residual = 2 * residual
-                mock_maybe_chunk_residual.assert_called_once()
-                mock_add_rmsnorm.assert_called_once()
-                mock_maybe_wait_prefetch_done.assert_called_once()
-                assert torch.allclose(out_x, expected_out_x)
-                assert torch.allclose(out_residual, expected_out_residual)
+            x_out, residual_out = layer.forward_oot(x, residual)
+
+            x_out_expected = 2 * x
+            residual_out_expected = 2 * residual
+
+            assert torch.allclose(x_out, x_out_expected)
+            assert torch.allclose(residual_out, residual_out_expected)
         else:
-            out_x = layer.forward(dummy_tensor, residual)
-            expected_out_x = dummy_tensor + 1
-
-            mock_rmsnorm.assert_called_once()
-            assert torch.allclose(out_x, expected_out_x)
-
-
-@patch("vllm_ascend.utils.is_310p", return_value=False)
-@patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm)
-@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
-@patch("torch.ops.vllm.maybe_chunk_residual",
-       side_effect=mock_maybe_chunk_residual)
-def test_RMSNorm_forward_with_flashcomm_v1(mock_maybe_chunk_residual,
-                                           mock_maybe_wait_prefetch_done,
-                                           mock_add_rms_norm, mock_is310p):
-    x = torch.randn(4, 512, dtype=torch.bfloat16)
-    residual = torch.randn(16, 512, dtype=torch.bfloat16)
-    layer = RMSNorm(hidden_size=512, eps=1e-05)
-
-    out_x, out_residual = layer.forward_oot(x, residual)
-
-    expected_out_x = 2 * x
-    expected_out_residual = 2 * residual[:4]
-
-    mock_maybe_chunk_residual.assert_called_once()
-    mock_add_rms_norm.assert_called_once()
-    mock_maybe_wait_prefetch_done.assert_called_once()
-    assert out_residual.size(0) == 4
-    assert torch.allclose(out_x, expected_out_x)
-    assert torch.allclose(out_residual, expected_out_residual)
+            x_out = layer.forward(x, residual)
+            x_out_expected = x + 1
+
+            assert torch.allclose(x_out, x_out_expected)
+
+    # Test case for flashcomm_v1 scenario
+    def test_forward_oot_with_flashcomm_v1(self):
+        layer = RMSNorm(hidden_size=512, eps=1e-05)
+        x = torch.randn(4, 512, dtype=torch.bfloat16)
+        residual = torch.randn(16, 512, dtype=torch.bfloat16)
+
+        x_out, residual_out = layer.forward_oot(x, residual)
+
+        x_out_expected = 2 * x
+        residual_out_expected = 2 * residual[:4]
+
+        assert residual_out.size(0) == 4
+        assert torch.allclose(x_out, x_out_expected)
+        assert torch.allclose(residual_out, residual_out_expected)
+
+    # Test case for addrmsnorm + w8a8 quant fusion
+    def test_forward_oot_with_quant_fusion(self, mocker: MockerFixture):
+        mock_is_310p = mocker.patch("vllm_ascend.utils.is_310p")
+        mock_is_310p.return_value = False
+        mock_get_forward_context = mocker.patch(
+            "vllm_ascend.ops.layernorm.get_forward_context")
+
+        # Simulating a scenario with quant_fusion enabled
+        mock_forward_context = mocker.MagicMock()
+
+        mock_model_instance = mocker.MagicMock()
+        mock_forward_context.model_instance = mock_model_instance
+        mock_model_instance.model.layers = [
+            mocker.MagicMock() for _ in range(2)
+        ]
+
+        mock_layer_0 = mock_model_instance.model.layers[0]
+        mock_layer_0.self_attn.qkv_proj = mocker.MagicMock()
+        mock_layer_0.mlp.gate_up_proj = mocker.MagicMock()
+
+        mock_layer_1 = mock_model_instance.model.layers[1]
+        mock_layer_1.self_attn.qkv_proj = mocker.MagicMock()
+        mock_layer_1.mlp.gate_up_proj = mocker.MagicMock()
+
+        mock_quant_method_0_qkv = mocker.MagicMock()
+        mock_quant_method_0_qkv.quant_method = AscendW8A8LinearMethod()
+        mock_quant_method_0_gate_up = mocker.MagicMock()
+        mock_quant_method_0_gate_up.quant_method = AscendW8A8LinearMethod()
+        mock_layer_0.self_attn.qkv_proj.quant_method = mock_quant_method_0_qkv
+        mock_layer_0.mlp.gate_up_proj.quant_method = mock_quant_method_0_gate_up
+
+        mock_quant_method_1_qkv = mocker.MagicMock()
+        mock_quant_method_1_qkv.quant_method = AscendW8A8LinearMethod()
+        mock_quant_method_1_gate_up = mocker.MagicMock()
+        mock_quant_method_1_gate_up.quant_method = AscendW8A8LinearMethod()
+        mock_layer_1.self_attn.qkv_proj.quant_method = mock_quant_method_1_qkv
+        mock_layer_1.mlp.gate_up_proj.quant_method = mock_quant_method_1_gate_up
+
+        mock_get_forward_context.return_value = mock_forward_context
+
+        mock_forward_context.addrmsnorm_quant_fusion_enabled = True
+        mock_forward_context.prefetch_mlp_enabled = False
+        mock_forward_context.layer_idx = 0
+        mock_forward_context.num_hidden_layers = 2
+        mock_forward_context.fusion_linear = "gate_up_dense"
+
+        # Ensure fusion and layer_idx increment are handled correctly
+        x = torch.randn(4, 8, dtype=torch.float16)
+        residual = torch.randn(4, 8, dtype=torch.float16)
+        layer = RMSNorm(hidden_size=8, eps=1e-05)
+
+        x_out, residual_out = layer.forward_oot(x, residual)
+
+        assert mock_get_forward_context.call_count == 1
+        assert mock_forward_context.fusion_linear == "qkv_dense"
+        assert mock_forward_context.layer_idx == 1
+
+        x_out, residual_out = layer.forward_oot(x, residual)
+
+        assert mock_get_forward_context.call_count == 2
+        assert mock_forward_context.fusion_linear == "gate_up_dense"
+        assert mock_forward_context.layer_idx == 1
+
+        x_out, residual_out = layer.forward_oot(x, residual)
+
+        assert mock_get_forward_context.call_count == 3
+        assert mock_forward_context.fusion_linear == "qkv_dense"
+        assert mock_forward_context.layer_idx == 2
+
+        x_out, residual_out = layer.forward_oot(x, residual)
+
+        assert mock_get_forward_context.call_count == 4
+        assert mock_forward_context.fusion_linear == "qkv_dense"
+        assert mock_forward_context.layer_idx == 2
+
+
+if __name__ == '__main__':
+    unittest.main()
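The assertions at the end of `test_forward_oot_with_quant_fusion` pin down the bookkeeping the fused path is expected to perform on each `forward_oot` call: `fusion_linear` alternates between `"gate_up_dense"` and `"qkv_dense"`, `layer_idx` advances on the gate_up step, and the state stops changing once every layer has been visited. Below is a minimal standalone sketch of that transition logic, inferred only from the test's expectations rather than copied from the vllm_ascend implementation.

```python
class FusionState:
    """Toy model of the forward-context fields the test above asserts on."""

    def __init__(self, num_hidden_layers: int):
        self.fusion_linear = "gate_up_dense"  # layer 0 starts at gate_up_proj
        self.layer_idx = 0
        self.num_hidden_layers = num_hidden_layers

    def step(self):
        # Once every decoder layer has been visited, nothing changes anymore.
        if self.layer_idx >= self.num_hidden_layers:
            return
        if self.fusion_linear == "gate_up_dense":
            # Next quantized linear is gate_up_proj; afterwards move on to the
            # qkv_proj of the following layer.
            self.fusion_linear = "qkv_dense"
            self.layer_idx += 1
        else:
            self.fusion_linear = "gate_up_dense"


state = FusionState(num_hidden_layers=2)
expected = [("qkv_dense", 1), ("gate_up_dense", 1), ("qkv_dense", 2),
            ("qkv_dense", 2)]
for fusion_linear, layer_idx in expected:
    state.step()
    assert (state.fusion_linear, state.layer_idx) == (fusion_linear, layer_idx)
```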

vllm_ascend/ascend_forward_context.py

Lines changed: 16 additions & 0 deletions
@@ -129,6 +129,22 @@ def set_ascend_forward_context(
         forward_context.prefetch_mlp_down_proj = False
         forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled
 
+        # TODO(rjg-lyh): The current implementation is somewhat brute force and not elegant.
+        # It will be improved later by implementing operator fusion through the FX graph.
+        #
+        # set for addrmsnorm+quant fusion.
+        # this optim now just support dense models due to the specific operators used.
+        # Once the necessary conditions are met, support for MOE models will also be added.
+        from vllm_ascend.quantization.quant_config import AscendQuantConfig
+        addrmsnorm_quant_fusion_enabled = isinstance(vllm_config.quant_config, AscendQuantConfig) and \
+            vllm_config.model_config.hf_config.model_type in ["llama", "qwen2", "qwen3"] and \
+            forward_context.layer_idx is not None
+        if addrmsnorm_quant_fusion_enabled:
+            forward_context.model_instance = model_instance
+            forward_context.num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers
+            forward_context.fusion_linear = "gate_up_dense" if forward_context.layer_idx == 0 else "qkv_dense"
+        forward_context.addrmsnorm_quant_fusion_enabled = addrmsnorm_quant_fusion_enabled
+
         if num_tokens is None and attn_metadata is not None:
             num_tokens = attn_metadata.num_actual_tokens
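For quick reference, these are the forward-context fields this hunk (together with the test above) relies on. The dataclass below is a hypothetical summary for readability only; the real forward context in vLLM/vllm-ascend carries many more fields and is not defined this way.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class AddRMSNormQuantFusionFields:
    # Hypothetical summary of the fields set in set_ascend_forward_context
    # when addrmsnorm+quant fusion is enabled.
    addrmsnorm_quant_fusion_enabled: bool = False
    model_instance: Optional[Any] = None  # reaches each layer's qkv_proj / gate_up_proj
    num_hidden_layers: int = 0            # from hf_config.num_hidden_layers
    layer_idx: Optional[int] = None       # advanced as decoder layers are visited
    fusion_linear: str = "qkv_dense"      # "gate_up_dense" when layer_idx == 0
    prefetch_mlp_enabled: bool = False    # the fused path also checks this flag
```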

vllm_ascend/models/__init__.py

Lines changed: 0 additions & 3 deletions
@@ -35,9 +35,6 @@ def register_model():
         "Qwen3MoeForCausalLM",
         "vllm_ascend.models.qwen3_moe:CustomQwen3MoeForCausalLM")
 
-    ModelRegistry.register_model(
-        "Qwen3ForCausalLM", "vllm_ascend.models.qwen3:CustomQwen3ForCausalLM")
-
     # There is no PanguProMoEForCausalLM in vLLM, so we should register it before vLLM config initialization
     # to make sure the model can be loaded correctly. This register step can be removed once vLLM support PanguProMoEForCausalLM.
     ModelRegistry.register_model(
