11 | 11 | import unittest
12 | 12 |
13 | 13 | import torch
| 14 | +
| 15 | +import torchao
14 | 16 | from torch import Tensor
15 | 17 | from torch.ao.quantization import QConfigMapping
16 | 18 | from torch.ao.quantization.qconfig import (
17 | | -    QConfig,
18 | 19 |     default_per_channel_symmetric_qnnpack_qconfig,
19 | 20 |     per_channel_weight_observer_range_neg_127_to_127,
| 21 | +    QConfig,
20 | 22 |     weight_observer_range_neg_127_to_127,
21 | 23 | )
| 24 | +from torch.export import ExportedProgram
22 | 25 | from torch.fx import Node
| 26 | +from torch.fx.graph_module import GraphModule
23 | 27 | from torch.testing._internal.common_quantization import (
24 | 28 |     NodeSpec as ns,
25 | | -)
26 | | -from torch.testing._internal.common_quantization import (
27 | | -    TestHelperModules,
28 | 29 |     skipIfNoQNNPACK,
| 30 | +    TestHelperModules,
29 | 31 | )
30 | 32 | from torch.testing._internal.common_utils import (
31 | | -    TEST_CUDA,
32 | | -    TemporaryFileName,
33 | 33 |     instantiate_parametrized_tests,
34 | 34 |     parametrize,
35 | 35 |     run_tests,
| 36 | +    TemporaryFileName,
| 37 | +    TEST_CUDA,
36 | 38 | )
37 | | -
38 | | -import torchao
39 | | -from torchao.quantization.pt2e import ObserverOrFakeQuantize, observer
| 39 | +from torchao.quantization.pt2e import observer, ObserverOrFakeQuantize
40 | 40 | from torchao.quantization.pt2e.quantize_pt2e import (
41 | 41 |     convert_pt2e,
42 | 42 |     prepare_pt2e,

58 | 58 |     EmbeddingQuantizer,
59 | 59 | )
60 | 60 | from torchao.testing.pt2e._xnnpack_quantizer import (
61 | | -    XNNPACKQuantizer,
62 | 61 |     get_symmetric_quantization_config,
| 62 | +    XNNPACKQuantizer,
63 | 63 | )
64 | 64 | from torchao.testing.pt2e._xnnpack_quantizer_utils import (
65 | 65 |     OP_TO_ANNOTATOR,

75 | 75 | DEVICE_LIST = ["cpu"] + (["cuda"] if TEST_CUDA else [])
76 | 76 |
77 | 77 | if TORCH_VERSION_AT_LEAST_2_7:
78 | | -    from torch.testing._internal.common_utils import (
79 | | -        TEST_HPU,
80 | | -    )
| 78 | +    from torch.testing._internal.common_utils import TEST_HPU
81 | 79 |
82 | 80 |     DEVICE_LIST += ["hpu"] if TEST_HPU else []
83 | 81 |

@@ -2826,6 +2824,88 @@ def check_nn_module(node):
2826 | 2824 |             if node.name == "mul":
2827 | 2825 |                 check_nn_module(node)
2828 | 2826 |
| 2827 | +    def test_quantize_in_place_ops(self):
| 2828 | +        class TestQuantizer(Quantizer):
| 2829 | +            example_inputs = None
| 2830 | +
| 2831 | +            def set_example_inputs(self, example_inputs):
| 2832 | +                self.example_inputs = example_inputs
| 2833 | +
| 2834 | +            def transform_for_annotation(
| 2835 | +                self, model: torch.fx.GraphModule
| 2836 | +            ) -> torch.fx.GraphModule:
| 2837 | +                # Make a copy of the graph to ensure that we are using the
| 2838 | +                # return value of this function.
| 2839 | +                ep = torch.export.export(model, self.example_inputs)
| 2840 | +                ep = ep.run_decompositions({})
| 2841 | +                return ep.module()
| 2842 | +
| 2843 | +            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
| 2844 | +                act_qspec = QuantizationSpec(
| 2845 | +                    dtype=torch.uint8,
| 2846 | +                    quant_min=0,
| 2847 | +                    quant_max=255,
| 2848 | +                    qscheme=torch.per_tensor_affine,
| 2849 | +                    is_dynamic=False,
| 2850 | +                    observer_or_fake_quant_ctr=observer.default_observer,
| 2851 | +                )
| 2852 | +                for node in model.graph.nodes:
| 2853 | +                    if (
| 2854 | +                        node.op == "call_function"
| 2855 | +                        and node.target == torch.ops.aten.add.Tensor
| 2856 | +                    ):
| 2857 | +                        input_act0 = node.args[0]
| 2858 | +                        assert isinstance(input_act0, torch.fx.Node)
| 2859 | +                        input_act1 = node.args[1]
| 2860 | +                        assert isinstance(input_act1, torch.fx.Node)
| 2861 | +                        print("input_act1 is a node")
| 2862 | +                        node.meta["quantization_annotation"] = QuantizationAnnotation(
| 2863 | +                            input_qspec_map={
| 2864 | +                                input_act0: act_qspec,
| 2865 | +                                input_act1: act_qspec,
| 2866 | +                            },
| 2867 | +                            output_qspec=act_qspec,
| 2868 | +                            _annotated=True,
| 2869 | +                        )
| 2870 | +
| 2871 | +            def validate(self, model: torch.fx.GraphModule) -> None:
| 2872 | +                pass
| 2873 | +
| 2874 | +        class M(torch.nn.Module):
| 2875 | +            def __init__(self):
| 2876 | +                super().__init__()
| 2877 | +                self.register_buffer("buf", torch.randn(1, 2, 3, 3))
| 2878 | +
| 2879 | +            def forward(self, x):
| 2880 | +                self.buf.add_(x)
| 2881 | +                return self.buf
| 2882 | +
| 2883 | +        def has_inplace_ops(graph_module: GraphModule) -> bool:
| 2884 | +            return len([
| 2885 | +                n for n in graph_module.graph.nodes if n.op == "call_function" and n.name.endswith("_") and n.name != "copy_"
| 2886 | +            ]) > 0
| 2887 | +
| 2888 | +        m = M().eval()
| 2889 | +        quantizer = TestQuantizer()
| 2890 | +        example_inputs = (torch.randn(1, 2, 3, 3),)
| 2891 | +        quantizer.set_example_inputs(example_inputs)
| 2892 | +        m = export_for_training(m, example_inputs, strict=True).module()
| 2893 | +        # Check that the model has in-place ops
| 2894 | +        self.assertTrue(has_inplace_ops(m))
| 2895 | +        m = prepare_pt2e(m, quantizer)
| 2896 | +        # Check that the model no longer has in-place ops because the graph is functionalized during transform_for_annotation
| 2897 | +        self.assertFalse(has_inplace_ops(m))
| 2898 | +        m(*example_inputs)
| 2899 | +        m = convert_pt2e(m, fold_quantize=True)
| 2900 | +        for node in m.graph.nodes:
| 2901 | +            if node.name == "quantize_per_tensor_default":
| 2902 | +                # Ensure the quant node is not fused with the mutable buffer
| 2903 | +                self.assertTrue(node.op == "call_function")
| 2904 | +
| 2905 | +        # Verify the quantized model works
| 2906 | +        result = m(*example_inputs)
| 2907 | +        self.assertIsNotNone(result)
| 2908 | +
2829 | 2909 |
2830 | 2910 | @skipIfNoQNNPACK
2831 | 2911 | @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Requires torch 2.7+")
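
For context, the new test depends on torch.export functionalizing in-place buffer mutations when transform_for_annotation re-exports the model. A minimal standalone sketch of that behavior follows; it reuses the M module and shapes from the test, assumes a recent PyTorch where torch.export.export and run_decompositions are available, and the final print is only illustrative.

import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("buf", torch.randn(1, 2, 3, 3))

    def forward(self, x):
        # In-place mutation of a registered buffer.
        self.buf.add_(x)
        return self.buf

example_inputs = (torch.randn(1, 2, 3, 3),)
# Export functionalizes the graph: the mutation becomes a functional
# aten.add.Tensor plus a copy_ node that writes the result back to the
# buffer, so no aten.add_ remains in the graph.
ep = torch.export.export(M().eval(), example_inputs)
gm = ep.run_decompositions({}).module()
print([n.target for n in gm.graph.nodes if n.op == "call_function"])

This is why has_inplace_ops(m) is expected to be True right after export_for_training (which still carries the add_ node) and False after prepare_pt2e has run the quantizer's transform_for_annotation.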