+import copy
 from unittest.mock import Mock, patch

 import torch
@@ -31,79 +32,139 @@ def test_get_pergroup_param(self):


 class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
+    experts = 8
+    input_size = 16
+    output_size = 56
+    group_size = 2

+    @patch('vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config')
     @patch('vllm_ascend.quantization.w4a8_dynamic.get_ep_group')
     @patch("vllm_ascend.ascend_config.get_ascend_config")
     @patch('vllm_ascend.quantization.w4a8_dynamic.get_mc2_group')
     @patch('torch.distributed.get_rank', return_value=0)
     def setUp(self, mock_get_rank, mock_get_mc2_group, mock_get_ascend_config,
-              mock_get_ep_group):
+              mock_get_ep_group, get_current_vllm_config):
         mock_ascend_config = Mock()
         mock_ascend_config.torchair_graph_config = Mock(enabled=False)
         mock_get_ascend_config.return_value = mock_ascend_config
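+        # quant_description (group size, quant version) is read from the
+        # global vllm config when the method is constructed, so the mock is
+        # installed once here instead of being patched in each test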
+        mock_vllm_config = Mock()
+        mock_vllm_config.quant_config = Mock(quant_description={
+            "group_size": self.group_size,
+            "version": "0.0.0"
+        })
+        mock_vllm_config.parallel_config = Mock(enable_expert_parallel=True)
+        get_current_vllm_config.return_value = mock_vllm_config
         self.quant_method = AscendW4A8DynamicFusedMoEMethod()

     def test_get_weight(self):
-        param_dict = self.quant_method.get_weight(8, 4, 14, torch.bfloat16)
+        # old quant version w4a8 weight
+        param_dict = self.quant_method.get_weight(self.experts,
+                                                  self.input_size,
+                                                  self.output_size,
+                                                  torch.bfloat16)
+        self.assertEqual(param_dict["w13_weight"].dtype, torch.int8)
+        self.assertEqual(param_dict["w13_weight"].shape,
+                         (self.experts, 2 * self.input_size, self.output_size))
+        # new quant version weight
+        self.quant_method.new_quant_version = True
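+        # packed layout: dim 1 is halved (two int4 values per int8 element)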
+        param_dict = self.quant_method.get_weight(self.experts,
+                                                  self.input_size,
+                                                  self.output_size,
+                                                  torch.bfloat16)
         self.assertEqual(param_dict["w13_weight"].dtype, torch.int8)
-        self.assertEqual(param_dict["w13_weight"].shape, (8, 8, 14))
+        self.assertEqual(param_dict["w13_weight"].shape,
+                         (self.experts, self.input_size, self.output_size))

-    @patch('vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config')
-    def test_get_dynamic_quant_param(self, mock_get_current_vllm_config):
-        mock_vllm_config = Mock()
-        mock_vllm_config.quant_config = Mock(
-            quant_description={"group_size": 2})
-        mock_get_current_vllm_config.return_value = mock_vllm_config
+    def test_get_dynamic_quant_param(self):
+        # old quant version weight
         param_dict = self.quant_method.get_dynamic_quant_param(
-            8, 4, 14, torch.bfloat16)
+            self.experts, self.input_size, self.output_size, torch.bfloat16)
         self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.bfloat16)
-        self.assertEqual(param_dict["w13_weight_scale"].shape, (8, 8, 1))
+        self.assertEqual(param_dict["w13_weight_scale"].shape,
+                         (self.experts, 2 * self.input_size, 1))
         self.assertEqual(param_dict["w13_weight_scale_second"].dtype,
                          torch.bfloat16)
         self.assertEqual(param_dict["w13_weight_scale_second"].shape,
-                         (8, 8, 7))
+                         (self.experts, 2 * self.input_size,
+                          self.output_size // self.group_size))
         self.assertEqual(param_dict["w2_weight_scale"].dtype, torch.bfloat16)
-        self.assertEqual(param_dict["w2_weight_scale"].shape, (8, 14, 1))
+        self.assertEqual(param_dict["w2_weight_scale"].shape,
+                         (self.experts, self.output_size, 1))
         self.assertEqual(param_dict["w2_weight_scale_second"].dtype,
                          torch.bfloat16)
         self.assertEqual(param_dict["w2_weight_scale_second"].shape,
-                         (8, 14, 2))
+                         (self.experts, self.output_size,
+                          self.input_size // self.group_size))
+        # new quant version weight
+        self.quant_method.new_quant_version = True
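+        # the new format also returns a float32 w2_scale_bias whose
+        # trailing dim is 16 // tp_size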
+        param_dict = self.quant_method.get_dynamic_quant_param(
+            self.experts, self.input_size, self.output_size, torch.bfloat16)
+        self.assertEqual(param_dict["w2_scale_bias"].dtype, torch.float32)
+        self.assertEqual(
+            param_dict["w2_scale_bias"].shape,
+            (self.experts, self.output_size, 16 // self.quant_method.tp_size))

     @patch('torch_npu.npu_quantize')
     @patch('torch.Tensor.npu')
     def test_process_weights_after_loading(self, mock_npu, mock_npu_quantize):
+        # old quant version weight
         layer = torch.nn.Module()
-        layer.w13_weight = torch.nn.Parameter(torch.zeros((8, 8, 14),
-                                                          dtype=torch.int8),
+        layer.w13_weight = torch.nn.Parameter(torch.zeros(
+            (self.experts, 2 * self.input_size, self.output_size),
+            dtype=torch.int8),
                                               requires_grad=False)
-        layer.w2_weight = torch.nn.Parameter(torch.zeros((8, 14, 4),
-                                                         dtype=torch.int8),
+        layer.w2_weight = torch.nn.Parameter(torch.zeros(
+            (self.experts, self.output_size, self.input_size),
+            dtype=torch.int8),
                                              requires_grad=False)
         layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
-            (8, 8, 1), dtype=torch.bfloat16),
+            (self.experts, 2 * self.input_size, 1), dtype=torch.bfloat16),
                                                     requires_grad=False)
-        layer.w13_weight_offset = torch.nn.Parameter(torch.zeros(
-            (8, 8, 1), dtype=torch.bfloat16),
-                                                     requires_grad=False)
         layer.w13_weight_scale_second = torch.nn.Parameter(torch.ones(
-            (8, 8, 7), dtype=torch.bfloat16),
+            (self.experts, 2 * self.input_size,
+             self.output_size // self.group_size),
+            dtype=torch.bfloat16),
                                                            requires_grad=False)
         layer.w2_weight_scale = torch.nn.Parameter(torch.ones(
-            (8, 14, 1), dtype=torch.bfloat16),
+            (self.experts, self.output_size, 1), dtype=torch.bfloat16),
                                                    requires_grad=False)
-        layer.w2_weight_offset = torch.nn.Parameter(torch.zeros(
-            (8, 14, 1), dtype=torch.bfloat16),
-                                                    requires_grad=False)
         layer.w2_weight_scale_second = torch.nn.Parameter(torch.ones(
-            (8, 14, 2), dtype=torch.bfloat16),
+            (self.experts, self.output_size,
+             self.input_size // self.group_size),
+            dtype=torch.bfloat16),
                                                           requires_grad=False)
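+        # keep an unprocessed copy of the layer for the new-version path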
+        new_layer = copy.deepcopy(layer)

         mock_npu.return_value = torch.Tensor()
         mock_npu_quantize.return_value = torch.Tensor()
         self.quant_method.process_weights_after_loading(layer)
         self.assertTrue(hasattr(layer, "w13_scale_bias"))
-        self.assertEqual(layer.w13_scale_bias.data.shape, (8, 8))
+        self.assertEqual(layer.w13_scale_bias.data.shape,
+                         (self.experts, 2 * self.input_size))
         self.assertEqual(layer.w13_scale_bias.data.dtype, torch.float32)
         self.assertTrue(hasattr(layer, "w2_scale_bias"))
-        self.assertEqual(layer.w2_scale_bias.data.shape, (8, 14))
+        self.assertEqual(layer.w2_scale_bias.data.shape,
+                         (self.experts, self.output_size))
         self.assertEqual(layer.w2_scale_bias.data.dtype, torch.float32)
+        # new quant version weight
+        self.quant_method.new_quant_version = True
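+        # the new-version layer starts from already-packed weight shapes
+        # and pre-created scale-bias params before post-load processing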
+        new_layer.w13_weight.data = torch.zeros(
+            (self.experts, self.input_size, self.output_size),
+            dtype=torch.int8)
+        new_layer.w2_weight.data = torch.zeros(
+            (self.experts, self.output_size // 2, self.input_size),
+            dtype=torch.int8)
+        w13_scale_bias = torch.zeros((self.experts, 2 * self.input_size, 1),
+                                     dtype=torch.float32)
+        new_layer.w13_scale_bias = torch.nn.Parameter(w13_scale_bias,
+                                                      requires_grad=False)
+        w2_scale_bias = torch.zeros(
+            (self.experts, self.output_size, 16 // self.quant_method.tp_size),
+            dtype=torch.float32)
+        new_layer.w2_scale_bias = torch.nn.Parameter(w2_scale_bias,
+                                                     requires_grad=False)
+        self.quant_method.process_weights_after_loading(new_layer)
+        self.assertEqual(new_layer.w13_scale_bias.data.shape,
+                         (self.experts, 2 * self.input_size))
+        self.assertEqual(new_layer.w2_scale_bias.data.shape,
+                         (self.experts, self.output_size))
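
The shape checks above assume the new quant format packs two int4 values
into each int8 element along dim 1 of the weight (w13: 2 * input_size ->
input_size; w2: output_size -> output_size // 2). A minimal sketch of that
packing, assuming a low/high-nibble convention; the helper name and nibble
order are illustrative, not taken from vllm-ascend:

import torch

def pack_int4_pairs(w: torch.Tensor) -> torch.Tensor:
    """Pack int4 values (one per int8 element) into int8 pairs along dim 1."""
    assert w.dtype == torch.int8 and w.shape[1] % 2 == 0
    lo = (w[:, 0::2] & 0x0F).to(torch.uint8)       # even slices -> low nibble
    hi = (w[:, 1::2] & 0x0F).to(torch.uint8) << 4  # odd slices -> high nibble
    return (hi | lo).view(torch.int8)              # reinterpret as int8

w13 = torch.zeros((8, 2 * 16, 56), dtype=torch.int8)  # old (unpacked) layout
print(pack_int4_pairs(w13).shape)  # torch.Size([8, 16, 56]): new layout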