File tree Expand file tree Collapse file tree 1 file changed +7
-1
lines changed Expand file tree Collapse file tree 1 file changed +7
-1
lines changed Original file line number Diff line number Diff line change @@ -240,10 +240,16 @@ def enable_ipex_fusion(linear, x):
240
240
)
241
241
elif x .device .type == "xpu" and ipex_xpu and _ipex_xpu_version_prereq (2 , 5 ):
242
242
converted_weight = reverse_4bit_compress_format (linear .weight .data )
243
- new_weight = converted_weight .reshape ([quant_state .shape [0 ], quant_state .shape [1 ] // 2 ])
244
243
new_scales = quant_state .absmax .view (quant_state .shape [0 ], quant_state .shape [1 ] // quant_state .blocksize )
245
244
new_zeros = None
246
245
compensation = None
246
+ new_weight = converted_weight .reshape ([quant_state .shape [0 ], quant_state .shape [1 ] // 2 ])
247
+ # ipex 2.7 requires new_scales is a list of tensors
248
+ if _ipex_xpu_version_prereq (2 , 7 ):
249
+ new_scales = list (new_scales )
250
+ # ipex 2.7 can dequant converted_weight directly.
251
+ if linear .training or x .requires_grad == False :
252
+ new_weight = converted_weight
247
253
else :
248
254
raise ValueError (
249
255
"Please check the device and ipex version. The device should be cpu or xpu while ipex version should >= 2.5"
You can’t perform that action at this time.
0 commit comments