
Commit d5c9a11

geng-meta authored and facebook-github-bot committed
Correct the inaccurate docstring for int4_row_quantize()
Summary: X-link: facebookresearch/FBGEMM#1930

While reading the code in `quantize.py`, I noticed that the docstring of `int4_row_quantize()` is inaccurate. The function returns `wq` with shape `[N, K]`, not `[N, K // 2]`:

```
# Line 101: Concatenate chunks back to original K dimension
out = torch.cat(out, dim=-1)
# Line 104: Convert to int8 dtype
out = out.to(dtype=torch.int8)
```

So `wq` is `[N, K]` stored as `int8` elements, where each `int8` element contains a single int4 value. The actual packing can then be done afterwards, via:
https://www.internalfb.com/code/fbsource/[c12bfdd174f7897f4615f98b630f2ac612c471fa]/fbcode/deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai/gen_ai/quantize.py?lines=18

Differential Revision: D82919480

Privacy Context Container: 151967047006994
1 parent 3cefe05 · commit d5c9a11

File tree: 1 file changed (+1 −1 lines)


fbgemm_gpu/experimental/gen_ai/gen_ai/quantize.py

Lines changed: 1 addition & 1 deletion
```
@@ -81,7 +81,7 @@ def int4_row_quantize(
         x (Tensor): [N, K] Higher precision weight tensor to quantize.
         group_size (int): Number of elements to calculate group scale for.
     Returns:
-        wq (Tensor): [N, K // 2] Quantized int4 tensor stored in int8 elements.
+        wq (Tensor): [N, K] Quantized int4 tensor stored in int8 elements.
         group_scale (Tensor): [K / group_size, N] FP32 Scale per group.
     """
     n_bit = 4  # Number of target bits.
```
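
To make the corrected shapes concrete, here is a minimal, self-contained sketch of the two-step flow the summary describes. It is an illustration only, not FBGEMM's implementation: the helpers `toy_int4_row_quantize` and `pack_int4`, the symmetric max-abs scaling, and the nibble layout in the packing step are all assumptions.

```
import torch

def toy_int4_row_quantize(x: torch.Tensor, group_size: int = 128):
    # Stand-in for int4_row_quantize(): quantize each group to the signed
    # int4 range [-8, 7], keeping ONE int4 value per int8 element, so the
    # output shape is [N, K], not [N, K // 2] (the point of this commit).
    n, k = x.shape
    xg = x.reshape(n, k // group_size, group_size)
    # Assumed symmetric per-group scale: max-abs mapped to the int4 maximum.
    scale = xg.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / 7.0
    wq = (xg / scale).round().clamp(-8, 7).reshape(n, k).to(torch.int8)
    # group_scale is [K / group_size, N], matching the docstring.
    group_scale = scale.squeeze(-1).t().contiguous()
    return wq, group_scale

def pack_int4(wq: torch.Tensor) -> torch.Tensor:
    # The separate packing step: two int4 values per byte -> [N, K // 2].
    # Nibble layout (even column in the low nibble) is an assumption.
    w = wq.to(torch.uint8) & 0x0F  # two's-complement nibbles
    return ((w[:, 1::2] << 4) | w[:, ::2]).to(torch.int8)

x = torch.randn(3, 256)
wq, group_scale = toy_int4_row_quantize(x, group_size=128)
print(wq.shape)             # torch.Size([3, 256]) -> [N, K], one int4 per int8
print(group_scale.shape)    # torch.Size([2, 3])   -> [K / group_size, N]
print(pack_int4(wq).shape)  # torch.Size([3, 128]) -> [N, K // 2] after packing
```

The sketch makes the docstring distinction visible: the `[N, K // 2]` shape only appears after the separate packing step, while `int4_row_quantize()` itself stops at `[N, K]`.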
