@@ -34,6 +34,8 @@
 )
 from torchao.quantization.prototype.qat.linear import (
     FakeQuantizedLinear,
+    Int8DynActInt4WeightQATLinear,
+    Int4WeightOnlyQATLinear,
 )
 from torchao.quantization.prototype.qat.utils import (
     _choose_qparams_per_token_asymmetric,
@@ -66,6 +68,10 @@
     TORCH_VERSION_AT_LEAST_2_5,
 )
 
+from torchao.quantization.GPTQ import (
+    _replace_linear_8da4w,
+    _replace_linear_int4,
+)
 
 # TODO: put this in a common test utils file
 _CUDA_IS_AVAILABLE = torch.cuda.is_available()
@@ -854,6 +860,48 @@ def linear_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
         fq_out = fq_linear(x)
         baseline_out = linear_forward_4w(x2, fq_linear.weight)
         torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0)
+
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
+    def test_replace_linear_8da4w(self):
+        module = torch.nn.ModuleList([
+            torch.nn.Linear(in_features=256, out_features=50, bias=True)
+        ])
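+        # Positional args: groupsize=256, padding_allowed=False, precision, scales_precision.
+        # A linear with bias is expected to be left untouched rather than swapped.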
+        _replace_linear_8da4w(module, 256, False, torch.float32, torch.float32, Int8DynActInt4WeightQATLinear, copy_weights=True)
+        assert not isinstance(module[0], Int8DynActInt4WeightQATLinear) and isinstance(module[0], torch.nn.Linear)
+        module = torch.nn.ModuleList([
+            torch.nn.Linear(in_features=256, out_features=50, bias=False)
+        ])
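+        # Without bias, the swap to the QAT linear should go through.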
+        _replace_linear_8da4w(module, 256, False, torch.float32, torch.float32, Int8DynActInt4WeightQATLinear, copy_weights=True)
+        assert isinstance(module[0], Int8DynActInt4WeightQATLinear)
+
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
+    def test_replace_linear_int4(self):
+        module = torch.nn.ModuleList([
+            torch.nn.Linear(in_features=256, out_features=50, bias=True)
+        ])
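+        # The two positional ints are groupsize=256 and inner_k_tiles=8; as above,
+        # a linear with bias should not be replaced.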
+        _replace_linear_int4(
+            module,
+            256,
+            8,
+            padding_allowed=True,
+            precision=torch.bfloat16,
+            scales_precision=torch.bfloat16,
+            linear_class=Int4WeightOnlyQATLinear,
+            copy_weights=True)
+        assert not isinstance(module[0], Int4WeightOnlyQATLinear) and isinstance(module[0], torch.nn.Linear)
+        module = torch.nn.ModuleList([
+            torch.nn.Linear(in_features=256, out_features=50, bias=False)
+        ])
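+        # bias=False: the linear should be swapped for the QAT variant.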
+        _replace_linear_int4(
+            module,
+            256,
+            8,
+            padding_allowed=True,
+            precision=torch.bfloat16,
+            scales_precision=torch.bfloat16,
+            linear_class=Int4WeightOnlyQATLinear,
+            copy_weights=True)
+        assert isinstance(module[0], Int4WeightOnlyQATLinear)
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_fake_quantized_embedding_4w(self):
@@ -891,4 +939,4 @@ def embedding_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
 
 
 if __name__ == "__main__":
-    unittest.main()
+    unittest.main()