From 62483bcac7fcd6ec98d24fe87c5074dfab7e46c4 Mon Sep 17 00:00:00 2001
From: Chen Lai
Date: Tue, 17 Jun 2025 13:54:57 -0700
Subject: [PATCH] Add inplace quantizer examples (#2345)

Summary:
Pull Request resolved: https://github.com/pytorch/ao/pull/2345

Add a quantizer example for in-place ops, and patch the constant fold pass so
that the mutable buffer won't be folded.

Reviewed By: jerryzh168

Differential Revision: D76312488
---
 test/quantization/pt2e/test_quantize_pt2e.py | 91 ++++++++++++++++++++
 torchao/quantization/pt2e/constant_fold.py   | 25 ++++++
 2 files changed, 116 insertions(+)

diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py
index 730969ba9c..be5a4dc537 100644
--- a/test/quantization/pt2e/test_quantize_pt2e.py
+++ b/test/quantization/pt2e/test_quantize_pt2e.py
@@ -2826,6 +2826,97 @@ def check_nn_module(node):
             if node.name == "mul":
                 check_nn_module(node)
 
+    def test_quantize_in_place_ops(self):
+        class TestQuantizer(Quantizer):
+            example_inputs = None
+
+            def set_example_inputs(self, example_inputs):
+                self.example_inputs = example_inputs
+
+            def transform_for_annotation(
+                self, model: torch.fx.GraphModule
+            ) -> torch.fx.GraphModule:
+                # Make a copy of the graph to ensure that we are using the
+                # return value of this function.
+                ep = torch.export.export(model, self.example_inputs)
+                ep = ep.run_decompositions({})
+                return ep.module()
+
+            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+                act_qspec = QuantizationSpec(
+                    dtype=torch.uint8,
+                    quant_min=0,
+                    quant_max=255,
+                    qscheme=torch.per_tensor_affine,
+                    is_dynamic=False,
+                    observer_or_fake_quant_ctr=observer.default_observer,
+                )
+                for node in model.graph.nodes:
+                    if (
+                        node.op == "call_function"
+                        and node.target == torch.ops.aten.add.Tensor
+                    ):
+                        input_act0 = node.args[0]
+                        assert isinstance(input_act0, torch.fx.Node)
+                        input_act1 = node.args[1]
+                        assert isinstance(input_act1, torch.fx.Node)
+                        print("input_act1 is a node")
+                        node.meta["quantization_annotation"] = QuantizationAnnotation(
+                            input_qspec_map={
+                                input_act0: act_qspec,
+                                input_act1: act_qspec,
+                            },
+                            output_qspec=act_qspec,
+                            _annotated=True,
+                        )
+
+            def validate(self, model: torch.fx.GraphModule) -> None:
+                pass
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.register_buffer("buf", torch.randn(1, 2, 3, 3))
+
+            def forward(self, x):
+                self.buf.add_(x)
+                return self.buf
+
+        def has_inplace_ops(graph_module: torch.fx.GraphModule) -> bool:
+            return (
+                len(
+                    [
+                        n
+                        for n in graph_module.graph.nodes
+                        if n.op == "call_function"
+                        and n.name.endswith("_")
+                        and n.name != "copy_"
+                    ]
+                )
+                > 0
+            )
+
+        m = M().eval()
+        quantizer = TestQuantizer()
+        example_inputs = (torch.randn(1, 2, 3, 3),)
+        quantizer.set_example_inputs(example_inputs)
+        m = export_for_training(m, example_inputs, strict=True).module()
+        # Check that the model has in-place ops
+        self.assertTrue(has_inplace_ops(m))
+        m = prepare_pt2e(m, quantizer)
+        # Check that the model no longer has in-place ops because the graph is functionalized during transform_for_annotation
+        self.assertFalse(has_inplace_ops(m))
+        m(*example_inputs)
+        m = convert_pt2e(m, fold_quantize=True)
+        for node in m.graph.nodes:
+            if node.name == "quantize_per_tensor_default":
+                # Ensure the quant node is not fused with the mutable buffer
+                self.assertTrue(node.op == "call_function")
+
+        # Verify the quantized model works
+        result = m(*example_inputs)
+        self.assertIsNotNone(result)
+
 
 @skipIfNoQNNPACK
 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Requires torch 2.7+")
diff --git a/torchao/quantization/pt2e/constant_fold.py b/torchao/quantization/pt2e/constant_fold.py
index 37a84c45bf..27f82e6757 100644
--- a/torchao/quantization/pt2e/constant_fold.py
+++ b/torchao/quantization/pt2e/constant_fold.py
@@ -93,6 +93,24 @@ def __init__(
         self.deferred_value = object()
         self.skip_folding_node_fn = skip_folding_node_fn
 
+        # Identify mutable buffers by finding copy_ operations
+        self.mutable_buffers = self._find_mutable_buffers()
+
+    def _find_mutable_buffers(self) -> set[torch.fx.Node]:
+        """Find mutable buffers by identifying copy_ operations.
+        The first argument of the copy_ op is the mutable buffer."""
+        mutable_buffers = set()
+        for node in self.module.graph.nodes:
+            if (
+                node.op == "call_function"
+                and hasattr(node.target, "_schema")
+                and "copy_" in str(node.target)
+            ):
+                # The first argument of copy_ is the mutable buffer
+                if len(node.args) > 0 and isinstance(node.args[0], torch.fx.Node):
+                    mutable_buffers.add(node.args[0])
+        return mutable_buffers
+
     def _support_dynamic_shape(self) -> bool:
         # ConstantFolder not support dynamic shape now
         return False
@@ -156,6 +174,13 @@ def is_woq_int8_pattern(node: torch.fx.node.Node) -> bool:
             # We only folding fp32_weight -> q
             # int8_weight and leave dq in graph to be fused
             return True
+
+        # Check if any input to this node is a mutable buffer
+        # If so, prevent constant folding to avoid issues with quantize_per_tensor_default
+        for arg in node.args:
+            if isinstance(arg, torch.fx.Node) and arg in self.mutable_buffers:
+                return True
+
         return False
 
     def node_to_last_non_output_use(self) -> dict[torch.fx.Node, list[torch.fx.Node]]:
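
As a companion to the patch, the sketch below illustrates the graph shape both hunks depend on. It mirrors the transform_for_annotation recipe from the new test: exporting a module that runs self.buf.add_(x) and applying run_decompositions({}) produces a functionalized graph in which the in-place add_ is gone and the buffer write-back shows up as an aten.copy_ node whose first argument is the mutable buffer, which is the node _find_mutable_buffers collects and the folding guard checks. This is only an illustrative sketch; the BufAdd module and the mutable_buffers variable are made-up names, not identifiers from the patch.

# Illustrative sketch only; BufAdd and mutable_buffers are hypothetical names.
import torch


class BufAdd(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("buf", torch.randn(1, 2, 3, 3))

    def forward(self, x):
        self.buf.add_(x)  # in-place update of a mutable buffer
        return self.buf


example_inputs = (torch.randn(1, 2, 3, 3),)

# Same recipe as transform_for_annotation in the new test: export + decompose
# functionalizes the graph, so add_ disappears and a copy_ write-back remains.
ep = torch.export.export(BufAdd().eval(), example_inputs)
gm = ep.run_decompositions({}).module()

# Same idea as _find_mutable_buffers in the constant_fold.py hunk: the first
# argument of each copy_ call is the mutable buffer that must not be folded.
mutable_buffers = {
    n.args[0]
    for n in gm.graph.nodes
    if n.op == "call_function"
    and "copy_" in str(n.target)
    and n.args
    and isinstance(n.args[0], torch.fx.Node)
}

print(gm.graph)         # expect aten.add and aten.copy_ nodes, no aten.add_
print(mutable_buffers)  # the buffer node(s) the folding guard will skip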