
Commit 2855577

feat(quantization): support FP4 quantized models on AMD CDNA2/CDNA3 GPUs
Signed-off-by: feng <fengli1702@gmail.com>
1 parent e5ebeeb commit 2855577

8 files changed: +376 additions, −3 deletions

requirements/rocm.txt

Lines changed: 1 addition & 1 deletion
@@ -17,4 +17,4 @@ setuptools>=77.0.3,<80.0.0
 setuptools-scm>=8
 runai-model-streamer==0.11.0
 runai-model-streamer-s3==0.11.0
-conch-triton-kernels==1.2.1
+conch-triton-kernels==1.2.1

setup.py

Lines changed: 2 additions & 0 deletions
@@ -666,6 +666,8 @@ def _read_requirements(filename: str) -> list[str]:
         "video": [],  # Kept for backwards compatibility
         # FlashInfer should be updated together with the Dockerfile
         "flashinfer": ["flashinfer-python==0.2.10"],
+        # Optional deps for AMD FP4 quantization support
+        "petit-kernel": ["petit-kernel"],
     },
     cmdclass=cmdclass,
     package_data=package_data,
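
With this extra in place, the Petit kernels become an opt-in dependency: assuming the petit-kernel package is published on PyPI under that name, it can be pulled in at install time with pip install "vllm[petit-kernel]".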

vllm/config.py

Lines changed: 3 additions & 1 deletion
@@ -1126,7 +1126,8 @@ def _verify_quantization(self) -> None:
         optimized_quantization_methods = [
             "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin",
             "awq_marlin", "fbgemm_fp8", "compressed-tensors", "experts_int8",
-            "quark", "modelopt_fp4", "bitblas", "gptq_bitblas", "inc"
+            "quark", "modelopt_fp4", "bitblas", "gptq_bitblas", "inc",
+            "petit_nvfp4",
         ]
         if self.quantization is not None:
             self.quantization = cast(me_quant.QuantizationMethods,

@@ -1159,6 +1160,7 @@ def _verify_quantization(self) -> None:
                 "moe_wna16",
                 "modelopt",
                 "modelopt_fp4",
+                "petit_nvfp4",
             ]
             quantization_methods = [
                 q for q in supported_quantization if q not in overrides

vllm/model_executor/layers/linear.py

Lines changed: 1 addition & 0 deletions
@@ -52,6 +52,7 @@
     "HQQMarlinMethod",
     "QuarkLinearMethod",
     "ModelOptNvFp4LinearMethod",
+    "PetitNvFp4LinearMethod",
 ]


vllm/model_executor/layers/quantization/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -38,6 +38,7 @@
     "rtn",
     "inc",
     "mxfp4",
+    "petit_nvfp4",
 ]
 QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))

@@ -118,6 +119,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
     from .rtn import RTNConfig
     from .torchao import TorchAOConfig
     from .tpu_int8 import Int8TpuConfig
+    from .petit import PetitNvFp4Config

     method_to_config: dict[str, type[QuantizationConfig]] = {
         "aqlm": AQLMConfig,

@@ -151,6 +153,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
         "rtn": RTNConfig,
         "inc": INCConfig,
         "mxfp4": Mxfp4Config,
+        "petit_nvfp4": PetitNvFp4Config,
     }
     # Update the `method_to_config` with customized quantization methods.
     method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)
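
A minimal sketch of how the new method plugs into the existing registry, assuming the optional petit-kernel dependency (and the petit_utils helpers added elsewhere in this change) are importable; nothing below is part of the diff itself:

from vllm.model_executor.layers.quantization import get_quantization_config

# "petit_nvfp4" now resolves to PetitNvFp4Config through method_to_config.
config_cls = get_quantization_config("petit_nvfp4")
print(config_cls.get_name())              # "petit_nvfp4"
print(config_cls.get_config_filenames())  # ["hf_quant_config.json"]
print(config_cls.get_min_capability())    # 90 (gfx90a / gfx942)
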
vllm/model_executor/layers/quantization/petit.py

Lines changed: 262 additions & 0 deletions
@@ -0,0 +1,262 @@
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py

from typing import Any, Callable, Dict, List, Optional

import regex as re
import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.petit_utils import (
    apply_petit_nvfp4_linear,
    prepare_nvfp4_layer_for_petit,
    verify_petit_nvfp4_supported,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
from vllm.model_executor.parameter import (
    ModelWeightParameter,
    PerTensorScaleParameter,
)

# Initialize logger for the module
logger = init_logger(__name__)


# Configuration class for NVFP4 quantized models produced by the ModelOpt
# quantization tool.
class PetitNvFp4Config(QuantizationConfig):
    """Config class for Petit FP4."""

    def __init__(
        self,
        is_checkpoint_nvfp4_serialized: bool = False,
        kv_cache_quant_algo: Optional[str] = None,
        group_size: Optional[int] = None,
        exclude_modules: Optional[List[str]] = None,
    ) -> None:
        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
        if is_checkpoint_nvfp4_serialized:
            logger.warning(
                "Detected nvfp4 checkpoint. Please note that the "
                "format is experimental and subject to change."
            )
        self.group_size = group_size
        self.kv_cache_quant_algo = kv_cache_quant_algo
        self.exclude_modules = exclude_modules

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "petit_nvfp4"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        # Petit supports the gfx90a (CDNA2) and gfx942 (CDNA3) GPUs
        return 90

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["hf_quant_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "PetitNvFp4Config":
        quant_config = cls.get_from_keys(config, ["quantization"])
        quant_method = quant_config["quant_algo"]
        group_size = quant_config.get("group_size", None)
        verify_petit_nvfp4_supported(quant_method, group_size)

        is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method
        kv_cache_quant_algo = quant_config["kv_cache_quant_algo"]
        if not kv_cache_quant_algo:
            kv_cache_quant_algo = "auto"
        exclude_modules = quant_config.get("exclude_modules", None)
        if not (group_size and kv_cache_quant_algo and (exclude_modules is not None)):
            logger.warning(
                f"group_size: {group_size}, "
                f"kv_cache_quant_algo: {kv_cache_quant_algo}, "
                f"exclude_modules: {exclude_modules}"
            )
            raise ValueError(
                "NVFP4 quantization requires group_size, "
                "kv_cache_quant_algo, and exclude_modules to be "
                "specified in hf_quant_config.json"
            )
        return cls(
            is_checkpoint_nvfp4_serialized,
            kv_cache_quant_algo,
            group_size,
            exclude_modules,
        )

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
        qc = hf_quant_cfg.get("quantization", hf_quant_cfg)
        algo = (qc.get("quant_algo") or qc.get("quant_method") or "").upper()
        if algo in ("NVFP4", "MODELOPT_FP4", "MODELOPT"):
            return cls.get_name()  # "petit_nvfp4"
        return None

    @classmethod
    def is_petit_nvfp4_compatible(cls, quant_config: Dict[str, Any]) -> bool:
        qc = quant_config.get("quantization", quant_config)
        algo = (qc.get("quant_algo") or qc.get("quant_method") or "").upper()
        return algo == "NVFP4"

    def is_layer_excluded(self, prefix: str, exclude_modules: list):
        for pattern in exclude_modules:
            regex_str = pattern.replace(".", r"\.").replace("*", r".*")
            if re.fullmatch(regex_str, prefix):
                return True
        return False

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        from vllm.attention.layer import Attention  # Avoid circular import

        if isinstance(layer, LinearBase):
            if is_layer_skipped(prefix, self.exclude_modules) or self.is_layer_excluded(
                prefix, self.exclude_modules
            ):
                return UnquantizedLinearMethod()
            return PetitNvFp4LinearMethod(self)
        elif isinstance(layer, Attention):
            return PetitFp8KVCacheMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []


class PetitFp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
    """

    def __init__(self, quant_config: PetitNvFp4Config):
        super().__init__(quant_config)


class PetitNvFp4LinearMethod(LinearMethodBase):
    """Linear method for NVFP4.

    Supports loading NVFP4 checkpoints with the following structure:

    | Tensor Name    | datatype      | shape       |
    |----------------|---------------|-------------|
    | input_scale    | torch.float32 | scalar      |
    | weight         | NVFP4(SE2M1)  | [1, X, Y/2] |
    | weight_scale   | FP8-E4M3      | [X, Y]      |
    | weight_scale_2 | torch.float32 | scalar      |

    The weights are quantized per block of 16 elements.

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: PetitNvFp4Config):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        del input_size, output_size
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError(
                "NVFP4 quantization was selected, "
                "dynamic quantization is not supported."
            )

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")

        layer.logical_widths = output_partition_sizes

        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        if input_size_per_partition % 16 != 0:
            raise ValueError(
                "Unsupported model: the input feature size is not a multiple of 16"
            )

        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_nvfp4_serialized
            else params_dtype
        )

        weight = ModelWeightParameter(
            data=torch.empty(
                # Two FP4 values are packed into one uint8 along the input dimension
                output_size_per_partition,
                input_size_per_partition // 2,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        input_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )

        layer.register_parameter("input_scale", input_scale)

        weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale_2", weight_scale_2)

        weight_scale = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // self.quant_config.group_size,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )

        layer.register_parameter("weight_scale", weight_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        input_scale_2 = layer.input_scale.max().to(torch.float32)
        weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
        layer.input_scale = Parameter(input_scale_2, requires_grad=False)
        layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)
        layer.alpha = Parameter(
            layer.input_scale * layer.weight_scale_2, requires_grad=False
        )

        prepare_nvfp4_layer_for_petit(layer)
        del layer.input_scale

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return apply_petit_nvfp4_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            weight_scale_2=layer.weight_scale_2,
            size_n=layer.output_size_per_partition,
            size_k=layer.input_size_per_partition,
            bias=bias,
        )
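
For context, a sketch of how the new config is exercised end to end. The dictionary mirrors the "quantization" block that PetitNvFp4Config.from_config reads from hf_quant_config.json (quant_algo, kv_cache_quant_algo, group_size, exclude_modules); the values and the model name are illustrative placeholders, not part of this commit, and running it assumes a ROCm build of vLLM on a gfx90a/gfx942 GPU with the optional petit-kernel package installed:

from vllm import LLM, SamplingParams
from vllm.model_executor.layers.quantization.petit import PetitNvFp4Config

# Illustrative hf_quant_config.json contents for an NVFP4-serialized checkpoint.
hf_quant_config = {
    "quantization": {
        "quant_algo": "NVFP4",
        "kv_cache_quant_algo": "auto",
        "group_size": 16,
        "exclude_modules": ["lm_head"],
    }
}
cfg = PetitNvFp4Config.from_config(hf_quant_config)
assert cfg.is_checkpoint_nvfp4_serialized and cfg.group_size == 16

# Serving such a checkpoint through the new method; the model name is a placeholder.
llm = LLM(model="<path-to-nvfp4-checkpoint>", quantization="petit_nvfp4")
out = llm.generate(["Hello"], SamplingParams(max_tokens=16))
print(out[0].outputs[0].text)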

0 commit comments