)
from torchao.prototype.parq.quant.uniform_torchao import _BIT_WIDTH_TO_DTYPE
from torchao.quantization.granularity import PerGroup
+ from torchao.quantization.qat import (
+     FakeQuantizeConfig,
+     FromIntXQuantizationAwareTrainingConfig,
+     IntXQuantizationAwareTrainingConfig,
+ )
from torchao.quantization.quant_api import (
+     Int8DynamicActivationIntxWeightConfig,
    IntxWeightOnlyConfig,
+     MappingType,
    _is_linear,
    int4_weight_only,
    quantize_,
@@ -68,9 +75,9 @@ def build_param_groups(model, b: int = 2, group_size: Optional[int] = None):


class M(nn.Module):
-     def __init__(self, m=256, n=128, k=16, bias=False):
+     def __init__(self, m=256, n=128, k=16, bias=False, embedding=True):
        super().__init__()
-         self.embedding = nn.Embedding(10, m)
+         self.embedding = nn.Embedding(10, m) if embedding else nn.Identity()
        self.linear1 = nn.Linear(m, n, bias=bias)
        self.linear2 = nn.Linear(n, k, bias=bias)
        self.relu = nn.ReLU()
@@ -83,7 +90,11 @@ def reset_parameters(self):
                nn.init.zeros_(module.bias)

    def example_inputs(self, device=None):
-         return torch.randint(1, 10, (1, 256), device=device)
+         return (
+             torch.randint(1, 10, (1, self.linear1.in_features), device=device)
+             if isinstance(self.embedding, nn.Embedding)
+             else torch.randn(1, self.linear1.in_features, device=device)
+         )

    def forward(self, x):
        x = self.embedding(x)
@@ -150,11 +161,11 @@ def compare_quantized_models(
            p = p.view(-1, group_size)

            q, Q = quantizer.quantize(p, b=b, dim=-1)
-             q = q.view(original_shape)

            # compare to AffineQuantizedTensor instance
+             q = q.view(original_shape)
            ref = getattr(m_ref, n).weight.dequantize()
-             self.assertTrue(q.equal(ref))
+             torch.testing.assert_close(q, ref, atol=0, rtol=0)

    def compare_parq_convert(
        self,
@@ -182,13 +193,13 @@ def compare_parq_convert(
            p = module.weight.dequantize()  # PARQ weight after quantize_
            p_ref = getattr(m_ref, n).weight.dequantize()  # native quantize_

-             self.assertTrue(p_orig.equal(p_ref))
-             self.assertTrue(p.equal(p_ref))
+             torch.testing.assert_close(p_orig, p_ref, atol=0, rtol=0)
+             torch.testing.assert_close(p, p_ref, atol=0, rtol=0)

    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
    @common_utils.parametrize("group_size", [32, 256])
    def test_int4_weight_only(self, group_size: int = 32):
-         model = M(m=512, n=512).to(torch.bfloat16).to(_DEVICE)
+         model = M(m=512, n=512).to(_DEVICE, dtype=torch.bfloat16)
        model.reset_parameters()

        m_ref = copy.deepcopy(model).eval().to(_DEVICE)
@@ -265,8 +276,70 @@ def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32):
        self.compare_parq_convert(model, m_ref, optimizer, config)


+ class TestInt8DynamicActivationTorchaoQuantizer(common_utils.TestCase):
+     def setUp(self):
+         torch.manual_seed(123)
+
+     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+")
+     @common_utils.parametrize("b", [2, 3, 4, 8])
+     @common_utils.parametrize("model_dtype", [torch.float16, torch.float32])
+     @common_utils.parametrize("group_size", [32, 128])
+     def test_int8_dynamic_activation_intx_e2e(
+         self,
+         b: int = 2,
+         model_dtype: torch.dtype = torch.float32,
+         group_size: int = 32,
+     ):
+         model = M(embedding=False).to(_DEVICE, dtype=model_dtype)
+         x = model.example_inputs(device=_DEVICE).to(model_dtype)
+
+         # reference model using native quantization
+         m_ref = copy.deepcopy(model).eval().to(_DEVICE)
+         quantizer = UnifTorchaoQuantizer()
+         config = Int8DynamicActivationIntxWeightConfig(
+             weight_dtype=_BIT_WIDTH_TO_DTYPE[b],
+             weight_granularity=PerGroup(group_size),
+             weight_mapping_type=quantizer.mapping_type,
+             act_mapping_type=MappingType.ASYMMETRIC,
+         )
+         quantize_(m_ref, config)
+         ref_out = m_ref(x)
+
+         # quantize weights with PARQ
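+         # zero_grad + step quantizes the latent weights in place (ProxHardQuant
+         # projects them onto the b-bit grid) without otherwise updating them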
+         base_optimizer = torch.optim.SGD(build_param_groups(model, b, group_size))
+         optimizer = QuantOptimizer(
+             base_optimizer, quantizer, ProxHardQuant(), quant_per_channel=True
+         )
+         optimizer.zero_grad()
+         optimizer.step()
+
+         # apply torchao quantized activations on top
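+         # per-token asymmetric int8 fake-quant, mirroring act_mapping_type
+         # in the reference config above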
+         activation_config = FakeQuantizeConfig(
+             torch.int8,
+             granularity="per_token",
+             mapping_type=config.act_mapping_type,
+         )
+         filter_fn = optimizer.get_filter_fn(model)
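+         # limit QAT to the same modules whose weights PARQ quantized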
+         quantize_(
+             model,
+             IntXQuantizationAwareTrainingConfig(activation_config=activation_config),
+             filter_fn=filter_fn,
+         )
+         out = model(x)
+         torch.testing.assert_close(out, ref_out, atol=0, rtol=0)
+
+         # equivalent to torchao's convert step
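+         # restore_latent_params() swaps the quantized weights back to their
+         # latent fp values before quantize_ re-quantizes them for inference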
+         model.eval()
+         optimizer.restore_latent_params()
+         quantize_(model, FromIntXQuantizationAwareTrainingConfig(), filter_fn=filter_fn)
+         quantize_(model, config, filter_fn=filter_fn)
+         converted_out = model(x)
+         torch.testing.assert_close(converted_out, ref_out, atol=0, rtol=0)
+
+
common_utils.instantiate_parametrized_tests(TestPARQuantization)
common_utils.instantiate_parametrized_tests(TestUnifTorchaoQuantizer)
+ common_utils.instantiate_parametrized_tests(TestInt8DynamicActivationTorchaoQuantizer)


if __name__ == "__main__":