tests/e2e/multicard/test_offline_inference_distributed.py (9 changes: 7 additions & 2 deletions)
@@ -31,6 +31,10 @@
 from tests.e2e.conftest import VllmRunner
 
 os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"
+DEEPSEEK_W4A8_MODELS = [
+    "vllm-ascend/DeepSeek-V3-W4A8-Pruing",
+    "vllm-ascend/DeepSeek-R1-w4a8-pruning"
+]
 
 
 def test_models_distributed_QwQ():
@@ -211,14 +215,15 @@ def test_models_distributed_Qwen3_W4A8DYNAMIC():
         vllm_model.generate_greedy(example_prompts, max_tokens)
 
 
+@pytest.mark.parametrize("model", DEEPSEEK_W4A8_MODELS)
 @patch.dict(os.environ, {"VLLM_ASCEND_MLA_PA": "1"})
-def test_models_distributed_DeepSeek_W4A8DYNAMIC():
+def test_models_distributed_DeepSeek_W4A8DYNAMIC(model):
     prompts = [
         "Hello, my name is",
     ]
     max_tokens = 5
     with VllmRunner(
-            snapshot_download("vllm-ascend/DeepSeek-R1-w4a8-pruning"),
+            snapshot_download(model),
             dtype="auto",
             tensor_parallel_size=2,
            quantization="ascend",
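For reference, a sketch of how the parametrized end-to-end test reads once the hunks above are applied. The imports and the VllmRunner arguments past quantization="ascend" are folded out of the diff, so those parts (the modelscope import, enforce_eager, and the generate_greedy call) are assumptions modeled on the neighbouring tests rather than a verbatim copy of the file:

import os
from unittest.mock import patch

import pytest
from modelscope import snapshot_download  # assumption: mirrors the file's existing import

from tests.e2e.conftest import VllmRunner

DEEPSEEK_W4A8_MODELS = [
    "vllm-ascend/DeepSeek-V3-W4A8-Pruing",
    "vllm-ascend/DeepSeek-R1-w4a8-pruning"
]


@pytest.mark.parametrize("model", DEEPSEEK_W4A8_MODELS)
@patch.dict(os.environ, {"VLLM_ASCEND_MLA_PA": "1"})
def test_models_distributed_DeepSeek_W4A8DYNAMIC(model):
    # pytest generates one test case per entry in DEEPSEEK_W4A8_MODELS,
    # so both the V3 and R1 pruned checkpoints are exercised.
    prompts = ["Hello, my name is"]
    max_tokens = 5
    with VllmRunner(
            snapshot_download(model),
            dtype="auto",
            tensor_parallel_size=2,
            quantization="ascend",
            enforce_eager=True,  # assumption: hidden behind the fold above
    ) as vllm_model:
        vllm_model.generate_greedy(prompts, max_tokens)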
tests/ut/quantization/test_w4a8_dynamic.py (121 changes: 91 additions & 30 deletions)
@@ -1,3 +1,4 @@
+import copy
 from unittest.mock import Mock, patch
 
 import torch
@@ -31,79 +32,139 @@ def test_get_pergroup_param(self):
 
 
 class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
+    experts = 8
+    input_size = 16
+    output_size = 56
+    group_size = 2
 
+    @patch('vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config')
     @patch('vllm_ascend.quantization.w4a8_dynamic.get_ep_group')
     @patch("vllm_ascend.ascend_config.get_ascend_config")
     @patch('vllm_ascend.quantization.w4a8_dynamic.get_mc2_group')
     @patch('torch.distributed.get_rank', return_value=0)
     def setUp(self, mock_get_rank, mock_get_mc2_group, mock_get_ascend_config,
-              mock_get_ep_group):
+              mock_get_ep_group, get_current_vllm_config):
         mock_ascend_config = Mock()
         mock_ascend_config.torchair_graph_config = Mock(enabled=False)
         mock_get_ascend_config.return_value = mock_ascend_config
+        mock_vllm_config = Mock()
+        mock_vllm_config.quant_config = Mock(quant_description={
+            "group_size": self.group_size,
+            "version": "0.0.0"
+        })
+        mock_vllm_config.parallel_config = Mock(enable_expert_parallel=True)
+        get_current_vllm_config.return_value = mock_vllm_config
         self.quant_method = AscendW4A8DynamicFusedMoEMethod()
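One detail of the setUp change that is easy to misread: the new @patch for get_current_vllm_config sits at the top of the decorator stack, yet its mock arrives as the last setUp parameter. That is standard unittest.mock behaviour, since decorators are applied bottom-up and the mocks are injected in that same order; a minimal self-contained illustration (the patched targets here are unrelated examples, not the project's code):

import os
from unittest import TestCase
from unittest.mock import patch


class PatchOrderExample(TestCase):

    @patch('os.getcwd')  # outermost patch, injected last
    @patch('os.getpid')  # innermost patch, injected first
    def test_order(self, mock_getpid, mock_getcwd):
        # The decorator closest to the function supplies the first mock
        # argument, which is why the new get_current_vllm_config mock is
        # appended at the end of setUp's signature in the diff above.
        mock_getpid.return_value = 1234
        mock_getcwd.return_value = "/tmp"
        self.assertEqual(os.getpid(), 1234)
        self.assertEqual(os.getcwd(), "/tmp")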

     def test_get_weight(self):
-        param_dict = self.quant_method.get_weight(8, 4, 14, torch.bfloat16)
+        # old quant version weight
+        param_dict = self.quant_method.get_weight(self.experts,
+                                                  self.input_size,
+                                                  self.output_size,
+                                                  torch.bfloat16)
         self.assertEqual(param_dict["w13_weight"].dtype, torch.int8)
+        self.assertEqual(param_dict["w13_weight"].shape,
+                         (self.experts, 2 * self.input_size, self.output_size))
+        # new quant version weight
+        self.quant_method.new_quant_version = True
+        param_dict = self.quant_method.get_weight(self.experts,
+                                                  self.input_size,
+                                                  self.output_size,
+                                                  torch.bfloat16)
+        self.assertEqual(param_dict["w13_weight"].dtype, torch.int8)
-        self.assertEqual(param_dict["w13_weight"].shape, (8, 8, 14))
+        self.assertEqual(param_dict["w13_weight"].shape,
+                         (self.experts, self.input_size, self.output_size))
 
-    @patch('vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config')
-    def test_get_dynamic_quant_param(self, mock_get_current_vllm_config):
-        mock_vllm_config = Mock()
-        mock_vllm_config.quant_config = Mock(
-            quant_description={"group_size": 2})
-        mock_get_current_vllm_config.return_value = mock_vllm_config
+    def test_get_dynamic_quant_param(self):
+        # old quant version weight
         param_dict = self.quant_method.get_dynamic_quant_param(
-            8, 4, 14, torch.bfloat16)
+            self.experts, self.input_size, self.output_size, torch.bfloat16)
         self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.bfloat16)
-        self.assertEqual(param_dict["w13_weight_scale"].shape, (8, 8, 1))
+        self.assertEqual(param_dict["w13_weight_scale"].shape,
+                         (self.experts, 2 * self.input_size, 1))
         self.assertEqual(param_dict["w13_weight_scale_second"].dtype,
                          torch.bfloat16)
         self.assertEqual(param_dict["w13_weight_scale_second"].shape,
-                         (8, 8, 7))
+                         (self.experts, 2 * self.input_size,
+                          self.output_size // self.group_size))
         self.assertEqual(param_dict["w2_weight_scale"].dtype, torch.bfloat16)
-        self.assertEqual(param_dict["w2_weight_scale"].shape, (8, 14, 1))
+        self.assertEqual(param_dict["w2_weight_scale"].shape,
+                         (self.experts, self.output_size, 1))
         self.assertEqual(param_dict["w2_weight_scale_second"].dtype,
                          torch.bfloat16)
         self.assertEqual(param_dict["w2_weight_scale_second"].shape,
-                         (8, 14, 2))
+                         (self.experts, self.output_size,
+                          self.input_size // self.group_size))
+        # new quant version weight
+        self.quant_method.new_quant_version = True
+        param_dict = self.quant_method.get_dynamic_quant_param(
+            self.experts, self.input_size, self.output_size, torch.bfloat16)
+        self.assertEqual(param_dict["w2_scale_bias"].dtype, torch.float32)
+        self.assertEqual(
+            param_dict["w2_scale_bias"].shape,
+            (self.experts, self.output_size, 16 // self.quant_method.tp_size))
 
     @patch('torch_npu.npu_quantize')
     @patch('torch.Tensor.npu')
     def test_process_weights_after_loading(self, mock_npu, mock_npu_quantize):
+        # old quant version weight
         layer = torch.nn.Module()
-        layer.w13_weight = torch.nn.Parameter(torch.zeros((8, 8, 14),
-                                                          dtype=torch.int8),
+        layer.w13_weight = torch.nn.Parameter(torch.zeros(
+            (self.experts, 2 * self.input_size, self.output_size),
+            dtype=torch.int8),
                                               requires_grad=False)
-        layer.w2_weight = torch.nn.Parameter(torch.zeros((8, 14, 4),
-                                                         dtype=torch.int8),
+        layer.w2_weight = torch.nn.Parameter(torch.zeros(
+            (self.experts, self.output_size, self.input_size),
+            dtype=torch.int8),
                                              requires_grad=False)
         layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
-            (8, 8, 1), dtype=torch.bfloat16),
+            (self.experts, 2 * self.input_size, 1), dtype=torch.bfloat16),
                                                     requires_grad=False)
         layer.w13_weight_offset = torch.nn.Parameter(torch.zeros(
            (8, 8, 1), dtype=torch.bfloat16),
                                                      requires_grad=False)
         layer.w13_weight_scale_second = torch.nn.Parameter(torch.ones(
-            (8, 8, 7), dtype=torch.bfloat16),
+            (self.experts, 2 * self.input_size,
+             self.output_size // self.group_size),
+            dtype=torch.bfloat16),
                                                            requires_grad=False)
         layer.w2_weight_scale = torch.nn.Parameter(torch.ones(
-            (8, 14, 1), dtype=torch.bfloat16),
+            (self.experts, self.output_size, 1), dtype=torch.bfloat16),
                                                    requires_grad=False)
         layer.w2_weight_offset = torch.nn.Parameter(torch.zeros(
            (8, 14, 1), dtype=torch.bfloat16),
                                                     requires_grad=False)
         layer.w2_weight_scale_second = torch.nn.Parameter(torch.ones(
-            (8, 14, 2), dtype=torch.bfloat16),
+            (self.experts, self.output_size,
+             self.input_size // self.group_size),
+            dtype=torch.bfloat16),
                                                           requires_grad=False)
+        new_layer = copy.deepcopy(layer)
 
         mock_npu.return_value = torch.Tensor()
         mock_npu_quantize.return_value = torch.Tensor()
         self.quant_method.process_weights_after_loading(layer)
         self.assertTrue(hasattr(layer, "w13_scale_bias"))
-        self.assertEqual(layer.w13_scale_bias.data.shape, (8, 8))
+        self.assertEqual(layer.w13_scale_bias.data.shape,
+                         (self.experts, 2 * self.input_size))
         self.assertEqual(layer.w13_scale_bias.data.dtype, torch.float32)
         self.assertTrue(hasattr(layer, "w2_scale_bias"))
-        self.assertEqual(layer.w2_scale_bias.data.shape, (8, 14))
+        self.assertEqual(layer.w2_scale_bias.data.shape,
+                         (self.experts, self.output_size))
         self.assertEqual(layer.w2_scale_bias.data.dtype, torch.float32)
+        # new quant version weight
+        self.quant_method.new_quant_version = True
+        new_layer.w13_weight.data = torch.zeros(
+            (self.experts, self.input_size, self.output_size),
+            dtype=torch.int8)
+        new_layer.w2_weight.data = torch.zeros(
+            (self.experts, self.output_size // 2, self.input_size),
+            dtype=torch.int8)
+        w13_scale_bias = torch.zeros((self.experts, 2 * self.input_size, 1),
+                                     dtype=torch.float32)
+        new_layer.w13_scale_bias = torch.nn.Parameter(w13_scale_bias,
+                                                      requires_grad=False)
Comment on lines +157 to +160 (Contributor review, severity: high)

The setup for w13_scale_bias seems to be based on an incorrect shape definition in get_dynamic_quant_param. The dimension 2 * self.input_size is inconsistent with the corresponding w13_weight's dimension for the new quantization version. This should be self.input_size to match the weight.

Suggested change:
-        w13_scale_bias = torch.zeros((self.experts, 2 * self.input_size, 1),
-                                     dtype=torch.float32)
-        new_layer.w13_scale_bias = torch.nn.Parameter(w13_scale_bias,
-                                                      requires_grad=False)
+        w13_scale_bias = torch.zeros((self.experts, self.input_size, 1),
+                                     dtype=torch.float32)
+        new_layer.w13_scale_bias = torch.nn.Parameter(w13_scale_bias,
+                                                      requires_grad=False)
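To make the shape question concrete, here is a small standalone sketch of the two candidate layouts, using the test constants defined above (experts=8, input_size=16, output_size=56). Reading the halved weight dimension as two int4 values packed per int8 byte is an inference from the shapes in this diff, not something the PR states explicitly:

import torch

experts, input_size, output_size = 8, 16, 56

# Old quant version: w13_weight keeps one int8 per logical weight element.
w13_unpacked = torch.zeros((experts, 2 * input_size, output_size),
                           dtype=torch.int8)

# New quant version: the test builds w13_weight with half as many rows,
# consistent with pairs of int4 values sharing one int8 byte (assumption).
w13_packed = torch.zeros((experts, input_size, output_size), dtype=torch.int8)
assert w13_unpacked.shape[1] == 2 * w13_packed.shape[1]

# The review comment is about which row count w13_scale_bias should follow:
# the logical channel count (2 * input_size, as the test currently uses)
# or the packed row count (input_size, as the suggestion proposes).
bias_logical = torch.zeros((experts, 2 * input_size, 1), dtype=torch.float32)
bias_packed = torch.zeros((experts, input_size, 1), dtype=torch.float32)
print(bias_logical.shape, bias_packed.shape)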

+        w2_scale_bias = torch.zeros(
+            (self.experts, self.output_size, 16 // self.quant_method.tp_size),
+            dtype=torch.float32)
+        new_layer.w2_scale_bias = torch.nn.Parameter(w2_scale_bias,
+                                                     requires_grad=False)
+        self.quant_method.process_weights_after_loading(new_layer)
+        self.assertEqual(new_layer.w13_scale_bias.data.shape,
+                         (self.experts, 2 * self.input_size))
Comment on lines +167 to +168 (Contributor review, severity: high)

This assertion checks for a shape that is inconsistent with the corresponding weight's shape. Following the correction in the w13_scale_bias setup, this assertion should be updated to check for the correct shape, which should use self.input_size instead of 2 * self.input_size.

Suggested change:
-        self.assertEqual(new_layer.w13_scale_bias.data.shape,
-                         (self.experts, 2 * self.input_size))
+        self.assertEqual(new_layer.w13_scale_bias.data.shape,
+                         (self.experts, self.input_size))

+        self.assertEqual(new_layer.w2_scale_bias.data.shape,
+                         (self.experts, self.output_size))
vllm_ascend/quantization/quant_config.py (7 changes: 5 additions & 2 deletions)
@@ -44,7 +44,7 @@
 @register_quantization_config(ASCEND_QUATIZATION_METHOD)
 class AscendQuantConfig(QuantizationConfig):
     """Config class for Ascend
 
     This class is a general class that parse quantization configs
     that are supported on ascend hardware.
     """
@@ -295,14 +295,17 @@ def create_weights(
 
         extra_weight_attrs.update(
             {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
+        per_group_param = [
+            "weight_scale_second", "weight_offset_second", "scale_bias"
+        ]
         dynamic_quant_param = self.quant_method.get_dynamic_quant_param(
             num_experts, intermediate_size_per_partition, hidden_size,
             params_dtype)
         for param_key, param_value in dynamic_quant_param.items():
             param = torch.nn.Parameter(param_value, requires_grad=False)
             layer.register_parameter(param_key, param)
             set_weight_attrs(param, extra_weight_attrs)
-            if "weight_scale_second" in param_key or "weight_offset_second" in param_key:
+            if any(fields in param_key for fields in per_group_param):
                 setattr(param, "quant_method",
                         FusedMoeWeightScaleSupported.GROUP.value)