
Commit 2c6f2dc

22dimensions authored and offline0806 committed
[1/N][Refactor][Quantization] remove redundant quantizer class (vllm-project#2680)
### What this PR does / why we need it?

The AscendQuantizer/LLMQuantizer classes were used to select a quant method based on the quant config and a few other arguments, but replacing these classes with a plain map is simpler and cleaner, so I removed them.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

UT and e2e tests.

- vLLM version: v0.10.1.1
- vLLM main: vllm-project/vllm@6997a25

Signed-off-by: 22dimensions <waitingwind@foxmail.com>
Signed-off-by: offline0806 <z00858301@china.huawei.com>
1 parent 5c933b5 commit 2c6f2dc
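
For orientation, the map-based dispatch that replaces the quantizer classes can be pictured roughly as below. This is a minimal sketch inferred from the diff: the `{quant_type: {layer_type: method_class}}` shape of `ASCEND_QUANTIZATION_METHOD_MAP`, the `get_quant_method` signature, and the `NotImplementedError` behavior are visible in the new tests, while the map entries and the lookup details are illustrative assumptions, not the actual implementation.

```python
# Sketch of the map-based dispatch in vllm_ascend/quantization/utils.py.
# The map entries below are placeholders; the real module maps quant types
# such as W8A8 or C8 to concrete Ascend quant method classes.
from typing import Any, Dict, Optional, Type

ASCEND_QUANTIZATION_METHOD_MAP: Dict[str, Dict[str, Type]] = {
    # "W8A8": {"linear": AscendW8A8LinearMethod, "moe": AscendW8A8FusedMoEMethod},
    # "C8":   {"attention": AscendC8KVCacheMethod},
}


def get_quant_method(quant_description: Dict[str, Any],
                     prefix: str,
                     layer_type: str,
                     packed_modules_mapping: Optional[Dict[str, Any]] = None):
    # Attention layers carry their quant type in fa_quant_type/kv_quant_type;
    # other layers are assumed here to be looked up by their weight name.
    # (The real function presumably also uses packed_modules_mapping to
    # resolve prefixes of packed modules for linear/moe layers.)
    if layer_type == "attention":
        quant_type = (quant_description.get("fa_quant_type")
                      or quant_description.get("kv_quant_type"))
    else:
        quant_type = quant_description.get(f"{prefix}.weight")

    layer_map = ASCEND_QUANTIZATION_METHOD_MAP.get(quant_type, {})
    if layer_type not in layer_map:
        raise NotImplementedError(
            f"{quant_type} is not supported for layer type {layer_type}.")
    return layer_map[layer_type]()
```

Under this shape, every call site in quant_config.py (see the diff below) reduces to a single `get_quant_method(...)` call per layer type.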

File tree

10 files changed, +322 -555 lines changed


tests/ut/quantization/test_quant_config.py

Lines changed: 8 additions & 19 deletions
@@ -156,33 +156,22 @@ class TestAscendKVCacheMethod(TestBase):
     def setUp(self):
         # Setup common test fixtures
         self.mock_quant_config = MagicMock(spec=AscendQuantConfig)
-        self.mock_quant_config.quant_description = {"some_config": "value"}
-        self.prefix = "attention_layer"
+        self.mock_quant_config.quant_description = {"kv_quant_type": "C8"}
+        self.prefix = "layer.attn"

-        # Mock the quantizer and quant_method
-        self.mock_quantizer = MagicMock()
+        # Mock quant_method
         self.mock_quant_method = MagicMock()
-
-        # Patch the AscendQuantizer
-        self.quantizer_patcher = patch(
-            'vllm_ascend.quantization.quant_config.AscendQuantizer.get_quantizer',
-            return_value=self.mock_quantizer)
-        self.mock_get_quantizer = self.quantizer_patcher.start()
-
-        self.mock_quantizer.build_attention_method.return_value = self.mock_quant_method
+        self.patcher = patch(
+            'vllm_ascend.quantization.quant_config.get_quant_method')
+        self.mock_get_quant_method = self.patcher.start()
+        self.mock_get_quant_method.return_value = self.mock_quant_method

         # Create instance
         self.kv_cache_method = AscendKVCacheMethod(self.mock_quant_config,
                                                    self.prefix)

     def tearDown(self):
-        self.quantizer_patcher.stop()
-
-    def test_init(self):
-        """Test initialization with proper quantizer setup."""
-        self.mock_get_quantizer.assert_called_once_with(
-            self.mock_quant_config.quant_description, self.prefix)
-        self.mock_quantizer.build_attention_method.assert_called_once()
+        self.patcher.stop()

     def test_create_weights(self):
         """Test create_weights delegates to quant_method."""

tests/ut/quantization/test_quantizer.py

Lines changed: 0 additions & 145 deletions
This file was deleted.

(new file)

Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
+import types
+
+from tests.ut.base import TestBase
+from vllm_ascend.quantization.utils import (ASCEND_QUANTIZATION_METHOD_MAP,
+                                            get_quant_method)
+
+
+class TestGetQuantMethod(TestBase):
+
+    def setUp(self):
+        self.original_quantization_method_map = ASCEND_QUANTIZATION_METHOD_MAP.copy(
+        )
+        for quant_type, layer_map in ASCEND_QUANTIZATION_METHOD_MAP.items():
+            for layer_type in layer_map.keys():
+                ASCEND_QUANTIZATION_METHOD_MAP[quant_type][
+                    layer_type] = types.new_class(f"{quant_type}_{layer_type}")
+
+    def tearDown(self):
+        # Restore original map
+        ASCEND_QUANTIZATION_METHOD_MAP.clear()
+        ASCEND_QUANTIZATION_METHOD_MAP.update(
+            self.original_quantization_method_map)
+
+    def test_linear_quant_methods(self):
+        for quant_type, layer_map in ASCEND_QUANTIZATION_METHOD_MAP.items():
+            if "linear" in layer_map.keys():
+                prefix = "linear_layer"
+                cls = layer_map["linear"]
+                method = get_quant_method({"linear_layer.weight": quant_type},
+                                          prefix, "linear")
+                self.assertIsInstance(method, cls)
+
+    def test_moe_quant_methods(self):
+        for quant_type, layer_map in ASCEND_QUANTIZATION_METHOD_MAP.items():
+            if "moe" in layer_map.keys():
+                prefix = "layer"
+                cls = layer_map["moe"]
+                method = get_quant_method({"layer.weight": quant_type}, prefix,
+                                          "moe")
+                self.assertIsInstance(method, cls)
+
+    def test_with_fa_quant_type(self):
+        quant_description = {"fa_quant_type": "C8"}
+        method = get_quant_method(quant_description, ".attn", "attention")
+        self.assertIsInstance(
+            method, ASCEND_QUANTIZATION_METHOD_MAP["C8"]["attention"])
+
+    def test_with_kv_quant_type(self):
+        quant_description = {"kv_quant_type": "C8"}
+        method = get_quant_method(quant_description, ".attn", "attention")
+        self.assertIsInstance(
+            method, ASCEND_QUANTIZATION_METHOD_MAP["C8"]["attention"])
+
+    def test_invalid_layer_type(self):
+        quant_description = {"linear_layer.weight": "W8A8"}
+        with self.assertRaises(NotImplementedError):
+            get_quant_method(quant_description, "linear_layer", "unsupported")
+
+    def test_invalid_quant_type(self):
+        quant_description = {"linear_layer.weight": "UNKNOWN"}
+        with self.assertRaises(NotImplementedError):
+            get_quant_method(quant_description, "linear_layer", "linear")

tests/ut/torchair/ops/test_torchair_fused_moe.py

Lines changed: 1 addition & 5 deletions
@@ -24,7 +24,6 @@

 from vllm_ascend.ascend_forward_context import _get_fused_moe_state
 from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod
-from vllm_ascend.quantization.quantizer import W8A8Quantizer
 from vllm_ascend.torchair.ops.torchair_fused_moe import (
     TorchairAscendFusedMoE, TorchairAscendUnquantizedFusedMoEMethod)
 from vllm_ascend.utils import AscendSocVersion, adapt_patch  # noqa E402
@@ -236,12 +235,9 @@ def test_init_with_quant(self, mock_dist_env, default_moe_config):
         mock_quant_method = MockFusedMoEMethod()
         mock_quant_config.get_quant_method.return_value = mock_quant_method
         mock_quant_config.is_layer_skipped_ascend.return_value = False
-        with patch(
-                'vllm_ascend.quantization.quantizer.AscendQuantizer.get_quantizer',
-                return_value=W8A8Quantizer):
+        with patch("vllm_ascend.quantization.quant_config.get_quant_method"):
             moe = TorchairAscendFusedMoE(**default_moe_config,
                                          quant_config=mock_quant_config)
-
         assert moe.quant_method is not None
         assert isinstance(moe.quant_method, AscendFusedMoEMethod)

tests/ut/torchair/test_utils.py

Lines changed: 0 additions & 13 deletions
@@ -6,7 +6,6 @@
 import torch

 from tests.ut.base import TestBase
-from vllm_ascend.quantization.quantizer import SUPPORT_ASCEND_QUANTIZER_TYPE
 from vllm_ascend.torchair import utils


@@ -135,15 +134,3 @@ def test_converting_weight_acl_format_format_true(self, mock_npu_cast,

         utils.converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_NZ)
         mock_npu_cast.assert_not_called()
-
-    def test_torchair_quant_method_register(self):
-
-        TorchairW8A8DYNAMICQuantizer = SUPPORT_ASCEND_QUANTIZER_TYPE[
-            "W8A8_DYNAMIC"]
-        TorchairW4A8DYNAMICQuantizer = SUPPORT_ASCEND_QUANTIZER_TYPE[
-            "W4A8_DYNAMIC"]
-        utils.torchair_quant_method_register()
-        self.assertNotEqual(TorchairW8A8DYNAMICQuantizer,
-                            SUPPORT_ASCEND_QUANTIZER_TYPE["W8A8_DYNAMIC"])
-        self.assertNotEqual(TorchairW4A8DYNAMICQuantizer,
-                            SUPPORT_ASCEND_QUANTIZER_TYPE["W4A8_DYNAMIC"])

vllm_ascend/quantization/quant_config.py

Lines changed: 13 additions & 24 deletions
@@ -38,7 +38,7 @@
 from vllm_ascend.ops.fused_moe import AscendUnquantizedFusedMoEMethod
 from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD

-from .quantizer import AscendQuantizer
+from .utils import get_quant_method


 @register_quantization_config(ASCEND_QUANTIZATION_METHOD)
@@ -150,18 +150,15 @@ def get_scaled_act_names(self) -> List[str]:
 class AscendLinearMethod(LinearMethodBase):
     """Linear method for Ascend quantization.

-    This class calls AscendQuantizer to search a specific quantization
-    implementations supported on ascend hardware for linear methods.
-
     Args:
         quant_config: The Ascend quantization config.
     """

     def __init__(self, quant_config: AscendQuantConfig, prefix: str,
                  packed_modules_mapping: Dict[str, Any]) -> None:
-        self.quantizer = AscendQuantizer.get_quantizer(
-            quant_config.quant_description, prefix, packed_modules_mapping)
-        self.quant_method = self.quantizer.build_linear_method()
+        self.quant_method = get_quant_method(quant_config.quant_description,
+                                             prefix, "linear",
+                                             packed_modules_mapping)

     def create_weights(
         self,
@@ -231,17 +228,13 @@ def apply(
 class AscendKVCacheMethod(BaseKVCacheMethod):
     """KVCache method for Ascend quantization.

-    This class calls AscendQuantizer to search a specific quantization
-    implementations supported on ascend hardware for kvcache methods.
-
     Args:
         quant_config: The Ascend quantization config.
     """

     def __init__(self, quant_config: AscendQuantConfig, prefix: str) -> None:
-        self.quantizer = AscendQuantizer.get_quantizer(
-            quant_config.quant_description, prefix)
-        self.quant_method = self.quantizer.build_attention_method()
+        self.quant_method = get_quant_method(quant_config.quant_description,
+                                             prefix, "attention")

     def create_weights(self, layer: torch.nn.Module) -> None:
         # Different from linear method, there are no weight processing/slicing
@@ -263,18 +256,15 @@ def apply(self, layer: torch.nn.Module, query: torch.Tensor,
 class AscendFusedMoEMethod(FusedMoEMethodBase):
     """FusedMoE method for Ascend quantization.

-    This class calls AscendQuantizer to search a specific quantization
-    implementations supported on ascend hardware for kvcache methods.
-
     Args:
         quant_config: The Ascend quantization config.
     """

     def __init__(self, quant_config: AscendQuantConfig, prefix: str,
                  packed_modules_mapping: Dict[str, Any]):
-        self.quantizer = AscendQuantizer.get_quantizer(
-            quant_config.quant_description, prefix, packed_modules_mapping)
-        self.quant_method = self.quantizer.build_moe_method()
+        self.quant_method = get_quant_method(quant_config.quant_description,
+                                             prefix, "moe",
+                                             packed_modules_mapping)

     def create_weights(
         self,
@@ -344,14 +334,13 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:

 class AscendEmbeddingMethod(AscendLinearMethod):
     """Embedding method for Ascend quantization.
-    This class calls AscendQuantizer to search a specific quantization
-    implementations supported on ascend hardware for Embedding methods.
+
     Args:
         quant_config: The Ascend quantization config.
     """

     def __init__(self, quant_config: AscendQuantConfig, prefix: str,
                  packed_modules_mapping: Dict[str, Any]) -> None:
-        self.quantizer = AscendQuantizer.get_quantizer(
-            quant_config.quant_description, prefix, packed_modules_mapping)
-        self.quant_method = self.quantizer.build_linear_method()
+        self.quant_method = get_quant_method(quant_config.quant_description,
+                                             prefix, "linear",
+                                             packed_modules_mapping)
