# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py

from typing import Any, Dict, List, Optional

import regex as re
import torch
from torch.nn.parameter import Parameter

from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.petit_utils import (
    apply_petit_nvfp4_linear,
    prepare_nvfp4_layer_for_petit,
    verify_petit_nvfp4_supported,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter

# Initialize the module-level logger.
logger = init_logger(__name__)


# Config class supporting NVFP4 checkpoints produced by the ModelOpt
# quantization tool.
class PetitNvFp4Config(QuantizationConfig):
    """Config class for Petit FP4."""

    def __init__(
        self,
        is_checkpoint_nvfp4_serialized: bool = False,
        kv_cache_quant_algo: Optional[str] = None,
        group_size: Optional[int] = None,
        exclude_modules: Optional[List[str]] = None,
    ) -> None:
        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
        if is_checkpoint_nvfp4_serialized:
            logger.warning(
                "Detected NVFP4 checkpoint. Please note that the "
                "format is experimental and subject to change."
            )
        self.group_size = group_size
        self.kv_cache_quant_algo = kv_cache_quant_algo
        self.exclude_modules = exclude_modules

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "petit_nvfp4"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        # Petit supports the gfx90a and gfx942 AMD GPUs.
        return 90

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["hf_quant_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "PetitNvFp4Config":
        quant_config = cls.get_from_keys(config, ["quantization"])
        quant_method = quant_config["quant_algo"]
        group_size = quant_config.get("group_size", None)
        verify_petit_nvfp4_supported(quant_method, group_size)

        is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method
        kv_cache_quant_algo = quant_config.get("kv_cache_quant_algo") or "auto"
        exclude_modules = quant_config.get("exclude_modules", None)
        if not (group_size and kv_cache_quant_algo and exclude_modules is not None):
            logger.warning(
                "group_size: %s, kv_cache_quant_algo: %s, exclude_modules: %s",
                group_size,
                kv_cache_quant_algo,
                exclude_modules,
            )
            raise ValueError(
                "NVFP4 quantization requires group_size, kv_cache_quant_algo, "
                "and exclude_modules to be specified in hf_quant_config.json"
            )
        return cls(
            is_checkpoint_nvfp4_serialized,
            kv_cache_quant_algo,
            group_size,
            exclude_modules,
        )
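
    # Illustrative sketch (field values assumed, not taken from a specific
    # checkpoint): from_config() expects a "quantization" block in
    # hf_quant_config.json shaped roughly like
    #
    #   "quantization": {
    #       "quant_algo": "NVFP4",
    #       "group_size": 16,
    #       "kv_cache_quant_algo": "FP8",
    #       "exclude_modules": ["lm_head"]
    #   }
    #
    # where a missing or empty kv_cache_quant_algo falls back to "auto".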

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
        qc = hf_quant_cfg.get("quantization", hf_quant_cfg)
        algo = (qc.get("quant_algo") or qc.get("quant_method") or "").upper()
        if algo in ("NVFP4", "MODELOPT_FP4", "MODELOPT"):
            return cls.get_name()  # "petit_nvfp4"
        return None

    @classmethod
    def is_petit_nvfp4_compatible(cls, quant_config: Dict[str, Any]) -> bool:
        qc = quant_config.get("quantization", quant_config)
        algo = (qc.get("quant_algo") or qc.get("quant_method") or "").upper()
        return algo == "NVFP4"

    def is_layer_excluded(self, prefix: str, exclude_modules: List[str]) -> bool:
        for pattern in exclude_modules:
            # Translate the glob-style pattern into a regular expression.
            regex_str = pattern.replace(".", r"\.").replace("*", r".*")
            if re.fullmatch(regex_str, prefix):
                return True
        return False
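
    # Example (hypothetical pattern): "*.lm_head" becomes the regex
    # r".*\.lm_head", so a prefix like "model.lm_head" is excluded while
    # "model.layers.0.self_attn.qkv_proj" is not.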

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        from vllm.attention.layer import Attention  # Avoid circular import

        if isinstance(layer, LinearBase):
            if is_layer_skipped(prefix, self.exclude_modules) or self.is_layer_excluded(
                prefix, self.exclude_modules
            ):
                return UnquantizedLinearMethod()
            return PetitNvFp4LinearMethod(self)
        elif isinstance(layer, Attention):
            return PetitFp8KVCacheMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []


class PetitFp8KVCacheMethod(BaseKVCacheMethod):
    """Supports loading kv-cache scaling factors from FP8 checkpoints."""

    def __init__(self, quant_config: PetitNvFp4Config):
        super().__init__(quant_config)


class PetitNvFp4LinearMethod(LinearMethodBase):
    """Linear method for NVFP4.

    Supports loading NVFP4 checkpoints with the following structure:

    | Tensor Name    | datatype      | shape       |
    |----------------|---------------|-------------|
    | input_scale    | torch.float32 | scalar      |
    | weight         | NVFP4(SE2M1)  | [1, X, Y/2] |
    | weight_scale   | FP8-E4M3      | [X, Y]      |
    | weight_scale_2 | torch.float32 | scalar      |

    The weights are quantized per block of 16 elements.

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: PetitNvFp4Config):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        del input_size, output_size
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError(
                "NVFP4 quantization was selected, "
                "dynamic quantization is not supported."
            )

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")

        layer.logical_widths = output_partition_sizes

        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        if input_size_per_partition % 16 != 0:
            raise ValueError(
                "Unsupported model: the input feature size must be "
                "a multiple of 16"
            )

        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_nvfp4_serialized
            else params_dtype
        )

        weight = ModelWeightParameter(
            data=torch.empty(
                # Two FP4 values are packed into one uint8 along the input
                # dimension.
                output_size_per_partition,
                input_size_per_partition // 2,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        input_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("input_scale", input_scale)

        weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale_2", weight_scale_2)

        # One FP8 scale per group_size-element block of the input dimension.
        weight_scale = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // self.quant_config.group_size,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale", weight_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Collapse the per-shard scales into single per-tensor scales.
        input_scale_2 = layer.input_scale.max().to(torch.float32)
        weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
        layer.input_scale = Parameter(input_scale_2, requires_grad=False)
        layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)
        layer.alpha = Parameter(
            layer.input_scale * layer.weight_scale_2, requires_grad=False
        )

        prepare_nvfp4_layer_for_petit(layer)
        del layer.input_scale
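
    # Dequantization sketch (assumed semantics, mirroring the ModelOpt NVFP4
    # recipe): a logical weight element is reconstructed roughly as
    #   w ~= fp4_value * weight_scale[row, col // group_size] * weight_scale_2
    # and alpha = input_scale * weight_scale_2 folds the two per-tensor scales
    # into a single multiplier for the kernel.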

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return apply_petit_nvfp4_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            weight_scale_2=layer.weight_scale_2,
            size_n=layer.output_size_per_partition,
            size_k=layer.input_size_per_partition,
            bias=bias,
        )
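
# Usage sketch (illustrative; not part of the module): inside vLLM, a linear
# layer whose quant method resolves to PetitNvFp4LinearMethod is invoked as
#   out = method.apply(layer, x)   # x: [num_tokens, K] -> out: [num_tokens, N]
# with the NVFP4 GEMM delegated to the Petit kernel via
# apply_petit_nvfp4_linear().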