1
- from unittest . mock import patch
1
+ import unittest
2
2
3
3
import pytest
4
4
import torch
5
+ from pytest_mock import MockerFixture
5
6
from vllm .model_executor .layers .layernorm import RMSNorm
6
7
7
-
8
- @pytest .fixture
9
- def dummy_tensor ():
10
- return torch .randn (4 , 8 , dtype = torch .float16 )
8
+ from tests .ut .base import PytestBase
9
+ from vllm_ascend .quantization .w8a8 import AscendW8A8LinearMethod
11
10
12
11
13
12
def mock_maybe_chunk_residual(x, residual):
    """Stand-in for ``torch.ops.vllm.maybe_chunk_residual``.

    When the hidden-states batch differs from the residual's batch,
    return only the first 4 residual rows; otherwise pass the residual
    through unchanged.  The chunk size 4 matches the test fixtures'
    batch dimension.
    """
    batches_match = x.size(0) == residual.size(0)
    return residual if batches_match else residual[:4]
18
16
19
17
@@ -25,69 +23,139 @@ def mock_add_rms_norm(x, residual, weight, eps):
25
23
return 2 * x , None , 2 * residual
26
24
27
25
28
- @pytest .mark .parametrize ("is_310p_return" , [True , False ])
29
- @pytest .mark .parametrize ("residual" ,
30
- [None , torch .randn (4 , 8 , dtype = torch .float32 )])
31
- @patch ("torch_npu.npu_rms_norm" , side_effect = mock_rms_norm )
32
- @patch ("torch_npu.npu_add_rms_norm" , side_effect = mock_add_rms_norm )
33
- @patch ("torch.ops.vllm.maybe_wait_prefetch_done" , side_effect = lambda x : None )
34
- @patch ("torch.ops.vllm.maybe_chunk_residual" ,
35
- side_effect = mock_maybe_chunk_residual )
36
- def test_RMSNorm_forward (mock_maybe_chunk_residual ,
37
- mock_maybe_wait_prefetch_done , mock_add_rmsnorm ,
38
- mock_rmsnorm , is_310p_return , residual , dummy_tensor ):
39
-
40
- with patch ("vllm_ascend.utils.is_310p" , return_value = is_310p_return ):
26
def mock_add_rms_norm_quant(x, residual, weight, quant_scale, quant_offset,
                            epsilon):
    """Stand-in for ``torch_npu.npu_add_rms_norm_quant``.

    Doubles both inputs and casts them to int8; the weight and
    quantization parameters are ignored.  The middle element of the
    returned triple (the rstd slot of the real op) is always None.
    """
    doubled_x, doubled_residual = 2 * x, 2 * residual
    return doubled_x.to(torch.int8), None, doubled_residual.to(torch.int8)
33
+
34
+
35
class TestAscendRMSNorm(PytestBase):
    """Unit tests for the Ascend override of ``RMSNorm.forward_oot``.

    All NPU ops and vllm custom ops the layer touches are replaced with
    deterministic mocks (see the module-level ``mock_*`` helpers), so the
    tests only verify routing/bookkeeping, not numerics.
    """

    @pytest.fixture(autouse=True)
    def context(self, mocker: MockerFixture):
        # Patch every external op the layer may call; side effects are the
        # deterministic module-level mocks so expected values are exact.
        patch_table = (
            ("torch.ops.vllm.maybe_chunk_residual", mock_maybe_chunk_residual),
            ("torch_npu.npu_rms_norm", mock_rms_norm),
            ("torch_npu.npu_add_rms_norm", mock_add_rms_norm),
            ("torch_npu.npu_add_rms_norm_quant", mock_add_rms_norm_quant),
            ("torch.ops.vllm.maybe_wait_prefetch_done", lambda x: None),
        )
        for target, effect in patch_table:
            mocker.patch(target, side_effect=effect)

    # Most common, basic scenario: with and without a residual input.
    @pytest.mark.parametrize(
        "residual", [None, torch.randn(4, 8, dtype=torch.float16)])
    def test_forward_oot_basic(self, residual):
        layer = RMSNorm(hidden_size=8, eps=1e-05)
        hidden = torch.randn(4, 8, dtype=torch.float16)

        if residual is None:
            # No residual: the mocked npu_rms_norm simply adds one.
            out = layer.forward(hidden, residual)
            assert torch.allclose(out, hidden + 1)
        else:
            # With a residual: the mocked npu_add_rms_norm doubles both
            # the hidden states and the residual.
            out, res_out = layer.forward_oot(hidden, residual)
            assert torch.allclose(out, 2 * hidden)
            assert torch.allclose(res_out, 2 * residual)

    # flashcomm_v1 scenario: residual batch is larger than the input batch.
    def test_forward_oot_with_flashcomm_v1(self):
        layer = RMSNorm(hidden_size=512, eps=1e-05)
        hidden = torch.randn(4, 512, dtype=torch.bfloat16)
        residual = torch.randn(16, 512, dtype=torch.bfloat16)

        out, res_out = layer.forward_oot(hidden, residual)

        # maybe_chunk_residual (mocked) trims the oversized residual down
        # to the hidden-states batch before npu_add_rms_norm doubles it.
        assert res_out.size(0) == 4
        assert torch.allclose(out, 2 * hidden)
        assert torch.allclose(res_out, 2 * residual[:4])

    # addrmsnorm + w8a8 quant fusion: verify fusion_linear / layer_idx
    # bookkeeping across consecutive forward_oot calls.
    def test_forward_oot_with_quant_fusion(self, mocker: MockerFixture):
        mocker.patch("vllm_ascend.utils.is_310p").return_value = False
        mock_get_forward_context = mocker.patch(
            "vllm_ascend.ops.layernorm.get_forward_context")

        # Build a fake model with two decoder layers whose qkv and gate_up
        # projections carry AscendW8A8LinearMethod quant methods, so the
        # fusion path recognizes them as w8a8-quantized.
        fwd_ctx = mocker.MagicMock()
        model = mocker.MagicMock()
        fwd_ctx.model_instance = model
        model.model.layers = [mocker.MagicMock() for _ in range(2)]
        for decoder_layer in model.model.layers:
            for proj in (decoder_layer.self_attn.qkv_proj,
                         decoder_layer.mlp.gate_up_proj):
                holder = mocker.MagicMock()
                holder.quant_method = AscendW8A8LinearMethod()
                proj.quant_method = holder

        mock_get_forward_context.return_value = fwd_ctx

        # Simulate a forward context with quant fusion enabled.
        fwd_ctx.addrmsnorm_quant_fusion_enabled = True
        fwd_ctx.prefetch_mlp_enabled = False
        fwd_ctx.layer_idx = 0
        fwd_ctx.num_hidden_layers = 2
        fwd_ctx.fusion_linear = "gate_up_dense"

        hidden = torch.randn(4, 8, dtype=torch.float16)
        residual = torch.randn(4, 8, dtype=torch.float16)
        layer = RMSNorm(hidden_size=8, eps=1e-05)

        # Expected (fusion_linear, layer_idx) after each successive call:
        # the two projections of a layer alternate, and layer_idx advances
        # once per decoder layer until it saturates at num_hidden_layers.
        expectations = (
            ("qkv_dense", 1),
            ("gate_up_dense", 1),
            ("qkv_dense", 2),
            ("qkv_dense", 2),
        )
        for call_count, (fusion, idx) in enumerate(expectations, start=1):
            layer.forward_oot(hidden, residual)
            assert mock_get_forward_context.call_count == call_count
            assert fwd_ctx.fusion_linear == fusion
            assert fwd_ctx.layer_idx == idx
158
+
159
+
160
# Allow running this module directly; pytest is the normal entry point.
if __name__ == '__main__':
    unittest.main()
0 commit comments