121 changes: 91 additions & 30 deletions tests/ut/quantization/test_w4a8_dynamic.py
@@ -1,3 +1,4 @@
import copy
from unittest.mock import Mock, patch

import torch
@@ -31,79 +32,139 @@ def test_get_pergroup_param(self):


class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
experts = 8
input_size = 16
output_size = 56
group_size = 2

@patch('vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config')
@patch('vllm_ascend.quantization.w4a8_dynamic.get_ep_group')
@patch("vllm_ascend.ascend_config.get_ascend_config")
@patch('vllm_ascend.quantization.w4a8_dynamic.get_mc2_group')
@patch('torch.distributed.get_rank', return_value=0)
def setUp(self, mock_get_rank, mock_get_mc2_group, mock_get_ascend_config,
mock_get_ep_group):
mock_get_ep_group, get_current_vllm_config):
mock_ascend_config = Mock()
mock_ascend_config.torchair_graph_config = Mock(enabled=False)
mock_get_ascend_config.return_value = mock_ascend_config
mock_vllm_config = Mock()
mock_vllm_config.quant_config = Mock(quant_description={
"group_size": self.group_size,
"version": "0.0.0"
})
mock_vllm_config.parallel_config = Mock(enable_expert_parallel=True)
get_current_vllm_config.return_value = mock_vllm_config
self.quant_method = AscendW4A8DynamicFusedMoEMethod()

def test_get_weight(self):
param_dict = self.quant_method.get_weight(8, 4, 14, torch.bfloat16)
# old quant version weight
param_dict = self.quant_method.get_weight(self.experts,
self.input_size,
self.output_size,
torch.bfloat16)
self.assertEqual(param_dict["w13_weight"].dtype, torch.int8)
self.assertEqual(param_dict["w13_weight"].shape,
(self.experts, 2 * self.input_size, self.output_size))
# new quant version weight
self.quant_method.new_quant_version = True
param_dict = self.quant_method.get_weight(self.experts,
self.input_size,
self.output_size,
torch.bfloat16)
self.assertEqual(param_dict["w13_weight"].dtype, torch.int8)
self.assertEqual(param_dict["w13_weight"].shape, (8, 8, 14))
self.assertEqual(param_dict["w13_weight"].shape,
(self.experts, self.input_size, self.output_size))

@patch('vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config')
def test_get_dynamic_quant_param(self, mock_get_current_vllm_config):
mock_vllm_config = Mock()
mock_vllm_config.quant_config = Mock(
quant_description={"group_size": 2})
mock_get_current_vllm_config.return_value = mock_vllm_config
def test_get_dynamic_quant_param(self):
# old quant version weight
param_dict = self.quant_method.get_dynamic_quant_param(
8, 4, 14, torch.bfloat16)
self.experts, self.input_size, self.output_size, torch.bfloat16)
self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.bfloat16)
self.assertEqual(param_dict["w13_weight_scale"].shape, (8, 8, 1))
self.assertEqual(param_dict["w13_weight_scale"].shape,
(self.experts, 2 * self.input_size, 1))
self.assertEqual(param_dict["w13_weight_scale_second"].dtype,
torch.bfloat16)
self.assertEqual(param_dict["w13_weight_scale_second"].shape,
(8, 8, 7))
(self.experts, 2 * self.input_size,
self.output_size // self.group_size))
self.assertEqual(param_dict["w2_weight_scale"].dtype, torch.bfloat16)
self.assertEqual(param_dict["w2_weight_scale"].shape, (8, 14, 1))
self.assertEqual(param_dict["w2_weight_scale"].shape,
(self.experts, self.output_size, 1))
self.assertEqual(param_dict["w2_weight_scale_second"].dtype,
torch.bfloat16)
self.assertEqual(param_dict["w2_weight_scale_second"].shape,
(8, 14, 2))
(self.experts, self.output_size,
self.input_size // self.group_size))
# new quant version weight
self.quant_method.new_quant_version = True
param_dict = self.quant_method.get_dynamic_quant_param(
self.experts, self.input_size, self.output_size, torch.bfloat16)
self.assertEqual(param_dict["w2_scale_bias"].dtype, torch.float32)
self.assertEqual(
param_dict["w2_scale_bias"].shape,
(self.experts, self.output_size, 16 // self.quant_method.tp_size))

@patch('torch_npu.npu_quantize')
@patch('torch.Tensor.npu')
def test_process_weights_after_loading(self, mock_npu, mock_npu_quantize):
# old quant version weight
layer = torch.nn.Module()
layer.w13_weight = torch.nn.Parameter(torch.zeros((8, 8, 14),
dtype=torch.int8),
layer.w13_weight = torch.nn.Parameter(torch.zeros(
(self.experts, 2 * self.input_size, self.output_size),
dtype=torch.int8),
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(torch.zeros((8, 14, 4),
dtype=torch.int8),
layer.w2_weight = torch.nn.Parameter(torch.zeros(
(self.experts, self.output_size, self.input_size),
dtype=torch.int8),
requires_grad=False)
layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
(8, 8, 1), dtype=torch.bfloat16),
(self.experts, 2 * self.input_size, 1), dtype=torch.bfloat16),
requires_grad=False)
layer.w13_weight_offset = torch.nn.Parameter(torch.zeros(
(8, 8, 1), dtype=torch.bfloat16),
requires_grad=False)
layer.w13_weight_scale_second = torch.nn.Parameter(torch.ones(
(8, 8, 7), dtype=torch.bfloat16),
(self.experts, 2 * self.input_size,
self.output_size // self.group_size),
dtype=torch.bfloat16),
requires_grad=False)
layer.w2_weight_scale = torch.nn.Parameter(torch.ones(
(8, 14, 1), dtype=torch.bfloat16),
(self.experts, self.output_size, 1), dtype=torch.bfloat16),
requires_grad=False)
layer.w2_weight_offset = torch.nn.Parameter(torch.zeros(
(8, 14, 1), dtype=torch.bfloat16),
requires_grad=False)
layer.w2_weight_scale_second = torch.nn.Parameter(torch.ones(
(8, 14, 2), dtype=torch.bfloat16),
(self.experts, self.output_size,
self.input_size // self.group_size),
dtype=torch.bfloat16),
requires_grad=False)
new_layer = copy.deepcopy(layer)

mock_npu.return_value = torch.Tensor()
mock_npu_quantize.return_value = torch.Tensor()
self.quant_method.process_weights_after_loading(layer)
self.assertTrue(hasattr(layer, "w13_scale_bias"))
self.assertEqual(layer.w13_scale_bias.data.shape, (8, 8))
self.assertEqual(layer.w13_scale_bias.data.shape,
(self.experts, 2 * self.input_size))
self.assertEqual(layer.w13_scale_bias.data.dtype, torch.float32)
self.assertTrue(hasattr(layer, "w2_scale_bias"))
self.assertEqual(layer.w2_scale_bias.data.shape, (8, 14))
self.assertEqual(layer.w2_scale_bias.data.shape,
(self.experts, self.output_size))
self.assertEqual(layer.w2_scale_bias.data.dtype, torch.float32)
# new quant version weight
self.quant_method.new_quant_version = True
new_layer.w13_weight.data = torch.zeros(
(self.experts, self.input_size, self.output_size),
dtype=torch.int8)
new_layer.w2_weight.data = torch.zeros(
(self.experts, self.output_size // 2, self.input_size),
dtype=torch.int8)
w13_scale_bias = torch.zeros((self.experts, 2 * self.input_size, 1),
dtype=torch.float32)
new_layer.w13_scale_bias = torch.nn.Parameter(w13_scale_bias,
requires_grad=False)
Comment on lines +157 to +160
Contributor
high

The setup for w13_scale_bias seems to be based on an incorrect shape definition in get_dynamic_quant_param. The dimension 2 * self.input_size is inconsistent with the corresponding w13_weight's dimension for the new quantization version. This should be self.input_size to match the weight.

Suggested change
- w13_scale_bias = torch.zeros((self.experts, 2 * self.input_size, 1),
- dtype=torch.float32)
- new_layer.w13_scale_bias = torch.nn.Parameter(w13_scale_bias,
- requires_grad=False)
+ w13_scale_bias = torch.zeros((self.experts, self.input_size, 1),
+ dtype=torch.float32)
+ new_layer.w13_scale_bias = torch.nn.Parameter(w13_scale_bias,
+ requires_grad=False)
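
For context, a minimal standalone sketch of the shapes under discussion (not code from the PR; the constants mirror the test's class attributes, and reading the halved channel dimension as int4 packing is an assumption based on the shapes above):

import torch

experts, input_size, output_size = 8, 16, 56

# New quant version as exercised by this test: the int8 weight presumably
# holds packed int4 values, so its channel dimension is input_size rather
# than the old version's 2 * input_size.
w13_weight_packed = torch.zeros((experts, input_size, output_size),
                                dtype=torch.int8)

# Bias shape the PR registers (one row per unpacked channel) ...
bias_as_written = torch.zeros((experts, 2 * input_size, 1),
                              dtype=torch.float32)
# ... versus the shape proposed in this comment (one row per packed row).
bias_as_suggested = torch.zeros((experts, input_size, 1),
                                dtype=torch.float32)

print(w13_weight_packed.shape[1])   # 16
print(bias_as_written.shape[1])     # 32
print(bias_as_suggested.shape[1])   # 16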

w2_scale_bias = torch.zeros(
(self.experts, self.output_size, 16 // self.quant_method.tp_size),
dtype=torch.float32)
new_layer.w2_scale_bias = torch.nn.Parameter(w2_scale_bias,
requires_grad=False)
self.quant_method.process_weights_after_loading(new_layer)
self.assertEqual(new_layer.w13_scale_bias.data.shape,
(self.experts, 2 * self.input_size))
Comment on lines +167 to +168
Contributor
high

This assertion checks for a shape that is inconsistent with the corresponding weight's shape. Following the correction in the w13_scale_bias setup, this assertion should be updated to check for the correct shape, which should use self.input_size instead of 2 * self.input_size.

Suggested change
- self.assertEqual(new_layer.w13_scale_bias.data.shape,
- (self.experts, 2 * self.input_size))
+ self.assertEqual(new_layer.w13_scale_bias.data.shape,
+ (self.experts, self.input_size))
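
A quick sanity check of the relationship being asserted (illustration only, not PR code): the expected 2-D shape is the registered 3-D bias shape with its trailing singleton scale dimension dropped, so whichever channel count the setup uses, this assertion should track it.

import torch

# (experts, channels, 1) squeezed on the last axis gives (experts, channels).
registered = torch.zeros((8, 16, 1), dtype=torch.float32)
assert registered.squeeze(-1).shape == torch.Size([8, 16])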

self.assertEqual(new_layer.w2_scale_bias.data.shape,
(self.experts, self.output_size))
7 changes: 5 additions & 2 deletions vllm_ascend/quantization/quant_config.py
@@ -44,7 +44,7 @@
@register_quantization_config(ASCEND_QUATIZATION_METHOD)
class AscendQuantConfig(QuantizationConfig):
"""Config class for Ascend

This class is a general class that parses quantization configs
that are supported on ascend hardware.
"""
@@ -295,14 +295,17 @@ def create_weights(

extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
per_group_param = [
"weight_scale_second", "weight_offset_second", "scale_bias"
]
dynamic_quant_param = self.quant_method.get_dynamic_quant_param(
num_experts, intermediate_size_per_partition, hidden_size,
params_dtype)
for param_key, param_value in dynamic_quant_param.items():
param = torch.nn.Parameter(param_value, requires_grad=False)
layer.register_parameter(param_key, param)
set_weight_attrs(param, extra_weight_attrs)
if "weight_scale_second" in param_key or "weight_offset_second" in param_key:
if any(fields in param_key for fields in per_group_param):
setattr(param, "quant_method",
FusedMoeWeightScaleSupported.GROUP.value)
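
For clarity, a small standalone sketch of what the new any(...) check does (illustration only; the three markers come from per_group_param above, and the GROUP/CHANNEL labels mirror the FusedMoeWeightScaleSupported values used in this hunk):

# Route per-group parameters to the GROUP scheme; everything else keeps the
# CHANNEL default set via extra_weight_attrs.
per_group_param = [
    "weight_scale_second", "weight_offset_second", "scale_bias"
]

def quant_scheme_for(param_key: str) -> str:
    # Per-group if any marker is a substring of the parameter name.
    if any(fields in param_key for fields in per_group_param):
        return "GROUP"
    return "CHANNEL"

for key in ("w13_weight_scale_second", "w2_scale_bias", "w13_weight_scale"):
    print(key, "->", quant_scheme_for(key))
# w13_weight_scale_second -> GROUP
# w2_scale_bias -> GROUP
# w13_weight_scale -> CHANNEL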
