Commit c40d417

[main][quantization] Adapt to the new format of ds w4a8 weight (#2392)
### What this PR does / why we need it?
The DeepSeek w4a8 weights we supported before were in the mindie format, which uses one int8 to represent each int4 value. As a result the weight size is close to w8a8, and vllm-ascend needs several extra steps to load it. With the new weight format, two int4 values are packed into one int8, so the weights are smaller and can be used on vllm-ascend directly without extra conversion. Weights in the previous mindie format remain supported.

Changes in the new weight version:
1. The weights are packed (two int4 values per int8).
2. The bias required by the apply method is generated directly by modelslim.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Added UT cases in `tests/ut/quantization/test_w4a8_dynamic.py`.

#### 1. How to get weights using Modelslim

##### Installation steps
Use the branch `br_release_MindStudio_8.1.RC2_TR5_20260624`:
```bash
git clone -b br_release_MindStudio_8.1.RC2_TR5_20260624 https://gitee.com/ascend/msit.git
cd msit/msmodelslim
bash install.sh
```

##### Generate w4a8 weights
```bash
cd example/DeepSeek
```
Command reference: `msmodelslim/example/DeepSeek/README.md`. Follow the [pre-check](https://gitee.com/ascend/msit/blob/br_release_MindStudio_8.1.RC2_TR5_20260624/msmodelslim/example/DeepSeek/README.md#%E8%BF%90%E8%A1%8C%E5%89%8D%E5%BF%85%E6%A3%80) and the [DeepSeek-R1 w4a8 mixed quantization](https://gitee.com/ascend/msit/blob/br_release_MindStudio_8.1.RC2_TR5_20260624/msmodelslim/example/DeepSeek/README.md#deepseek-r1-w4a8-%E6%B7%B7%E5%90%88%E9%87%8F%E5%8C%96%E5%89%8D%E4%B8%89%E5%B1%82-mlpw8a8-dynamic-%E9%87%8F%E5%8C%96mla%E5%85%B1%E4%BA%AB%E4%B8%93%E5%AE%B6w8a8%E9%87%8F%E5%8C%96%E8%B7%AF%E7%94%B1%E4%B8%93%E5%AE%B6w4a8-dynamic%E9%87%8F%E5%8C%96) chapters.
Reference command:
```bash
python3 quant_deepseek_w4a8.py --model_path {original weight path} --save_path {generated weight path}
```

##### Adapt to vllm-ascend
In `config.json`, change `"model_type": "deepseekv2"` to `"model_type": "deepseek_v3"`.

#### 2. How to run w4a8

##### a. How to run eager mode
```bash
export VLLM_ASCEND_MLA_PA=1
python -m vllm.entrypoints.openai.api_server --model=$1 --trust-remote-code -tp $2 -dp $3 --enable_expert_parallel --quantization ascend --port $4 --max-model-len $5 --max-num-seqs $6 --enforce-eager
```
e.g.:
```bash
python -m vllm.entrypoints.openai.api_server --model=/weightpath/w4a8_4_layer --trust-remote-code -tp 4 -dp 4 --enable_expert_parallel --quantization ascend --port 8002 --max-model-len 5120 --max-num-seqs 128 --enforce-eager
```

##### b. How to run graph mode
```bash
export HCCL_BUFFSIZE=1024
python -m vllm.entrypoints.openai.api_server --model=$1 --trust-remote-code -tp $2 -dp $3 --enable_expert_parallel --quantization ascend --port $4 --max-model-len $5 --additional_config='{"ascend_scheduler_config":{"enabled":true},"torchair_graph_config":{"enabled":true}}'
```
e.g.:
```bash
python -m vllm.entrypoints.openai.api_server --model=/weight/dsr1_w4a8_vllm --trust-remote-code -tp 4 -dp 4 --enable_expert_parallel --quantization ascend --port 8002 --max-model-len 5120 --additional_config='{"ascend_scheduler_config":{"enabled":true},"torchair_graph_config":{"enabled":true}}'
```

- vLLM version: v0.10.0
- vLLM main: vllm-project/vllm@103f1ec

---------
Signed-off-by: Wang Kunpeng <1289706727@qq.com>
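As an aside (not part of the diff), the size saving comes from the nibble packing described above. Below is a minimal Python sketch of the idea, assuming the common low-nibble/high-nibble layout; the actual packing is performed by modelslim and its byte layout may differ.

```python
# Illustrative sketch only, NOT the modelslim packing code: two signed int4
# values (-8..7) sharing one int8 byte is why the packed w4a8 weights are
# roughly half the size of the old int8-per-int4 layout.
def pack_int4_pair(lo: int, hi: int) -> int:
    """Pack two signed int4 values into one signed int8 byte (assumed layout)."""
    assert -8 <= lo <= 7 and -8 <= hi <= 7
    byte = ((hi & 0xF) << 4) | (lo & 0xF)        # two's-complement nibbles
    return byte - 256 if byte >= 128 else byte   # reinterpret as signed int8


def unpack_int4_pair(byte: int) -> tuple[int, int]:
    """Recover the two signed int4 values from a packed int8 byte."""
    u = byte & 0xFF
    lo, hi = u & 0xF, (u >> 4) & 0xF

    def to_signed(n: int) -> int:
        return n - 16 if n >= 8 else n

    return to_signed(lo), to_signed(hi)


assert unpack_int4_pair(pack_int4_pair(-3, 7)) == (-3, 7)
```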
1 parent eccfb71 commit c40d417

4 files changed: +184 -71 lines changed

tests/e2e/multicard/test_offline_inference_distributed.py

Lines changed: 7 additions & 2 deletions

```diff
@@ -31,6 +31,10 @@
 from tests.e2e.conftest import VllmRunner
 
 os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"
+DEEPSEEK_W4A8_MODELS = [
+    "vllm-ascend/DeepSeek-V3-W4A8-Pruing",
+    "vllm-ascend/DeepSeek-R1-w4a8-pruning"
+]
 
 
 def test_models_distributed_QwQ():
@@ -211,14 +215,15 @@ def test_models_distributed_Qwen3_W4A8DYNAMIC():
         vllm_model.generate_greedy(example_prompts, max_tokens)
 
 
+@pytest.mark.parametrize("model", DEEPSEEK_W4A8_MODELS)
 @patch.dict(os.environ, {"VLLM_ASCEND_MLA_PA": "1"})
-def test_models_distributed_DeepSeek_W4A8DYNAMIC():
+def test_models_distributed_DeepSeek_W4A8DYNAMIC(model):
     prompts = [
         "Hello, my name is",
     ]
     max_tokens = 5
     with VllmRunner(
-            snapshot_download("vllm-ascend/DeepSeek-R1-w4a8-pruning"),
+            snapshot_download(model),
             dtype="auto",
             tensor_parallel_size=2,
             quantization="ascend",
```
tests/ut/quantization/test_w4a8_dynamic.py

Lines changed: 91 additions & 30 deletions

```diff
@@ -1,3 +1,4 @@
+import copy
 from unittest.mock import Mock, patch
 
 import torch
@@ -31,79 +32,139 @@ def test_get_pergroup_param(self):
 
 
 class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
+    experts = 8
+    input_size = 16
+    output_size = 56
+    group_size = 2
 
+    @patch('vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config')
     @patch('vllm_ascend.quantization.w4a8_dynamic.get_ep_group')
     @patch("vllm_ascend.ascend_config.get_ascend_config")
     @patch('vllm_ascend.quantization.w4a8_dynamic.get_mc2_group')
     @patch('torch.distributed.get_rank', return_value=0)
     def setUp(self, mock_get_rank, mock_get_mc2_group, mock_get_ascend_config,
-              mock_get_ep_group):
+              mock_get_ep_group, get_current_vllm_config):
         mock_ascend_config = Mock()
         mock_ascend_config.torchair_graph_config = Mock(enabled=False)
         mock_get_ascend_config.return_value = mock_ascend_config
+        mock_vllm_config = Mock()
+        mock_vllm_config.quant_config = Mock(quant_description={
+            "group_size": self.group_size,
+            "version": "0.0.0"
+        })
+        mock_vllm_config.parallel_config = Mock(enable_expert_parallel=True)
+        get_current_vllm_config.return_value = mock_vllm_config
         self.quant_method = AscendW4A8DynamicFusedMoEMethod()
 
     def test_get_weight(self):
-        param_dict = self.quant_method.get_weight(8, 4, 14, torch.bfloat16)
+        # old quant version w4a8 weight
+        param_dict = self.quant_method.get_weight(self.experts,
+                                                  self.input_size,
+                                                  self.output_size,
+                                                  torch.bfloat16)
+        self.assertEqual(param_dict["w13_weight"].dtype, torch.int8)
+        self.assertEqual(param_dict["w13_weight"].shape,
+                         (self.experts, 2 * self.input_size, self.output_size))
+        # new quant version weight
+        self.quant_method.new_quant_version = True
+        param_dict = self.quant_method.get_weight(self.experts,
+                                                  self.input_size,
+                                                  self.output_size,
+                                                  torch.bfloat16)
         self.assertEqual(param_dict["w13_weight"].dtype, torch.int8)
-        self.assertEqual(param_dict["w13_weight"].shape, (8, 8, 14))
+        self.assertEqual(param_dict["w13_weight"].shape,
+                         (self.experts, self.input_size, self.output_size))
 
-    @patch('vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config')
-    def test_get_dynamic_quant_param(self, mock_get_current_vllm_config):
-        mock_vllm_config = Mock()
-        mock_vllm_config.quant_config = Mock(
-            quant_description={"group_size": 2})
-        mock_get_current_vllm_config.return_value = mock_vllm_config
+    def test_get_dynamic_quant_param(self):
+        # old quant version weight
         param_dict = self.quant_method.get_dynamic_quant_param(
-            8, 4, 14, torch.bfloat16)
+            self.experts, self.input_size, self.output_size, torch.bfloat16)
         self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.bfloat16)
-        self.assertEqual(param_dict["w13_weight_scale"].shape, (8, 8, 1))
+        self.assertEqual(param_dict["w13_weight_scale"].shape,
+                         (self.experts, 2 * self.input_size, 1))
         self.assertEqual(param_dict["w13_weight_scale_second"].dtype,
                          torch.bfloat16)
         self.assertEqual(param_dict["w13_weight_scale_second"].shape,
-                         (8, 8, 7))
+                         (self.experts, 2 * self.input_size,
+                          self.output_size // self.group_size))
         self.assertEqual(param_dict["w2_weight_scale"].dtype, torch.bfloat16)
-        self.assertEqual(param_dict["w2_weight_scale"].shape, (8, 14, 1))
+        self.assertEqual(param_dict["w2_weight_scale"].shape,
+                         (self.experts, self.output_size, 1))
         self.assertEqual(param_dict["w2_weight_scale_second"].dtype,
                          torch.bfloat16)
         self.assertEqual(param_dict["w2_weight_scale_second"].shape,
-                         (8, 14, 2))
+                         (self.experts, self.output_size,
+                          self.input_size // self.group_size))
+        # new quant version weight
+        self.quant_method.new_quant_version = True
+        param_dict = self.quant_method.get_dynamic_quant_param(
+            self.experts, self.input_size, self.output_size, torch.bfloat16)
+        self.assertEqual(param_dict["w2_scale_bias"].dtype, torch.float32)
+        self.assertEqual(
+            param_dict["w2_scale_bias"].shape,
+            (self.experts, self.output_size, 16 // self.quant_method.tp_size))
 
     @patch('torch_npu.npu_quantize')
     @patch('torch.Tensor.npu')
     def test_process_weights_after_loading(self, mock_npu, mock_npu_quantize):
+        # old quant version weight
         layer = torch.nn.Module()
-        layer.w13_weight = torch.nn.Parameter(torch.zeros((8, 8, 14),
-                                                          dtype=torch.int8),
+        layer.w13_weight = torch.nn.Parameter(torch.zeros(
+            (self.experts, 2 * self.input_size, self.output_size),
+            dtype=torch.int8),
                                               requires_grad=False)
-        layer.w2_weight = torch.nn.Parameter(torch.zeros((8, 14, 4),
-                                                         dtype=torch.int8),
+        layer.w2_weight = torch.nn.Parameter(torch.zeros(
+            (self.experts, self.output_size, self.input_size),
+            dtype=torch.int8),
                                              requires_grad=False)
         layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
-            (8, 8, 1), dtype=torch.bfloat16),
+            (self.experts, 2 * self.input_size, 1), dtype=torch.bfloat16),
                                                     requires_grad=False)
-        layer.w13_weight_offset = torch.nn.Parameter(torch.zeros(
-            (8, 8, 1), dtype=torch.bfloat16),
-                                                     requires_grad=False)
         layer.w13_weight_scale_second = torch.nn.Parameter(torch.ones(
-            (8, 8, 7), dtype=torch.bfloat16),
+            (self.experts, 2 * self.input_size,
+             self.output_size // self.group_size),
+            dtype=torch.bfloat16),
                                                            requires_grad=False)
         layer.w2_weight_scale = torch.nn.Parameter(torch.ones(
-            (8, 14, 1), dtype=torch.bfloat16),
+            (self.experts, self.output_size, 1), dtype=torch.bfloat16),
                                                    requires_grad=False)
-        layer.w2_weight_offset = torch.nn.Parameter(torch.zeros(
-            (8, 14, 1), dtype=torch.bfloat16),
-                                                    requires_grad=False)
         layer.w2_weight_scale_second = torch.nn.Parameter(torch.ones(
-            (8, 14, 2), dtype=torch.bfloat16),
+            (self.experts, self.output_size,
+             self.input_size // self.group_size),
+            dtype=torch.bfloat16),
                                                           requires_grad=False)
+        new_layer = copy.deepcopy(layer)
 
         mock_npu.return_value = torch.Tensor()
         mock_npu_quantize.return_value = torch.Tensor()
         self.quant_method.process_weights_after_loading(layer)
         self.assertTrue(hasattr(layer, "w13_scale_bias"))
-        self.assertEqual(layer.w13_scale_bias.data.shape, (8, 8))
+        self.assertEqual(layer.w13_scale_bias.data.shape,
+                         (self.experts, 2 * self.input_size))
         self.assertEqual(layer.w13_scale_bias.data.dtype, torch.float32)
         self.assertTrue(hasattr(layer, "w2_scale_bias"))
-        self.assertEqual(layer.w2_scale_bias.data.shape, (8, 14))
+        self.assertEqual(layer.w2_scale_bias.data.shape,
+                         (self.experts, self.output_size))
         self.assertEqual(layer.w2_scale_bias.data.dtype, torch.float32)
+        # new quant version weight
+        self.quant_method.new_quant_version = True
+        new_layer.w13_weight.data = torch.zeros(
+            (self.experts, self.input_size, self.output_size),
+            dtype=torch.int8)
+        new_layer.w2_weight.data = torch.zeros(
+            (self.experts, self.output_size // 2, self.input_size),
+            dtype=torch.int8)
+        w13_scale_bias = torch.zeros((self.experts, 2 * self.input_size, 1),
+                                     dtype=torch.float32)
+        new_layer.w13_scale_bias = torch.nn.Parameter(w13_scale_bias,
+                                                      requires_grad=False)
+        w2_scale_bias = torch.zeros(
+            (self.experts, self.output_size, 16 // self.quant_method.tp_size),
+            dtype=torch.float32)
+        new_layer.w2_scale_bias = torch.nn.Parameter(w2_scale_bias,
+                                                     requires_grad=False)
+        self.quant_method.process_weights_after_loading(new_layer)
+        self.assertEqual(new_layer.w13_scale_bias.data.shape,
+                         (self.experts, 2 * self.input_size))
+        self.assertEqual(new_layer.w2_scale_bias.data.shape,
+                         (self.experts, self.output_size))
```
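For readers following the assertions above, the shape arithmetic can be spelled out with the test's class constants. This is only a restatement of the numbers already in the test; the halved dimension in the new format reflects the two-int4-per-int8 packing described in the commit message.

```python
# Constants mirror TestAscendW4A8DynamicFusedMoEMethod in the diff above.
experts, input_size, output_size, group_size = 8, 16, 56, 2

# Old format: one int8 per int4 value.
old_w13 = (experts, 2 * input_size, output_size)   # (8, 32, 56)
old_w2 = (experts, output_size, input_size)        # (8, 56, 16)

# New format: two int4 values packed into one int8, halving one dimension.
new_w13 = (experts, input_size, output_size)       # (8, 16, 56)
new_w2 = (experts, output_size // 2, input_size)   # (8, 28, 16)

# Per-group (second-level) scales keep the unpacked dimensions.
w13_scale_second = (experts, 2 * input_size, output_size // group_size)  # (8, 32, 28)
w2_scale_second = (experts, output_size, input_size // group_size)       # (8, 56, 8)
```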

vllm_ascend/quantization/quant_config.py

Lines changed: 5 additions & 2 deletions

```diff
@@ -44,7 +44,7 @@
 @register_quantization_config(ASCEND_QUATIZATION_METHOD)
 class AscendQuantConfig(QuantizationConfig):
     """Config class for Ascend
-    
+
     This class is a general class that parse quantization configs
     that are supported on ascend hardware.
     """
@@ -295,14 +295,17 @@ def create_weights(
 
         extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
+        per_group_param = [
+            "weight_scale_second", "weight_offset_second", "scale_bias"
+        ]
         dynamic_quant_param = self.quant_method.get_dynamic_quant_param(
             num_experts, intermediate_size_per_partition, hidden_size,
             params_dtype)
         for param_key, param_value in dynamic_quant_param.items():
             param = torch.nn.Parameter(param_value, requires_grad=False)
             layer.register_parameter(param_key, param)
             set_weight_attrs(param, extra_weight_attrs)
-            if "weight_scale_second" in param_key or "weight_offset_second" in param_key:
+            if any(fields in param_key for fields in per_group_param):
                 setattr(param, "quant_method",
                         FusedMoeWeightScaleSupported.GROUP.value)
 
```
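As a minimal, standalone illustration of the widened match in the last hunk (string values stand in for the `FusedMoeWeightScaleSupported` enum; this is not the vllm-ascend code itself):

```python
per_group_param = [
    "weight_scale_second", "weight_offset_second", "scale_bias"
]


def granularity_for(param_key: str) -> str:
    # Keys containing any per-group field are tagged GROUP; everything else
    # keeps the CHANNEL default set earlier in create_weights().
    if any(field in param_key for field in per_group_param):
        return "GROUP"
    return "CHANNEL"


assert granularity_for("w13_weight_scale_second") == "GROUP"
assert granularity_for("w2_scale_bias") == "GROUP"      # newly matched key
assert granularity_for("w13_weight_scale") == "CHANNEL"
```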
