Skip to content

Commit e0cbad4

Browse files
[Neuron] Support quantization on neuron (#18283)
Signed-off-by: Satyajith Chilappagari <satchill@amazon.com>
1 parent b48d5cc commit e0cbad4

File tree

3 files changed

+20
-2
lines changed

3 files changed

+20
-2
lines changed
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
from vllm.model_executor.layers.quantization.neuron_quant import (
3+
NeuronQuantConfig)
4+
5+
6+
def test_get_supported_act_dtypes():
7+
neuron_quant_config = NeuronQuantConfig()
8+
supported_act_dtypes = neuron_quant_config.get_supported_act_dtypes()
9+
target_list = ["any_dtype1", "any_dtype2"]
10+
for dtype in target_list:
11+
assert dtype in supported_act_dtypes

vllm/model_executor/layers/quantization/neuron_quant.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,12 @@
1313
SUPPORTED_QUANT_DTYPE_LIST = ['s8', 'f8e4m3fn']
1414

1515

16+
class AlwaysSupportedDtypes(list):
17+
18+
def __contains__(self, item):
19+
return True
20+
21+
1622
class NeuronQuantConfig(QuantizationConfig):
1723
"""Int8 Quantization Config class for Neuron Backend."""
1824

@@ -35,7 +41,8 @@ def get_name(self) -> QuantizationMethods:
3541
return "neuron_quant"
3642

3743
def get_supported_act_dtypes(self) -> list[str]:
38-
return SUPPORTED_QUANT_DTYPE_LIST
44+
# Neuron implements custom handling logic for quantization support
45+
return AlwaysSupportedDtypes()
3946

4047
@classmethod
4148
def get_min_capability(cls) -> int:

vllm/platforms/neuron.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class NeuronPlatform(Platform):
2828
device_name: str = "neuron"
2929
device_type: str = "neuron"
3030
ray_device_key: str = "neuron_cores"
31-
supported_quantization: list[str] = ["neuron_quant"]
31+
supported_quantization: list[str] = ["neuron_quant", "fbgemm_fp8"]
3232
device_control_env_var: str = "NEURON_RT_VISIBLE_CORES"
3333

3434
@classmethod

0 commit comments

Comments
 (0)