17 changes: 14 additions & 3 deletions src/compressed_tensors/quantization/quant_args.py
@@ -407,14 +407,18 @@ def get_observer(self) -> str:
 
 
 def round_to_quantized_type_dtype(
-    tensor: torch.Tensor, dtype: torch.dtype
+    tensor: torch.Tensor,
+    dtype: torch.dtype,
+    cast_to_original_dtype: Optional[bool] = True,
 ) -> torch.Tensor:
     """
     Rounds an input tensor to the nearest quantized representation given a dtype.
     The original dtype is kept post-rounding.
 
     :param tensor: tensor to round
     :param dtype: dtype to use for rounding
+    :param cast_to_original_dtype: whether or not we cast the rounded tensor to
+        the original dtype
     :return: rounded tensor
     """
     original_dtype = tensor.dtype
@@ -425,14 +429,17 @@ def round_to_quantized_type_dtype(
         iinfo = torch.iinfo(dtype)
         rounded = torch.round(torch.clamp(tensor, iinfo.min, iinfo.max))
 
-    return rounded.to(original_dtype)
+    if cast_to_original_dtype:
+        return rounded.to(original_dtype)
+    return rounded
 
 
 def round_to_quantized_type_args(
     tensor: torch.Tensor,
     args: QuantizationArgs,
     min: torch.Tensor,
     max: torch.Tensor,
+    cast_to_original_dtype: Optional[bool] = True,
 ) -> torch.Tensor:
     """
     Rounds an input tensor to the nearest quantized representation given
@@ -442,6 +449,8 @@ def round_to_quantized_type_args(
     :param args: quantization args to use for rounding
     :param min: min value to use for clamping
     :param max: max value to use for clamping
+    :param cast_to_original_dtype: whether or not we cast the rounded tensor to
+        the original dtype
     :return: rounded tensor
     """
 
@@ -459,4 +468,6 @@ def round_to_quantized_type_args(
     else:
        raise ValueError(f"Invalid quantization type {args.type}")
 
-    return rounded.to(original_dtype)
+    if cast_to_original_dtype:
+        return rounded.to(original_dtype)
+    return rounded
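
For reference, a minimal usage sketch of the new flag. The import path is inferred from the file location `src/compressed_tensors/quantization/quant_args.py` and the example values are made up:

```python
import torch

# Import path inferred from src/compressed_tensors/quantization/quant_args.py
from compressed_tensors.quantization.quant_args import round_to_quantized_type_dtype

zp = torch.tensor([-1.4, 0.6, 130.2], dtype=torch.float32)

# Default behaviour (cast_to_original_dtype=True): the result is explicitly cast
# back to zp.dtype after being clamped/rounded to the int8 range.
defaulted = round_to_quantized_type_dtype(zp, dtype=torch.int8)

# New opt-out: the tensor is returned exactly as the rounding step produced it,
# with no final cast back to the caller's dtype.
raw = round_to_quantized_type_dtype(zp, dtype=torch.int8, cast_to_original_dtype=False)

print(defaulted.dtype)  # guaranteed to equal zp.dtype
print(raw.dtype)        # whatever dtype the clamping/rounding branch yields for the target dtype
```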
2 changes: 1 addition & 1 deletion src/compressed_tensors/quantization/utils/helpers.py
@@ -127,7 +127,7 @@ def calculate_qparams(
 
     # 5. Round the zp to zp_dtype
     zero_points = round_to_quantized_type_dtype(
-        zero_points, dtype=quantization_args.zp_dtype
+        zero_points, dtype=quantization_args.zp_dtype, cast_to_original_dtype=False
     )
 
     if scales.ndim == 0:
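
A rough sketch of the effect at this call site, using made-up zero-point values and an assumed `zp_dtype` of `torch.int8` (the real values come from `quantization_args` inside `calculate_qparams`):

```python
import torch

# Import path inferred from the file layout shown in this diff
from compressed_tensors.quantization.quant_args import round_to_quantized_type_dtype

zero_points = torch.tensor([-0.7, 3.2, 127.9])  # hypothetical working values (float32)
zp_dtype = torch.int8                           # assumed stand-in for quantization_args.zp_dtype

# Previously the rounded zero points were always cast back to the working float dtype;
# with cast_to_original_dtype=False they are returned as the rounding step left them.
zero_points = round_to_quantized_type_dtype(
    zero_points, dtype=zp_dtype, cast_to_original_dtype=False
)
```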
4 changes: 2 additions & 2 deletions tests/test_quantization/lifecycle/test_enabled.py
@@ -26,8 +26,8 @@
 
 
 def test_quantization_enabled_disabled():
-    inp = torch.randn(16, dtype=torch.bfloat16)
-    model = Linear(16, 16, dtype=torch.bfloat16)
+    inp = torch.randn(16)
+    model = Linear(16, 16)
     quantized_model = deepcopy(model)
     apply_quantization_config(
         model=quantized_model,
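
The test now drops the explicit bfloat16 dtype and relies on torch's defaults; a quick sanity check of what those defaults are, assuming the global default dtype has not been overridden:

```python
import torch
from torch.nn import Linear

# Without an explicit dtype, both the input and the layer weights use the default
# dtype, which is torch.float32 unless torch.set_default_dtype has changed it.
inp = torch.randn(16)
model = Linear(16, 16)
print(inp.dtype, model.weight.dtype)  # torch.float32 torch.float32
```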