File tree Expand file tree Collapse file tree 3 files changed +20
-2
lines changed
model_executor/layers/quantization Expand file tree Collapse file tree 3 files changed +20
-2
lines changed Original file line number Diff line number Diff line change
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ from vllm .model_executor .layers .quantization .neuron_quant import (
3
+ NeuronQuantConfig )
4
+
5
+
6
+ def test_get_supported_act_dtypes ():
7
+ neuron_quant_config = NeuronQuantConfig ()
8
+ supported_act_dtypes = neuron_quant_config .get_supported_act_dtypes ()
9
+ target_list = ["any_dtype1" , "any_dtype2" ]
10
+ for dtype in target_list :
11
+ assert dtype in supported_act_dtypes
Original file line number Diff line number Diff line change 13
13
SUPPORTED_QUANT_DTYPE_LIST = ['s8' , 'f8e4m3fn' ]
14
14
15
15
16
+ class AlwaysSupportedDtypes (list ):
17
+
18
+ def __contains__ (self , item ):
19
+ return True
20
+
21
+
16
22
class NeuronQuantConfig (QuantizationConfig ):
17
23
"""Int8 Quantization Config class for Neuron Backend."""
18
24
@@ -35,7 +41,8 @@ def get_name(self) -> QuantizationMethods:
35
41
return "neuron_quant"
36
42
37
43
def get_supported_act_dtypes (self ) -> list [str ]:
38
- return SUPPORTED_QUANT_DTYPE_LIST
44
+ # Neuron implements custom handling logic for quantization support
45
+ return AlwaysSupportedDtypes ()
39
46
40
47
@classmethod
41
48
def get_min_capability (cls ) -> int :
Original file line number Diff line number Diff line change @@ -28,7 +28,7 @@ class NeuronPlatform(Platform):
28
28
device_name : str = "neuron"
29
29
device_type : str = "neuron"
30
30
ray_device_key : str = "neuron_cores"
31
- supported_quantization : list [str ] = ["neuron_quant" ]
31
+ supported_quantization : list [str ] = ["neuron_quant" , "fbgemm_fp8" ]
32
32
device_control_env_var : str = "NEURON_RT_VISIBLE_CORES"
33
33
34
34
@classmethod
You can’t perform that action at this time.
0 commit comments