
Commit f1f2304

[Tests] Small Fixes (#516)

* flakey?
* update
* try float16
* update

1 parent ba35114 commit f1f2304

3 files changed, +17 -6 lines changed


src/compressed_tensors/quantization/quant_args.py

Lines changed: 14 additions & 3 deletions
@@ -407,14 +407,18 @@ def get_observer(self) -> str:
 
 
 def round_to_quantized_type_dtype(
-    tensor: torch.Tensor, dtype: torch.dtype
+    tensor: torch.Tensor,
+    dtype: torch.dtype,
+    cast_to_original_dtype: Optional[bool] = True,
 ) -> torch.Tensor:
     """
     Rounds an input tensor to the nearest quantized representation given a dtype.
     The original dtype is kept post-rounding.
 
     :param tensor: tensor to round
     :param dtype: dtype to use for rounding
+    :param cast_to_original_dtype: whether or not we cast the rounded tensor to
+        the original dtype
     :return: rounded tensor
     """
     original_dtype = tensor.dtype
@@ -425,14 +429,17 @@ def round_to_quantized_type_dtype(
         iinfo = torch.iinfo(dtype)
         rounded = torch.round(torch.clamp(tensor, iinfo.min, iinfo.max))
 
-    return rounded.to(original_dtype)
+    if cast_to_original_dtype:
+        return rounded.to(original_dtype)
+    return rounded
 
 
 def round_to_quantized_type_args(
     tensor: torch.Tensor,
     args: QuantizationArgs,
     min: torch.Tensor,
     max: torch.Tensor,
+    cast_to_original_dtype: Optional[bool] = True,
 ) -> torch.Tensor:
     """
     Rounds an input tensor to the nearest quantized representation given
@@ -442,6 +449,8 @@ def round_to_quantized_type_args(
     :param args: quantization args to use for rounding
     :param min: min value to use for clamping
     :param max: max value to use for clamping
+    :param cast_to_original_dtype: whether or not we cast the rounded tensor to
+        the original dtype
     :return: rounded tensor
     """
 
@@ -459,4 +468,6 @@ def round_to_quantized_type_args(
     else:
         raise ValueError(f"Invalid quantization type {args.type}")
 
-    return rounded.to(original_dtype)
+    if cast_to_original_dtype:
+        return rounded.to(original_dtype)
+    return rounded
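
For orientation, the new flag only changes the final cast. Below is a minimal standalone sketch of the behavior, not the library's implementation: the floating-point branch is an assumption (only the integer branch appears in the hunks above), and torch.float8_e4m3fn is used purely as an illustrative target dtype.

    import torch
    from typing import Optional


    def round_sketch(
        tensor: torch.Tensor,
        dtype: torch.dtype,
        cast_to_original_dtype: Optional[bool] = True,
    ) -> torch.Tensor:
        original_dtype = tensor.dtype
        if dtype.is_floating_point:
            # Assumed branch: clamp to the target range, then convert to the target dtype.
            finfo = torch.finfo(dtype)
            rounded = torch.clamp(tensor, finfo.min, finfo.max).to(dtype)
        else:
            # Integer branch, mirroring the hunk above.
            iinfo = torch.iinfo(dtype)
            rounded = torch.round(torch.clamp(tensor, iinfo.min, iinfo.max))
        if cast_to_original_dtype:
            # Default (and previous) behavior: return the result in the caller's dtype.
            return rounded.to(original_dtype)
        # New opt-out: keep whatever dtype the rounding branch produced.
        return rounded


    zp = torch.tensor([-0.7, 1.3, 500.0])  # float32 inputs
    print(round_sketch(zp, torch.float8_e4m3fn).dtype)  # torch.float32
    print(round_sketch(zp, torch.float8_e4m3fn, cast_to_original_dtype=False).dtype)  # torch.float8_e4m3fn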

src/compressed_tensors/quantization/utils/helpers.py

Lines changed: 1 addition & 1 deletion
@@ -127,7 +127,7 @@ def calculate_qparams(
 
     # 5. Round the zp to zp_dtype
     zero_points = round_to_quantized_type_dtype(
-        zero_points, dtype=quantization_args.zp_dtype
+        zero_points, dtype=quantization_args.zp_dtype, cast_to_original_dtype=False
     )
 
     if scales.ndim == 0:
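
Taken together with the True default above, the effect of this call-site change is that the rounded zero points keep the dtype produced for quantization_args.zp_dtype rather than being cast back to their pre-rounding dtype, while other callers are untouched. A hypothetical illustration of that intent (a float8 zp_dtype and float32 working dtype are assumptions for the example, not taken from the diff):

    import torch

    zp_dtype = torch.float8_e4m3fn          # assumed zp_dtype, for illustration only
    zero_points = torch.tensor([0.0, 3.7])  # candidate zero points in float32

    finfo = torch.finfo(zp_dtype)
    rounded = torch.clamp(zero_points, finfo.min, finfo.max).to(zp_dtype)
    print(rounded.dtype)  # torch.float8_e4m3fn: kept, rather than cast back to float32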

tests/test_quantization/lifecycle/test_enabled.py

Lines changed: 2 additions & 2 deletions
@@ -26,8 +26,8 @@
 
 
 def test_quantization_enabled_disabled():
-    inp = torch.randn(16, dtype=torch.bfloat16)
-    model = Linear(16, 16, dtype=torch.bfloat16)
+    inp = torch.randn(16)
+    model = Linear(16, 16)
     quantized_model = deepcopy(model)
     apply_quantization_config(
         model=quantized_model,
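
Dropping the explicit dtype means the test now runs in PyTorch's default dtype, which is torch.float32 unless torch.set_default_dtype has been changed; presumably this addresses the "flakey?" note in the commit message. A quick check of that default:

    import torch

    print(torch.randn(16).dtype)                 # torch.float32
    print(torch.nn.Linear(16, 16).weight.dtype)  # torch.float32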
