Commit bc50dc0 (parent: d474616)

Int4Linear should have the same shape regardless of device

This makes Int4 quantization work for both CPU and CUDA devices.

1 file changed: +4 −10 lines


quantize.py

Lines changed: 4 additions & 10 deletions
@@ -500,16 +500,10 @@ def __init__(
 
         assert out_features % 8 == 0, "require out_features % 8 == 0"
         assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0"
-        if use_cuda:
-            self.register_buffer(
-                "weight",
-                torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
-            )
-        else:
-            self.register_buffer(
-                "weight",
-                torch.empty((out_features, in_features // 2), dtype=torch.uint8)
-            )
+        self.register_buffer(
+            "weight",
+            torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
+        )
         self.register_buffer(
             "scales_and_zeros",
             torch.empty((in_features // groupsize, out_features, 2), dtype=torch.bfloat16)
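
With both devices sharing the packed int32 layout, a checkpoint quantized on one device loads on the other without shape mismatches. Below is a minimal sketch of the resulting buffer layout; the class name Int4LinearSketch and the default groupsize/inner_k_tiles values are illustrative assumptions, not part of the commit:

import torch
import torch.nn as nn

class Int4LinearSketch(nn.Module):
    # Illustrative stand-in for the module patched in quantize.py;
    # the defaults for groupsize/inner_k_tiles are assumptions.
    def __init__(self, in_features, out_features, groupsize=128, inner_k_tiles=8):
        super().__init__()
        assert out_features % 8 == 0, "require out_features % 8 == 0"
        assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0"
        # Same packed int32 weight shape on every device (the commit's change).
        self.register_buffer(
            "weight",
            torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32),
        )
        self.register_buffer(
            "scales_and_zeros",
            torch.empty((in_features // groupsize, out_features, 2), dtype=torch.bfloat16),
        )

# Buffer shapes now agree across devices, so state dicts are interchangeable:
src = Int4LinearSketch(4096, 4096)     # e.g. quantized on CUDA
dst = Int4LinearSketch(4096, 4096)     # e.g. a CPU model
dst.load_state_dict(src.state_dict())  # no shape/dtype mismatch

Before the change, the CPU branch registered an unpacked (out_features, in_features // 2) uint8 buffer, so a checkpoint was tied to the device it was quantized on.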
