From 232aa5deb2b00d4f3ef5b0c3b2ed6251748323dd Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Mon, 10 Nov 2025 16:48:33 -0500 Subject: [PATCH 1/4] flakey? --- tests/test_quantization/lifecycle/test_enabled.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_quantization/lifecycle/test_enabled.py b/tests/test_quantization/lifecycle/test_enabled.py index 24be64bf..87e25d55 100644 --- a/tests/test_quantization/lifecycle/test_enabled.py +++ b/tests/test_quantization/lifecycle/test_enabled.py @@ -26,8 +26,8 @@ def test_quantization_enabled_disabled(): - inp = torch.randn(16, dtype=torch.bfloat16) - model = Linear(16, 16, dtype=torch.bfloat16) + inp = torch.randn(16) + model = Linear(16, 16) quantized_model = deepcopy(model) apply_quantization_config( model=quantized_model, From 824e12f5ead957f3b48e5b74ada854c8139f011d Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Mon, 10 Nov 2025 17:09:46 -0500 Subject: [PATCH 2/4] update --- .../quantization/quant_args.py | 17 ++++++++++++++--- .../quantization/utils/helpers.py | 2 +- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py index df11d1ec..450ccb37 100644 --- a/src/compressed_tensors/quantization/quant_args.py +++ b/src/compressed_tensors/quantization/quant_args.py @@ -394,7 +394,9 @@ def get_observer(self) -> str: def round_to_quantized_type_dtype( - tensor: torch.Tensor, dtype: torch.dtype + tensor: torch.Tensor, + dtype: torch.dtype, + cast_to_original_dtype: Optional[bool] = True, ) -> torch.Tensor: """ Rounds an input tensor to the nearest quantized representation given a dtype. @@ -402,6 +404,8 @@ def round_to_quantized_type_dtype( :param tensor: tensor to round :param dtype: dtype to use for rounding + :param cast_to_original_dtype: whether or not we cast the rounded tensor to + the original dtype :return: rounded tensor """ original_dtype = tensor.dtype @@ -412,7 +416,9 @@ def round_to_quantized_type_dtype( iinfo = torch.iinfo(dtype) rounded = torch.round(torch.clamp(tensor, iinfo.min, iinfo.max)) - return rounded.to(original_dtype) + if cast_to_original_dtype: + return rounded.to(original_dtype) + return rounded def round_to_quantized_type_args( @@ -420,6 +426,7 @@ def round_to_quantized_type_args( args: QuantizationArgs, min: torch.Tensor, max: torch.Tensor, + cast_to_original_dtype: Optional[bool] = True, ) -> torch.Tensor: """ Rounds an input tensor to the nearest quantized representation given @@ -429,6 +436,8 @@ def round_to_quantized_type_args( :param args: quantization args to use for rounding :param min: min value to use for clamping :param max: max value to use for clamping + :param cast_to_original_dtype: whether or not we cast the rounded tensor to + the original dtype :return: rounded tensor """ @@ -446,4 +455,6 @@ def round_to_quantized_type_args( else: raise ValueError(f"Invalid quantization type {args.type}") - return rounded.to(original_dtype) + if cast_to_original_dtype: + return rounded.to(original_dtype) + return rounded diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index 5e728dd5..45a4ef83 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -127,7 +127,7 @@ def calculate_qparams( # 5. Round the zp to zp_dtype zero_points = round_to_quantized_type_dtype( - zero_points, dtype=quantization_args.zp_dtype + zero_points, dtype=quantization_args.zp_dtype, cast_to_original_dtype=False ) if scales.ndim == 0: From 7f14dac3a20eb1f3bf40f83da0a255b99c315dbd Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Mon, 10 Nov 2025 17:14:41 -0500 Subject: [PATCH 3/4] try float16 --- tests/test_quantization/lifecycle/test_enabled.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_quantization/lifecycle/test_enabled.py b/tests/test_quantization/lifecycle/test_enabled.py index 87e25d55..63eb8e8b 100644 --- a/tests/test_quantization/lifecycle/test_enabled.py +++ b/tests/test_quantization/lifecycle/test_enabled.py @@ -26,8 +26,8 @@ def test_quantization_enabled_disabled(): - inp = torch.randn(16) - model = Linear(16, 16) + inp = torch.randn(16, dtype=torch.float16) + model = Linear(16, 16, dtype=torch.float16) quantized_model = deepcopy(model) apply_quantization_config( model=quantized_model, From b805b7798729b391f1b5e0cd4d2bd76ca8fc2cf5 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Mon, 10 Nov 2025 17:33:49 -0500 Subject: [PATCH 4/4] update --- tests/test_quantization/lifecycle/test_enabled.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_quantization/lifecycle/test_enabled.py b/tests/test_quantization/lifecycle/test_enabled.py index 63eb8e8b..87e25d55 100644 --- a/tests/test_quantization/lifecycle/test_enabled.py +++ b/tests/test_quantization/lifecycle/test_enabled.py @@ -26,8 +26,8 @@ def test_quantization_enabled_disabled(): - inp = torch.randn(16, dtype=torch.float16) - model = Linear(16, 16, dtype=torch.float16) + inp = torch.randn(16) + model = Linear(16, 16) quantized_model = deepcopy(model) apply_quantization_config( model=quantized_model,