diff --git a/docs/source/api_ref_quantization.rst b/docs/source/api_ref_quantization.rst index 5293684ab9..f2fad00b69 100644 --- a/docs/source/api_ref_quantization.rst +++ b/docs/source/api_ref_quantization.rst @@ -63,14 +63,8 @@ Quantization Primitives choose_qparams_affine choose_qparams_affine_with_min_max - choose_qparams_affine_floatx quantize_affine - quantize_affine_floatx dequantize_affine - dequantize_affine_floatx - choose_qparams_and_quantize_affine_hqq - fake_quantize_affine - fake_quantize_affine_cachemask safe_int_mm int_scaled_matmul MappingType diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py index 879551fc0a..b63a406715 100644 --- a/test/dtypes/test_affine_quantized_float.py +++ b/test/dtypes/test_affine_quantized_float.py @@ -42,10 +42,10 @@ ) from torchao.quantization.quant_primitives import ( MappingType, + _choose_qparams_affine_float8, + _dequantize_affine_float8, + _quantize_affine_float8, choose_qparams_affine, - choose_qparams_affine_float8, - dequantize_affine_float8, - quantize_affine_float8, ) from torchao.utils import ( is_sm_at_least_89, @@ -358,21 +358,21 @@ def test_mm_float8dq_per_row( @common_utils.parametrize("output_dtype", [torch.float32, torch.bfloat16]) @common_utils.parametrize("block_size", [None, (1, 32), (2, 16), (4, 8)]) def test_dequantize_affine_float8(self, float8_dtype, output_dtype, block_size): - """Test dequantize_affine_float8 with various configurations""" + """Test _dequantize_affine_float8 with various configurations""" device = "cuda" input_tensor = torch.randn(8, 64, device=device, dtype=torch.float32) # Choose quantization parameters - scale = choose_qparams_affine_float8( + scale = _choose_qparams_affine_float8( input_tensor, float8_dtype=float8_dtype, block_size=block_size ) # Quantize - quantized = quantize_affine_float8(input_tensor, scale, float8_dtype) + quantized = _quantize_affine_float8(input_tensor, scale, float8_dtype) # Dequantize - dequantized = dequantize_affine_float8(quantized, scale, output_dtype) + dequantized = _dequantize_affine_float8(quantized, scale, output_dtype) # Verify output properties self.assertEqual(dequantized.dtype, output_dtype) @@ -395,7 +395,7 @@ def test_dequantize_affine_float8_scale_broadcasting(self): block_size = (2, 16) # 2x2 blocks in first dim, 2x16 blocks in second dim # Choose quantization parameters - scale = choose_qparams_affine_float8( + scale = _choose_qparams_affine_float8( input_tensor, float8_dtype=torch.float8_e4m3fn, block_size=block_size ) @@ -407,10 +407,10 @@ def test_dequantize_affine_float8_scale_broadcasting(self): self.assertEqual(scale.shape, expected_scale_shape) # Quantize - quantized = quantize_affine_float8(input_tensor, scale, torch.float8_e4m3fn) + quantized = _quantize_affine_float8(input_tensor, scale, torch.float8_e4m3fn) # Dequantize - dequantized = dequantize_affine_float8(quantized, scale, torch.float32) + dequantized = _dequantize_affine_float8(quantized, scale, torch.float32) # Verify shapes match self.assertEqual(dequantized.shape, input_tensor.shape) diff --git a/test/dtypes/test_floatx.py b/test/dtypes/test_floatx.py index 956ef9a03e..237bc2bd92 100644 --- a/test/dtypes/test_floatx.py +++ b/test/dtypes/test_floatx.py @@ -91,13 +91,13 @@ def test_from_scaled_tc_floatx_compile(self, ebits, mbits, device): @parametrize("ebits,mbits", _Floatx_DTYPES) def test_to_copy_device(self, ebits, mbits): from torchao.quantization.quant_primitives import ( - choose_qparams_affine_floatx, - quantize_affine_floatx, + 
_choose_qparams_affine_floatx, + _quantize_affine_floatx, ) x = torch.randn(256, 64) - scale = choose_qparams_affine_floatx(x, ebits, mbits) - x = quantize_affine_floatx(x, scale, ebits, mbits) + scale = _choose_qparams_affine_floatx(x, ebits, mbits) + x = _quantize_affine_floatx(x, scale, ebits, mbits) _layout = FloatxTensorCoreLayout(ebits, mbits) floatx_tensor_impl = FloatxTensorCoreAQTTensorImpl.from_plain( x, scale, None, _layout diff --git a/test/prototype/test_gguf_quant.py b/test/prototype/test_gguf_quant.py index b68d84b101..af44243fe4 100644 --- a/test/prototype/test_gguf_quant.py +++ b/test/prototype/test_gguf_quant.py @@ -13,7 +13,7 @@ GGUFWeightOnlyConfig, ) from torchao.quantization import quantize_ -from torchao.quantization.quant_primitives import choose_qparams_gguf +from torchao.quantization.quant_primitives import _choose_qparams_gguf from torchao.quantization.utils import compute_error @@ -31,7 +31,7 @@ def test_choose_qparams_gguf(self): super_block_min_scale, quantized_block_scale, quantized_block_min, - ) = choose_qparams_gguf(self.input, self.block_size, self.dtype) + ) = _choose_qparams_gguf(self.input, self.block_size, self.dtype) assert super_block_scale_scale.shape, (2, 8) assert super_block_min_scale.shape, (2, 8) diff --git a/test/quantization/test_marlin_qqq.py b/test/quantization/test_marlin_qqq.py index cff46ad329..8fe21c6bd3 100644 --- a/test/quantization/test_marlin_qqq.py +++ b/test/quantization/test_marlin_qqq.py @@ -21,7 +21,7 @@ ) from torchao.quantization.quant_primitives import ( MappingType, - choose_qparams_and_quantize_affine_qqq, + _choose_qparams_and_quantize_affine_qqq, ) from torchao.testing.utils import skip_if_rocm from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 @@ -102,7 +102,7 @@ def test_pack_unpack_equivalence(self): for group_size in [-1, 128]: # Quantize weights - q_w, s_group, s_channel, _ = choose_qparams_and_quantize_affine_qqq( + q_w, s_group, s_channel, _ = _choose_qparams_and_quantize_affine_qqq( w, num_bits, group_size ) diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index 323802757d..f0404a2ac2 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -64,9 +64,9 @@ MappingType, TorchAODType, ZeroPointDomain, + _fake_quantize_affine, choose_qparams_affine, dequantize_affine, - fake_quantize_affine, quantize_affine, ) from torchao.quantization.unified import ( @@ -637,7 +637,7 @@ def test_qat_4w_primitives(self): group_size, scales_precision, ) - w_fq = fake_quantize_affine( + w_fq = _fake_quantize_affine( weight, block_size, scales, diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py index e69d68b27f..ac2a42b9cf 100644 --- a/test/quantization/test_quant_primitives.py +++ b/test/quantization/test_quant_primitives.py @@ -13,11 +13,11 @@ from torchao.quantization.quant_primitives import ( MappingType, ZeroPointDomain, + _choose_qparams_affine_tinygemm, + _fake_quantize_affine, + _fake_quantize_affine_cachemask, choose_qparams_affine, - choose_qparams_affine_tinygemm, dequantize_affine, - fake_quantize_affine, - fake_quantize_affine_cachemask, quantize_affine, ) @@ -672,7 +672,7 @@ def test_get_groupwise_affine_qparams(self): zero_point_domain=zero_point_domain, ) if zero_point_domain == ZeroPointDomain.FLOAT: - scale, zero_point = choose_qparams_affine_tinygemm( + scale, zero_point = _choose_qparams_affine_tinygemm( input, mapping_type, block_size, @@ -780,7 +780,7 @@ def test_fake_quantize_affine(self): dequantized = 
dequantize_affine( quantized, block_size, scale, zero_point, dtype, quant_min, quant_max ) - fake_quantized = fake_quantize_affine( + fake_quantized = _fake_quantize_affine( input, block_size, scale, zero_point, dtype, quant_min, quant_max ) torch.testing.assert_close(dequantized, fake_quantized) @@ -816,7 +816,7 @@ def test_fake_quantize_affine_cachemask(self): dequantized = dequantize_affine( quantized, block_size, scale, zero_point, dtype, quant_min, quant_max ) - (fake_quantized, mask) = fake_quantize_affine_cachemask( + (fake_quantized, mask) = _fake_quantize_affine_cachemask( input, block_size, scale, diff --git a/test/test_ops.py b/test/test_ops.py index 012a4d562d..faec689a69 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -23,7 +23,9 @@ marlin_qqq_workspace, pack_to_marlin_qqq, ) -from torchao.quantization.quant_primitives import choose_qparams_and_quantize_affine_qqq +from torchao.quantization.quant_primitives import ( + _choose_qparams_and_quantize_affine_qqq, +) from torchao.sparsity.marlin import inject_24, marlin_24_workspace, pack_to_marlin_24 from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_5, @@ -713,7 +715,7 @@ def test_marlin_qqq(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_fact ) # Quantize weights - q_w, s_group, s_channel, w_ref = choose_qparams_and_quantize_affine_qqq( + q_w, s_group, s_channel, w_ref = _choose_qparams_and_quantize_affine_qqq( b_weight, num_bits, group_size ) q_w = q_w.t() diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 132ac0f28e..39f9131a9e 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -18,22 +18,22 @@ FP8_TYPES, MappingType, ZeroPointDomain, + _choose_qparams_affine_dont_preserve_zero, + _choose_qparams_affine_float8, + _choose_qparams_affine_floatx, + _choose_qparams_affine_tinygemm, + _choose_qparams_and_quantize_affine_hqq, + _dequantize_affine_float8, + _dequantize_affine_floatx, + _dequantize_affine_no_zero_point, + _dequantize_affine_tinygemm, + _quantize_affine_float8, + _quantize_affine_floatx, + _quantize_affine_no_zero_point, + _quantize_affine_tinygemm, choose_qparams_affine, - choose_qparams_affine_dont_preserve_zero, - choose_qparams_affine_float8, - choose_qparams_affine_floatx, - choose_qparams_affine_tinygemm, - choose_qparams_and_quantize_affine_hqq, dequantize_affine, - dequantize_affine_float8, - dequantize_affine_floatx, - dequantize_affine_no_zero_point, - dequantize_affine_tinygemm, quantize_affine, - quantize_affine_float8, - quantize_affine_floatx, - quantize_affine_no_zero_point, - quantize_affine_tinygemm, ) from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_5, @@ -142,7 +142,7 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor if isinstance(self._layout, FloatxTensorCoreLayout): int_data, scale = self.tensor_impl.get_plain() - return dequantize_affine_floatx( + return _dequantize_affine_floatx( int_data, scale, self._layout.ebits, @@ -151,11 +151,11 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor ) elif isinstance(self._layout, Float8Layout): data, scale, _ = self.tensor_impl.get_plain() - return dequantize_affine_float8(data, scale, output_dtype) + return _dequantize_affine_float8(data, scale, output_dtype) else: data, scale, zero_point = self.tensor_impl.get_plain() if self.zero_point_domain == ZeroPointDomain.FLOAT: - dq = dequantize_affine_tinygemm( + dq = _dequantize_affine_tinygemm( data, self.block_size, 
scale, @@ -166,7 +166,7 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor output_dtype=output_dtype, ) elif self.zero_point_domain == ZeroPointDomain.NONE: - dq = dequantize_affine_no_zero_point( + dq = _dequantize_affine_no_zero_point( data, self.block_size, scale, @@ -270,7 +270,7 @@ def from_hp_to_intx( from torchao.dtypes import Int4CPULayout from torchao.dtypes.uintx import TensorCoreTiledLayout - data, scale, zero_point, _ = choose_qparams_and_quantize_affine_hqq( + data, scale, zero_point, _ = _choose_qparams_and_quantize_affine_hqq( input_float, nbits=nbits, group_size=group_size, @@ -291,7 +291,7 @@ def from_hp_to_intx( data = data.to(target_dtype) else: if zero_point_domain == ZeroPointDomain.FLOAT and not preserve_zero: - scale, zero_point = choose_qparams_affine_tinygemm( + scale, zero_point = _choose_qparams_affine_tinygemm( input_float, mapping_type, block_size, @@ -303,7 +303,7 @@ def from_hp_to_intx( zero_point_dtype, ) elif zero_point_domain == ZeroPointDomain.INT and not preserve_zero: - scale, zero_point = choose_qparams_affine_dont_preserve_zero( + scale, zero_point = _choose_qparams_affine_dont_preserve_zero( input_float, mapping_type, block_size, @@ -329,7 +329,7 @@ def from_hp_to_intx( # choose_qparams_affine is a custom op that does support returning optional Tensors. We thus set the zero_point to None if its domain is None if zero_point_domain == ZeroPointDomain.NONE: zero_point = None - data = quantize_affine_no_zero_point( + data = _quantize_affine_no_zero_point( input_float, block_size, scale, @@ -339,7 +339,7 @@ def from_hp_to_intx( quant_max, ) elif zero_point_domain == ZeroPointDomain.FLOAT: - data = quantize_affine_tinygemm( + data = _quantize_affine_tinygemm( input_float, block_size, scale, @@ -400,7 +400,7 @@ def from_hp_to_intx_static( if zero_point_domain == ZeroPointDomain.NONE: zero_point = None - int_data = quantize_affine_no_zero_point( + int_data = _quantize_affine_no_zero_point( input_float, block_size, scale, @@ -410,7 +410,7 @@ def from_hp_to_intx_static( quant_max, ) elif zero_point_domain == ZeroPointDomain.FLOAT: - int_data = quantize_affine_tinygemm( + int_data = _quantize_affine_tinygemm( input_float, block_size, scale, @@ -462,10 +462,10 @@ def from_hp_to_floatx( if target_dtype in FP8_TYPES: original_shape = input_float.shape input_float = _layout.pre_process(input_float) - scale = choose_qparams_affine_float8( + scale = _choose_qparams_affine_float8( input_float, float8_dtype=target_dtype, block_size=block_size ) - data = quantize_affine_float8(input_float, scale, target_dtype) + data = _quantize_affine_float8(input_float, scale, target_dtype) data, scale, zero_point = _layout.post_process( data, scale, None, block_size ) @@ -499,7 +499,7 @@ def from_hp_to_floatx_static( input_float, scale, ZeroPointDomain.NONE, block_size ) - data = quantize_affine_float8( + data = _quantize_affine_float8( input_float, scale, target_dtype, @@ -545,8 +545,8 @@ def from_hp_to_fpx( ebits, mbits = _layout.ebits, _layout.mbits # Note: these ops are hardcoded to have per axis quantization (axis=1) right now - scale = choose_qparams_affine_floatx(input_float, ebits, mbits) - floatx_unpacked = quantize_affine_floatx(input_float, scale, ebits, mbits) + scale = _choose_qparams_affine_floatx(input_float, ebits, mbits) + floatx_unpacked = _quantize_affine_floatx(input_float, scale, ebits, mbits) floatx_packed, scale, _ = _layout.post_process( floatx_unpacked, scale, None, block_size ) diff --git 
a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index a76b4daa23..02a2d3004a 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -92,9 +92,9 @@ ) from torchao.quantization.quant_primitives import ( ZeroPointDomain, + _dequantize_affine_no_zero_point, + _dequantize_affine_tinygemm, dequantize_affine, - dequantize_affine_no_zero_point, - dequantize_affine_tinygemm, ) from torchao.utils import ( fill_defaults, @@ -318,9 +318,9 @@ def _(func, types, args, kwargs): # we need to increase block size to correct dim new_blocks = idx.dim() - 1 if args[1].zero_point_domain == ZeroPointDomain.FLOAT: - _dequantize_affine = dequantize_affine_tinygemm + _dequantize_affine = _dequantize_affine_tinygemm elif args[1].zero_point_domain == ZeroPointDomain.NONE: - _dequantize_affine = dequantize_affine_no_zero_point + _dequantize_affine = _dequantize_affine_no_zero_point else: _dequantize_affine = dequantize_affine diff --git a/torchao/dtypes/floatx/floatx_tensor_core_layout.py b/torchao/dtypes/floatx/floatx_tensor_core_layout.py index 6871033f1a..c7fb1e1a7c 100644 --- a/torchao/dtypes/floatx/floatx_tensor_core_layout.py +++ b/torchao/dtypes/floatx/floatx_tensor_core_layout.py @@ -467,7 +467,7 @@ class FloatxTensorCoreLayout(Layout): class FloatxTensorCoreAQTTensorImpl(AQTTensorImpl): """FloatxTensorCoreAQTTensorImpl represents a Tensor with dtype floatx(ebits=a, mbits=b), it has a internal tensor field of "packed_floatx_data", which is packed from the - uint8 unpacked data (the output of `quantize_affine_floatx` operator) + uint8 unpacked data (the output of `_quantize_affine_floatx` operator) The packing is optimized for TensorCore, from the fp6-llm paper: https://arxiv.org/abs/2401.14112 github repo: https://github.com/usyd-fsalab/fp6_llm, now renamed to quant-llm diff --git a/torchao/dtypes/uintx/int4_cpu_layout.py b/torchao/dtypes/uintx/int4_cpu_layout.py index 6c89f98ff7..bf9446d265 100644 --- a/torchao/dtypes/uintx/int4_cpu_layout.py +++ b/torchao/dtypes/uintx/int4_cpu_layout.py @@ -19,7 +19,7 @@ from torchao.dtypes.utils import AQTTensorImpl, Layout, is_device from torchao.quantization.quant_primitives import ( ZeroPointDomain, - quantize_affine_tinygemm, + _quantize_affine_tinygemm, ) from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_5, @@ -266,7 +266,7 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # TODO: move this to `unpack_tinygemm_scales_and_zeros`? scale = scale.reshape(scale.shape[:-1]).contiguous() zero = zero.reshape(zero.shape[:-1]).contiguous() - int_data = quantize_affine_tinygemm( + int_data = _quantize_affine_tinygemm( dequantized, block_size, scale, diff --git a/torchao/dtypes/uintx/int4_xpu_layout.py b/torchao/dtypes/uintx/int4_xpu_layout.py index c67eebd747..955a7a8610 100644 --- a/torchao/dtypes/uintx/int4_xpu_layout.py +++ b/torchao/dtypes/uintx/int4_xpu_layout.py @@ -377,8 +377,8 @@ def __torch_dispatch__(cls, func, types, args, kwargs): def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: from torchao.quantization.quant_primitives import ( + _quantize_affine_tinygemm, quantize_affine, - quantize_affine_tinygemm, ) from torchao.quantization.utils import unpack_tinygemm_scales_and_zeros @@ -429,7 +429,7 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # TODO: move this to `unpack_tinygemm_scales_and_zeros`? 
scale = scale.reshape(scale.shape[:-1]).contiguous() zero = zero.reshape(zero.shape[:-1]).contiguous() - int_data = quantize_affine_tinygemm( + int_data = _quantize_affine_tinygemm( dequantized, block_size, scale, diff --git a/torchao/dtypes/uintx/marlin_qqq_tensor.py b/torchao/dtypes/uintx/marlin_qqq_tensor.py index 3f3f4fa075..04066a6c65 100644 --- a/torchao/dtypes/uintx/marlin_qqq_tensor.py +++ b/torchao/dtypes/uintx/marlin_qqq_tensor.py @@ -24,8 +24,8 @@ from torchao.dtypes.utils import AQTTensorImpl, Layout from torchao.quantization.quant_primitives import ( ZeroPointDomain, - choose_qparams_and_quantize_affine_qqq, - dequantize_affine_qqq, + _choose_qparams_and_quantize_affine_qqq, + _dequantize_affine_qqq, ) logger = logging.getLogger(__name__) @@ -36,9 +36,9 @@ class MarlinQQQTensor(AffineQuantizedTensor): """MarlinQQQ quantized tensor subclass which inherits AffineQuantizedTensor class. - To see what happens during choose_qparams_and_quantize_affine_qqq, quantization and dequantization for marlin qqq quantization, + To see what happens during _choose_qparams_and_quantize_affine_qqq, quantization and dequantization for marlin qqq quantization, please checkout https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py - and check the two quant primitive ops: choose_qparams_and_quantize_affine_qqq and dequantize_affine_qqq + and check the two quant primitive ops: _choose_qparams_and_quantize_affine_qqq and _dequantize_affine_qqq """ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: @@ -48,7 +48,7 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor int_data, s_group, s_channel = self.tensor_impl.get_plain() nbits = int(math.log2(self.quant_max - self.quant_min + 1)) group_size = max(self.block_size) - return dequantize_affine_qqq( + return _dequantize_affine_qqq( int_data, s_group, s_channel, nbits, group_size, output_dtype ) @@ -69,7 +69,7 @@ def from_hp_to_intx( input_float = _layout.pre_process(input_float) nbits = int(math.log2(quant_max - quant_min + 1)) group_size = max(block_size) - data, s_group, s_channel, _ = choose_qparams_and_quantize_affine_qqq( + data, s_group, s_channel, _ = _choose_qparams_and_quantize_affine_qqq( input_float, nbits, group_size ) tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) diff --git a/torchao/dtypes/uintx/tensor_core_tiled_layout.py b/torchao/dtypes/uintx/tensor_core_tiled_layout.py index 0856d22fee..591d9a9be1 100644 --- a/torchao/dtypes/uintx/tensor_core_tiled_layout.py +++ b/torchao/dtypes/uintx/tensor_core_tiled_layout.py @@ -21,7 +21,7 @@ from torchao.quantization.quant_primitives import ( ZeroPointDomain, _get_reduction_params, - quantize_affine_tinygemm, + _quantize_affine_tinygemm, ) from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_5, @@ -511,7 +511,7 @@ def dequant_4d(self): target_dtype = torch.int32 quant_min = 0 quant_max = 15 - int_data = quantize_affine_tinygemm( + int_data = _quantize_affine_tinygemm( dequantized, self.block_size, scale, diff --git a/torchao/prototype/parq/quant/uniform_torchao.py b/torchao/prototype/parq/quant/uniform_torchao.py index 4f90f9cb92..ebe4e775e6 100644 --- a/torchao/prototype/parq/quant/uniform_torchao.py +++ b/torchao/prototype/parq/quant/uniform_torchao.py @@ -14,15 +14,15 @@ _DTYPE_TO_QVALUE_BOUNDS, MappingType, ZeroPointDomain, + _choose_qparams_affine_dont_preserve_zero, + _choose_qparams_affine_tinygemm, + _dequantize_affine_no_zero_point, + _dequantize_affine_tinygemm, + 
_quantize_affine_no_zero_point, + _quantize_affine_tinygemm, choose_qparams_affine, - choose_qparams_affine_dont_preserve_zero, - choose_qparams_affine_tinygemm, dequantize_affine, - dequantize_affine_no_zero_point, - dequantize_affine_tinygemm, quantize_affine, - quantize_affine_no_zero_point, - quantize_affine_tinygemm, ) from .quantizer import Quantizer @@ -57,16 +57,16 @@ def __init__( self._dequantize = dequantize_affine if zero_point_domain == ZeroPointDomain.FLOAT and not preserve_zero: - self._choose_qparams = choose_qparams_affine_tinygemm - self._quantize = quantize_affine_tinygemm - self._dequantize = dequantize_affine_tinygemm + self._choose_qparams = _choose_qparams_affine_tinygemm + self._quantize = _quantize_affine_tinygemm + self._dequantize = _dequantize_affine_tinygemm elif zero_point_domain == ZeroPointDomain.INT and not preserve_zero: - self._choose_qparams = choose_qparams_affine_dont_preserve_zero + self._choose_qparams = _choose_qparams_affine_dont_preserve_zero self._quantize = quantize_affine self._dequantize = dequantize_affine elif zero_point_domain == ZeroPointDomain.NONE: - self._quantize = quantize_affine_no_zero_point - self._dequantize = dequantize_affine_no_zero_point + self._quantize = _quantize_affine_no_zero_point + self._dequantize = _dequantize_affine_no_zero_point def _init_quant_min_max(self, b: int) -> None: if self.quant_min is None or self.quant_max is None: diff --git a/torchao/prototype/quantization/gguf/gguf_quantized_tensor.py b/torchao/prototype/quantization/gguf/gguf_quantized_tensor.py index 9757769d16..c1272fceb6 100644 --- a/torchao/prototype/quantization/gguf/gguf_quantized_tensor.py +++ b/torchao/prototype/quantization/gguf/gguf_quantized_tensor.py @@ -10,9 +10,9 @@ from torch.utils._python_dispatch import return_and_correct_aliasing from torchao.quantization.quant_primitives import ( - choose_qparams_gguf, - dequantize_gguf, - quantize_gguf, + _choose_qparams_gguf, + _dequantize_gguf, + _quantize_gguf, ) from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_5, @@ -130,7 +130,7 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor block_size = tuple( [1] * (self.int_data.ndim - 1) + [_QK_K // self.n_blocks_per_superblock] ) - return dequantize_gguf( + return _dequantize_gguf( self.int_data, block_size, self.dtype, @@ -198,9 +198,9 @@ def from_float(cls, input_float, n_blocks_per_superblock, target_dtype): super_block_min_scale, quantized_block_scale, quantized_block_min, - ) = choose_qparams_gguf(input_float, block_size, target_dtype) + ) = _choose_qparams_gguf(input_float, block_size, target_dtype) - int_data = quantize_gguf( + int_data = _quantize_gguf( input_float, block_size, target_dtype, diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index 44fc6c8397..d9aba0bcc5 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -83,17 +83,9 @@ TorchAODType, ZeroPointDomain, choose_qparams_affine, - choose_qparams_affine_dont_preserve_zero, - choose_qparams_affine_floatx, - choose_qparams_affine_tinygemm, choose_qparams_affine_with_min_max, - choose_qparams_and_quantize_affine_hqq, dequantize_affine, - dequantize_affine_floatx, - fake_quantize_affine, - fake_quantize_affine_cachemask, quantize_affine, - quantize_affine_floatx, ) from .smoothquant import ( SmoothFakeDynamicallyQuantizedLinear, @@ -172,17 +164,9 @@ "AffineQuantizedObserverBase", # quant primitive ops "choose_qparams_affine", - "choose_qparams_affine_tinygemm", - 
"choose_qparams_affine_dont_preserve_zero", "choose_qparams_affine_with_min_max", - "choose_qparams_affine_floatx", "quantize_affine", - "quantize_affine_floatx", "dequantize_affine", - "dequantize_affine_floatx", - "choose_qparams_and_quantize_affine_hqq", - "fake_quantize_affine", - "fake_quantize_affine_cachemask", # operators/kernels "safe_int_mm", "int_scaled_matmul", diff --git a/torchao/quantization/qat/affine_fake_quantized_tensor.py b/torchao/quantization/qat/affine_fake_quantized_tensor.py index 6896588971..80ecd173c2 100644 --- a/torchao/quantization/qat/affine_fake_quantized_tensor.py +++ b/torchao/quantization/qat/affine_fake_quantized_tensor.py @@ -12,11 +12,11 @@ from torchao.quantization.quant_primitives import ( MappingType, ZeroPointDomain, + _choose_qparams_affine_dont_preserve_zero, + _choose_qparams_affine_tinygemm, + _fake_quantize_affine, _get_and_check_qmin_qmax, choose_qparams_affine, - choose_qparams_affine_dont_preserve_zero, - choose_qparams_affine_tinygemm, - fake_quantize_affine, ) from torchao.utils import TorchAOBaseTensor @@ -55,7 +55,7 @@ def apply_fake_quant_fn(t: torch.Tensor): assert isinstance(t, AffineFakeQuantizedTensor) qmin, qmax = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max) if zero_point_domain == ZeroPointDomain.FLOAT and not preserve_zero: - scale, zero_point = choose_qparams_affine_tinygemm( + scale, zero_point = _choose_qparams_affine_tinygemm( t.original_tensor, mapping_type, block_size, @@ -67,7 +67,7 @@ def apply_fake_quant_fn(t: torch.Tensor): zero_point_dtype, ) elif zero_point_domain == ZeroPointDomain.INT and not preserve_zero: - scale, zero_point = choose_qparams_affine_dont_preserve_zero( + scale, zero_point = _choose_qparams_affine_dont_preserve_zero( t.original_tensor, mapping_type, block_size, @@ -90,7 +90,7 @@ def apply_fake_quant_fn(t: torch.Tensor): scale_dtype, zero_point_dtype, ) - fq = fake_quantize_affine( + fq = _fake_quantize_affine( t, block_size, scale, diff --git a/torchao/quantization/qat/utils.py b/torchao/quantization/qat/utils.py index 132020499c..4f3323a1e8 100644 --- a/torchao/quantization/qat/utils.py +++ b/torchao/quantization/qat/utils.py @@ -9,7 +9,7 @@ from torchao.quantization.quant_primitives import ( ZeroPointDomain, - fake_quantize_affine, + _fake_quantize_affine, ) from torchao.quantization.utils import ( _get_per_token_block_size, @@ -87,7 +87,7 @@ def _fake_quantize_per_channel_group( assert input.shape[-1] % group_size == 0 assert input.dim() == 2 block_size = (1, group_size) - return fake_quantize_affine( + return _fake_quantize_affine( input, block_size, scales, @@ -110,7 +110,7 @@ def _fake_quantize_per_token( _per_token_quant_qparam_dim_check(input, scales, zero_points) block_size = _get_per_token_block_size(input) - fq = fake_quantize_affine( + fq = _fake_quantize_affine( input, block_size, scales, diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index 9e0c6447c8..df136bc06e 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -24,32 +24,32 @@ __all__ = [ "choose_qparams_affine", - "choose_qparams_affine_tinygemm", - "choose_qparams_affine_dont_preserve_zero", "choose_qparams_affine_with_min_max", - "choose_qparams_affine_floatx", "quantize_affine", - "quantize_affine_no_zero_point", - "quantize_affine_tinygemm", "dequantize_affine", - "dequantize_affine_no_zero_point", - "dequantize_affine_tinygemm", - "quantize_affine_floatx", - "dequantize_affine_floatx", - 
"fake_quantize_affine", - "fake_quantize_affine_cachemask", - "choose_qparams_and_quantize_affine_hqq", - "choose_qparams_and_quantize_affine_qqq", - "dequantize_affine_qqq", "MappingType", "ZeroPointDomain", "TorchAODType", - "choose_qparams_affine_float8", - "quantize_affine_float8", - "dequantize_affine_float8", - "choose_qparams_gguf", - "quantize_gguf", - "dequantize_gguf", + "_choose_qparams_affine_tinygemm", + "_choose_qparams_affine_dont_preserve_zero", + "_choose_qparams_affine_floatx", + "_choose_qparams_and_quantize_affine_hqq", + "_choose_qparams_and_quantize_affine_qqq", + "_choose_qparams_affine_float8", + "_choose_qparams_gguf", + "_quantize_affine_no_zero_point", + "_quantize_affine_tinygemm", + "_quantize_affine_floatx", + "_quantize_affine_float8", + "_quantize_gguf", + "_dequantize_affine_no_zero_point", + "_dequantize_affine_tinygemm", + "_dequantize_affine_floatx", + "_dequantize_affine_qqq", + "_dequantize_affine_float8", + "_dequantize_gguf", + "_fake_quantize_affine", + "_fake_quantize_affine_cachemask", ] @@ -228,9 +228,19 @@ def backward(ctx, gy: torch.Tensor) -> torch.Tensor: # TODO: decide on if we want to allow custom quant_min/quant_max here def _get_and_check_qmin_qmax(dtype, quant_min, quant_max): - """Get quant_min and quant_max args based on dtype and also - verify that they are within the range of possible quant_min/quant_max - for dtype + """Get quant_min and quant_max args based on dtype and also verify bounds. + + Args: + dtype: Target quantization dtype (e.g., torch.uint8, torch.int8, or FP8 types) + quant_min: Minimum quantized value, or None to use dtype default + quant_max: Maximum quantized value, or None to use dtype default + + Returns: + Tuple[int/float, int/float]: Validated (quant_min, quant_max) values + + Raises: + ValueError: If dtype is unsupported + AssertionError: If quant_min/quant_max are out of bounds for dtype """ if dtype in FP8_TYPES: quant_min_lower_bound, quant_max_upper_bound = ( @@ -357,11 +367,25 @@ def _quantize_affine( quant_min: Optional[Union[int, float, bool]] = None, quant_max: Optional[Union[int, float, bool]] = None, ) -> torch.Tensor: - """op definition that has compatible signatures with custom op library + """Quantize tensor using affine quantization with integer zero point domain. + + Op definition that has compatible signatures with custom op library. 
+ + Args: + input: Input tensor to quantize (float32, float16, or bfloat16) + block_size: Granularity of quantization - size of tensor elements sharing same qparam + scale: Quantization scale parameter + zero_point: Quantization zero point parameter (optional) + output_dtype: Target quantized dtype (e.g., torch.uint8, torch.int8) + quant_min: Minimum quantized value, derived from dtype if None + quant_max: Maximum quantized value, derived from dtype if None + + Returns: + Quantized tensor with requested dtype Note: - zero_point_domain is pre-defined specifies how we quantize the floating point to quantized data: - INT: quantized_val = (float_val / scale) (integer) + zero_point (integer) + zero_point_domain is pre-defined as INT, meaning: + quantized_val = (float_val / scale) (integer) + zero_point (integer) """ quant_min, quant_max = _get_and_check_qmin_qmax(output_dtype, quant_min, quant_max) # workaround for uintx dtypes, since we don't have native Uintx dtype connected with @@ -386,12 +410,26 @@ def _quantize_affine_no_dtype_cast( quant_min: Union[int, float], quant_max: Union[int, float], ) -> torch.Tensor: - """ + """Quantize tensor using affine quantization without dtype casting. + + Performs quantization with integer zero point domain without casting to target dtype. + + Args: + input: Input tensor to quantize (float32, float16, or bfloat16) + block_size: Granularity of quantization - size of tensor elements sharing same qparam + scale: Quantization scale parameter + zero_point: Quantization zero point parameter (optional) + quant_min: Minimum quantized value + quant_max: Maximum quantized value + + Returns: + Quantized tensor without dtype casting + The op does the following: - 1. figure out the dimension for reduction based on block_size, also reshape the input to align with + 1. Figure out the dimension for reduction based on block_size, also reshape the input to align with the shape after reduction - 2. quantize the input based on the quantization parameters scale and zero_point and zero_point_domain = INT - 3. reshape the quantized result to origianl shape + 2. Quantize the input based on the quantization parameters scale and zero_point with zero_point_domain = INT + 3. Reshape the quantized result to original shape """ # TODO: validations # TODO: validate scale/zero_point dimensions are compatible with block_size @@ -428,7 +466,7 @@ def _quantize_affine_no_dtype_cast( return quant -def quantize_affine_tinygemm( +def _quantize_affine_tinygemm( input: torch.Tensor, block_size: List[int], scale: torch.Tensor, @@ -437,16 +475,31 @@ def quantize_affine_tinygemm( quant_min: Optional[Union[int, float, bool]] = None, quant_max: Optional[Union[int, float, bool]] = None, ) -> torch.Tensor: - """ + """Quantize tensor using affine quantization with float zero point domain for tinygemm. + + Specialized quantization for tinygemm int4mm kernel where zero point is in floating point domain. + + Args: + input: Input tensor to quantize (float32, float16, or bfloat16) + block_size: Granularity of quantization - size of tensor elements sharing same qparam + scale: Quantization scale parameter + zero_point: Quantization zero point parameter (optional) + output_dtype: Target quantized dtype (e.g., torch.uint8, torch.int8) + quant_min: Minimum quantized value, derived from dtype if None + quant_max: Maximum quantized value, derived from dtype if None + + Returns: + Quantized tensor with requested dtype + The op does the following: - 1. 
figure out the dimension for reduction based on block_size, also reshape the input to align with + 1. Figure out the dimension for reduction based on block_size, also reshape the input to align with the shape after reduction - 2. quantize the input based on the quantization parameters scale and zero_point and zero_point_domain = FLOAT - 3. reshape the quantized result to origianl shape + 2. Quantize the input based on the quantization parameters scale and zero_point with zero_point_domain = FLOAT + 3. Reshape the quantized result to original shape Note: - zero_point_domain is pre-defined specifies how we quantize the floating point to quantized data: - FLOAT: quantized_val = (float_val - (zero_point (float) - scale * mid_point)) / scale + zero_point_domain is pre-defined as FLOAT, meaning: + quantized_val = (float_val - (zero_point (float) - scale * mid_point)) / scale """ quant_min, quant_max = _get_and_check_qmin_qmax(output_dtype, quant_min, quant_max) # workaround for uintx dtypes, since we don't have native Uintx dtype connected with @@ -471,12 +524,26 @@ def _quantize_affine_tinygemm_no_dtype_cast( quant_min: Optional[Union[int, float]] = None, quant_max: Optional[Union[int, float]] = None, ) -> torch.Tensor: - """ + """Quantize tensor using affine quantization with float zero point domain without dtype casting. + + Specialized quantization for tinygemm int4mm kernel where zero point is in floating point domain. + + Args: + input: Input tensor to quantize (float32, float16, or bfloat16) + block_size: Granularity of quantization - size of tensor elements sharing same qparam + scale: Quantization scale parameter + zero_point: Quantization zero point parameter (optional) + quant_min: Minimum quantized value + quant_max: Maximum quantized value + + Returns: + Quantized tensor without dtype casting + The op does the following: - 1. figure out the dimension for reduction based on block_size, also reshape the input to align with + 1. Figure out the dimension for reduction based on block_size, also reshape the input to align with the shape after reduction - 2. quantize the input based on the quantization parameters scale and zero_point and zero_point_domain = FLOAT - 3. reshape the quantized result to origianl shape + 2. Quantize the input based on the quantization parameters scale and zero_point with zero_point_domain = FLOAT + 3. Reshape the quantized result to original shape """ # TODO: validations # TODO: validate scale/zero_point dimensions are compatible with block_size @@ -513,7 +580,7 @@ def _quantize_affine_tinygemm_no_dtype_cast( return quant -def quantize_affine_no_zero_point( +def _quantize_affine_no_zero_point( input: torch.Tensor, block_size: List[int], scale: torch.Tensor, @@ -522,17 +589,32 @@ def quantize_affine_no_zero_point( quant_min: Optional[Union[int, float, bool]] = None, quant_max: Optional[Union[int, float, bool]] = None, ) -> torch.Tensor: - """ + """Quantize tensor using affine quantization without zero point. + + Specialized quantization for cases where zero point is not needed (e.g., floatx quantization). 
+ + Args: + input: Input tensor to quantize (float32, float16, or bfloat16) + block_size: Granularity of quantization - size of tensor elements sharing same qparam + scale: Quantization scale parameter + zero_point: Quantization zero point parameter (ignored, should be None) + output_dtype: Target quantized dtype (e.g., torch.uint8, torch.int8) + quant_min: Minimum quantized value, derived from dtype if None + quant_max: Maximum quantized value, derived from dtype if None + + Returns: + Quantized tensor with requested dtype + The op does the following: - 1. figure out the dimension for reduction based on block_size, also reshape the input to align with + 1. Figure out the dimension for reduction based on block_size, also reshape the input to align with the shape after reduction - 2. quantize the input based on the quantization parameters scale and zero_point and zero_point_domain = NONE - 3. reshape the quantized result to origianl shape + 2. Quantize the input based on the quantization parameters scale with zero_point_domain = NONE + 3. Reshape the quantized result to original shape Note: - zero_point_domain is pre-defined specifies how we quantize the floating point to quantized data: - None: quantized_val = (float_val / scale) | this is primarily used for floatx quantization - Where we do not want to round values to nearest integer and instead scale and cast. + zero_point_domain is pre-defined as NONE, meaning: + quantized_val = (float_val / scale) | This is primarily used for floatx quantization + where we do not want to round values to nearest integer and instead scale and cast. """ quant_min, quant_max = _get_and_check_qmin_qmax(output_dtype, quant_min, quant_max) # workaround for uintx dtypes, since we don't have native Uintx dtype connected with @@ -557,12 +639,26 @@ def _quantize_affine_no_zero_point_no_dtype_cast( quant_min: Optional[Union[int, float]] = None, quant_max: Optional[Union[int, float]] = None, ) -> torch.Tensor: - """ + """Quantize tensor using affine quantization without zero point and without dtype casting. + + Specialized quantization for cases where zero point is not needed without casting to target dtype. + + Args: + input: Input tensor to quantize (float32, float16, or bfloat16) + block_size: Granularity of quantization - size of tensor elements sharing same qparam + scale: Quantization scale parameter + zero_point: Quantization zero point parameter (ignored, should be None) + quant_min: Minimum quantized value + quant_max: Maximum quantized value + + Returns: + Quantized tensor without dtype casting + The op does the following: - 1. figure out the dimension for reduction based on block_size, also reshape the input to align with + 1. Figure out the dimension for reduction based on block_size, also reshape the input to align with the shape after reduction - 2. quantize the input based on the quantization parameters scale and zero_point and zero_point_domain = NONE - 3. reshape the quantized result to origianl shape + 2. Quantize the input based on the quantization parameters scale with zero_point_domain = NONE + 3. 
Reshape the quantized result to original shape """ # TODO: validations # TODO: validate scale/zero_point dimensions are compatible with block_size @@ -648,7 +744,23 @@ def _dequantize_affine( quant_max: Optional[Union[int, float, bool]] = None, output_dtype: torch.dtype = torch.float32, ) -> torch.Tensor: - """op definition that has compatible signatures with custom op library""" + """Dequantize tensor using affine dequantization with integer zero point domain. + + Op definition that has compatible signatures with custom op library. + + Args: + input: Quantized tensor to dequantize + block_size: Granularity of quantization - size of tensor elements sharing same qparam + scale: Quantization scale parameter + zero_point: Quantization zero point parameter (optional) + input_dtype: Expected dtype of input tensor (e.g., torch.uint8, torch.int8) + quant_min: Minimum quantized value for input tensor + quant_max: Maximum quantized value for input tensor + output_dtype: Target output dtype (default: torch.float32) + + Returns: + Dequantized tensor with requested output dtype + """ # TODO: validate scale/zero_point dimensions are compatible with block_size if input_dtype not in _SUB_BYTE_UINT_BOUNDS: assert input.dtype == input_dtype, ( @@ -680,13 +792,27 @@ def _dequantize_affine_no_dtype_check( quant_max: Union[int, float], output_dtype: torch.dtype = torch.float32, ) -> torch.Tensor: - """This function converts AQT tensors to their high precision floating point representation + """Dequantize tensor using affine dequantization without dtype checking. + + Converts quantized tensors to their high precision floating point representation. + + Args: + input: Quantized tensor to dequantize + block_size: Granularity of quantization - size of tensor elements sharing same qparam + scale: Quantization scale parameter + zero_point: Quantization zero point parameter (optional) + quant_min: Minimum quantized value for input tensor + quant_max: Maximum quantized value for input tensor + output_dtype: Target output dtype (default: torch.float32) + + Returns: + Dequantized tensor with requested output dtype The op does the following: - 1. figure out the dimension for reduction based on block_size, also reshape the input to align with + 1. Figure out the dimension for reduction based on block_size, also reshape the input to align with the shape after reduction - 2. dequantize the input based on the quantization parameters scale and zero_point and args like zero_point_domain - 3. reshape the quantized result to origianl shape and change dtype to the output_dtype + 2. Dequantize the input based on the quantization parameters scale and zero_point + 3. Reshape the quantized result to original shape and change dtype to the output_dtype """ assert len(block_size) == input.dim(), ( f"Got input dim:{input.dim()}, block_size: {block_size}" @@ -723,13 +849,27 @@ def _dequantize_affine_no_zero_point_no_dtype_check( quant_max: Union[int, float], output_dtype: torch.dtype = torch.float32, ) -> torch.Tensor: - """This function converts AQT tensors to their high precision floating point representation + """Dequantize tensor using affine dequantization without zero point and without dtype checking. + + Converts quantized tensors to their high precision floating point representation without zero point. 
+ + Args: + input: Quantized tensor to dequantize + block_size: Granularity of quantization - size of tensor elements sharing same qparam + scale: Quantization scale parameter + zero_point: Quantization zero point parameter (ignored, should be None) + quant_min: Minimum quantized value for input tensor + quant_max: Maximum quantized value for input tensor + output_dtype: Target output dtype (default: torch.float32) + + Returns: + Dequantized tensor with requested output dtype The op does the following: - 1. figure out the dimension for reduction based on block_size, also reshape the input to align with + 1. Figure out the dimension for reduction based on block_size, also reshape the input to align with the shape after reduction - 2. dequantize the input based on the quantization parameters scale and zero_point and args like zero_point_domain - 3. reshape the quantized result to origianl shape and change dtype to the output_dtype + 2. Dequantize the input based on the quantization parameters scale (no zero point) + 3. Reshape the quantized result to original shape and change dtype to the output_dtype """ assert len(block_size) == input.dim(), ( f"Got input dim:{input.dim()}, block_size: {block_size}" @@ -745,7 +885,7 @@ def _dequantize_affine_no_zero_point_no_dtype_check( scale = scale.view(shape_after_reduction) assert zero_point is None, ( - "zero_point should be None for dequantize_affine_no_zero_point" + "zero_point should be None for _dequantize_affine_no_zero_point" ) dequant = input.to(output_dtype) dequant = dequant * scale @@ -753,7 +893,7 @@ def _dequantize_affine_no_zero_point_no_dtype_check( return dequant.view(original_shape).to(output_dtype) -def dequantize_affine_no_zero_point( +def _dequantize_affine_no_zero_point( input: torch.Tensor, block_size: Tuple[int, ...], scale: torch.Tensor, @@ -848,7 +988,7 @@ def _dequantize_affine_tinygemm_no_dtype_check( return dequant.view(original_shape).to(output_dtype) -def dequantize_affine_tinygemm( +def _dequantize_affine_tinygemm( input: torch.Tensor, block_size: Tuple[int, ...], scale: torch.Tensor, @@ -898,7 +1038,7 @@ def dequantize_affine_tinygemm( ) -def fake_quantize_affine( +def _fake_quantize_affine( input: torch.Tensor, block_size: Tuple[int, ...], scale: torch.Tensor, @@ -946,7 +1086,7 @@ def fake_quantize_affine( return fq -def fake_quantize_affine_cachemask( +def _fake_quantize_affine_cachemask( input: torch.Tensor, block_size: Tuple[int, ...], scale: torch.Tensor, @@ -961,12 +1101,12 @@ def fake_quantize_affine_cachemask( This is equivalent to calling `quantize_affine` + `dequantize_affine` but without the dtype casts. - Note: Compared to :func:`~torchao.quantization.quant_primitives.fake_quantize_affine`, + Note: Compared to :func:`~torchao.quantization.quant_primitives._fake_quantize_affine`, this consumes more memory and returns an additional outlier mask for intermediate quantized values. Args: - Same as :func:`~torchao.quantization.quant_primitives.fake_quantize_affine`. + Same as :func:`~torchao.quantization.quant_primitives._fake_quantize_affine`. Returns: A 2-tuple of ( @@ -1003,8 +1143,25 @@ def _do_fake_quantize_affine( quant_max: Optional[Union[int, float]] = None, zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Helper function for `fake_quantize_affine` that returns both the + """Helper function for fake quantization that returns both intermediate and final values. 
+ + Performs quantization followed by dequantization without dtype casting, returning both + the intermediate quantized values and the final dequantized values. + + Args: + input: Input tensor to fake quantize (float32, float16, or bfloat16) + block_size: Granularity of quantization - size of tensor elements sharing same qparam + scale: Quantization scale parameter + zero_point: Quantization zero point parameter (optional) + quant_dtype: Target quantized dtype for determining quant_min/quant_max + quant_min: Minimum quantized value, derived from dtype if None + quant_max: Maximum quantized value, derived from dtype if None + zero_point_domain: Domain of zero point (INT, FLOAT, or NONE) + + Returns: + Tuple of (intermediate quantized values, final dequantized values) + + Helper function for `_fake_quantize_affine` that returns both the intermediate quantized values and the final dequantized values. """ input_dtype = input.dtype @@ -1086,7 +1243,7 @@ def choose_qparams_affine( # TODO: lower this op to custom op library @torch.no_grad() -def choose_qparams_affine_tinygemm( +def _choose_qparams_affine_tinygemm( input: torch.Tensor, mapping_type: MappingType, block_size: Tuple[int], @@ -1157,7 +1314,7 @@ def choose_qparams_affine_tinygemm( # TODO: lower this op to custom op library -def choose_qparams_affine_dont_preserve_zero( +def _choose_qparams_affine_dont_preserve_zero( input: torch.Tensor, mapping_type: MappingType, block_size: Tuple[int], @@ -1427,7 +1584,7 @@ def _choose_qparams_affine( ) -def choose_qparams_and_quantize_affine_qqq( +def _choose_qparams_and_quantize_affine_qqq( w: torch.Tensor, num_bits: int, group_size: int, @@ -1497,7 +1654,7 @@ def reshape_w(w): return q_w, s_group, s_channel, w_ref -def choose_qparams_gguf( +def _choose_qparams_gguf( input: Optional[torch.Tensor], block_size: List[int], target_dtype: torch.dtype, @@ -1580,7 +1737,7 @@ def choose_qparams_gguf( ) -def quantize_gguf( +def _quantize_gguf( input: torch.Tensor, block_size: List[int], target_dtype: torch.dtype, @@ -1642,7 +1799,7 @@ def quantize_gguf( return int_data -def dequantize_gguf( +def _dequantize_gguf( input: torch.Tensor, block_size: List[int], target_dtype: torch.dtype, @@ -1705,7 +1862,7 @@ def dequantize_gguf( return dequant -def dequantize_affine_qqq( +def _dequantize_affine_qqq( w: torch.Tensor, s_group: torch.Tensor, s_channel: torch.Tensor, @@ -1845,7 +2002,7 @@ def _convert_to_affinequantized_format( # Main hqq quantizer function -def choose_qparams_and_quantize_affine_hqq( +def _choose_qparams_and_quantize_affine_hqq( tensor: torch.Tensor, nbits: float = 4, group_size: int = 64, @@ -1857,6 +2014,28 @@ def choose_qparams_and_quantize_affine_hqq( raw_output: bool = False, # If True, it will return the quant params in hqq lib format optimize_weights: Callable = optimize_weights_proximal_legacy, # weights proximal optimizer function ) -> tuple: + """Choose quantization parameters and quantize tensor using HQQ (Half-Quadratic Quantization). + + Performs quantization using HQQ method with optional weight optimization via proximal solver. 
+ + Args: + tensor: Input tensor to quantize (float32, float16, or bfloat16) + nbits: Number of bits for quantization (default: 4) + group_size: Size of quantization groups (default: 64) + optimize: Whether to optimize weights using proximal solver (default: True) + axis: Axis along which to perform quantization (0 or 1, default: 1) + compute_dtype: Target compute dtype (default: torch.float16) + device: Target device for computation (default: "cuda") + verbose: Whether to print optimization error information (default: False) + raw_output: If True, return params in HQQ library format (default: False) + optimize_weights: Weight optimization function (default: optimize_weights_proximal_legacy) + + Returns: + Tuple of (quantized_weights, scale, zero_point, original_shape) + + Note: + Uses proximal solver to minimize ||W - dequantize(quantize(W))||_p^p for weight optimization. + """ assert axis in [0, 1], "axis should be either 0 or 1" if group_size is not None: assert _is_divisible(tensor.numel(), group_size), ( @@ -1939,9 +2118,25 @@ def choose_qparams_and_quantize_affine_hqq( return W_q, scale, zero, shape -def choose_qparams_affine_floatx( +def _choose_qparams_affine_floatx( tensor: torch.Tensor, ebits: int, mbits: int ) -> torch.Tensor: + """Choose quantization parameters for floatx quantization. + + Calculates scale parameter for quantizing to custom floating point format. + + Args: + tensor: Input tensor to quantize (float32, float16, or bfloat16) + ebits: Number of exponent bits in target floatx format + mbits: Number of mantissa bits in target floatx format + + Returns: + Scale tensor for floatx quantization + + Note: + Uses global lookup table as workaround for torch.compile() compatibility + since _n_ones() is not compatible due to << operator. 
+ """ # _n_ones() is not compatible with torch.compile() due to << operator # https://github.com/pytorch/pytorch/issues/119152 # exp_bias = _n_ones(ebits - 1) @@ -1959,7 +2154,7 @@ def choose_qparams_affine_floatx( return scale.to(dtype) -def quantize_affine_floatx( +def _quantize_affine_floatx( tensor: torch.Tensor, scale: torch.Tensor, ebits: int, mbits: int ) -> torch.Tensor: """Quantizes the float32 high precision floating point tensor to low precision floating point number and @@ -1970,7 +2165,7 @@ def quantize_affine_floatx( return tensor_floatx -def dequantize_affine_floatx( +def _dequantize_affine_floatx( tensor: torch.Tensor, scale: torch.Tensor, ebits: int, @@ -1983,7 +2178,7 @@ def dequantize_affine_floatx( return tensor -def choose_qparams_affine_float8( +def _choose_qparams_affine_float8( tensor: torch.Tensor, float8_dtype: torch.dtype = torch.float8_e4m3fn, scale_dtype: torch.dtype = torch.float32, @@ -2075,7 +2270,7 @@ def _expand_scale_to_tensor_shape( return expanded_scale -def quantize_affine_float8( +def _quantize_affine_float8( tensor: torch.Tensor, scale: torch.Tensor, float8_dtype: torch.dtype = torch.float8_e4m3fn, @@ -2095,7 +2290,7 @@ def quantize_affine_float8( return fp8_tensor -def dequantize_affine_float8( +def _dequantize_affine_float8( tensor: torch.Tensor, scale: torch.Tensor, output_dtype: torch.dtype = torch.float32, diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index 3c968e2d40..c7dd92d55c 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -15,15 +15,15 @@ from torchao.quantization.quant_primitives import ( MappingType, ZeroPointDomain, + _choose_qparams_affine_dont_preserve_zero, + _choose_qparams_affine_tinygemm, + _dequantize_affine_no_zero_point, + _dequantize_affine_tinygemm, + _quantize_affine_no_zero_point, + _quantize_affine_tinygemm, choose_qparams_affine, - choose_qparams_affine_dont_preserve_zero, - choose_qparams_affine_tinygemm, dequantize_affine, - dequantize_affine_no_zero_point, - dequantize_affine_tinygemm, quantize_affine, - quantize_affine_no_zero_point, - quantize_affine_tinygemm, ) from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_5, @@ -357,7 +357,7 @@ def get_groupwise_affine_qparams( ) if zero_point_domain == ZeroPointDomain.FLOAT and not preserve_zero: - scale, zero_point = choose_qparams_affine_tinygemm( + scale, zero_point = _choose_qparams_affine_tinygemm( w, mapping_type, block_size, @@ -369,7 +369,7 @@ def get_groupwise_affine_qparams( zero_point_dtype=zero_point_dtype, ) elif zero_point_domain == ZeroPointDomain.INT and not preserve_zero: - scale, zero_point = choose_qparams_affine_dont_preserve_zero( + scale, zero_point = _choose_qparams_affine_dont_preserve_zero( w, mapping_type, block_size, @@ -439,9 +439,9 @@ def groupwise_affine_quantize_tensor_from_qparams( if zero_point_domain == ZeroPointDomain.INT: _quantize_affine = quantize_affine elif zero_point_domain == ZeroPointDomain.FLOAT: - _quantize_affine = quantize_affine_tinygemm + _quantize_affine = _quantize_affine_tinygemm elif ZeroPointDomain == ZeroPointDomain.NONE: - _quantize_affine = quantize_affine_no_zero_point + _quantize_affine = _quantize_affine_no_zero_point else: raise ValueError(f"Unrecognized zero point domain: {zero_point_domain}") @@ -508,9 +508,9 @@ def groupwise_affine_dequantize_tensor_from_qparams( if zero_point_domain == ZeroPointDomain.INT: _dequantize_affine = dequantize_affine elif zero_point_domain == ZeroPointDomain.FLOAT: - _dequantize_affine = dequantize_affine_tinygemm + 
_dequantize_affine = _dequantize_affine_tinygemm else: - _dequantize_affine = dequantize_affine_no_zero_point + _dequantize_affine = _dequantize_affine_no_zero_point return _dequantize_affine( w_int32, block_size, diff --git a/tutorials/calibration_flow/gptq_like.py b/tutorials/calibration_flow/gptq_like.py index ab7a2b4f37..df824e506f 100644 --- a/tutorials/calibration_flow/gptq_like.py +++ b/tutorials/calibration_flow/gptq_like.py @@ -48,7 +48,7 @@ LinearActivationQuantizedTensor, MappingType, PerTensor, - fake_quantize_affine, + _fake_quantize_affine, quantize_, to_linear_activation_quantized, ) @@ -237,7 +237,9 @@ def forward_pre_hook( new_input = [] for inp in args[0]: new_input.append( - fake_quantize_affine(inp, inp.shape, input_scale, input_zp, torch.uint8) + _fake_quantize_affine( + inp, inp.shape, input_scale, input_zp, torch.uint8 + ) ) mt = MultiTensor(new_input)
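
For downstream code that imported the old public names, the migration implied by this patch is a one-line import change; call signatures are unchanged. Below is a minimal round-trip sketch mirroring test_dequantize_affine_float8 from the test diff above. It is not part of the patch, assumes a torchao build that already contains this rename, and the underscore-prefixed primitives are private APIs that may change without notice.

# Minimal migration sketch (illustrative only, not part of the patch).
import torch

from torchao.quantization.quant_primitives import (
    _choose_qparams_affine_float8,
    _dequantize_affine_float8,
    _quantize_affine_float8,
)

# The upstream test runs on CUDA; float8 casting on CPU depends on the PyTorch version.
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(8, 64, device=device, dtype=torch.float32)

# Per-tensor scale (block_size=None), as in test_dequantize_affine_float8.
scale = _choose_qparams_affine_float8(x, float8_dtype=torch.float8_e4m3fn, block_size=None)
x_q = _quantize_affine_float8(x, scale, torch.float8_e4m3fn)
x_dq = _dequantize_affine_float8(x_q, scale, torch.float32)

assert x_dq.shape == x.shape and x_dq.dtype == torch.float32

Note that the renamed ops remain in quant_primitives.__all__ but are dropped from the torchao.quantization __init__ exports and from docs/source/api_ref_quantization.rst, so wildcard imports keep working while the documented public surface shrinks to the non-underscore ops.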