This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit bf7d0f6

shape consistency
1 parent c894305 commit bf7d0f6

1 file changed: +13 -1 lines changed

src/sparseml/modifiers/obcq/utils/sgpt_wrapper.py

Lines changed: 13 additions & 1 deletion
@@ -179,10 +179,22 @@ def fasterprune(
     fake_quantize,
 )

+while scale.ndim < 2:
+    scale = scale.unsqueeze(1)
+    zero_point = zero_point.unsqueeze(1)
+
+while q.ndim < 2:
+    q = q.unsqueeze(1)
 q = fake_quantize(
-    q, scale, zero_point, self.layer.quantization_scheme.weights
+    q,
+    scale[:, i],
+    zero_point[:, i],
+    self.layer.quantization_scheme.weights,
 )

+while q.ndim != 1:
+    q.squeeze()
+
 Q1[:, i] = q
 Losses1[:, i] = (w - q) ** 2 / d**2

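The shape handling in this hunk can be tried in isolation. The following is a minimal sketch, not the code from the repository: `ensure_2d`, `toy_fake_quantize`, and `quantize_column` are hypothetical helpers, and `toy_fake_quantize` is a simple round-and-clamp stand-in rather than the library's `fake_quantize`. Two choices differ from the diff and are assumptions: the sketch slices with `i:i + 1` so the scale keeps its column dimension and broadcasts against an `(n, 1)` column, and it assigns the result of the final reshape, because `Tensor.squeeze()` returns a new tensor and does not modify `q` in place.

```python
import torch


def ensure_2d(t: torch.Tensor) -> torch.Tensor:
    """Append trailing singleton dims until the tensor has at least 2 dims."""
    while t.ndim < 2:
        # unsqueeze(-1) also works for 0-dim tensors, unlike unsqueeze(1).
        t = t.unsqueeze(-1)
    return t


def toy_fake_quantize(x, scale, zero_point, qmin=-128, qmax=127):
    """Hypothetical stand-in: quantize to integers, then dequantize."""
    q = torch.clamp(torch.round(x / scale + zero_point), qmin, qmax)
    return (q - zero_point) * scale


def quantize_column(w_col, scale, zero_point, i):
    """Fake-quantize one weight column and return it as a 1-D tensor."""
    scale = ensure_2d(scale)
    zero_point = ensure_2d(zero_point)
    q = ensure_2d(w_col)  # (rows,) -> (rows, 1)
    # Keep the column dim so (rows, 1) broadcasts against (rows, 1).
    q = toy_fake_quantize(q, scale[:, i:i + 1], zero_point[:, i:i + 1])
    # Assign the result: reshape/squeeze return new tensors.
    return q.reshape(-1)


W = torch.tensor([[0.26, -1.0], [0.5, 0.74]])

# Assumed case 1: one scale per weight element, shape (rows, cols).
q_full = quantize_column(W[:, 0], torch.full((2, 2), 0.5), torch.zeros(2, 2), 0)

# Assumed case 2: one scale per row, shape (rows,), promoted to (rows, 1).
q_row = quantize_column(
    W[:, 0], torch.tensor([0.5, 0.25]), torch.zeros(2), 0
)
```

In both cases the returned tensor is 1-D with one entry per row, so it can be written into a column such as `Q1[:, i]` without a shape mismatch.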