
Commit 3103e7e

elfisworking and yumin authored
Temporary fix for QAT quantizer when linear layer bias is True (#1087)
Temporary fix for QAT when linear layer bias is True

Signed-off-by: yumin <zhangym33@chinatelecom.cn>
Co-authored-by: yumin <zhangym33@chinatelecom.cn>
1 parent 6b52996 commit 3103e7e
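
At a glance, the effect of this change (mirroring the new tests in test/quantization/test_qat.py below): _replace_linear_8da4w now leaves an nn.Linear that has a bias untouched, and only swaps bias-free layers for the QAT linear. A minimal sketch, using only the helpers exercised by the tests:

import torch
from torchao.quantization.GPTQ import _replace_linear_8da4w
from torchao.quantization.prototype.qat.linear import Int8DynActInt4WeightQATLinear

# Linear with a bias: left as a plain nn.Linear for now (bias support is still a TODO).
biased = torch.nn.ModuleList([torch.nn.Linear(256, 50, bias=True)])
_replace_linear_8da4w(biased, 256, False, torch.float32, torch.float32,
                      Int8DynActInt4WeightQATLinear, copy_weights=True)
assert not isinstance(biased[0], Int8DynActInt4WeightQATLinear)

# Bias-free linear: swapped for the fake-quantized QAT linear, as before.
unbiased = torch.nn.ModuleList([torch.nn.Linear(256, 50, bias=False)])
_replace_linear_8da4w(unbiased, 256, False, torch.float32, torch.float32,
                      Int8DynActInt4WeightQATLinear, copy_weights=True)
assert isinstance(unbiased[0], Int8DynActInt4WeightQATLinear)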

File tree

2 files changed: +53 -3 lines changed

test/quantization/test_qat.py
torchao/quantization/GPTQ.py

test/quantization/test_qat.py

Lines changed: 49 additions & 1 deletion
@@ -34,6 +34,8 @@
 )
 from torchao.quantization.prototype.qat.linear import (
     FakeQuantizedLinear,
+    Int8DynActInt4WeightQATLinear,
+    Int4WeightOnlyQATLinear
 )
 from torchao.quantization.prototype.qat.utils import (
     _choose_qparams_per_token_asymmetric,
@@ -66,6 +68,10 @@
     TORCH_VERSION_AT_LEAST_2_5,
 )
 
+from torchao.quantization.GPTQ import (
+    _replace_linear_8da4w,
+    _replace_linear_int4
+)
 
 # TODO: put this in a common test utils file
 _CUDA_IS_AVAILABLE = torch.cuda.is_available()
@@ -854,6 +860,48 @@ def linear_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
         fq_out = fq_linear(x)
         baseline_out = linear_forward_4w(x2, fq_linear.weight)
         torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0)
+
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
+    def test_replace_linear_8da4w(self):
+        module = torch.nn.ModuleList([
+            torch.nn.Linear(in_features=256, out_features=50, bias=True)
+        ])
+        _replace_linear_8da4w(module, 256, False, torch.float32, torch.float32, Int8DynActInt4WeightQATLinear, copy_weights=True)
+        assert(not isinstance(module[0], Int8DynActInt4WeightQATLinear) and isinstance(module[0], torch.nn.Linear))
+        module = torch.nn.ModuleList([
+            torch.nn.Linear(in_features=256, out_features=50, bias=False)
+        ])
+        _replace_linear_8da4w(module, 256, False, torch.float32, torch.float32, Int8DynActInt4WeightQATLinear, copy_weights=True)
+        assert(isinstance(module[0], Int8DynActInt4WeightQATLinear))
+
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
+    def test_replace_linear_int4(self):
+        module = torch.nn.ModuleList([
+            torch.nn.Linear(in_features=256, out_features=50, bias=True)
+        ])
+        _replace_linear_int4(
+            module,
+            256,
+            8,
+            padding_allowed=True,
+            precision=torch.bfloat16,
+            scales_precision=torch.bfloat16,
+            linear_class=Int4WeightOnlyQATLinear,
+            copy_weights=True)
+        assert(not isinstance(module[0], Int4WeightOnlyQATLinear) and isinstance(module[0], torch.nn.Linear))
+        module = torch.nn.ModuleList([
+            torch.nn.Linear(in_features=256, out_features=50, bias=False)
+        ])
+        _replace_linear_int4(
+            module,
+            256,
+            8,
+            padding_allowed=True,
+            precision=torch.bfloat16,
+            scales_precision=torch.bfloat16,
+            linear_class=Int4WeightOnlyQATLinear,
+            copy_weights=True)
+        assert(isinstance(module[0], Int4WeightOnlyQATLinear))
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_fake_quantized_embedding_4w(self):
@@ -891,4 +939,4 @@ def embedding_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
 
 
 if __name__ == "__main__":
-    unittest.main()
+    unittest.main()

torchao/quantization/GPTQ.py

Lines changed: 4 additions & 2 deletions
@@ -617,7 +617,8 @@ def _replace_linear_int4(
     copy_weights: bool = False,
 ):
     for name, child in module.named_children():
-        if isinstance(child, nn.Linear) and (skip_layer_func is None or not skip_layer_func(child.weight)):
+        # TODO: support linear bias
+        if isinstance(child, nn.Linear) and child.bias is None and (skip_layer_func is None or not skip_layer_func(child.weight)):
             if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles) or padding_allowed:
                 new_linear = linear_class(
                     child.in_features,
@@ -979,7 +980,8 @@ def _replace_linear_8da4w(
     from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
 
     def filter_fn(child: torch.nn.Module, cur_fqn:str) -> bool:
-        return isinstance(child, nn.Linear) and (_check_linear_int4_k(child.in_features, groupsize) or padding_allowed)
+        # TODO: support linear bias
+        return isinstance(child, nn.Linear) and child.bias is None and (_check_linear_int4_k(child.in_features, groupsize) or padding_allowed)
 
     def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
         new_linear = linear_class(
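
For context, a usage-level sketch of what this guard means for the prototype QAT flow. It assumes the Int8DynActInt4WeightQATQuantizer.prepare path in torchao.quantization.prototype.qat still routes through _replace_linear_8da4w (not shown in this diff), so treat the import path and class name as assumptions:

import torch
# Assumed import path for the prototype QAT quantizer; not part of this diff.
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer

model = torch.nn.Sequential(
    torch.nn.Linear(256, 256, bias=False),  # bias-free: eligible for the QAT swap
    torch.nn.Linear(256, 50, bias=True),    # biased: skipped (TODO: support linear bias)
)
prepared = Int8DynActInt4WeightQATQuantizer().prepare(model)
print(type(prepared[0]).__name__)  # expected: Int8DynActInt4WeightQATLinear
print(type(prepared[1]).__name__)  # expected: Linear, left untouched by this temporary fix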
