From d22cfbf868e32e6e392241104387d0e98d07a85f Mon Sep 17 00:00:00 2001
From: jainapurva
Date: Tue, 10 Jun 2025 12:14:02 -0700
Subject: [PATCH 1/4] Convert public methods to private

---
 docs/source/api_ref_quantization.rst          |  12 +-
 test/dtypes/test_affine_quantized_float.py    |  24 ++--
 test/dtypes/test_floatx.py                    |   8 +-
 test/prototype/test_gguf_quant.py             |   6 +-
 test/quantization/test_marlin_qqq.py          |   4 +-
 test/quantization/test_qat.py                 |   4 +-
 test/quantization/test_quant_primitives.py    |  16 +--
 test/test_ops.py                              |   6 +-
 torchao/dtypes/affine_quantized_tensor.py     |  58 ++++-----
 torchao/dtypes/affine_quantized_tensor_ops.py |   8 +-
 .../floatx/floatx_tensor_core_layout.py       |   2 +-
 torchao/dtypes/uintx/int4_cpu_layout.py       |   4 +-
 torchao/dtypes/uintx/int4_xpu_layout.py       |   4 +-
 torchao/dtypes/uintx/marlin_qqq_tensor.py     |  12 +-
 .../dtypes/uintx/tensor_core_tiled_layout.py  |   4 +-
 .../prototype/parq/quant/uniform_torchao.py   |  24 ++--
 .../gguf/gguf_quantized_tensor.py             |  12 +-
 torchao/quantization/__init__.py              |  32 ++---
 .../qat/affine_fake_quantized_tensor.py       |  12 +-
 torchao/quantization/qat/utils.py             |   6 +-
 torchao/quantization/quant_primitives.py      | 118 +++++++++---------
 torchao/quantization/utils.py                 |  24 ++--
 tutorials/calibration_flow/gptq_like.py       |   6 +-
 23 files changed, 205 insertions(+), 201 deletions(-)

diff --git a/docs/source/api_ref_quantization.rst b/docs/source/api_ref_quantization.rst
index 5293684ab9..ba8fa4269d 100644
--- a/docs/source/api_ref_quantization.rst
+++ b/docs/source/api_ref_quantization.rst
@@ -63,14 +63,14 @@ Quantization Primitives
     choose_qparams_affine
     choose_qparams_affine_with_min_max
-    choose_qparams_affine_floatx
+    _choose_qparams_affine_floatx
     quantize_affine
-    quantize_affine_floatx
+    _quantize_affine_floatx
     dequantize_affine
-    dequantize_affine_floatx
-    choose_qparams_and_quantize_affine_hqq
-    fake_quantize_affine
-    fake_quantize_affine_cachemask
+    _dequantize_affine_floatx
+    _choose_qparams_and_quantize_affine_hqq
+    _fake_quantize_affine
+    _fake_quantize_affine_cachemask
     safe_int_mm
     int_scaled_matmul
     MappingType
diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py
index 879551fc0a..1efa900efc 100644
--- a/test/dtypes/test_affine_quantized_float.py
+++ b/test/dtypes/test_affine_quantized_float.py
@@ -42,10 +42,10 @@
 )
 from torchao.quantization.quant_primitives import (
     MappingType,
+    _choose_qparams_affine_float8,
+    _dequantize_affine_float8,
+    _quantize_affine_float8,
     choose_qparams_affine,
-    choose_qparams_affine_float8,
-    dequantize_affine_float8,
-    quantize_affine_float8,
 )
 from torchao.utils import (
     is_sm_at_least_89,
@@ -357,22 +357,22 @@ def test_mm_float8dq_per_row(
     @common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
     @common_utils.parametrize("output_dtype", [torch.float32, torch.bfloat16])
     @common_utils.parametrize("block_size", [None, (1, 32), (2, 16), (4, 8)])
-    def test_dequantize_affine_float8(self, float8_dtype, output_dtype, block_size):
-        """Test dequantize_affine_float8 with various configurations"""
+    def test__dequantize_affine_float8(self, float8_dtype, output_dtype, block_size):
+        """Test _dequantize_affine_float8 with various configurations"""
         device = "cuda"
         input_tensor = torch.randn(8, 64, device=device, dtype=torch.float32)
         # Choose quantization parameters
-        scale = choose_qparams_affine_float8(
+        scale = _choose_qparams_affine_float8(
             input_tensor, float8_dtype=float8_dtype, block_size=block_size
         )
         # Quantize
-        quantized = quantize_affine_float8(input_tensor, scale,
float8_dtype) + quantized = _quantize_affine_float8(input_tensor, scale, float8_dtype) # Dequantize - dequantized = dequantize_affine_float8(quantized, scale, output_dtype) + dequantized = _dequantize_affine_float8(quantized, scale, output_dtype) # Verify output properties self.assertEqual(dequantized.dtype, output_dtype) @@ -387,7 +387,7 @@ def test_dequantize_affine_float8(self, float8_dtype, output_dtype, block_size): @unittest.skipIf( not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" ) - def test_dequantize_affine_float8_scale_broadcasting(self): + def test__dequantize_affine_float8_scale_broadcasting(self): """Test that scale broadcasting works correctly for block-wise quantization""" device = "cuda" # Create input tensor with known block structure @@ -395,7 +395,7 @@ def test_dequantize_affine_float8_scale_broadcasting(self): block_size = (2, 16) # 2x2 blocks in first dim, 2x16 blocks in second dim # Choose quantization parameters - scale = choose_qparams_affine_float8( + scale = _choose_qparams_affine_float8( input_tensor, float8_dtype=torch.float8_e4m3fn, block_size=block_size ) @@ -407,10 +407,10 @@ def test_dequantize_affine_float8_scale_broadcasting(self): self.assertEqual(scale.shape, expected_scale_shape) # Quantize - quantized = quantize_affine_float8(input_tensor, scale, torch.float8_e4m3fn) + quantized = _quantize_affine_float8(input_tensor, scale, torch.float8_e4m3fn) # Dequantize - dequantized = dequantize_affine_float8(quantized, scale, torch.float32) + dequantized = _dequantize_affine_float8(quantized, scale, torch.float32) # Verify shapes match self.assertEqual(dequantized.shape, input_tensor.shape) diff --git a/test/dtypes/test_floatx.py b/test/dtypes/test_floatx.py index 956ef9a03e..237bc2bd92 100644 --- a/test/dtypes/test_floatx.py +++ b/test/dtypes/test_floatx.py @@ -91,13 +91,13 @@ def test_from_scaled_tc_floatx_compile(self, ebits, mbits, device): @parametrize("ebits,mbits", _Floatx_DTYPES) def test_to_copy_device(self, ebits, mbits): from torchao.quantization.quant_primitives import ( - choose_qparams_affine_floatx, - quantize_affine_floatx, + _choose_qparams_affine_floatx, + _quantize_affine_floatx, ) x = torch.randn(256, 64) - scale = choose_qparams_affine_floatx(x, ebits, mbits) - x = quantize_affine_floatx(x, scale, ebits, mbits) + scale = _choose_qparams_affine_floatx(x, ebits, mbits) + x = _quantize_affine_floatx(x, scale, ebits, mbits) _layout = FloatxTensorCoreLayout(ebits, mbits) floatx_tensor_impl = FloatxTensorCoreAQTTensorImpl.from_plain( x, scale, None, _layout diff --git a/test/prototype/test_gguf_quant.py b/test/prototype/test_gguf_quant.py index b68d84b101..53ffcb5c60 100644 --- a/test/prototype/test_gguf_quant.py +++ b/test/prototype/test_gguf_quant.py @@ -13,7 +13,7 @@ GGUFWeightOnlyConfig, ) from torchao.quantization import quantize_ -from torchao.quantization.quant_primitives import choose_qparams_gguf +from torchao.quantization.quant_primitives import _choose_qparams_gguf from torchao.quantization.utils import compute_error @@ -25,13 +25,13 @@ def setUp(self): self.block_size = (1, 32) self.dtype = torch.uint4 - def test_choose_qparams_gguf(self): + def test__choose_qparams_gguf(self): ( super_block_scale_scale, super_block_min_scale, quantized_block_scale, quantized_block_min, - ) = choose_qparams_gguf(self.input, self.block_size, self.dtype) + ) = _choose_qparams_gguf(self.input, self.block_size, self.dtype) assert super_block_scale_scale.shape, (2, 8) assert super_block_min_scale.shape, (2, 8) diff --git 
a/test/quantization/test_marlin_qqq.py b/test/quantization/test_marlin_qqq.py index cff46ad329..8fe21c6bd3 100644 --- a/test/quantization/test_marlin_qqq.py +++ b/test/quantization/test_marlin_qqq.py @@ -21,7 +21,7 @@ ) from torchao.quantization.quant_primitives import ( MappingType, - choose_qparams_and_quantize_affine_qqq, + _choose_qparams_and_quantize_affine_qqq, ) from torchao.testing.utils import skip_if_rocm from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 @@ -102,7 +102,7 @@ def test_pack_unpack_equivalence(self): for group_size in [-1, 128]: # Quantize weights - q_w, s_group, s_channel, _ = choose_qparams_and_quantize_affine_qqq( + q_w, s_group, s_channel, _ = _choose_qparams_and_quantize_affine_qqq( w, num_bits, group_size ) diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index 323802757d..f0404a2ac2 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -64,9 +64,9 @@ MappingType, TorchAODType, ZeroPointDomain, + _fake_quantize_affine, choose_qparams_affine, dequantize_affine, - fake_quantize_affine, quantize_affine, ) from torchao.quantization.unified import ( @@ -637,7 +637,7 @@ def test_qat_4w_primitives(self): group_size, scales_precision, ) - w_fq = fake_quantize_affine( + w_fq = _fake_quantize_affine( weight, block_size, scales, diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py index e69d68b27f..ae0fc9987f 100644 --- a/test/quantization/test_quant_primitives.py +++ b/test/quantization/test_quant_primitives.py @@ -13,11 +13,11 @@ from torchao.quantization.quant_primitives import ( MappingType, ZeroPointDomain, + _choose_qparams_affine_tinygemm, + _fake_quantize_affine, + _fake_quantize_affine_cachemask, choose_qparams_affine, - choose_qparams_affine_tinygemm, dequantize_affine, - fake_quantize_affine, - fake_quantize_affine_cachemask, quantize_affine, ) @@ -672,7 +672,7 @@ def test_get_groupwise_affine_qparams(self): zero_point_domain=zero_point_domain, ) if zero_point_domain == ZeroPointDomain.FLOAT: - scale, zero_point = choose_qparams_affine_tinygemm( + scale, zero_point = _choose_qparams_affine_tinygemm( input, mapping_type, block_size, @@ -752,7 +752,7 @@ def test_groupwise_affine_dequantize_tensor_from_qparams(self): @unittest.skipIf( not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" ) - def test_fake_quantize_affine(self): + def test__fake_quantize_affine(self): input = torch.randn(10, 10) mapping_type = MappingType.SYMMETRIC @@ -780,7 +780,7 @@ def test_fake_quantize_affine(self): dequantized = dequantize_affine( quantized, block_size, scale, zero_point, dtype, quant_min, quant_max ) - fake_quantized = fake_quantize_affine( + fake_quantized = _fake_quantize_affine( input, block_size, scale, zero_point, dtype, quant_min, quant_max ) torch.testing.assert_close(dequantized, fake_quantized) @@ -788,7 +788,7 @@ def test_fake_quantize_affine(self): @unittest.skipIf( not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" ) - def test_fake_quantize_affine_cachemask(self): + def test__fake_quantize_affine_cachemask(self): input = torch.randn(10, 10) mapping_type = MappingType.SYMMETRIC @@ -816,7 +816,7 @@ def test_fake_quantize_affine_cachemask(self): dequantized = dequantize_affine( quantized, block_size, scale, zero_point, dtype, quant_min, quant_max ) - (fake_quantized, mask) = fake_quantize_affine_cachemask( + (fake_quantized, mask) = _fake_quantize_affine_cachemask( input, block_size, scale, diff --git 
a/test/test_ops.py b/test/test_ops.py index 012a4d562d..faec689a69 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -23,7 +23,9 @@ marlin_qqq_workspace, pack_to_marlin_qqq, ) -from torchao.quantization.quant_primitives import choose_qparams_and_quantize_affine_qqq +from torchao.quantization.quant_primitives import ( + _choose_qparams_and_quantize_affine_qqq, +) from torchao.sparsity.marlin import inject_24, marlin_24_workspace, pack_to_marlin_24 from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_5, @@ -713,7 +715,7 @@ def test_marlin_qqq(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_fact ) # Quantize weights - q_w, s_group, s_channel, w_ref = choose_qparams_and_quantize_affine_qqq( + q_w, s_group, s_channel, w_ref = _choose_qparams_and_quantize_affine_qqq( b_weight, num_bits, group_size ) q_w = q_w.t() diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 132ac0f28e..39f9131a9e 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -18,22 +18,22 @@ FP8_TYPES, MappingType, ZeroPointDomain, + _choose_qparams_affine_dont_preserve_zero, + _choose_qparams_affine_float8, + _choose_qparams_affine_floatx, + _choose_qparams_affine_tinygemm, + _choose_qparams_and_quantize_affine_hqq, + _dequantize_affine_float8, + _dequantize_affine_floatx, + _dequantize_affine_no_zero_point, + _dequantize_affine_tinygemm, + _quantize_affine_float8, + _quantize_affine_floatx, + _quantize_affine_no_zero_point, + _quantize_affine_tinygemm, choose_qparams_affine, - choose_qparams_affine_dont_preserve_zero, - choose_qparams_affine_float8, - choose_qparams_affine_floatx, - choose_qparams_affine_tinygemm, - choose_qparams_and_quantize_affine_hqq, dequantize_affine, - dequantize_affine_float8, - dequantize_affine_floatx, - dequantize_affine_no_zero_point, - dequantize_affine_tinygemm, quantize_affine, - quantize_affine_float8, - quantize_affine_floatx, - quantize_affine_no_zero_point, - quantize_affine_tinygemm, ) from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_5, @@ -142,7 +142,7 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor if isinstance(self._layout, FloatxTensorCoreLayout): int_data, scale = self.tensor_impl.get_plain() - return dequantize_affine_floatx( + return _dequantize_affine_floatx( int_data, scale, self._layout.ebits, @@ -151,11 +151,11 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor ) elif isinstance(self._layout, Float8Layout): data, scale, _ = self.tensor_impl.get_plain() - return dequantize_affine_float8(data, scale, output_dtype) + return _dequantize_affine_float8(data, scale, output_dtype) else: data, scale, zero_point = self.tensor_impl.get_plain() if self.zero_point_domain == ZeroPointDomain.FLOAT: - dq = dequantize_affine_tinygemm( + dq = _dequantize_affine_tinygemm( data, self.block_size, scale, @@ -166,7 +166,7 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor output_dtype=output_dtype, ) elif self.zero_point_domain == ZeroPointDomain.NONE: - dq = dequantize_affine_no_zero_point( + dq = _dequantize_affine_no_zero_point( data, self.block_size, scale, @@ -270,7 +270,7 @@ def from_hp_to_intx( from torchao.dtypes import Int4CPULayout from torchao.dtypes.uintx import TensorCoreTiledLayout - data, scale, zero_point, _ = choose_qparams_and_quantize_affine_hqq( + data, scale, zero_point, _ = _choose_qparams_and_quantize_affine_hqq( input_float, nbits=nbits, 
group_size=group_size, @@ -291,7 +291,7 @@ def from_hp_to_intx( data = data.to(target_dtype) else: if zero_point_domain == ZeroPointDomain.FLOAT and not preserve_zero: - scale, zero_point = choose_qparams_affine_tinygemm( + scale, zero_point = _choose_qparams_affine_tinygemm( input_float, mapping_type, block_size, @@ -303,7 +303,7 @@ def from_hp_to_intx( zero_point_dtype, ) elif zero_point_domain == ZeroPointDomain.INT and not preserve_zero: - scale, zero_point = choose_qparams_affine_dont_preserve_zero( + scale, zero_point = _choose_qparams_affine_dont_preserve_zero( input_float, mapping_type, block_size, @@ -329,7 +329,7 @@ def from_hp_to_intx( # choose_qparams_affine is a custom op that does support returning optional Tensors. We thus set the zero_point to None if its domain is None if zero_point_domain == ZeroPointDomain.NONE: zero_point = None - data = quantize_affine_no_zero_point( + data = _quantize_affine_no_zero_point( input_float, block_size, scale, @@ -339,7 +339,7 @@ def from_hp_to_intx( quant_max, ) elif zero_point_domain == ZeroPointDomain.FLOAT: - data = quantize_affine_tinygemm( + data = _quantize_affine_tinygemm( input_float, block_size, scale, @@ -400,7 +400,7 @@ def from_hp_to_intx_static( if zero_point_domain == ZeroPointDomain.NONE: zero_point = None - int_data = quantize_affine_no_zero_point( + int_data = _quantize_affine_no_zero_point( input_float, block_size, scale, @@ -410,7 +410,7 @@ def from_hp_to_intx_static( quant_max, ) elif zero_point_domain == ZeroPointDomain.FLOAT: - int_data = quantize_affine_tinygemm( + int_data = _quantize_affine_tinygemm( input_float, block_size, scale, @@ -462,10 +462,10 @@ def from_hp_to_floatx( if target_dtype in FP8_TYPES: original_shape = input_float.shape input_float = _layout.pre_process(input_float) - scale = choose_qparams_affine_float8( + scale = _choose_qparams_affine_float8( input_float, float8_dtype=target_dtype, block_size=block_size ) - data = quantize_affine_float8(input_float, scale, target_dtype) + data = _quantize_affine_float8(input_float, scale, target_dtype) data, scale, zero_point = _layout.post_process( data, scale, None, block_size ) @@ -499,7 +499,7 @@ def from_hp_to_floatx_static( input_float, scale, ZeroPointDomain.NONE, block_size ) - data = quantize_affine_float8( + data = _quantize_affine_float8( input_float, scale, target_dtype, @@ -545,8 +545,8 @@ def from_hp_to_fpx( ebits, mbits = _layout.ebits, _layout.mbits # Note: these ops are hardcoded to have per axis quantization (axis=1) right now - scale = choose_qparams_affine_floatx(input_float, ebits, mbits) - floatx_unpacked = quantize_affine_floatx(input_float, scale, ebits, mbits) + scale = _choose_qparams_affine_floatx(input_float, ebits, mbits) + floatx_unpacked = _quantize_affine_floatx(input_float, scale, ebits, mbits) floatx_packed, scale, _ = _layout.post_process( floatx_unpacked, scale, None, block_size ) diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index a76b4daa23..02a2d3004a 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -92,9 +92,9 @@ ) from torchao.quantization.quant_primitives import ( ZeroPointDomain, + _dequantize_affine_no_zero_point, + _dequantize_affine_tinygemm, dequantize_affine, - dequantize_affine_no_zero_point, - dequantize_affine_tinygemm, ) from torchao.utils import ( fill_defaults, @@ -318,9 +318,9 @@ def _(func, types, args, kwargs): # we need to increase block size to correct dim new_blocks = 
idx.dim() - 1 if args[1].zero_point_domain == ZeroPointDomain.FLOAT: - _dequantize_affine = dequantize_affine_tinygemm + _dequantize_affine = _dequantize_affine_tinygemm elif args[1].zero_point_domain == ZeroPointDomain.NONE: - _dequantize_affine = dequantize_affine_no_zero_point + _dequantize_affine = _dequantize_affine_no_zero_point else: _dequantize_affine = dequantize_affine diff --git a/torchao/dtypes/floatx/floatx_tensor_core_layout.py b/torchao/dtypes/floatx/floatx_tensor_core_layout.py index 6871033f1a..c7fb1e1a7c 100644 --- a/torchao/dtypes/floatx/floatx_tensor_core_layout.py +++ b/torchao/dtypes/floatx/floatx_tensor_core_layout.py @@ -467,7 +467,7 @@ class FloatxTensorCoreLayout(Layout): class FloatxTensorCoreAQTTensorImpl(AQTTensorImpl): """FloatxTensorCoreAQTTensorImpl represents a Tensor with dtype floatx(ebits=a, mbits=b), it has a internal tensor field of "packed_floatx_data", which is packed from the - uint8 unpacked data (the output of `quantize_affine_floatx` operator) + uint8 unpacked data (the output of `_quantize_affine_floatx` operator) The packing is optimized for TensorCore, from the fp6-llm paper: https://arxiv.org/abs/2401.14112 github repo: https://github.com/usyd-fsalab/fp6_llm, now renamed to quant-llm diff --git a/torchao/dtypes/uintx/int4_cpu_layout.py b/torchao/dtypes/uintx/int4_cpu_layout.py index 6c89f98ff7..bf9446d265 100644 --- a/torchao/dtypes/uintx/int4_cpu_layout.py +++ b/torchao/dtypes/uintx/int4_cpu_layout.py @@ -19,7 +19,7 @@ from torchao.dtypes.utils import AQTTensorImpl, Layout, is_device from torchao.quantization.quant_primitives import ( ZeroPointDomain, - quantize_affine_tinygemm, + _quantize_affine_tinygemm, ) from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_5, @@ -266,7 +266,7 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # TODO: move this to `unpack_tinygemm_scales_and_zeros`? scale = scale.reshape(scale.shape[:-1]).contiguous() zero = zero.reshape(zero.shape[:-1]).contiguous() - int_data = quantize_affine_tinygemm( + int_data = _quantize_affine_tinygemm( dequantized, block_size, scale, diff --git a/torchao/dtypes/uintx/int4_xpu_layout.py b/torchao/dtypes/uintx/int4_xpu_layout.py index c67eebd747..955a7a8610 100644 --- a/torchao/dtypes/uintx/int4_xpu_layout.py +++ b/torchao/dtypes/uintx/int4_xpu_layout.py @@ -377,8 +377,8 @@ def __torch_dispatch__(cls, func, types, args, kwargs): def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: from torchao.quantization.quant_primitives import ( + _quantize_affine_tinygemm, quantize_affine, - quantize_affine_tinygemm, ) from torchao.quantization.utils import unpack_tinygemm_scales_and_zeros @@ -429,7 +429,7 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # TODO: move this to `unpack_tinygemm_scales_and_zeros`? 
scale = scale.reshape(scale.shape[:-1]).contiguous() zero = zero.reshape(zero.shape[:-1]).contiguous() - int_data = quantize_affine_tinygemm( + int_data = _quantize_affine_tinygemm( dequantized, block_size, scale, diff --git a/torchao/dtypes/uintx/marlin_qqq_tensor.py b/torchao/dtypes/uintx/marlin_qqq_tensor.py index 3f3f4fa075..04066a6c65 100644 --- a/torchao/dtypes/uintx/marlin_qqq_tensor.py +++ b/torchao/dtypes/uintx/marlin_qqq_tensor.py @@ -24,8 +24,8 @@ from torchao.dtypes.utils import AQTTensorImpl, Layout from torchao.quantization.quant_primitives import ( ZeroPointDomain, - choose_qparams_and_quantize_affine_qqq, - dequantize_affine_qqq, + _choose_qparams_and_quantize_affine_qqq, + _dequantize_affine_qqq, ) logger = logging.getLogger(__name__) @@ -36,9 +36,9 @@ class MarlinQQQTensor(AffineQuantizedTensor): """MarlinQQQ quantized tensor subclass which inherits AffineQuantizedTensor class. - To see what happens during choose_qparams_and_quantize_affine_qqq, quantization and dequantization for marlin qqq quantization, + To see what happens during _choose_qparams_and_quantize_affine_qqq, quantization and dequantization for marlin qqq quantization, please checkout https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py - and check the two quant primitive ops: choose_qparams_and_quantize_affine_qqq and dequantize_affine_qqq + and check the two quant primitive ops: _choose_qparams_and_quantize_affine_qqq and _dequantize_affine_qqq """ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: @@ -48,7 +48,7 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor int_data, s_group, s_channel = self.tensor_impl.get_plain() nbits = int(math.log2(self.quant_max - self.quant_min + 1)) group_size = max(self.block_size) - return dequantize_affine_qqq( + return _dequantize_affine_qqq( int_data, s_group, s_channel, nbits, group_size, output_dtype ) @@ -69,7 +69,7 @@ def from_hp_to_intx( input_float = _layout.pre_process(input_float) nbits = int(math.log2(quant_max - quant_min + 1)) group_size = max(block_size) - data, s_group, s_channel, _ = choose_qparams_and_quantize_affine_qqq( + data, s_group, s_channel, _ = _choose_qparams_and_quantize_affine_qqq( input_float, nbits, group_size ) tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) diff --git a/torchao/dtypes/uintx/tensor_core_tiled_layout.py b/torchao/dtypes/uintx/tensor_core_tiled_layout.py index 0856d22fee..591d9a9be1 100644 --- a/torchao/dtypes/uintx/tensor_core_tiled_layout.py +++ b/torchao/dtypes/uintx/tensor_core_tiled_layout.py @@ -21,7 +21,7 @@ from torchao.quantization.quant_primitives import ( ZeroPointDomain, _get_reduction_params, - quantize_affine_tinygemm, + _quantize_affine_tinygemm, ) from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_5, @@ -511,7 +511,7 @@ def dequant_4d(self): target_dtype = torch.int32 quant_min = 0 quant_max = 15 - int_data = quantize_affine_tinygemm( + int_data = _quantize_affine_tinygemm( dequantized, self.block_size, scale, diff --git a/torchao/prototype/parq/quant/uniform_torchao.py b/torchao/prototype/parq/quant/uniform_torchao.py index a71ac8b5b3..577a7ef499 100644 --- a/torchao/prototype/parq/quant/uniform_torchao.py +++ b/torchao/prototype/parq/quant/uniform_torchao.py @@ -14,15 +14,15 @@ _DTYPE_TO_QVALUE_BOUNDS, MappingType, ZeroPointDomain, + _choose_qparams_affine_dont_preserve_zero, + _choose_qparams_affine_tinygemm, + _dequantize_affine_no_zero_point, + _dequantize_affine_tinygemm, + 
_quantize_affine_no_zero_point, + _quantize_affine_tinygemm, choose_qparams_affine, - choose_qparams_affine_dont_preserve_zero, - choose_qparams_affine_tinygemm, dequantize_affine, - dequantize_affine_no_zero_point, - dequantize_affine_tinygemm, quantize_affine, - quantize_affine_no_zero_point, - quantize_affine_tinygemm, ) from .quantizer import Quantizer @@ -75,11 +75,11 @@ def quantize( block_size = (1, p.size(-1)) if dim is not None else p.size() if self.zero_point_domain == ZeroPointDomain.FLOAT and not self.preserve_zero: - _choose_qparams_affine = choose_qparams_affine_tinygemm - _quantize_affine = quantize_affine_tinygemm - _dequantize_affine = dequantize_affine_tinygemm + _choose_qparams_affine = _choose_qparams_affine_tinygemm + _quantize_affine = _quantize_affine_tinygemm + _dequantize_affine = _dequantize_affine_tinygemm elif self.zero_point_domain == ZeroPointDomain.INT and not self.preserve_zero: - _choose_qparams_affine = choose_qparams_affine_dont_preserve_zero + _choose_qparams_affine = _choose_qparams_affine_dont_preserve_zero _quantize_affine = quantize_affine _dequantize_affine = dequantize_affine else: # Default case: zero_point_domain == ZeroPointDomain.INT/NONE and preserve_zero @@ -88,8 +88,8 @@ def quantize( _quantize_affine = quantize_affine _dequantize_affine = dequantize_affine else: - _quantize_affine = quantize_affine_no_zero_point - _dequantize_affine = dequantize_affine_no_zero_point + _quantize_affine = _quantize_affine_no_zero_point + _dequantize_affine = _dequantize_affine_no_zero_point s, zero_point = _choose_qparams_affine( p, diff --git a/torchao/prototype/quantization/gguf/gguf_quantized_tensor.py b/torchao/prototype/quantization/gguf/gguf_quantized_tensor.py index 9757769d16..c1272fceb6 100644 --- a/torchao/prototype/quantization/gguf/gguf_quantized_tensor.py +++ b/torchao/prototype/quantization/gguf/gguf_quantized_tensor.py @@ -10,9 +10,9 @@ from torch.utils._python_dispatch import return_and_correct_aliasing from torchao.quantization.quant_primitives import ( - choose_qparams_gguf, - dequantize_gguf, - quantize_gguf, + _choose_qparams_gguf, + _dequantize_gguf, + _quantize_gguf, ) from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_5, @@ -130,7 +130,7 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor block_size = tuple( [1] * (self.int_data.ndim - 1) + [_QK_K // self.n_blocks_per_superblock] ) - return dequantize_gguf( + return _dequantize_gguf( self.int_data, block_size, self.dtype, @@ -198,9 +198,9 @@ def from_float(cls, input_float, n_blocks_per_superblock, target_dtype): super_block_min_scale, quantized_block_scale, quantized_block_min, - ) = choose_qparams_gguf(input_float, block_size, target_dtype) + ) = _choose_qparams_gguf(input_float, block_size, target_dtype) - int_data = quantize_gguf( + int_data = _quantize_gguf( input_float, block_size, target_dtype, diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index 44fc6c8397..4fd3bb1d6c 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -82,18 +82,18 @@ MappingType, TorchAODType, ZeroPointDomain, + _choose_qparams_affine_dont_preserve_zero, + _choose_qparams_affine_floatx, + _choose_qparams_affine_tinygemm, + _choose_qparams_and_quantize_affine_hqq, + _dequantize_affine_floatx, + _fake_quantize_affine, + _fake_quantize_affine_cachemask, + _quantize_affine_floatx, choose_qparams_affine, - choose_qparams_affine_dont_preserve_zero, - choose_qparams_affine_floatx, - choose_qparams_affine_tinygemm, 
choose_qparams_affine_with_min_max, - choose_qparams_and_quantize_affine_hqq, dequantize_affine, - dequantize_affine_floatx, - fake_quantize_affine, - fake_quantize_affine_cachemask, quantize_affine, - quantize_affine_floatx, ) from .smoothquant import ( SmoothFakeDynamicallyQuantizedLinear, @@ -172,17 +172,17 @@ "AffineQuantizedObserverBase", # quant primitive ops "choose_qparams_affine", - "choose_qparams_affine_tinygemm", - "choose_qparams_affine_dont_preserve_zero", + "_choose_qparams_affine_tinygemm", + "_choose_qparams_affine_dont_preserve_zero", "choose_qparams_affine_with_min_max", - "choose_qparams_affine_floatx", + "_choose_qparams_affine_floatx", "quantize_affine", - "quantize_affine_floatx", + "_quantize_affine_floatx", "dequantize_affine", - "dequantize_affine_floatx", - "choose_qparams_and_quantize_affine_hqq", - "fake_quantize_affine", - "fake_quantize_affine_cachemask", + "_dequantize_affine_floatx", + "_choose_qparams_and_quantize_affine_hqq", + "_fake_quantize_affine", + "_fake_quantize_affine_cachemask", # operators/kernels "safe_int_mm", "int_scaled_matmul", diff --git a/torchao/quantization/qat/affine_fake_quantized_tensor.py b/torchao/quantization/qat/affine_fake_quantized_tensor.py index 6896588971..80ecd173c2 100644 --- a/torchao/quantization/qat/affine_fake_quantized_tensor.py +++ b/torchao/quantization/qat/affine_fake_quantized_tensor.py @@ -12,11 +12,11 @@ from torchao.quantization.quant_primitives import ( MappingType, ZeroPointDomain, + _choose_qparams_affine_dont_preserve_zero, + _choose_qparams_affine_tinygemm, + _fake_quantize_affine, _get_and_check_qmin_qmax, choose_qparams_affine, - choose_qparams_affine_dont_preserve_zero, - choose_qparams_affine_tinygemm, - fake_quantize_affine, ) from torchao.utils import TorchAOBaseTensor @@ -55,7 +55,7 @@ def apply_fake_quant_fn(t: torch.Tensor): assert isinstance(t, AffineFakeQuantizedTensor) qmin, qmax = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max) if zero_point_domain == ZeroPointDomain.FLOAT and not preserve_zero: - scale, zero_point = choose_qparams_affine_tinygemm( + scale, zero_point = _choose_qparams_affine_tinygemm( t.original_tensor, mapping_type, block_size, @@ -67,7 +67,7 @@ def apply_fake_quant_fn(t: torch.Tensor): zero_point_dtype, ) elif zero_point_domain == ZeroPointDomain.INT and not preserve_zero: - scale, zero_point = choose_qparams_affine_dont_preserve_zero( + scale, zero_point = _choose_qparams_affine_dont_preserve_zero( t.original_tensor, mapping_type, block_size, @@ -90,7 +90,7 @@ def apply_fake_quant_fn(t: torch.Tensor): scale_dtype, zero_point_dtype, ) - fq = fake_quantize_affine( + fq = _fake_quantize_affine( t, block_size, scale, diff --git a/torchao/quantization/qat/utils.py b/torchao/quantization/qat/utils.py index 132020499c..4f3323a1e8 100644 --- a/torchao/quantization/qat/utils.py +++ b/torchao/quantization/qat/utils.py @@ -9,7 +9,7 @@ from torchao.quantization.quant_primitives import ( ZeroPointDomain, - fake_quantize_affine, + _fake_quantize_affine, ) from torchao.quantization.utils import ( _get_per_token_block_size, @@ -87,7 +87,7 @@ def _fake_quantize_per_channel_group( assert input.shape[-1] % group_size == 0 assert input.dim() == 2 block_size = (1, group_size) - return fake_quantize_affine( + return _fake_quantize_affine( input, block_size, scales, @@ -110,7 +110,7 @@ def _fake_quantize_per_token( _per_token_quant_qparam_dim_check(input, scales, zero_points) block_size = _get_per_token_block_size(input) - fq = fake_quantize_affine( + fq = _fake_quantize_affine( 
input, block_size, scales, diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index 9e0c6447c8..b99aafbc57 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -24,32 +24,32 @@ __all__ = [ "choose_qparams_affine", - "choose_qparams_affine_tinygemm", - "choose_qparams_affine_dont_preserve_zero", + "_choose_qparams_affine_tinygemm", + "_choose_qparams_affine_dont_preserve_zero", "choose_qparams_affine_with_min_max", - "choose_qparams_affine_floatx", + "_choose_qparams_affine_floatx", "quantize_affine", - "quantize_affine_no_zero_point", - "quantize_affine_tinygemm", + "_quantize_affine_no_zero_point", + "_quantize_affine_tinygemm", "dequantize_affine", - "dequantize_affine_no_zero_point", - "dequantize_affine_tinygemm", - "quantize_affine_floatx", - "dequantize_affine_floatx", - "fake_quantize_affine", - "fake_quantize_affine_cachemask", - "choose_qparams_and_quantize_affine_hqq", - "choose_qparams_and_quantize_affine_qqq", - "dequantize_affine_qqq", + "_dequantize_affine_no_zero_point", + "_dequantize_affine_tinygemm", + "_quantize_affine_floatx", + "_dequantize_affine_floatx", + "_fake_quantize_affine", + "_fake_quantize_affine_cachemask", + "_choose_qparams_and_quantize_affine_hqq", + "_choose_qparams_and_quantize_affine_qqq", + "_dequantize_affine_qqq", "MappingType", "ZeroPointDomain", "TorchAODType", - "choose_qparams_affine_float8", - "quantize_affine_float8", - "dequantize_affine_float8", - "choose_qparams_gguf", - "quantize_gguf", - "dequantize_gguf", + "_choose_qparams_affine_float8", + "_quantize_affine_float8", + "_dequantize_affine_float8", + "_choose_qparams_gguf", + "_quantize_gguf", + "_dequantize_gguf", ] @@ -428,7 +428,7 @@ def _quantize_affine_no_dtype_cast( return quant -def quantize_affine_tinygemm( +def _quantize_affine_tinygemm( input: torch.Tensor, block_size: List[int], scale: torch.Tensor, @@ -453,7 +453,7 @@ def quantize_affine_tinygemm( # torch.uintx dtypes yet if output_dtype in _SUB_BYTE_UINT_BOUNDS: output_dtype = torch.uint8 - return _quantize_affine_tinygemm_no_dtype_cast( + return __quantize_affine_tinygemm_no_dtype_cast( input, block_size, scale, @@ -463,7 +463,7 @@ def quantize_affine_tinygemm( ).to(output_dtype) -def _quantize_affine_tinygemm_no_dtype_cast( +def __quantize_affine_tinygemm_no_dtype_cast( input: torch.Tensor, block_size: Tuple[int, ...], scale: torch.Tensor, @@ -513,7 +513,7 @@ def _quantize_affine_tinygemm_no_dtype_cast( return quant -def quantize_affine_no_zero_point( +def _quantize_affine_no_zero_point( input: torch.Tensor, block_size: List[int], scale: torch.Tensor, @@ -539,7 +539,7 @@ def quantize_affine_no_zero_point( # torch.uintx dtypes yet if output_dtype in _SUB_BYTE_UINT_BOUNDS: output_dtype = torch.uint8 - return _quantize_affine_no_zero_point_no_dtype_cast( + return __quantize_affine_no_zero_point_no_dtype_cast( input, block_size, scale, @@ -549,7 +549,7 @@ def quantize_affine_no_zero_point( ).to(output_dtype) -def _quantize_affine_no_zero_point_no_dtype_cast( +def __quantize_affine_no_zero_point_no_dtype_cast( input: torch.Tensor, block_size: Tuple[int, ...], scale: torch.Tensor, @@ -714,7 +714,7 @@ def _dequantize_affine_no_dtype_check( return dequant.view(original_shape).to(output_dtype) -def _dequantize_affine_no_zero_point_no_dtype_check( +def __dequantize_affine_no_zero_point_no_dtype_check( input: torch.Tensor, block_size: List[int], scale: torch.Tensor, @@ -745,7 +745,7 @@ def _dequantize_affine_no_zero_point_no_dtype_check( 
scale = scale.view(shape_after_reduction) assert zero_point is None, ( - "zero_point should be None for dequantize_affine_no_zero_point" + "zero_point should be None for _dequantize_affine_no_zero_point" ) dequant = input.to(output_dtype) dequant = dequant * scale @@ -753,7 +753,7 @@ def _dequantize_affine_no_zero_point_no_dtype_check( return dequant.view(original_shape).to(output_dtype) -def dequantize_affine_no_zero_point( +def _dequantize_affine_no_zero_point( input: torch.Tensor, block_size: Tuple[int, ...], scale: torch.Tensor, @@ -792,7 +792,7 @@ def dequantize_affine_no_zero_point( torch.bfloat16, ], f"Unsupported output dtype: {output_dtype}" quant_min, quant_max = _get_and_check_qmin_qmax(input_dtype, quant_min, quant_max) - return _dequantize_affine_no_zero_point_no_dtype_check( + return __dequantize_affine_no_zero_point_no_dtype_check( input, block_size, scale, @@ -803,7 +803,7 @@ def dequantize_affine_no_zero_point( ) -def _dequantize_affine_tinygemm_no_dtype_check( +def __dequantize_affine_tinygemm_no_dtype_check( input: torch.Tensor, block_size: List[int], scale: torch.Tensor, @@ -848,7 +848,7 @@ def _dequantize_affine_tinygemm_no_dtype_check( return dequant.view(original_shape).to(output_dtype) -def dequantize_affine_tinygemm( +def _dequantize_affine_tinygemm( input: torch.Tensor, block_size: Tuple[int, ...], scale: torch.Tensor, @@ -887,7 +887,7 @@ def dequantize_affine_tinygemm( torch.bfloat16, ], f"Unsupported output dtype: {output_dtype}" quant_min, quant_max = _get_and_check_qmin_qmax(input_dtype, quant_min, quant_max) - return _dequantize_affine_tinygemm_no_dtype_check( + return __dequantize_affine_tinygemm_no_dtype_check( input, block_size, scale, @@ -898,7 +898,7 @@ def dequantize_affine_tinygemm( ) -def fake_quantize_affine( +def _fake_quantize_affine( input: torch.Tensor, block_size: Tuple[int, ...], scale: torch.Tensor, @@ -933,7 +933,7 @@ def fake_quantize_affine( raise ValueError("Please use ZeroPointDomain.NONE instead of None") elif zero_point_domain is ZeroPointDomain.NONE and zero_point is not None: raise ValueError("zero_point should be None when zero_point_domain is NONE") - (_, fq) = _do_fake_quantize_affine( + (_, fq) = _do__fake_quantize_affine( input, block_size, scale, @@ -946,7 +946,7 @@ def fake_quantize_affine( return fq -def fake_quantize_affine_cachemask( +def _fake_quantize_affine_cachemask( input: torch.Tensor, block_size: Tuple[int, ...], scale: torch.Tensor, @@ -961,12 +961,12 @@ def fake_quantize_affine_cachemask( This is equivalent to calling `quantize_affine` + `dequantize_affine` but without the dtype casts. - Note: Compared to :func:`~torchao.quantization.quant_primitives.fake_quantize_affine`, + Note: Compared to :func:`~torchao.quantization.quant_primitives._fake_quantize_affine`, this consumes more memory and returns an additional outlier mask for intermediate quantized values. Args: - Same as :func:`~torchao.quantization.quant_primitives.fake_quantize_affine`. + Same as :func:`~torchao.quantization.quant_primitives._fake_quantize_affine`. 
Returns: A 2-tuple of ( @@ -979,7 +979,7 @@ def fake_quantize_affine_cachemask( raise ValueError("Please use ZeroPointDomain.NONE instead of None") elif zero_point_domain is None and zero_point is not None: raise ValueError("zero_point should be None when zero_point_domain is NONE") - (q, dq) = _do_fake_quantize_affine( + (q, dq) = _do__fake_quantize_affine( input, block_size, scale, @@ -993,7 +993,7 @@ def fake_quantize_affine_cachemask( return (dq, mask) -def _do_fake_quantize_affine( +def _do__fake_quantize_affine( input: torch.Tensor, block_size: Tuple[int, ...], scale: torch.Tensor, @@ -1004,7 +1004,7 @@ def _do_fake_quantize_affine( zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, ) -> Tuple[torch.Tensor, torch.Tensor]: """ - Helper function for `fake_quantize_affine` that returns both the + Helper function for `_fake_quantize_affine` that returns both the intermediate quantized values and the final dequantized values. """ input_dtype = input.dtype @@ -1013,11 +1013,11 @@ def _do_fake_quantize_affine( _quantize_affine = _quantize_affine_no_dtype_cast _dequantize_affine = _dequantize_affine_no_dtype_check elif zero_point_domain == ZeroPointDomain.FLOAT: - _quantize_affine = _quantize_affine_tinygemm_no_dtype_cast - _dequantize_affine = _dequantize_affine_tinygemm_no_dtype_check + _quantize_affine = __quantize_affine_tinygemm_no_dtype_cast + _dequantize_affine = __dequantize_affine_tinygemm_no_dtype_check elif ZeroPointDomain == ZeroPointDomain.NONE: - _quantize_affine = _quantize_affine_no_zero_point_no_dtype_cast - _dequantize_affine = _dequantize_affine_no_zero_point_no_dtype_check + _quantize_affine = __quantize_affine_no_zero_point_no_dtype_cast + _dequantize_affine = __dequantize_affine_no_zero_point_no_dtype_check else: raise ValueError(f"Unrecognized zero point domain: {zero_point_domain}") q = _quantize_affine( @@ -1086,7 +1086,7 @@ def choose_qparams_affine( # TODO: lower this op to custom op library @torch.no_grad() -def choose_qparams_affine_tinygemm( +def _choose_qparams_affine_tinygemm( input: torch.Tensor, mapping_type: MappingType, block_size: Tuple[int], @@ -1157,7 +1157,7 @@ def choose_qparams_affine_tinygemm( # TODO: lower this op to custom op library -def choose_qparams_affine_dont_preserve_zero( +def _choose_qparams_affine_dont_preserve_zero( input: torch.Tensor, mapping_type: MappingType, block_size: Tuple[int], @@ -1427,7 +1427,7 @@ def _choose_qparams_affine( ) -def choose_qparams_and_quantize_affine_qqq( +def _choose_qparams_and_quantize_affine_qqq( w: torch.Tensor, num_bits: int, group_size: int, @@ -1497,7 +1497,7 @@ def reshape_w(w): return q_w, s_group, s_channel, w_ref -def choose_qparams_gguf( +def _choose_qparams_gguf( input: Optional[torch.Tensor], block_size: List[int], target_dtype: torch.dtype, @@ -1580,7 +1580,7 @@ def choose_qparams_gguf( ) -def quantize_gguf( +def _quantize_gguf( input: torch.Tensor, block_size: List[int], target_dtype: torch.dtype, @@ -1642,7 +1642,7 @@ def quantize_gguf( return int_data -def dequantize_gguf( +def _dequantize_gguf( input: torch.Tensor, block_size: List[int], target_dtype: torch.dtype, @@ -1705,7 +1705,7 @@ def dequantize_gguf( return dequant -def dequantize_affine_qqq( +def _dequantize_affine_qqq( w: torch.Tensor, s_group: torch.Tensor, s_channel: torch.Tensor, @@ -1845,7 +1845,7 @@ def _convert_to_affinequantized_format( # Main hqq quantizer function -def choose_qparams_and_quantize_affine_hqq( +def _choose_qparams_and_quantize_affine_hqq( tensor: torch.Tensor, nbits: float = 4, group_size: int = 64, @@ 
-1939,7 +1939,7 @@ def choose_qparams_and_quantize_affine_hqq( return W_q, scale, zero, shape -def choose_qparams_affine_floatx( +def _choose_qparams_affine_floatx( tensor: torch.Tensor, ebits: int, mbits: int ) -> torch.Tensor: # _n_ones() is not compatible with torch.compile() due to << operator @@ -1959,7 +1959,7 @@ def choose_qparams_affine_floatx( return scale.to(dtype) -def quantize_affine_floatx( +def _quantize_affine_floatx( tensor: torch.Tensor, scale: torch.Tensor, ebits: int, mbits: int ) -> torch.Tensor: """Quantizes the float32 high precision floating point tensor to low precision floating point number and @@ -1970,7 +1970,7 @@ def quantize_affine_floatx( return tensor_floatx -def dequantize_affine_floatx( +def _dequantize_affine_floatx( tensor: torch.Tensor, scale: torch.Tensor, ebits: int, @@ -1983,7 +1983,7 @@ def dequantize_affine_floatx( return tensor -def choose_qparams_affine_float8( +def _choose_qparams_affine_float8( tensor: torch.Tensor, float8_dtype: torch.dtype = torch.float8_e4m3fn, scale_dtype: torch.dtype = torch.float32, @@ -2075,7 +2075,7 @@ def _expand_scale_to_tensor_shape( return expanded_scale -def quantize_affine_float8( +def _quantize_affine_float8( tensor: torch.Tensor, scale: torch.Tensor, float8_dtype: torch.dtype = torch.float8_e4m3fn, @@ -2095,7 +2095,7 @@ def quantize_affine_float8( return fp8_tensor -def dequantize_affine_float8( +def _dequantize_affine_float8( tensor: torch.Tensor, scale: torch.Tensor, output_dtype: torch.dtype = torch.float32, diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index 3c968e2d40..c7dd92d55c 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -15,15 +15,15 @@ from torchao.quantization.quant_primitives import ( MappingType, ZeroPointDomain, + _choose_qparams_affine_dont_preserve_zero, + _choose_qparams_affine_tinygemm, + _dequantize_affine_no_zero_point, + _dequantize_affine_tinygemm, + _quantize_affine_no_zero_point, + _quantize_affine_tinygemm, choose_qparams_affine, - choose_qparams_affine_dont_preserve_zero, - choose_qparams_affine_tinygemm, dequantize_affine, - dequantize_affine_no_zero_point, - dequantize_affine_tinygemm, quantize_affine, - quantize_affine_no_zero_point, - quantize_affine_tinygemm, ) from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_5, @@ -357,7 +357,7 @@ def get_groupwise_affine_qparams( ) if zero_point_domain == ZeroPointDomain.FLOAT and not preserve_zero: - scale, zero_point = choose_qparams_affine_tinygemm( + scale, zero_point = _choose_qparams_affine_tinygemm( w, mapping_type, block_size, @@ -369,7 +369,7 @@ def get_groupwise_affine_qparams( zero_point_dtype=zero_point_dtype, ) elif zero_point_domain == ZeroPointDomain.INT and not preserve_zero: - scale, zero_point = choose_qparams_affine_dont_preserve_zero( + scale, zero_point = _choose_qparams_affine_dont_preserve_zero( w, mapping_type, block_size, @@ -439,9 +439,9 @@ def groupwise_affine_quantize_tensor_from_qparams( if zero_point_domain == ZeroPointDomain.INT: _quantize_affine = quantize_affine elif zero_point_domain == ZeroPointDomain.FLOAT: - _quantize_affine = quantize_affine_tinygemm + _quantize_affine = _quantize_affine_tinygemm elif ZeroPointDomain == ZeroPointDomain.NONE: - _quantize_affine = quantize_affine_no_zero_point + _quantize_affine = _quantize_affine_no_zero_point else: raise ValueError(f"Unrecognized zero point domain: {zero_point_domain}") @@ -508,9 +508,9 @@ def groupwise_affine_dequantize_tensor_from_qparams( if zero_point_domain == ZeroPointDomain.INT: 
_dequantize_affine = dequantize_affine
     elif zero_point_domain == ZeroPointDomain.FLOAT:
-        _dequantize_affine = dequantize_affine_tinygemm
+        _dequantize_affine = _dequantize_affine_tinygemm
     else:
-        _dequantize_affine = dequantize_affine_no_zero_point
+        _dequantize_affine = _dequantize_affine_no_zero_point
     return _dequantize_affine(
         w_int32,
         block_size,
diff --git a/tutorials/calibration_flow/gptq_like.py b/tutorials/calibration_flow/gptq_like.py
index ab7a2b4f37..df824e506f 100644
--- a/tutorials/calibration_flow/gptq_like.py
+++ b/tutorials/calibration_flow/gptq_like.py
@@ -48,7 +48,7 @@
     LinearActivationQuantizedTensor,
     MappingType,
     PerTensor,
-    fake_quantize_affine,
+    _fake_quantize_affine,
     quantize_,
     to_linear_activation_quantized,
 )
@@ -237,7 +237,9 @@ def forward_pre_hook(
     new_input = []
     for inp in args[0]:
         new_input.append(
-            fake_quantize_affine(inp, inp.shape, input_scale, input_zp, torch.uint8)
+            _fake_quantize_affine(
+                inp, inp.shape, input_scale, input_zp, torch.uint8
+            )
         )
     mt = MultiTensor(new_input)

From c2fb7cff9e0b39226a65c42daab4689e61361b46 Mon Sep 17 00:00:00 2001
From: jainapurva
Date: Fri, 13 Jun 2025 15:56:32 -0700
Subject: [PATCH 2/4] Update private APIs

---
 docs/source/api_ref_quantization.rst       |  6 ---
 test/dtypes/test_affine_quantized_float.py |  4 +-
 test/prototype/test_gguf_quant.py          |  2 +-
 test/quantization/test_quant_primitives.py |  4 +-
 torchao/quantization/README.md             | 31 +++++++++++
 torchao/quantization/__init__.py           | 16 ------
 torchao/quantization/quant_primitives.py   | 60 +++++++++++-----------
 7 files changed, 66 insertions(+), 57 deletions(-)

diff --git a/docs/source/api_ref_quantization.rst b/docs/source/api_ref_quantization.rst
index ba8fa4269d..f2fad00b69 100644
--- a/docs/source/api_ref_quantization.rst
+++ b/docs/source/api_ref_quantization.rst
@@ -63,14 +63,8 @@ Quantization Primitives
     choose_qparams_affine
     choose_qparams_affine_with_min_max
-    _choose_qparams_affine_floatx
     quantize_affine
-    _quantize_affine_floatx
     dequantize_affine
-    _dequantize_affine_floatx
-    _choose_qparams_and_quantize_affine_hqq
-    _fake_quantize_affine
-    _fake_quantize_affine_cachemask
     safe_int_mm
     int_scaled_matmul
     MappingType
diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py
index 1efa900efc..b63a406715 100644
--- a/test/dtypes/test_affine_quantized_float.py
+++ b/test/dtypes/test_affine_quantized_float.py
@@ -357,7 +357,7 @@ def test_mm_float8dq_per_row(
     @common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
     @common_utils.parametrize("output_dtype", [torch.float32, torch.bfloat16])
     @common_utils.parametrize("block_size", [None, (1, 32), (2, 16), (4, 8)])
-    def test__dequantize_affine_float8(self, float8_dtype, output_dtype, block_size):
+    def test_dequantize_affine_float8(self, float8_dtype, output_dtype, block_size):
         """Test _dequantize_affine_float8 with various configurations"""
         device = "cuda"
@@ -387,7 +387,7 @@ def test__dequantize_affine_float8(self, float8_dtype, output_dtype, block_size)
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
-    def test__dequantize_affine_float8_scale_broadcasting(self):
+    def test_dequantize_affine_float8_scale_broadcasting(self):
         """Test that scale broadcasting works correctly for block-wise quantization"""
         device = "cuda"
         # Create input tensor with known block structure
diff --git a/test/prototype/test_gguf_quant.py b/test/prototype/test_gguf_quant.py
index 53ffcb5c60..af44243fe4 100644
--- a/test/prototype/test_gguf_quant.py
+++ b/test/prototype/test_gguf_quant.py
@@ -25,7 +25,7 @@ def setUp(self):
         self.block_size = (1, 32)
         self.dtype = torch.uint4
-    def test__choose_qparams_gguf(self):
+    def test_choose_qparams_gguf(self):
         (
             super_block_scale_scale,
             super_block_min_scale,
diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py
index ae0fc9987f..ac2a42b9cf 100644
--- a/test/quantization/test_quant_primitives.py
+++ b/test/quantization/test_quant_primitives.py
@@ -752,7 +752,7 @@ def test_groupwise_affine_dequantize_tensor_from_qparams(self):
     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
     )
-    def test__fake_quantize_affine(self):
+    def test_fake_quantize_affine(self):
         input = torch.randn(10, 10)
         mapping_type = MappingType.SYMMETRIC
@@ -788,7 +788,7 @@ def test__fake_quantize_affine(self):
     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
     )
-    def test__fake_quantize_affine_cachemask(self):
+    def test_fake_quantize_affine_cachemask(self):
         input = torch.randn(10, 10)
         mapping_type = MappingType.SYMMETRIC
diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md
index 90f83661aa..ae1619fea0 100644
--- a/torchao/quantization/README.md
+++ b/torchao/quantization/README.md
@@ -414,6 +414,37 @@ an example can be found in `torchao/_models/llama/eval.py`.
 The `quantize_` and `autoquant` apis now automatically use our recommended inductor configuration setings. You can mimic the same configuration settings for your own experiments by using the `torchao.quantization.utils.recommended_inductor_config_setter` to replicate our recommended configuration settings. Alternatively if you wish to disable these recommended settings, you can use the key word argument `set_inductor_config` and set it to false in the `quantize_` or `autoquant` apis to prevent assignment of those configuration settings. You can also overwrite these configuration settings after they are assigned if you so desire, as long as they are overwritten before passing any inputs to the torch.compiled model. This means that previous flows which referenced a variety of inductor configurations that needed to be set are now outdated, though continuing to manually set those same inductor configurations is unlikely to cause any issues.
+
+## Developer Notes
+
+### Quantization Primitives
+The quantization primitives are implemented in `torchao/quantization/quant_primitives.py` and can be re-used while adding new quant techniques:
+- Public APIs:
+  - `choose_qparams_affine`
+  - `quantize_affine`
+  - `dequantize_affine`
+- Private APIs:
+  - `_choose_qparams_affine_tinygemm`
+  - `_choose_qparams_affine_dont_preserve_zero`
+  - `_choose_qparams_affine_floatx`
+  - `_choose_qparams_and_quantize_affine_hqq`
+  - `_choose_qparams_and_quantize_affine_qqq`
+  - `_choose_qparams_affine_float8`
+  - `_choose_qparams_gguf`
+  - `_quantize_affine_no_zero_point`
+  - `_quantize_affine_tinygemm`
+  - `_quantize_affine_floatx`
+  - `_quantize_affine_float8`
+  - `_quantize_gguf`
+  - `_dequantize_affine_no_zero_point`
+  - `_dequantize_affine_tinygemm`
+  - `_dequantize_affine_floatx`
+  - `_dequantize_affine_qqq`
+  - `_dequantize_affine_float8`
+  - `_dequantize_gguf`
+  - `_fake_quantize_affine`
+  - `_fake_quantize_affine_cachemask`
+
 ## Notes
 1.
APIs have been hardware tested on A100 and T4(colab) diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index 4fd3bb1d6c..d9aba0bcc5 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -82,14 +82,6 @@ MappingType, TorchAODType, ZeroPointDomain, - _choose_qparams_affine_dont_preserve_zero, - _choose_qparams_affine_floatx, - _choose_qparams_affine_tinygemm, - _choose_qparams_and_quantize_affine_hqq, - _dequantize_affine_floatx, - _fake_quantize_affine, - _fake_quantize_affine_cachemask, - _quantize_affine_floatx, choose_qparams_affine, choose_qparams_affine_with_min_max, dequantize_affine, @@ -172,17 +164,9 @@ "AffineQuantizedObserverBase", # quant primitive ops "choose_qparams_affine", - "_choose_qparams_affine_tinygemm", - "_choose_qparams_affine_dont_preserve_zero", "choose_qparams_affine_with_min_max", - "_choose_qparams_affine_floatx", "quantize_affine", - "_quantize_affine_floatx", "dequantize_affine", - "_dequantize_affine_floatx", - "_choose_qparams_and_quantize_affine_hqq", - "_fake_quantize_affine", - "_fake_quantize_affine_cachemask", # operators/kernels "safe_int_mm", "int_scaled_matmul", diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index b99aafbc57..6e64b39119 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -24,32 +24,32 @@ __all__ = [ "choose_qparams_affine", + "choose_qparams_affine_with_min_max", + "quantize_affine", + "dequantize_affine", + "MappingType", + "ZeroPointDomain", + "TorchAODType", "_choose_qparams_affine_tinygemm", "_choose_qparams_affine_dont_preserve_zero", - "choose_qparams_affine_with_min_max", "_choose_qparams_affine_floatx", - "quantize_affine", + "_choose_qparams_and_quantize_affine_hqq", + "_choose_qparams_and_quantize_affine_qqq", + "_choose_qparams_affine_float8", + "_choose_qparams_gguf", "_quantize_affine_no_zero_point", "_quantize_affine_tinygemm", - "dequantize_affine", + "_quantize_affine_floatx", + "_quantize_affine_float8", + "_quantize_gguf", "_dequantize_affine_no_zero_point", "_dequantize_affine_tinygemm", - "_quantize_affine_floatx", "_dequantize_affine_floatx", - "_fake_quantize_affine", - "_fake_quantize_affine_cachemask", - "_choose_qparams_and_quantize_affine_hqq", - "_choose_qparams_and_quantize_affine_qqq", "_dequantize_affine_qqq", - "MappingType", - "ZeroPointDomain", - "TorchAODType", - "_choose_qparams_affine_float8", - "_quantize_affine_float8", "_dequantize_affine_float8", - "_choose_qparams_gguf", - "_quantize_gguf", "_dequantize_gguf", + "_fake_quantize_affine", + "_fake_quantize_affine_cachemask", ] @@ -453,7 +453,7 @@ def _quantize_affine_tinygemm( # torch.uintx dtypes yet if output_dtype in _SUB_BYTE_UINT_BOUNDS: output_dtype = torch.uint8 - return __quantize_affine_tinygemm_no_dtype_cast( + return _quantize_affine_tinygemm_no_dtype_cast( input, block_size, scale, @@ -463,7 +463,7 @@ def _quantize_affine_tinygemm( ).to(output_dtype) -def __quantize_affine_tinygemm_no_dtype_cast( +def _quantize_affine_tinygemm_no_dtype_cast( input: torch.Tensor, block_size: Tuple[int, ...], scale: torch.Tensor, @@ -539,7 +539,7 @@ def _quantize_affine_no_zero_point( # torch.uintx dtypes yet if output_dtype in _SUB_BYTE_UINT_BOUNDS: output_dtype = torch.uint8 - return __quantize_affine_no_zero_point_no_dtype_cast( + return _quantize_affine_no_zero_point_no_dtype_cast( input, block_size, scale, @@ -549,7 +549,7 @@ def _quantize_affine_no_zero_point( ).to(output_dtype) -def 
__quantize_affine_no_zero_point_no_dtype_cast( +def _quantize_affine_no_zero_point_no_dtype_cast( input: torch.Tensor, block_size: Tuple[int, ...], scale: torch.Tensor, @@ -714,7 +714,7 @@ def _dequantize_affine_no_dtype_check( return dequant.view(original_shape).to(output_dtype) -def __dequantize_affine_no_zero_point_no_dtype_check( +def _dequantize_affine_no_zero_point_no_dtype_check( input: torch.Tensor, block_size: List[int], scale: torch.Tensor, @@ -792,7 +792,7 @@ def _dequantize_affine_no_zero_point( torch.bfloat16, ], f"Unsupported output dtype: {output_dtype}" quant_min, quant_max = _get_and_check_qmin_qmax(input_dtype, quant_min, quant_max) - return __dequantize_affine_no_zero_point_no_dtype_check( + return _dequantize_affine_no_zero_point_no_dtype_check( input, block_size, scale, @@ -803,7 +803,7 @@ def _dequantize_affine_no_zero_point( ) -def __dequantize_affine_tinygemm_no_dtype_check( +def _dequantize_affine_tinygemm_no_dtype_check( input: torch.Tensor, block_size: List[int], scale: torch.Tensor, @@ -887,7 +887,7 @@ def _dequantize_affine_tinygemm( torch.bfloat16, ], f"Unsupported output dtype: {output_dtype}" quant_min, quant_max = _get_and_check_qmin_qmax(input_dtype, quant_min, quant_max) - return __dequantize_affine_tinygemm_no_dtype_check( + return _dequantize_affine_tinygemm_no_dtype_check( input, block_size, scale, @@ -933,7 +933,7 @@ def _fake_quantize_affine( raise ValueError("Please use ZeroPointDomain.NONE instead of None") elif zero_point_domain is ZeroPointDomain.NONE and zero_point is not None: raise ValueError("zero_point should be None when zero_point_domain is NONE") - (_, fq) = _do__fake_quantize_affine( + (_, fq) = _do_fake_quantize_affine( input, block_size, scale, @@ -979,7 +979,7 @@ def _fake_quantize_affine_cachemask( raise ValueError("Please use ZeroPointDomain.NONE instead of None") elif zero_point_domain is None and zero_point is not None: raise ValueError("zero_point should be None when zero_point_domain is NONE") - (q, dq) = _do__fake_quantize_affine( + (q, dq) = _do_fake_quantize_affine( input, block_size, scale, @@ -993,7 +993,7 @@ def _fake_quantize_affine_cachemask( return (dq, mask) -def _do__fake_quantize_affine( +def _do_fake_quantize_affine( input: torch.Tensor, block_size: Tuple[int, ...], scale: torch.Tensor, @@ -1013,11 +1013,11 @@ def _do__fake_quantize_affine( _quantize_affine = _quantize_affine_no_dtype_cast _dequantize_affine = _dequantize_affine_no_dtype_check elif zero_point_domain == ZeroPointDomain.FLOAT: - _quantize_affine = __quantize_affine_tinygemm_no_dtype_cast - _dequantize_affine = __dequantize_affine_tinygemm_no_dtype_check + _quantize_affine = _quantize_affine_tinygemm_no_dtype_cast + _dequantize_affine = _dequantize_affine_tinygemm_no_dtype_check elif ZeroPointDomain == ZeroPointDomain.NONE: - _quantize_affine = __quantize_affine_no_zero_point_no_dtype_cast - _dequantize_affine = __dequantize_affine_no_zero_point_no_dtype_check + _quantize_affine = _quantize_affine_no_zero_point_no_dtype_cast + _dequantize_affine = _dequantize_affine_no_zero_point_no_dtype_check else: raise ValueError(f"Unrecognized zero point domain: {zero_point_domain}") q = _quantize_affine( From 43ff9962b9a9146d7b7e1d9dda685f35d75d78a6 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Fri, 13 Jun 2025 16:49:37 -0700 Subject: [PATCH 3/4] Add doc string --- torchao/quantization/quant_primitives.py | 277 +++++++++++++++++++---- 1 file changed, 236 insertions(+), 41 deletions(-) diff --git a/torchao/quantization/quant_primitives.py 
b/torchao/quantization/quant_primitives.py index 6e64b39119..df136bc06e 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -228,9 +228,19 @@ def backward(ctx, gy: torch.Tensor) -> torch.Tensor: # TODO: decide on if we want to allow custom quant_min/quant_max here def _get_and_check_qmin_qmax(dtype, quant_min, quant_max): - """Get quant_min and quant_max args based on dtype and also - verify that they are within the range of possible quant_min/quant_max - for dtype + """Get quant_min and quant_max args based on dtype and also verify bounds. + + Args: + dtype: Target quantization dtype (e.g., torch.uint8, torch.int8, or FP8 types) + quant_min: Minimum quantized value, or None to use dtype default + quant_max: Maximum quantized value, or None to use dtype default + + Returns: + Tuple[int/float, int/float]: Validated (quant_min, quant_max) values + + Raises: + ValueError: If dtype is unsupported + AssertionError: If quant_min/quant_max are out of bounds for dtype """ if dtype in FP8_TYPES: quant_min_lower_bound, quant_max_upper_bound = ( @@ -357,11 +367,25 @@ def _quantize_affine( quant_min: Optional[Union[int, float, bool]] = None, quant_max: Optional[Union[int, float, bool]] = None, ) -> torch.Tensor: - """op definition that has compatible signatures with custom op library + """Quantize tensor using affine quantization with integer zero point domain. + + Op definition that has compatible signatures with custom op library. + + Args: + input: Input tensor to quantize (float32, float16, or bfloat16) + block_size: Granularity of quantization - size of tensor elements sharing same qparam + scale: Quantization scale parameter + zero_point: Quantization zero point parameter (optional) + output_dtype: Target quantized dtype (e.g., torch.uint8, torch.int8) + quant_min: Minimum quantized value, derived from dtype if None + quant_max: Maximum quantized value, derived from dtype if None + + Returns: + Quantized tensor with requested dtype Note: - zero_point_domain is pre-defined specifies how we quantize the floating point to quantized data: - INT: quantized_val = (float_val / scale) (integer) + zero_point (integer) + zero_point_domain is pre-defined as INT, meaning: + quantized_val = (float_val / scale) (integer) + zero_point (integer) """ quant_min, quant_max = _get_and_check_qmin_qmax(output_dtype, quant_min, quant_max) # workaround for uintx dtypes, since we don't have native Uintx dtype connected with @@ -386,12 +410,26 @@ def _quantize_affine_no_dtype_cast( quant_min: Union[int, float], quant_max: Union[int, float], ) -> torch.Tensor: - """ + """Quantize tensor using affine quantization without dtype casting. + + Performs quantization with integer zero point domain without casting to target dtype. + + Args: + input: Input tensor to quantize (float32, float16, or bfloat16) + block_size: Granularity of quantization - size of tensor elements sharing same qparam + scale: Quantization scale parameter + zero_point: Quantization zero point parameter (optional) + quant_min: Minimum quantized value + quant_max: Maximum quantized value + + Returns: + Quantized tensor without dtype casting + The op does the following: - 1. figure out the dimension for reduction based on block_size, also reshape the input to align with + 1. Figure out the dimension for reduction based on block_size, also reshape the input to align with the shape after reduction - 2. quantize the input based on the quantization parameters scale and zero_point and zero_point_domain = INT - 3. 
reshape the quantized result to origianl shape + 2. Quantize the input based on the quantization parameters scale and zero_point with zero_point_domain = INT + 3. Reshape the quantized result to original shape """ # TODO: validations # TODO: validate scale/zero_point dimensions are compatible with block_size @@ -437,16 +475,31 @@ def _quantize_affine_tinygemm( quant_min: Optional[Union[int, float, bool]] = None, quant_max: Optional[Union[int, float, bool]] = None, ) -> torch.Tensor: - """ + """Quantize tensor using affine quantization with float zero point domain for tinygemm. + + Specialized quantization for tinygemm int4mm kernel where zero point is in floating point domain. + + Args: + input: Input tensor to quantize (float32, float16, or bfloat16) + block_size: Granularity of quantization - size of tensor elements sharing same qparam + scale: Quantization scale parameter + zero_point: Quantization zero point parameter (optional) + output_dtype: Target quantized dtype (e.g., torch.uint8, torch.int8) + quant_min: Minimum quantized value, derived from dtype if None + quant_max: Maximum quantized value, derived from dtype if None + + Returns: + Quantized tensor with requested dtype + The op does the following: - 1. figure out the dimension for reduction based on block_size, also reshape the input to align with + 1. Figure out the dimension for reduction based on block_size, also reshape the input to align with the shape after reduction - 2. quantize the input based on the quantization parameters scale and zero_point and zero_point_domain = FLOAT - 3. reshape the quantized result to origianl shape + 2. Quantize the input based on the quantization parameters scale and zero_point with zero_point_domain = FLOAT + 3. Reshape the quantized result to original shape Note: - zero_point_domain is pre-defined specifies how we quantize the floating point to quantized data: - FLOAT: quantized_val = (float_val - (zero_point (float) - scale * mid_point)) / scale + zero_point_domain is pre-defined as FLOAT, meaning: + quantized_val = (float_val - (zero_point (float) - scale * mid_point)) / scale """ quant_min, quant_max = _get_and_check_qmin_qmax(output_dtype, quant_min, quant_max) # workaround for uintx dtypes, since we don't have native Uintx dtype connected with @@ -471,12 +524,26 @@ def _quantize_affine_tinygemm_no_dtype_cast( quant_min: Optional[Union[int, float]] = None, quant_max: Optional[Union[int, float]] = None, ) -> torch.Tensor: - """ + """Quantize tensor using affine quantization with float zero point domain without dtype casting. + + Specialized quantization for tinygemm int4mm kernel where zero point is in floating point domain. + + Args: + input: Input tensor to quantize (float32, float16, or bfloat16) + block_size: Granularity of quantization - size of tensor elements sharing same qparam + scale: Quantization scale parameter + zero_point: Quantization zero point parameter (optional) + quant_min: Minimum quantized value + quant_max: Maximum quantized value + + Returns: + Quantized tensor without dtype casting + The op does the following: - 1. figure out the dimension for reduction based on block_size, also reshape the input to align with + 1. Figure out the dimension for reduction based on block_size, also reshape the input to align with the shape after reduction - 2. quantize the input based on the quantization parameters scale and zero_point and zero_point_domain = FLOAT - 3. reshape the quantized result to origianl shape + 2. 
Quantize the input based on the quantization parameters scale and zero_point with zero_point_domain = FLOAT + 3. Reshape the quantized result to original shape """ # TODO: validations # TODO: validate scale/zero_point dimensions are compatible with block_size @@ -522,17 +589,32 @@ def _quantize_affine_no_zero_point( quant_min: Optional[Union[int, float, bool]] = None, quant_max: Optional[Union[int, float, bool]] = None, ) -> torch.Tensor: - """ + """Quantize tensor using affine quantization without zero point. + + Specialized quantization for cases where zero point is not needed (e.g., floatx quantization). + + Args: + input: Input tensor to quantize (float32, float16, or bfloat16) + block_size: Granularity of quantization - size of tensor elements sharing same qparam + scale: Quantization scale parameter + zero_point: Quantization zero point parameter (ignored, should be None) + output_dtype: Target quantized dtype (e.g., torch.uint8, torch.int8) + quant_min: Minimum quantized value, derived from dtype if None + quant_max: Maximum quantized value, derived from dtype if None + + Returns: + Quantized tensor with requested dtype + The op does the following: - 1. figure out the dimension for reduction based on block_size, also reshape the input to align with + 1. Figure out the dimension for reduction based on block_size, also reshape the input to align with the shape after reduction - 2. quantize the input based on the quantization parameters scale and zero_point and zero_point_domain = NONE - 3. reshape the quantized result to origianl shape + 2. Quantize the input based on the quantization parameters scale with zero_point_domain = NONE + 3. Reshape the quantized result to original shape Note: - zero_point_domain is pre-defined specifies how we quantize the floating point to quantized data: - None: quantized_val = (float_val / scale) | this is primarily used for floatx quantization - Where we do not want to round values to nearest integer and instead scale and cast. + zero_point_domain is pre-defined as NONE, meaning: + quantized_val = (float_val / scale) | This is primarily used for floatx quantization + where we do not want to round values to nearest integer and instead scale and cast. """ quant_min, quant_max = _get_and_check_qmin_qmax(output_dtype, quant_min, quant_max) # workaround for uintx dtypes, since we don't have native Uintx dtype connected with @@ -557,12 +639,26 @@ def _quantize_affine_no_zero_point_no_dtype_cast( quant_min: Optional[Union[int, float]] = None, quant_max: Optional[Union[int, float]] = None, ) -> torch.Tensor: - """ + """Quantize tensor using affine quantization without zero point and without dtype casting. + + Specialized quantization for cases where zero point is not needed without casting to target dtype. + + Args: + input: Input tensor to quantize (float32, float16, or bfloat16) + block_size: Granularity of quantization - size of tensor elements sharing same qparam + scale: Quantization scale parameter + zero_point: Quantization zero point parameter (ignored, should be None) + quant_min: Minimum quantized value + quant_max: Maximum quantized value + + Returns: + Quantized tensor without dtype casting + The op does the following: - 1. figure out the dimension for reduction based on block_size, also reshape the input to align with + 1. Figure out the dimension for reduction based on block_size, also reshape the input to align with the shape after reduction - 2. 
quantize the input based on the quantization parameters scale and zero_point and zero_point_domain = NONE - 3. reshape the quantized result to origianl shape + 2. Quantize the input based on the quantization parameters scale with zero_point_domain = NONE + 3. Reshape the quantized result to original shape """ # TODO: validations # TODO: validate scale/zero_point dimensions are compatible with block_size @@ -648,7 +744,23 @@ def _dequantize_affine( quant_max: Optional[Union[int, float, bool]] = None, output_dtype: torch.dtype = torch.float32, ) -> torch.Tensor: - """op definition that has compatible signatures with custom op library""" + """Dequantize tensor using affine dequantization with integer zero point domain. + + Op definition that has compatible signatures with custom op library. + + Args: + input: Quantized tensor to dequantize + block_size: Granularity of quantization - size of tensor elements sharing same qparam + scale: Quantization scale parameter + zero_point: Quantization zero point parameter (optional) + input_dtype: Expected dtype of input tensor (e.g., torch.uint8, torch.int8) + quant_min: Minimum quantized value for input tensor + quant_max: Maximum quantized value for input tensor + output_dtype: Target output dtype (default: torch.float32) + + Returns: + Dequantized tensor with requested output dtype + """ # TODO: validate scale/zero_point dimensions are compatible with block_size if input_dtype not in _SUB_BYTE_UINT_BOUNDS: assert input.dtype == input_dtype, ( @@ -680,13 +792,27 @@ def _dequantize_affine_no_dtype_check( quant_max: Union[int, float], output_dtype: torch.dtype = torch.float32, ) -> torch.Tensor: - """This function converts AQT tensors to their high precision floating point representation + """Dequantize tensor using affine dequantization without dtype checking. + + Converts quantized tensors to their high precision floating point representation. + + Args: + input: Quantized tensor to dequantize + block_size: Granularity of quantization - size of tensor elements sharing same qparam + scale: Quantization scale parameter + zero_point: Quantization zero point parameter (optional) + quant_min: Minimum quantized value for input tensor + quant_max: Maximum quantized value for input tensor + output_dtype: Target output dtype (default: torch.float32) + + Returns: + Dequantized tensor with requested output dtype The op does the following: - 1. figure out the dimension for reduction based on block_size, also reshape the input to align with + 1. Figure out the dimension for reduction based on block_size, also reshape the input to align with the shape after reduction - 2. dequantize the input based on the quantization parameters scale and zero_point and args like zero_point_domain - 3. reshape the quantized result to origianl shape and change dtype to the output_dtype + 2. Dequantize the input based on the quantization parameters scale and zero_point + 3. Reshape the quantized result to original shape and change dtype to the output_dtype """ assert len(block_size) == input.dim(), ( f"Got input dim:{input.dim()}, block_size: {block_size}" @@ -723,13 +849,27 @@ def _dequantize_affine_no_zero_point_no_dtype_check( quant_max: Union[int, float], output_dtype: torch.dtype = torch.float32, ) -> torch.Tensor: - """This function converts AQT tensors to their high precision floating point representation + """Dequantize tensor using affine dequantization without zero point and without dtype checking. 
+ + Converts quantized tensors to their high precision floating point representation without zero point. + + Args: + input: Quantized tensor to dequantize + block_size: Granularity of quantization - size of tensor elements sharing same qparam + scale: Quantization scale parameter + zero_point: Quantization zero point parameter (ignored, should be None) + quant_min: Minimum quantized value for input tensor + quant_max: Maximum quantized value for input tensor + output_dtype: Target output dtype (default: torch.float32) + + Returns: + Dequantized tensor with requested output dtype The op does the following: - 1. figure out the dimension for reduction based on block_size, also reshape the input to align with + 1. Figure out the dimension for reduction based on block_size, also reshape the input to align with the shape after reduction - 2. dequantize the input based on the quantization parameters scale and zero_point and args like zero_point_domain - 3. reshape the quantized result to origianl shape and change dtype to the output_dtype + 2. Dequantize the input based on the quantization parameters scale (no zero point) + 3. Reshape the quantized result to original shape and change dtype to the output_dtype """ assert len(block_size) == input.dim(), ( f"Got input dim:{input.dim()}, block_size: {block_size}" @@ -1003,7 +1143,24 @@ def _do_fake_quantize_affine( quant_max: Optional[Union[int, float]] = None, zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, ) -> Tuple[torch.Tensor, torch.Tensor]: - """ + """Helper function for fake quantization that returns both intermediate and final values. + + Performs quantization followed by dequantization without dtype casting, returning both + the intermediate quantized values and the final dequantized values. + + Args: + input: Input tensor to fake quantize (float32, float16, or bfloat16) + block_size: Granularity of quantization - size of tensor elements sharing same qparam + scale: Quantization scale parameter + zero_point: Quantization zero point parameter (optional) + quant_dtype: Target quantized dtype for determining quant_min/quant_max + quant_min: Minimum quantized value, derived from dtype if None + quant_max: Maximum quantized value, derived from dtype if None + zero_point_domain: Domain of zero point (INT, FLOAT, or NONE) + + Returns: + Tuple of (intermediate quantized values, final dequantized values) + Helper function for `_fake_quantize_affine` that returns both the intermediate quantized values and the final dequantized values. """ @@ -1857,6 +2014,28 @@ def _choose_qparams_and_quantize_affine_hqq( raw_output: bool = False, # If True, it will return the quant params in hqq lib format optimize_weights: Callable = optimize_weights_proximal_legacy, # weights proximal optimizer function ) -> tuple: + """Choose quantization parameters and quantize tensor using HQQ (Half-Quadratic Quantization). + + Performs quantization using HQQ method with optional weight optimization via proximal solver. 
+ + Args: + tensor: Input tensor to quantize (float32, float16, or bfloat16) + nbits: Number of bits for quantization (default: 4) + group_size: Size of quantization groups (default: 64) + optimize: Whether to optimize weights using proximal solver (default: True) + axis: Axis along which to perform quantization (0 or 1, default: 1) + compute_dtype: Target compute dtype (default: torch.float16) + device: Target device for computation (default: "cuda") + verbose: Whether to print optimization error information (default: False) + raw_output: If True, return params in HQQ library format (default: False) + optimize_weights: Weight optimization function (default: optimize_weights_proximal_legacy) + + Returns: + Tuple of (quantized_weights, scale, zero_point, original_shape) + + Note: + Uses proximal solver to minimize ||W - dequantize(quantize(W))||_p^p for weight optimization. + """ assert axis in [0, 1], "axis should be either 0 or 1" if group_size is not None: assert _is_divisible(tensor.numel(), group_size), ( @@ -1942,6 +2121,22 @@ def _choose_qparams_and_quantize_affine_hqq( def _choose_qparams_affine_floatx( tensor: torch.Tensor, ebits: int, mbits: int ) -> torch.Tensor: + """Choose quantization parameters for floatx quantization. + + Calculates scale parameter for quantizing to custom floating point format. + + Args: + tensor: Input tensor to quantize (float32, float16, or bfloat16) + ebits: Number of exponent bits in target floatx format + mbits: Number of mantissa bits in target floatx format + + Returns: + Scale tensor for floatx quantization + + Note: + Uses global lookup table as workaround for torch.compile() compatibility + since _n_ones() is not compatible due to << operator. + """ # _n_ones() is not compatible with torch.compile() due to << operator # https://github.com/pytorch/pytorch/issues/119152 # exp_bias = _n_ones(ebits - 1) From 9acb9915532c585ee83b096fa48aabec7be38586 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Mon, 16 Jun 2025 10:52:32 -0700 Subject: [PATCH 4/4] Update --- torchao/quantization/README.md | 31 ------------------------------- 1 file changed, 31 deletions(-) diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md index ae1619fea0..90f83661aa 100644 --- a/torchao/quantization/README.md +++ b/torchao/quantization/README.md @@ -414,37 +414,6 @@ an example can be found in `torchao/_models/llama/eval.py`. The `quantize_` and `autoquant` apis now automatically use our recommended inductor configuration setings. You can mimic the same configuration settings for your own experiments by using the `torchao.quantization.utils.recommended_inductor_config_setter` to replicate our recommended configuration settings. Alternatively if you wish to disable these recommended settings, you can use the key word argument `set_inductor_config` and set it to false in the `quantize_` or `autoquant` apis to prevent assignment of those configuration settings. You can also overwrite these configuration settings after they are assigned if you so desire, as long as they are overwritten before passing any inputs to the torch.compiled model. This means that previous flows which referenced a variety of inductor configurations that needed to be set are now outdated, though continuing to manually set those same inductor configurations is unlikely to cause any issues. 
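For reference, a minimal sketch of the two options described in the README paragraph above. Only `quantize_`, the `set_inductor_config` keyword argument, and `torchao.quantization.utils.recommended_inductor_config_setter` come from that text; the `Int4WeightOnlyConfig`, the toy model, and the assumption that the setter takes no arguments are illustrative editorial choices, not part of this patch.

```python
import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_
from torchao.quantization.utils import recommended_inductor_config_setter

# Toy model; Int4WeightOnlyConfig is just one example config and assumes a
# CUDA device with bfloat16 weights.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).cuda()

# Default: quantize_ applies the recommended inductor settings for you.
quantize_(model, Int4WeightOnlyConfig())

# Opt out and manage the inductor configuration yourself instead:
# quantize_(model, Int4WeightOnlyConfig(), set_inductor_config=False)

# Or replicate the same settings manually (assumed to take no arguments),
# e.g. for an unquantized baseline you compile separately.
recommended_inductor_config_setter()
model_compiled = torch.compile(model)
```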
-
-## Developer Notes
-
-### Quantization Primitives
-The quantization primitives are implemented in `torchao/quantization/quant_primitives.py` and can be re-used while adding new quant techniques:
-- Public APIs:
-  - `choose_qparams_affine`
-  - `quantize_affine`
-  - `dequantize_affine`
-- Private APIs:
-  - `_choose_qparams_affine_tinygemm`
-  - `_choose_qparams_affine_dont_preserve_zero`
-  - `_choose_qparams_affine_floatx`
-  - `_choose_qparams_and_quantize_affine_hqq`
-  - `_choose_qparams_and_quantize_affine_qqq`
-  - `_choose_qparams_affine_float8`
-  - `_choose_qparams_gguf`
-  - `_quantize_affine_no_zero_point`
-  - `_quantize_affine_tinygemm`
-  - `_quantize_affine_floatx`
-  - `_quantize_affine_float8`
-  - `_quantize_gguf`
-  - `_dequantize_affine_no_zero_point`
-  - `_dequantize_affine_tinygemm`
-  - `_dequantize_affine_floatx`
-  - `_dequantize_affine_qqq`
-  - `_dequantize_affine_float8`
-  - `_dequantize_gguf`
-  - `_fake_quantize_affine`
-  - `_fake_quantize_affine_cachemask`
-
 ## Notes

 1. APIs have been hardware tested on A100 and T4(colab)
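For orientation, here is a minimal sketch of the public primitive flow that this series keeps exported (`choose_qparams_affine` → `quantize_affine` → `dequantize_affine`). It is illustrative only: the tensor shapes, the `(1, 32)` block size, the asymmetric uint8 mapping, and the exact keyword spellings are editorial assumptions based on the docstrings added in patch 3/4, not part of the patch itself.

```python
import torch
from torchao.quantization import (
    MappingType,
    choose_qparams_affine,
    dequantize_affine,
    quantize_affine,
)

# Arbitrary example input: two rows, quantized in groups of 32 values,
# so block_size has one entry per input dimension.
x = torch.randn(2, 64, dtype=torch.float32)
block_size = (1, 32)

# Per-block scale/zero_point for asymmetric uint8 quantization.
scale, zero_point = choose_qparams_affine(
    x, MappingType.ASYMMETRIC, block_size, torch.uint8
)

# Integer zero point domain: quantized_val = round(float_val / scale) + zero_point,
# clamped to the uint8 range and cast to the output dtype.
q = quantize_affine(x, block_size, scale, zero_point, torch.uint8)

# Back to float32: (quantized_val - zero_point) * scale.
dq = dequantize_affine(
    q, block_size, scale, zero_point, torch.uint8, output_dtype=torch.float32
)

print((x - dq).abs().max())  # small reconstruction error from rounding
```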