Commit 488c856
Test PARQ with torchao activation quantization (#2370)
* Test PARQ with torchao activation quantization
* Replace assertTrue with torch.testing.assert_close
1 parent 0a81ae8 commit 488c856

File tree: 2 files changed, +102 −31 lines
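The new test (hunks below) wires PARQ weight quantization together with torchao's QAT-style activation fake-quantization. A condensed sketch of that flow for orientation; M and build_param_groups are the helpers defined in the test file below, and the PARQ import paths are assumptions based on the test's existing imports:

import torch

from torchao.prototype.parq.optim import ProxHardQuant, QuantOptimizer
from torchao.prototype.parq.quant import UnifTorchaoQuantizer
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    IntXQuantizationAwareTrainingConfig,
)
from torchao.quantization.quant_api import quantize_

model = M(embedding=False)  # test module defined below

# 1) PARQ quantizes the weights inside optimizer.step
optimizer = QuantOptimizer(
    torch.optim.SGD(build_param_groups(model, b=2, group_size=32)),
    UnifTorchaoQuantizer(),
    ProxHardQuant(),
    quant_per_channel=True,
)
optimizer.zero_grad()
optimizer.step()

# 2) torchao fake-quantizes activations, restricted to PARQ-managed modules
activation_config = FakeQuantizeConfig(torch.int8, granularity="per_token")
quantize_(
    model,
    IntXQuantizationAwareTrainingConfig(activation_config=activation_config),
    filter_fn=optimizer.get_filter_fn(model),
)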

test/prototype/test_parq.py

Lines changed: 81 additions & 8 deletions
@@ -27,8 +27,15 @@
 )
 from torchao.prototype.parq.quant.uniform_torchao import _BIT_WIDTH_TO_DTYPE
 from torchao.quantization.granularity import PerGroup
+from torchao.quantization.qat import (
+    FakeQuantizeConfig,
+    FromIntXQuantizationAwareTrainingConfig,
+    IntXQuantizationAwareTrainingConfig,
+)
 from torchao.quantization.quant_api import (
+    Int8DynamicActivationIntxWeightConfig,
     IntxWeightOnlyConfig,
+    MappingType,
     _is_linear,
     int4_weight_only,
     quantize_,
@@ -68,9 +75,9 @@ def build_param_groups(model, b: int = 2, group_size: Optional[int] = None):


 class M(nn.Module):
-    def __init__(self, m=256, n=128, k=16, bias=False):
+    def __init__(self, m=256, n=128, k=16, bias=False, embedding=True):
         super().__init__()
-        self.embedding = nn.Embedding(10, m)
+        self.embedding = nn.Embedding(10, m) if embedding else nn.Identity()
         self.linear1 = nn.Linear(m, n, bias=bias)
         self.linear2 = nn.Linear(n, k, bias=bias)
         self.relu = nn.ReLU()
@@ -83,7 +90,11 @@ def reset_parameters(self):
                 nn.init.zeros_(module.bias)

     def example_inputs(self, device=None):
-        return torch.randint(1, 10, (1, 256), device=device)
+        return (
+            torch.randint(1, 10, (1, self.linear1.in_features), device=device)
+            if isinstance(self.embedding, nn.Embedding)
+            else torch.randn(1, self.linear1.in_features, device=device)
+        )

     def forward(self, x):
         x = self.embedding(x)
@@ -150,11 +161,11 @@ def compare_quantized_models(
                 p = p.view(-1, group_size)

             q, Q = quantizer.quantize(p, b=b, dim=-1)
-            q = q.view(original_shape)

             # compare to AffineQuantizedTensor instance
+            q = q.view(original_shape)
             ref = getattr(m_ref, n).weight.dequantize()
-            self.assertTrue(q.equal(ref))
+            torch.testing.assert_close(q, ref, atol=0, rtol=0)

     def compare_parq_convert(
         self,
@@ -182,13 +193,13 @@ def compare_parq_convert(
             p = module.weight.dequantize()  # PARQ weight after quantize_
             p_ref = getattr(m_ref, n).weight.dequantize()  # native quantize_

-            self.assertTrue(p_orig.equal(p_ref))
-            self.assertTrue(p.equal(p_ref))
+            torch.testing.assert_close(p_orig, p_ref, atol=0, rtol=0)
+            torch.testing.assert_close(p, p_ref, atol=0, rtol=0)

     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
     @common_utils.parametrize("group_size", [32, 256])
     def test_int4_weight_only(self, group_size: int = 32):
-        model = M(m=512, n=512).to(torch.bfloat16).to(_DEVICE)
+        model = M(m=512, n=512).to(_DEVICE, dtype=torch.bfloat16)
         model.reset_parameters()

         m_ref = copy.deepcopy(model).eval().to(_DEVICE)
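Note the swap from self.assertTrue(a.equal(b)) to torch.testing.assert_close(a, b, atol=0, rtol=0): with zero tolerances the comparison is still exact, but a failure now reports how many elements mismatch and by how much instead of a bare False. A small standalone example:

import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([1.0, 2.0, 3.5])

# Zero tolerances demand exact equality, like Tensor.equal, but the
# error message pinpoints the mismatch (here 1/3 elements, at index 2).
torch.testing.assert_close(a, b, atol=0, rtol=0)  # raises AssertionError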
@@ -265,8 +276,70 @@ def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32):
         self.compare_parq_convert(model, m_ref, optimizer, config)


+class TestInt8DynamicActivationTorchaoQuantizer(common_utils.TestCase):
+    def setUp(self):
+        torch.manual_seed(123)
+
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+")
+    @common_utils.parametrize("b", [2, 3, 4, 8])
+    @common_utils.parametrize("model_dtype", [torch.float16, torch.float32])
+    @common_utils.parametrize("group_size", [32, 128])
+    def test_int8_dynamic_activation_intx_e2e(
+        self,
+        b: int = 2,
+        model_dtype: torch.dtype = torch.float32,
+        group_size: int = 32,
+    ):
+        model = M(embedding=False).to(_DEVICE, dtype=model_dtype)
+        x = model.example_inputs(device=_DEVICE).to(model_dtype)
+
+        # reference model using native quantization
+        m_ref = copy.deepcopy(model).eval().to(_DEVICE)
+        quantizer = UnifTorchaoQuantizer()
+        config = Int8DynamicActivationIntxWeightConfig(
+            weight_dtype=_BIT_WIDTH_TO_DTYPE[b],
+            weight_granularity=PerGroup(group_size),
+            weight_mapping_type=quantizer.mapping_type,
+            act_mapping_type=MappingType.ASYMMETRIC,
+        )
+        quantize_(m_ref, config)
+        ref_out = m_ref(x)
+
+        # quantize weights with PARQ
+        base_optimizer = torch.optim.SGD(build_param_groups(model, b, group_size))
+        optimizer = QuantOptimizer(
+            base_optimizer, quantizer, ProxHardQuant(), quant_per_channel=True
+        )
+        optimizer.zero_grad()
+        optimizer.step()
+
+        # apply torchao quantized activations on top
+        activation_config = FakeQuantizeConfig(
+            torch.int8,
+            granularity="per_token",
+            mapping_type=config.act_mapping_type,
+        )
+        filter_fn = optimizer.get_filter_fn(model)
+        quantize_(
+            model,
+            IntXQuantizationAwareTrainingConfig(activation_config=activation_config),
+            filter_fn=filter_fn,
+        )
+        out = model(x)
+        torch.testing.assert_close(out, ref_out, atol=0, rtol=0)
+
+        # equivalent to torchao's convert step
+        model.eval()
+        optimizer.restore_latent_params()
+        quantize_(model, FromIntXQuantizationAwareTrainingConfig(), filter_fn=filter_fn)
+        quantize_(model, config, filter_fn=filter_fn)
+        converted_out = model(x)
+        torch.testing.assert_close(converted_out, ref_out, atol=0, rtol=0)
+
+
 common_utils.instantiate_parametrized_tests(TestPARQuantization)
 common_utils.instantiate_parametrized_tests(TestUnifTorchaoQuantizer)
+common_utils.instantiate_parametrized_tests(TestInt8DynamicActivationTorchaoQuantizer)


 if __name__ == "__main__":

torchao/prototype/parq/quant/uniform_torchao.py

Lines changed: 21 additions & 23 deletions
@@ -50,8 +50,23 @@ def __init__(
         self.quant_min = quant_min
         self.quant_max = quant_max
         self.eps = eps
-        self.preserve_zero = preserve_zero
-        self.zero_point_domain = zero_point_domain
+
+        # defaults: zero_point_domain=ZeroPointDomain.INT, preserve_zero=True
+        self._choose_qparams = choose_qparams_affine
+        self._quantize = quantize_affine
+        self._dequantize = dequantize_affine
+
+        if zero_point_domain == ZeroPointDomain.FLOAT and not preserve_zero:
+            self._choose_qparams = choose_qparams_affine_tinygemm
+            self._quantize = quantize_affine_tinygemm
+            self._dequantize = dequantize_affine_tinygemm
+        elif zero_point_domain == ZeroPointDomain.INT and not preserve_zero:
+            self._choose_qparams = choose_qparams_affine_dont_preserve_zero
+            self._quantize = quantize_affine
+            self._dequantize = dequantize_affine
+        elif zero_point_domain == ZeroPointDomain.NONE:
+            self._quantize = quantize_affine_no_zero_point
+            self._dequantize = dequantize_affine_no_zero_point

     def _init_quant_min_max(self, b: int) -> None:
         if self.quant_min is None or self.quant_max is None:
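This hunk binds the choose-qparams/quantize/dequantize implementations once at construction, replacing the per-call branching on zero_point_domain and preserve_zero that the next hunk deletes. A minimal sketch of the pattern itself, with hypothetical names unrelated to the torchao API:

# Hypothetical illustration: bind the implementation once in __init__
# so hot-path methods dispatch without re-checking configuration.
class Transform:
    def __init__(self, mode: str = "int", exact_zero: bool = True):
        self._apply = self._affine          # default branch
        if mode == "float" and not exact_zero:
            self._apply = self._special     # special-cased kernel
        elif mode == "none":
            self._apply = self._no_zero_point

    def _affine(self, x):
        return x

    def _special(self, x):
        return x

    def _no_zero_point(self, x):
        return x

    def __call__(self, x):
        return self._apply(x)  # no per-call configuration checks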
@@ -74,24 +89,7 @@ def quantize(
         # assume that p has already been grouped in QuantOptimizer.step
         block_size = (1, p.size(-1)) if dim is not None else p.size()

-        if self.zero_point_domain == ZeroPointDomain.FLOAT and not self.preserve_zero:
-            _choose_qparams_affine = choose_qparams_affine_tinygemm
-            _quantize_affine = quantize_affine_tinygemm
-            _dequantize_affine = dequantize_affine_tinygemm
-        elif self.zero_point_domain == ZeroPointDomain.INT and not self.preserve_zero:
-            _choose_qparams_affine = choose_qparams_affine_dont_preserve_zero
-            _quantize_affine = quantize_affine
-            _dequantize_affine = dequantize_affine
-        else:  # default: zero_point_domain == ZeroPointDomain.INT/NONE and preserve_zero
-            _choose_qparams_affine = choose_qparams_affine
-            if self.zero_point_domain == ZeroPointDomain.INT:
-                _quantize_affine = quantize_affine
-                _dequantize_affine = dequantize_affine
-            else:
-                _quantize_affine = quantize_affine_no_zero_point
-                _dequantize_affine = dequantize_affine_no_zero_point
-
-        s, zero_point = _choose_qparams_affine(
+        s, zero_point = self._choose_qparams(
             p,
             self.mapping_type,
             block_size,
@@ -101,13 +99,13 @@
             quant_max=self.quant_max,
         )
         q_args = (block_size, s, zero_point, self.target_dtype)
-        q = _quantize_affine(
+        q = self._quantize(
             p,
             *q_args,
             quant_min=self.quant_min,
             quant_max=self.quant_max,
         )
-        q = _dequantize_affine(
+        q = self._dequantize(
             q,
             *q_args,
             output_dtype=p.dtype,
@@ -124,7 +122,7 @@ def quantize(
         else:
             block_size = Q.shape

-        Q = _dequantize_affine(
+        Q = self._dequantize(
             Q,
             block_size,
             *q_args[1:],
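With the callables bound in __init__, quantize() reduces to a straight choose-qparams, quantize, dequantize pipeline. A hedged usage sketch mirroring compare_quantized_models in the test above; the import path is assumed from the test's imports, and the shapes and bit width are arbitrary:

import torch
from torchao.prototype.parq.quant import UnifTorchaoQuantizer

quantizer = UnifTorchaoQuantizer()
p = torch.randn(128, 32)  # one quantization group per row

# q: dequantized (fake-quantized) weights, same dtype/shape as p;
# Q: the quantization grid, as unpacked by the tests above.
q, Q = quantizer.quantize(p, b=4, dim=-1)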
