diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py
index 7444c3dbb5..81ac9a0f4b 100644
--- a/test/quantization/test_qat.py
+++ b/test/quantization/test_qat.py
@@ -17,13 +17,13 @@
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
 
 from torchao import quantize_
-from torchao.quantization.GPTQ import _replace_linear_8da4w, _replace_linear_int4
 from torchao.quantization.granularity import (
     PerAxis,
     PerGroup,
     PerRow,
     PerToken,
 )
+from torchao.quantization.MOVED_GPTQ import _replace_linear_8da4w, _replace_linear_int4
 from torchao.quantization.qat.api import (
     ComposableQATQuantizer,
     FakeQuantizeConfig,
@@ -284,7 +284,7 @@ def _set_ptq_weight(
     Set the weight to the quantized version of the given fp32 weights,
     for making linear outputs comparable with QAT.
     """
-    from torchao.quantization.GPTQ import (
+    from torchao.quantization.MOVED_GPTQ import (
         Int8DynActInt4WeightLinear,
         WeightOnlyInt4Linear,
     )
@@ -332,7 +332,7 @@ def _set_ptq_weight(
         not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
     )
     def test_qat_8da4w_linear(self):
-        from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
+        from torchao.quantization.MOVED_GPTQ import Int8DynActInt4WeightLinear
         from torchao.quantization.qat.linear import Int8DynActInt4WeightQATLinear
 
         group_size = 128
@@ -365,7 +365,7 @@ def test_qat_8da4w_linear(self):
         not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
     )
     def test_qat_8da4w_quantizer(self):
-        from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer
+        from torchao.quantization.MOVED_GPTQ import Int8DynActInt4WeightQuantizer
         from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
 
         group_size = 16
@@ -683,7 +683,7 @@ def test_qat_4w_primitives(self):
     )
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
     def test_qat_4w_linear(self):
-        from torchao.quantization.GPTQ import WeightOnlyInt4Linear
+        from torchao.quantization.MOVED_GPTQ import WeightOnlyInt4Linear
         from torchao.quantization.qat.linear import Int4WeightOnlyQATLinear
 
         group_size = 128
@@ -730,7 +730,7 @@ def test_qat_4w_quantizer_gradients(self):
     )
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
     def test_qat_4w_quantizer(self):
-        from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
+        from torchao.quantization.MOVED_GPTQ import Int4WeightOnlyQuantizer
         from torchao.quantization.qat import Int4WeightOnlyQATQuantizer
 
         group_size = 32
diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index b4ec9f4785..cd13c04165 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -310,7 +310,7 @@ def api(model):
         not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch verion is 2.3 or lower"
     )
     def test_8da4w_quantizer(self):
-        from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
+        from torchao.quantization.MOVED_GPTQ import Int8DynActInt4WeightLinear
         from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
 
         quantizer = Int8DynActInt4WeightQuantizer(groupsize=32)
@@ -325,7 +325,7 @@ def test_8da4w_quantizer(self):
         not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch verion is 2.3 or lower"
     )
     def test_8da4w_quantizer_linear_bias(self):
-        from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
+        from torchao.quantization.MOVED_GPTQ import Int8DynActInt4WeightLinear
         from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
 
         quantizer = Int8DynActInt4WeightQuantizer(groupsize=32)
@@ -343,7 +343,7 @@ def test_8da4w_quantizer_linear_bias(self):
     @unittest.skip("skipping until we get checkpoints for gpt-fast")
     def test_8da4w_gptq_quantizer(self):
         from torchao._models._eval import InputRecorder, TransformerEvalWrapper
-        from torchao.quantization.GPTQ import Int8DynActInt4WeightGPTQQuantizer
+        from torchao.quantization.MOVED_GPTQ import Int8DynActInt4WeightGPTQQuantizer
 
         # should be similar to TorchCompileDynamicQuantizer
         precision = torch.bfloat16
@@ -519,7 +519,7 @@ def test_gptq_quantizer_int4_weight_only(self):
     @unittest.skip("skipping until we get checkpoints for gpt-fast")
     def test_quantizer_int4_weight_only(self):
         from torchao._models._eval import TransformerEvalWrapper
-        from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
+        from torchao.quantization.MOVED_GPTQ import Int4WeightOnlyQuantizer
 
         precision = torch.bfloat16
         device = "cuda"
@@ -648,7 +648,7 @@ def test_quantized_tensor_subclass_8da4w(self, mapping_type):
         )
 
         # reference
-        from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
+        from torchao.quantization.MOVED_GPTQ import Int8DynActInt4WeightLinear
         from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
 
         quantizer = Int8DynActInt4WeightQuantizer(
diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/MOVED_GPTQ.py
similarity index 100%
rename from torchao/quantization/GPTQ.py
rename to torchao/quantization/MOVED_GPTQ.py
diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py
index 73ccd2e0ff..311a46c1b3 100644
--- a/torchao/quantization/__init__.py
+++ b/torchao/quantization/__init__.py
@@ -13,13 +13,6 @@
     OTHER_AUTOQUANT_CLASS_LIST,
     autoquant,
 )
-from .GPTQ import (
-    Int4WeightOnlyGPTQQuantizer,
-    Int4WeightOnlyQuantizer,
-    Int8DynActInt4WeightGPTQQuantizer,
-    Int8DynActInt4WeightLinear,
-    Int8DynActInt4WeightQuantizer,
-)
 from .granularity import (
     PerAxis,
     PerGroup,
@@ -34,6 +27,13 @@
 from .linear_activation_scale import (
     to_weight_tensor_with_linear_activation_scale_metadata,
 )
+from .MOVED_GPTQ import (
+    Int4WeightOnlyGPTQQuantizer,
+    Int4WeightOnlyQuantizer,
+    Int8DynActInt4WeightGPTQQuantizer,
+    Int8DynActInt4WeightLinear,
+    Int8DynActInt4WeightQuantizer,
+)
 from .observer import (
     AffineQuantizedMinMaxObserver,
     AffineQuantizedObserverBase,
diff --git a/torchao/quantization/qat/linear.py b/torchao/quantization/qat/linear.py
index 7c32bc4b19..4df6ad0a14 100644
--- a/torchao/quantization/qat/linear.py
+++ b/torchao/quantization/qat/linear.py
@@ -10,7 +10,8 @@
 import torch.nn.functional as F
 
 from torchao.dtypes.utils import is_device
-from torchao.quantization.GPTQ import (
+from torchao.quantization.granularity import PerGroup
+from torchao.quantization.MOVED_GPTQ import (
     Int8DynActInt4WeightLinear,
     WeightOnlyInt4Linear,
     _check_linear_int4_k,
@@ -18,7 +19,6 @@
     _replace_linear_int4,
     groupwise_affine_quantize_tensor,
 )
-from torchao.quantization.granularity import PerGroup
 from torchao.quantization.quant_primitives import (
     TorchAODType,
     ZeroPointDomain,
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 37f0cf5bfe..773c1d27b0 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -83,12 +83,6 @@
 )
 
 from .autoquant import AutoQuantizableLinearWeight, autoquant
-from .GPTQ import (
-    Int4WeightOnlyGPTQQuantizer,
-    Int4WeightOnlyQuantizer,
-    Int8DynActInt4WeightGPTQQuantizer,
-    Int8DynActInt4WeightQuantizer,
-)
 from .granularity import (
     Granularity,
     PerAxis,
@@ -100,6 +94,12 @@
     LinearActivationQuantizedTensor,
     to_linear_activation_quantized,
 )
+from .MOVED_GPTQ import (
+    Int4WeightOnlyGPTQQuantizer,
+    Int4WeightOnlyQuantizer,
+    Int8DynActInt4WeightGPTQQuantizer,
+    Int8DynActInt4WeightQuantizer,
+)
 from .qat import (
     intx_quantization_aware_training,
 )
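
A minimal sketch of what this rename means for downstream imports (module paths are taken from the diff above; the alias name and assert are illustrative only, not part of the patch):

    # Code that imported the module directly must move to the new path:
    from torchao.quantization.MOVED_GPTQ import Int8DynActInt4WeightQuantizer

    # The re-exports in torchao/quantization/__init__.py are updated by this
    # patch, so the stable top-level import continues to work unchanged:
    from torchao.quantization import Int8DynActInt4WeightQuantizer as ReExported

    # Both paths resolve to the same class object:
    assert Int8DynActInt4WeightQuantizer is ReExported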