Skip to content

Commit 2f217ff

Browse files
committed
Quantization: support FP4 quantized models on AMD CDNA2/CDNA3 GPUs
1 parent 3c91d66 commit 2f217ff

File tree

2 files changed

+34
-43
lines changed

2 files changed

+34
-43
lines changed

vllm/model_executor/layers/quantization/petit.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,9 @@ class PetitNvFp4Config(QuantizationConfig):
4242
def __init__(
4343
self,
4444
is_checkpoint_nvfp4_serialized: bool = False,
45-
kv_cache_quant_algo: str = None,
46-
group_size: int = None,
47-
exclude_modules: list[str] = None,
45+
kv_cache_quant_algo: Optional[str] = None,
46+
group_size: Optional[int] = None,
47+
exclude_modules: Optional[list[str]] = None,
4848
) -> None:
4949
self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
5050
if is_checkpoint_nvfp4_serialized:
@@ -87,10 +87,12 @@ def from_config(cls, config: dict[str, Any]) -> "PetitNvFp4Config":
8787
exclude_modules = quant_config.get("exclude_modules", None)
8888
if not (group_size and kv_cache_quant_algo and (exclude_modules is not None)):
8989
logger.warning(
90-
f"group_size: {group_size},"
91-
f"kv_cache_quant_algo: {kv_cache_quant_algo},"
92-
f"exclude_modules: {exclude_modules}"
90+
"group_size: %s, kv_cache_quant_algo: %s, exclude_modules: %s",
91+
group_size,
92+
kv_cache_quant_algo,
93+
exclude_modules,
9394
)
95+
9496
raise ValueError(
9597
"NVFP4 quantization requires group size and "
9698
"kv_cache_quant_algo specified in "

vllm/model_executor/layers/quantization/utils/petit_utils.py

Lines changed: 26 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3,61 +3,51 @@
33

44
try:
55
from petit_kernel import mul_nvfp4_a16, process_nvfp4_scales, repack_nvfp4
6+
_PETIT_AVAILABLE = True
67
except ImportError:
8+
_PETIT_AVAILABLE = False
79

8-
def _check_petit_nvfp4_supported(
9-
quant_method: str, group_size: Optional[int]
10-
) -> tuple[bool, Optional[str]]:
11-
return (
12-
False,
13-
"Petit is not installed. Please install it with `pip install petit-kernel`.",
14-
)
15-
16-
def prepare_nvfp4_layer_for_petit(layer: torch.nn.Module) -> None:
17-
raise ValueError(
18-
"Petit is not installed. Please install it with `pip install petit-kernel`."
19-
)
20-
21-
def apply_petit_nvfp4_linear(
22-
input: torch.Tensor,
23-
weight: torch.Tensor,
24-
weight_scale: torch.Tensor,
25-
weight_scale_2: torch.Tensor,
26-
size_n: int,
27-
size_k: int,
28-
bias: Optional[torch.Tensor] = None,
29-
) -> torch.Tensor:
30-
raise ValueError(
31-
"Petit is not installed. Please install it with `pip install petit-kernel`."
32-
)
10+
_PETIT_INSTALL_MSG = (
11+
"Petit is not installed. Please install it with "
12+
"`pip install petit-kernel`."
13+
)
3314

15+
def _require_petit() -> None:
16+
if not _PETIT_AVAILABLE:
17+
# Single, shared error-raising point; avoids duplicated code and overlong lines
18+
raise ImportError(_PETIT_INSTALL_MSG)
3419

3520
def _check_petit_nvfp4_supported(
3621
quant_method: str, group_size: Optional[int]
3722
) -> tuple[bool, Optional[str]]:
3823
if quant_method != "NVFP4":
3924
return (
4025
False,
41-
"Petit currently only supports: NVFP4"
42-
" quantizations in sglang. Please check the "
43-
"`hf_quant_config.json` file for your model's "
44-
"quant configuration.",
26+
(
27+
"Petit currently only supports: NVFP4 quantizations in sglang. "
28+
"Please check the `hf_quant_config.json` file for your model's "
29+
"quant configuration."
30+
),
4531
)
4632
if group_size is not None and group_size != 16:
4733
return (
4834
False,
49-
"Petit currently only supports: group_size=16" " quantizations.",
35+
"Petit currently only supports: group_size=16 quantizations.",
5036
)
5137
return (True, None)
5238

53-
54-
def verify_petit_nvfp4_supported(quant_method: str, group_size: Optional[int]) -> None:
39+
def verify_petit_nvfp4_supported(
40+
quant_method: str, group_size: Optional[int]
41+
) -> None:
5542
supported, error_msg = _check_petit_nvfp4_supported(quant_method, group_size)
5643
if not supported:
44+
# Narrow Optional[str] so mypy does not flag the raise below
45+
assert error_msg is not None
5746
raise ValueError(error_msg)
5847

59-
6048
def prepare_nvfp4_layer_for_petit(layer: torch.nn.Module) -> None:
49+
_require_petit()  # raise uniformly here when petit is not installed
50+
6151
# Repack weights to petit format
6252
part_size_n = layer.output_size_per_partition
6353
part_size_k = layer.input_size_per_partition
@@ -71,9 +61,6 @@ def prepare_nvfp4_layer_for_petit(layer: torch.nn.Module) -> None:
7161
)
7262
layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
7363

74-
return
75-
76-
7764
def apply_petit_nvfp4_linear(
7865
input: torch.Tensor,
7966
weight: torch.Tensor,
@@ -83,6 +70,8 @@ def apply_petit_nvfp4_linear(
8370
size_k: int,
8471
bias: Optional[torch.Tensor] = None,
8572
) -> torch.Tensor:
73+
_require_petit()  # raise uniformly here when petit is not installed
74+
8675
reshaped_x = input.reshape(-1, input.shape[-1])
8776
out_shape = input.shape[:-1] + (size_n,)
8877

@@ -100,4 +89,4 @@ def apply_petit_nvfp4_linear(
10089
if bias is not None:
10190
output.add_(bias) # In-place add
10291

103-
return output.reshape(out_shape)
92+
return output.reshape(out_shape)

0 commit comments

Comments
 (0)