Commit ce14d71

hust17yixuan committed

refactor quantization

Signed-off-by: hust17yixuan <303660421@qq.com>

1 parent 21b5727 commit ce14d71

File tree

10 files changed: +1752 -3 lines changed
Lines changed: 176 additions & 0 deletions
@@ -0,0 +1,176 @@
import copy
from unittest.mock import Mock, patch

import torch

from tests.ut.base import TestBase
from vllm_ascend.torchair.quantization.torchair_w4a8_dynamic import (
    TorchairAscendW4A8DynamicFusedMoEMethod,
    TorchairAscendW4A8DynamicLinearMethod)


class TestAscendW4A8DynamicLinearMethod(TestBase):

    def setUp(self):
        self.method = TorchairAscendW4A8DynamicLinearMethod()
        self.method.group_size = 8

    def test_get_weight(self):
        weight = self.method.get_weight(8, 32, torch.bfloat16)
        self.assertEqual(weight["weight"].dtype, torch.int8)
        self.assertEqual(weight["weight"].shape, (32, 8))

    def test_get_pergroup_param(self):
        params = self.method.get_pergroup_param(8, 32, torch.bfloat16)
        self.assertEqual(params["weight_scale"].dtype, torch.bfloat16)
        self.assertEqual(params["weight_scale"].shape, (32, 1))
        self.assertEqual(params["weight_offset"].dtype, torch.bfloat16)
        self.assertEqual(params["weight_offset"].shape, (32, 1))
        self.assertEqual(params["weight_scale_second"].dtype, torch.bfloat16)
        self.assertEqual(params["weight_scale_second"].shape, (32, 1))
        self.assertEqual(params["weight_offset_second"].dtype, torch.bfloat16)
        self.assertEqual(params["weight_offset_second"].shape, (32, 1))


class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
    experts = 8
    input_size = 16
    output_size = 56
    group_size = 2

    @patch(
        'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_current_vllm_config'
    )
    @patch(
        'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_ep_group')
    @patch("vllm_ascend.ascend_config.get_ascend_config")
    @patch(
        'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_mc2_group'
    )
    @patch('torch.distributed.get_rank', return_value=0)
    def setUp(self, mock_get_rank, mock_get_mc2_group, mock_get_ascend_config,
              mock_get_ep_group, get_current_vllm_config):
        mock_ascend_config = Mock()
        mock_ascend_config.torchair_graph_config = Mock(enabled=False)
        mock_get_ascend_config.return_value = mock_ascend_config
        mock_vllm_config = Mock()
        mock_vllm_config.quant_config = Mock(quant_description={
            "group_size": self.group_size,
            "version": "0.0.0"
        })
        mock_vllm_config.parallel_config = Mock(enable_expert_parallel=True)
        get_current_vllm_config.return_value = mock_vllm_config
        self.quant_method = TorchairAscendW4A8DynamicFusedMoEMethod()

    def test_get_weight(self):
        # old quant version w4a8 weight
        param_dict = self.quant_method.get_weight(self.experts,
                                                  self.input_size,
                                                  self.output_size,
                                                  torch.bfloat16)
        self.assertEqual(param_dict["w13_weight"].dtype, torch.int8)
        self.assertEqual(param_dict["w13_weight"].shape,
                         (self.experts, 2 * self.input_size, self.output_size))
        # new quant version weight
        self.quant_method.new_quant_version = True
        param_dict = self.quant_method.get_weight(self.experts,
                                                  self.input_size,
                                                  self.output_size,
                                                  torch.bfloat16)
        self.assertEqual(param_dict["w13_weight"].dtype, torch.int8)
        self.assertEqual(param_dict["w13_weight"].shape,
                         (self.experts, self.input_size, self.output_size))

    def test_get_dynamic_quant_param(self):
        # old quant version weight
        param_dict = self.quant_method.get_dynamic_quant_param(
            self.experts, self.input_size, self.output_size, torch.bfloat16)
        self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.bfloat16)
        self.assertEqual(param_dict["w13_weight_scale"].shape,
                         (self.experts, 2 * self.input_size, 1))
        self.assertEqual(param_dict["w13_weight_scale_second"].dtype,
                         torch.bfloat16)
        self.assertEqual(param_dict["w13_weight_scale_second"].shape,
                         (self.experts, 2 * self.input_size,
                          self.output_size // self.group_size))
        self.assertEqual(param_dict["w2_weight_scale"].dtype, torch.bfloat16)
        self.assertEqual(param_dict["w2_weight_scale"].shape,
                         (self.experts, self.output_size, 1))
        self.assertEqual(param_dict["w2_weight_scale_second"].dtype,
                         torch.bfloat16)
        self.assertEqual(param_dict["w2_weight_scale_second"].shape,
                         (self.experts, self.output_size,
                          self.input_size // self.group_size))
        # new quant version weight
        self.quant_method.new_quant_version = True
        param_dict = self.quant_method.get_dynamic_quant_param(
            self.experts, self.input_size, self.output_size, torch.bfloat16)
        self.assertEqual(param_dict["w2_scale_bias"].dtype, torch.float32)
        self.assertEqual(
            param_dict["w2_scale_bias"].shape,
            (self.experts, self.output_size, 16 // self.quant_method.tp_size))

    @patch('torch_npu.npu_quantize')
    @patch('torch.Tensor.npu')
    def test_process_weights_after_loading(self, mock_npu, mock_npu_quantize):
        # old quant version weight
        layer = torch.nn.Module()
        layer.w13_weight = torch.nn.Parameter(torch.zeros(
            (self.experts, 2 * self.input_size, self.output_size),
            dtype=torch.int8),
                                              requires_grad=False)
        layer.w2_weight = torch.nn.Parameter(torch.zeros(
            (self.experts, self.output_size, self.input_size),
            dtype=torch.int8),
                                             requires_grad=False)
        layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
            (self.experts, 2 * self.input_size, 1), dtype=torch.bfloat16),
                                                    requires_grad=False)
        layer.w13_weight_scale_second = torch.nn.Parameter(torch.ones(
            (self.experts, 2 * self.input_size,
             self.output_size // self.group_size),
            dtype=torch.bfloat16),
                                                           requires_grad=False)
        layer.w2_weight_scale = torch.nn.Parameter(torch.ones(
            (self.experts, self.output_size, 1), dtype=torch.bfloat16),
                                                   requires_grad=False)
        layer.w2_weight_scale_second = torch.nn.Parameter(torch.ones(
            (self.experts, self.output_size,
             self.input_size // self.group_size),
            dtype=torch.bfloat16),
                                                          requires_grad=False)
        new_layer = copy.deepcopy(layer)

        mock_npu.return_value = torch.Tensor()
        mock_npu_quantize.return_value = torch.Tensor()
        self.quant_method.process_weights_after_loading(layer)
        self.assertTrue(hasattr(layer, "w13_scale_bias"))
        self.assertEqual(layer.w13_scale_bias.data.shape,
                         (self.experts, 2 * self.input_size))
        self.assertEqual(layer.w13_scale_bias.data.dtype, torch.float32)
        self.assertTrue(hasattr(layer, "w2_scale_bias"))
        self.assertEqual(layer.w2_scale_bias.data.shape,
                         (self.experts, self.output_size))
        self.assertEqual(layer.w2_scale_bias.data.dtype, torch.float32)
        # new quant version weight
        self.quant_method.new_quant_version = True
        new_layer.w13_weight.data = torch.zeros(
            (self.experts, self.input_size, self.output_size),
            dtype=torch.int8)
        new_layer.w2_weight.data = torch.zeros(
            (self.experts, self.output_size // 2, self.input_size),
            dtype=torch.int8)
        w13_scale_bias = torch.zeros((self.experts, 2 * self.input_size, 1),
                                     dtype=torch.float32)
        new_layer.w13_scale_bias = torch.nn.Parameter(w13_scale_bias,
                                                      requires_grad=False)
        w2_scale_bias = torch.zeros(
            (self.experts, self.output_size, 16 // self.quant_method.tp_size),
            dtype=torch.float32)
        new_layer.w2_scale_bias = torch.nn.Parameter(w2_scale_bias,
                                                     requires_grad=False)
        self.quant_method.process_weights_after_loading(new_layer)
        self.assertEqual(new_layer.w13_scale_bias.data.shape,
                         (self.experts, 2 * self.input_size))
        self.assertEqual(new_layer.w2_scale_bias.data.shape,
                         (self.experts, self.output_size))
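In the test above, switching to new_quant_version halves the non-packed dimension of the int8 weights (w13_weight drops from (experts, 2 * input_size, output_size) to (experts, input_size, output_size), and w2_weight's second dimension is likewise halved), which is consistent with two int4 values being packed into each int8 byte. The sketch below only illustrates that packing idea; the helper name pack_int4_pairs, the nibble order, and the packing axis are assumptions, not taken from this commit, and the real layout is whatever the Ascend kernels behind the mocked torch_npu.npu_quantize expect.

# Illustrative sketch only: pack pairs of int4 values (carried in an int8
# tensor, each value in [-8, 7]) into single int8 bytes along dim 1, halving
# that dimension as the new-quant-version shapes in the test suggest.
import torch


def pack_int4_pairs(w: torch.Tensor) -> torch.Tensor:
    lo = w[:, 0::2, :] & 0x0F          # low nibble from even rows
    hi = (w[:, 1::2, :] & 0x0F) << 4   # high nibble from odd rows
    return (hi | lo).to(torch.int8)


packed = pack_int4_pairs(torch.zeros(8, 32, 56, dtype=torch.int8))
print(packed.shape)  # torch.Size([8, 16, 56]) -> (experts, input_size, output_size)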
Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
from unittest.mock import MagicMock, patch

import torch

from tests.ut.base import TestBase
from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import \
    torchair_fused_experts_with_all2all


class TestAscendW8A8FusedMoEMethod(TestBase):

    def setUp(self):
        self.hidden_size = 128
        self.num_tokens = 128
        self.placeholder = torch.randn(self.num_tokens,
                                       self.hidden_size,
                                       dtype=torch.bfloat16)

    @patch("torch.distributed.all_to_all_single")
    @patch("torch_npu.npu_moe_re_routing")
    @patch("torch_npu.npu_grouped_matmul")
    @patch("torch_npu.npu_swiglu")
    @patch("torch_npu.npu_dynamic_quant")
    @patch("torch_npu.npu_moe_finalize_routing")
    @patch("torch_npu.npu_moe_init_routing")
    def test_torchair_fused_experts_with_all2all(
            self, mock_moe_init_routing, mock_moe_finalize_routing,
            mock_dynamic_quant, mock_swiglu, mock_grouped_matmul,
            mock_moe_re_routing, mock_all_to_all_single):

        expert_map = MagicMock()
        ep_group = MagicMock()
        placeholder_int8 = torch.randint(0,
                                         100,
                                         (self.num_tokens, self.hidden_size),
                                         dtype=torch.int8)
        placeholder_ones = torch.ones(self.num_tokens, dtype=torch.int32)
        mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_(
            input)
        mock_moe_init_routing.return_value = (
            placeholder_int8,
            placeholder_ones,
            placeholder_ones,
        )
        mock_moe_re_routing.return_value = (placeholder_int8, self.placeholder,
                                            torch.randint(0,
                                                          100,
                                                          (self.num_tokens, ),
                                                          dtype=torch.int32),
                                            self.placeholder)
        mock_grouped_matmul.return_value = self.placeholder
        mock_swiglu.return_value = self.placeholder
        mock_dynamic_quant.return_value = (
            placeholder_int8,
            torch.randn(self.num_tokens),
        )
        mock_moe_finalize_routing.return_value = self.placeholder

        result = torchair_fused_experts_with_all2all(
            hidden_states=self.placeholder,
            w1=self.placeholder,
            w1_scale=self.placeholder,
            w2=self.placeholder,
            w2_scale=self.placeholder,
            topk_weights=self.placeholder,
            topk_ids=self.placeholder,
            top_k=8,
            expert_map=expert_map,
            ep_group=ep_group,
            log2phy=None,
            global_redundant_expert_num=256,
        )
        self.assertIsNotNone(result)
        self.assertEqual(result.dtype, torch.bfloat16)
        self.assertEqual(result.shape, (128, 128))

tests/ut/torchair/test_utils.py

Lines changed: 13 additions & 0 deletions
@@ -6,6 +6,7 @@
 import torch
 
 from tests.ut.base import TestBase
+from vllm_ascend.quantization.quantizer import SUPPORT_ASCEND_QUANTIZER_TYPE
 from vllm_ascend.torchair import utils
 
 
@@ -120,3 +121,15 @@ def test_converting_weight_acl_format_format_true(self, mock_npu_cast,
 
         utils.converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_NZ)
         mock_npu_cast.assert_not_called()
+
+    def test_torchair_quant_method_register(self):
+
+        TorchairW8A8DYNAMICQuantizer = SUPPORT_ASCEND_QUANTIZER_TYPE[
+            "W8A8_DYNAMIC"]
+        TorchairW4A8DYNAMICQuantizer = SUPPORT_ASCEND_QUANTIZER_TYPE[
+            "W4A8_DYNAMIC"]
+        utils.torchair_quant_method_register()
+        self.assertNotEqual(TorchairW8A8DYNAMICQuantizer,
+                            SUPPORT_ASCEND_QUANTIZER_TYPE["W8A8_DYNAMIC"])
+        self.assertNotEqual(TorchairW4A8DYNAMICQuantizer,
+                            SUPPORT_ASCEND_QUANTIZER_TYPE["W4A8_DYNAMIC"])
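The test above only pins down the observable effect of utils.torchair_quant_method_register(): after the call, the "W8A8_DYNAMIC" and "W4A8_DYNAMIC" entries of SUPPORT_ASCEND_QUANTIZER_TYPE no longer point at the original quantizer classes. A minimal sketch of such a register helper is shown below, assuming it simply overwrites the shared registry with the Torchair quantizers added in this commit; the import path of the quantizer module is a guess, since that file's name is not visible in this view.

# Hypothetical sketch of the registration helper; the real implementation in
# vllm_ascend.torchair.utils may differ.
from vllm_ascend.quantization.quantizer import SUPPORT_ASCEND_QUANTIZER_TYPE
# Module name below is assumed; the classes are the ones defined in this commit.
from vllm_ascend.torchair.quantization.torchair_quantizer import (
    TorchairW4A8DYNAMICQuantizer, TorchairW8A8DYNAMICQuantizer)


def torchair_quant_method_register():
    # Point the shared registry at the Torchair-specific quantizers so that
    # later lookups by quantization type resolve to them.
    SUPPORT_ASCEND_QUANTIZER_TYPE["W8A8_DYNAMIC"] = TorchairW8A8DYNAMICQuantizer
    SUPPORT_ASCEND_QUANTIZER_TYPE["W4A8_DYNAMIC"] = TorchairW4A8DYNAMICQuantizer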

vllm_ascend/torchair/models/torchair_deepseek_v2.py

Lines changed: 5 additions & 3 deletions
@@ -71,8 +71,9 @@
 
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.quantization.quant_config import AscendLinearMethod
-from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
 from vllm_ascend.torchair.ops.torchair_fused_moe import TorchairAscendFusedMoE
+from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import \
+    TorchairAscendW8A8DynamicLinearMethod
 from vllm_ascend.utils import dispose_tensor, npu_prefetch
 
 
@@ -261,8 +262,9 @@ def __init__(
         quant_method = self.gate_up_proj.quant_method
         if isinstance(quant_method, UnquantizedLinearMethod):
             self.act_fn = TorchairDeepseekV2SiluAndMul()
-        elif (isinstance(quant_method, AscendLinearMethod) and isinstance(
-                quant_method.quant_method, AscendW8A8DynamicLinearMethod)):
+        elif (isinstance(quant_method, AscendLinearMethod)
+              and isinstance(quant_method.quant_method,
+                             TorchairAscendW8A8DynamicLinearMethod)):
             # TODO(sdmyzlp): Currently preserved as before:
             # 1. The only quantization supported for silu is W8A8Dynamic
             # 2. Output dtype of gate_up/down is fixed to be int32/bfloat16

vllm_ascend/torchair/quantization/__init__.py

Whitespace-only changes.
Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
from vllm_ascend.quantization.quantizer import VLLMAscendQuantizer
from vllm_ascend.torchair.quantization.torchair_w4a8_dynamic import (
    TorchairAscendW4A8DynamicFusedMoEMethod,
    TorchairAscendW4A8DynamicLinearMethod)
from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import (
    TorchairAscendW8A8DynamicFusedMoEMethod,
    TorchairAscendW8A8DynamicLinearMethod)


class TorchairW8A8DYNAMICQuantizer(VLLMAscendQuantizer):

    @staticmethod
    def build_linear_method():
        return TorchairAscendW8A8DynamicLinearMethod()

    @staticmethod
    def build_moe_method():
        return TorchairAscendW8A8DynamicFusedMoEMethod()


class TorchairW4A8DYNAMICQuantizer(VLLMAscendQuantizer):

    @staticmethod
    def build_linear_method():
        return TorchairAscendW4A8DynamicLinearMethod()

    @staticmethod
    def build_moe_method():
        return TorchairAscendW4A8DynamicFusedMoEMethod()
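For reference, the intended consumption pattern (as exercised by the registry test in tests/ut/torchair/test_utils.py above) would roughly look like the following once torchair_quant_method_register() has run; this is a usage sketch, not code from the commit.

# Usage sketch; assumes torchair_quant_method_register() has already swapped
# the registry entries to the Torchair quantizer classes above.
from vllm_ascend.quantization.quantizer import SUPPORT_ASCEND_QUANTIZER_TYPE

quantizer_cls = SUPPORT_ASCEND_QUANTIZER_TYPE["W4A8_DYNAMIC"]
linear_method = quantizer_cls.build_linear_method()  # TorchairAscendW4A8DynamicLinearMethod
moe_method = quantizer_cls.build_moe_method()        # TorchairAscendW4A8DynamicFusedMoEMethod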
