
[do not land] testing if moving this breaks my PRs #2283

Open · wants to merge 2 commits into main
12 changes: 6 additions & 6 deletions test/quantization/test_qat.py
@@ -17,13 +17,13 @@
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401

from torchao import quantize_
-from torchao.quantization.GPTQ import _replace_linear_8da4w, _replace_linear_int4
from torchao.quantization.granularity import (
PerAxis,
PerGroup,
PerRow,
PerToken,
)
+from torchao.quantization.MOVED_GPTQ import _replace_linear_8da4w, _replace_linear_int4
from torchao.quantization.qat.api import (
ComposableQATQuantizer,
FakeQuantizeConfig,
@@ -284,7 +284,7 @@ def _set_ptq_weight(
Set the weight to the quantized version of the given fp32 weights,
for making linear outputs comparable with QAT.
"""
-from torchao.quantization.GPTQ import (
+from torchao.quantization.MOVED_GPTQ import (
Int8DynActInt4WeightLinear,
WeightOnlyInt4Linear,
)
@@ -332,7 +332,7 @@ def _set_ptq_weight(
not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
)
def test_qat_8da4w_linear(self):
-from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
+from torchao.quantization.MOVED_GPTQ import Int8DynActInt4WeightLinear
from torchao.quantization.qat.linear import Int8DynActInt4WeightQATLinear

group_size = 128
@@ -365,7 +365,7 @@ def test_qat_8da4w_linear(self):
not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
)
def test_qat_8da4w_quantizer(self):
-from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer
+from torchao.quantization.MOVED_GPTQ import Int8DynActInt4WeightQuantizer
from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer

group_size = 16
@@ -683,7 +683,7 @@ def test_qat_4w_primitives(self):
)
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
def test_qat_4w_linear(self):
-from torchao.quantization.GPTQ import WeightOnlyInt4Linear
+from torchao.quantization.MOVED_GPTQ import WeightOnlyInt4Linear
from torchao.quantization.qat.linear import Int4WeightOnlyQATLinear

group_size = 128
@@ -730,7 +730,7 @@ def test_qat_4w_quantizer_gradients(self):
)
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
def test_qat_4w_quantizer(self):
-from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
+from torchao.quantization.MOVED_GPTQ import Int4WeightOnlyQuantizer
from torchao.quantization.qat import Int4WeightOnlyQATQuantizer

group_size = 32
10 changes: 5 additions & 5 deletions test/quantization/test_quant_api.py
@@ -310,7 +310,7 @@ def api(model):
not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch verion is 2.3 or lower"
)
def test_8da4w_quantizer(self):
-from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
+from torchao.quantization.MOVED_GPTQ import Int8DynActInt4WeightLinear
from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

quantizer = Int8DynActInt4WeightQuantizer(groupsize=32)
@@ -325,7 +325,7 @@ def test_8da4w_quantizer(self):
not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch verion is 2.3 or lower"
)
def test_8da4w_quantizer_linear_bias(self):
-from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
+from torchao.quantization.MOVED_GPTQ import Int8DynActInt4WeightLinear
from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

quantizer = Int8DynActInt4WeightQuantizer(groupsize=32)
@@ -343,7 +343,7 @@ def test_8da4w_quantizer_linear_bias(self):
@unittest.skip("skipping until we get checkpoints for gpt-fast")
def test_8da4w_gptq_quantizer(self):
from torchao._models._eval import InputRecorder, TransformerEvalWrapper
-from torchao.quantization.GPTQ import Int8DynActInt4WeightGPTQQuantizer
+from torchao.quantization.MOVED_GPTQ import Int8DynActInt4WeightGPTQQuantizer

# should be similar to TorchCompileDynamicQuantizer
precision = torch.bfloat16
@@ -519,7 +519,7 @@ def test_gptq_quantizer_int4_weight_only(self):
@unittest.skip("skipping until we get checkpoints for gpt-fast")
def test_quantizer_int4_weight_only(self):
from torchao._models._eval import TransformerEvalWrapper
-from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
+from torchao.quantization.MOVED_GPTQ import Int4WeightOnlyQuantizer

precision = torch.bfloat16
device = "cuda"
@@ -648,7 +648,7 @@ def test_quantized_tensor_subclass_8da4w(self, mapping_type):
)

# reference
-from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
+from torchao.quantization.MOVED_GPTQ import Int8DynActInt4WeightLinear
from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

quantizer = Int8DynActInt4WeightQuantizer(
File renamed without changes.
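A rename like this breaks any downstream code that imports the old submodule path directly, which is presumably what the PR title is probing. One common mitigation, not included in this PR, is a thin forwarding shim left at the old location for a deprecation window. A minimal sketch, assuming the renamed file is torchao/quantization/GPTQ.py becoming torchao/quantization/MOVED_GPTQ.py (the filenames are inferred from the updated imports; the rename entry above does not list them):

```python
# torchao/quantization/GPTQ.py: hypothetical backward-compatibility shim, not included in this PR.
# Re-exports the symbols the rest of the codebase pulls from the old path so that
# `from torchao.quantization.GPTQ import ...` keeps working while callers migrate.
import warnings

from torchao.quantization.MOVED_GPTQ import (  # noqa: F401
    Int4WeightOnlyGPTQQuantizer,
    Int4WeightOnlyQuantizer,
    Int8DynActInt4WeightGPTQQuantizer,
    Int8DynActInt4WeightLinear,
    Int8DynActInt4WeightQuantizer,
    WeightOnlyInt4Linear,
    _check_linear_int4_k,
    _replace_linear_8da4w,
    _replace_linear_int4,
    groupwise_affine_quantize_tensor,
)

warnings.warn(
    "torchao.quantization.GPTQ has moved to torchao.quantization.MOVED_GPTQ; "
    "import from the new location instead.",
    DeprecationWarning,
    stacklevel=2,
)
```

With such a shim, the in-tree updates in this diff remain the long-term fix; the shim only buys time for out-of-tree callers.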
14 changes: 7 additions & 7 deletions torchao/quantization/__init__.py
@@ -13,13 +13,6 @@
OTHER_AUTOQUANT_CLASS_LIST,
autoquant,
)
-from .GPTQ import (
-    Int4WeightOnlyGPTQQuantizer,
-    Int4WeightOnlyQuantizer,
-    Int8DynActInt4WeightGPTQQuantizer,
-    Int8DynActInt4WeightLinear,
-    Int8DynActInt4WeightQuantizer,
-)
from .granularity import (
PerAxis,
PerGroup,
@@ -34,6 +27,13 @@
from .linear_activation_scale import (
to_weight_tensor_with_linear_activation_scale_metadata,
)
+from .MOVED_GPTQ import (
+    Int4WeightOnlyGPTQQuantizer,
+    Int4WeightOnlyQuantizer,
+    Int8DynActInt4WeightGPTQQuantizer,
+    Int8DynActInt4WeightLinear,
+    Int8DynActInt4WeightQuantizer,
+)
from .observer import (
AffineQuantizedMinMaxObserver,
AffineQuantizedObserverBase,
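Because torchao/quantization/__init__.py re-exports the same five names from the new module, code that goes through the package namespace should keep working unchanged; only direct imports of the old submodule path are affected. A small illustration, sketched from the imports shown in this diff rather than taken from the PR:

```python
# Still works after the move: the package-level name now resolves to MOVED_GPTQ internally.
from torchao.quantization import Int8DynActInt4WeightQuantizer

quantizer = Int8DynActInt4WeightQuantizer(groupsize=32)  # same public API as before

# Breaks after the move: the old submodule path no longer exists.
# from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer  # ModuleNotFoundError
```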
4 changes: 2 additions & 2 deletions torchao/quantization/qat/linear.py
@@ -10,15 +10,15 @@
import torch.nn.functional as F

from torchao.dtypes.utils import is_device
-from torchao.quantization.GPTQ import (
+from torchao.quantization.granularity import PerGroup
+from torchao.quantization.MOVED_GPTQ import (
Int8DynActInt4WeightLinear,
WeightOnlyInt4Linear,
_check_linear_int4_k,
_replace_linear_8da4w,
_replace_linear_int4,
groupwise_affine_quantize_tensor,
)
-from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_primitives import (
TorchAODType,
ZeroPointDomain,
12 changes: 6 additions & 6 deletions torchao/quantization/quant_api.py
@@ -83,12 +83,6 @@
)

from .autoquant import AutoQuantizableLinearWeight, autoquant
-from .GPTQ import (
-    Int4WeightOnlyGPTQQuantizer,
-    Int4WeightOnlyQuantizer,
-    Int8DynActInt4WeightGPTQQuantizer,
-    Int8DynActInt4WeightQuantizer,
-)
from .granularity import (
Granularity,
PerAxis,
@@ -100,6 +94,12 @@
LinearActivationQuantizedTensor,
to_linear_activation_quantized,
)
+from .MOVED_GPTQ import (
+    Int4WeightOnlyGPTQQuantizer,
+    Int4WeightOnlyQuantizer,
+    Int8DynActInt4WeightGPTQQuantizer,
+    Int8DynActInt4WeightQuantizer,
+)
from .qat import (
intx_quantization_aware_training,
)
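The PR title suggests the author is checking whether dependent branches survive the move. A quick local sanity check along the same lines (a sketch under the assumptions above; not part of the PR) is to confirm that every updated import path still resolves and that the old one is gone:

```python
# Import-resolution smoke test for the module move (sketch, not part of this PR).
import importlib

updated_paths = [
    "torchao.quantization",             # package-level re-exports updated in __init__.py
    "torchao.quantization.MOVED_GPTQ",  # new direct path
    "torchao.quantization.qat.linear",  # consumer updated in this diff
    "torchao.quantization.quant_api",   # consumer updated in this diff
]

for path in updated_paths:
    module = importlib.import_module(path)
    print(f"ok: {path} -> {module.__name__}")

# The old path should now fail to import.
try:
    importlib.import_module("torchao.quantization.GPTQ")
except ModuleNotFoundError:
    print("ok: torchao.quantization.GPTQ is gone, as expected")
```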