
Commit 98917b7

fix qparams decompression
Signed-off-by: shanjiaz <zsjwpianpian@gmail.com>
1 parent 8471264

4 files changed, +26 -12 lines changed


src/compressed_tensors/compressors/base.py

Lines changed: 17 additions & 1 deletion
@@ -20,6 +20,11 @@
 from compressed_tensors.quantization import QuantizationArgs, QuantizationConfig
 from compressed_tensors.registry import RegistryMixin
 from compressed_tensors.utils import has_offloaded_params
+from compressed_tensors.utils.offload import (
+    delete_offload_parameter,
+    get_offloaded_device,
+    register_offload_parameter,
+)
 from torch import Tensor
 from torch.nn import Module
 

@@ -185,10 +190,21 @@ def decompress_module(self, module: Module):
         for name, parameter in module.named_parameters():
             compressed_data[name] = parameter
 
-        return self.decompress_weight(
+        result = self.decompress_weight(
             compressed_data=compressed_data, quantization_args=quantization_args
         ).to(device)
 
+        # Update module's parameters if they were unpacked/upcast during decompression
+        for param_name in ["weight_zero_point", "weight_scale"]:
+            if param_name in compressed_data and hasattr(module, param_name):
+                # Delete the old parameter and register the updated one
+                delete_offload_parameter(module, param_name)
+                offload_device = get_offloaded_device(module)
+                param = torch.nn.Parameter(compressed_data[param_name], requires_grad=False)
+                register_offload_parameter(module, param_name, param, offload_device)
+
+        return result
+
     def decompress_weight(
         self, compressed_data: Dict[str, Tensor], **kwargs
     ) -> torch.Tensor:
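
The loop added to decompress_module is needed because decompress_weight may replace entries in compressed_data with tensors of a new shape or dtype (see the two compressors below), leaving the module's registered parameters stale. A minimal sketch of the re-registration pattern, assuming the offload helpers behave as they are used in the diff above (replace_module_param is a hypothetical name, not part of the library):

    import torch
    from compressed_tensors.utils.offload import (
        delete_offload_parameter,
        get_offloaded_device,
        register_offload_parameter,
    )

    def replace_module_param(module: torch.nn.Module, name: str, data: torch.Tensor):
        # Drop the stale parameter first so a tensor with a new shape/dtype
        # can be registered under the same name.
        delete_offload_parameter(module, name)
        offload_device = get_offloaded_device(module)
        param = torch.nn.Parameter(data, requires_grad=False)
        register_offload_parameter(module, name, param, offload_device)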

src/compressed_tensors/compressors/quantized_compressors/base.py

Lines changed: 1 addition & 11 deletions
@@ -155,17 +155,7 @@ def _skip_zp(
         if zp_name == "output_zero_point":
             args = scheme.output_activations
 
-        symmetric = args.symmetric
-        packable_strategies = [
-            QuantizationStrategy.GROUP.value,
-            QuantizationStrategy.CHANNEL.value,
-        ]
-        packed = (
-            isinstance(self, PackedQuantizationCompressor)
-            and args.strategy in packable_strategies
-        )
-
-        return symmetric or packed
+        return args.symmetric
 
     def decompress(
         self,
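
After this change a zero point is skipped only when quantization is symmetric, where it is identically zero and can be reconstructed at load time. The removed branch also treated packed GROUP/CHANNEL schemes as skippable even when asymmetric, which drops real zero-point data. A toy illustration (made-up values) of what losing an asymmetric zero point does to dequantization:

    import torch

    x_q = torch.tensor([3, 7, 11], dtype=torch.int32)
    scale = torch.tensor(0.5)
    zp = torch.tensor(4, dtype=torch.int32)  # asymmetric zero point

    correct = (x_q - zp) * scale  # tensor([-0.5000, 1.5000, 3.5000])
    skipped = x_q * scale         # zp lost -> tensor([1.5000, 3.5000, 5.5000])
    assert not torch.equal(correct, skipped)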

src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py

Lines changed: 6 additions & 0 deletions
@@ -117,6 +117,12 @@ def decompress_weight(
         m, n = weight.shape
         # TODO: use a user provided dequant dtype
         unpacked = unpack_fp4_from_uint8(weight, m, n * 2)
+
+        # cast scale dtype to match unpacked dtype for dequantization
+        if scale.dtype != unpacked.dtype:
+            scale = scale.to(unpacked.dtype)
+            compressed_data["weight_scale"] = scale
+
         decompressed_weight = dequantize(
             x_q=unpacked, scale=scale, global_scale=global_scale, dtype=unpacked.dtype
         )
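
Casting the scale and writing it back into compressed_data serves two purposes: dequantize sees matching dtypes, and the new loop in base.py re-registers the upcast scale on the module. A small self-contained sketch of the dtype alignment (the dtypes here are stand-ins; in the real path unpack_fp4_from_uint8 determines unpacked.dtype):

    import torch

    unpacked = torch.randn(4, 8, dtype=torch.bfloat16)  # stand-in for unpacked FP4 values
    scale = torch.ones(4, 1, dtype=torch.float16)       # stored scale in another dtype

    if scale.dtype != unpacked.dtype:
        scale = scale.to(unpacked.dtype)  # align dtypes before dequantizing

    assert (unpacked * scale).dtype == unpacked.dtype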

src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py

Lines changed: 2 additions & 0 deletions
@@ -175,6 +175,8 @@ def decompress_weight(
             zero_point = unpack_from_int32(
                 zero_point, num_bits, original_zp_shape, packed_dim=0
             )
+            # Update the compressed_data dict with the unpacked zero_point
+            compressed_data["weight_zero_point"] = zero_point
 
         decompressed_weight = dequantize(
             x_q=unpacked, scale=scale, zero_point=zero_point, g_idx=g_idx
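
This works because compressed_data is mutated in place: decompress_weight receives the same dict that decompress_module scans afterwards, so the unpacked zero point becomes visible to the caller. A toy sketch of that contract, with simplified stand-ins for unpack_from_int32 and dequantize:

    from typing import Dict
    import torch

    def decompress_weight(compressed_data: Dict[str, torch.Tensor]) -> torch.Tensor:
        # Stand-in for unpack_from_int32: widen the packed zero point.
        zp = compressed_data["weight_zero_point"].to(torch.int8)
        compressed_data["weight_zero_point"] = zp  # caller sees the unpacked tensor
        weight = compressed_data["weight"].to(torch.float32)
        return weight - zp.to(torch.float32)  # stand-in for dequantize

    data = {
        "weight": torch.randint(-8, 8, (2, 4), dtype=torch.int32),
        "weight_zero_point": torch.zeros(2, 1, dtype=torch.int32),
    }
    _ = decompress_weight(data)
    assert data["weight_zero_point"].dtype == torch.int8  # updated in place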
