Skip to content

Commit 263179a

Browse files
authored
fix xpu ipex linear in torch2.7 (#1618)
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
1 parent 5027e64 commit 263179a

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

bitsandbytes/utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,10 +240,16 @@ def enable_ipex_fusion(linear, x):
240240
)
241241
elif x.device.type == "xpu" and ipex_xpu and _ipex_xpu_version_prereq(2, 5):
242242
converted_weight = reverse_4bit_compress_format(linear.weight.data)
243-
new_weight = converted_weight.reshape([quant_state.shape[0], quant_state.shape[1] // 2])
244243
new_scales = quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize)
245244
new_zeros = None
246245
compensation = None
246+
new_weight = converted_weight.reshape([quant_state.shape[0], quant_state.shape[1] // 2])
247+
# ipex 2.7 requires new_scales is a list of tensors
248+
if _ipex_xpu_version_prereq(2, 7):
249+
new_scales = list(new_scales)
250+
# ipex 2.7 can dequant converted_weight directly.
251+
if linear.training or x.requires_grad == False:
252+
new_weight = converted_weight
247253
else:
248254
raise ValueError(
249255
"Please check the device and ipex version. The device should be cpu or xpu while ipex version should >= 2.5"

0 commit comments

Comments
 (0)