
Commit 52dbbc6

cccclai authored and facebook-github-bot committed
1 parent: e4f2715

Add inplace quantizer examples (pytorch#2345)

Summary: Pull Request resolved: pytorch#2345. Adds a quantizer example for in-place ops, and patches the constant-fold pass so that mutable buffers are not folded.

Differential Revision: D76312488
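For context before the diffs, here is a minimal sketch of the flow this commit exercises, condensed from the new test below. `quantizer` stands in for a `Quantizer` implementation such as the `TestQuantizer` defined in the test (it is not constructed here); everything else uses the same APIs the test calls.

```python
import torch
from torch.export import export_for_training
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Mutable buffer: forward() updates it in place.
        self.register_buffer("buf", torch.randn(1, 2, 3, 3))

    def forward(self, x):
        self.buf.add_(x)  # in-place op on the buffer
        return self.buf

example_inputs = (torch.randn(1, 2, 3, 3),)
m = export_for_training(M().eval(), example_inputs, strict=True).module()
m = prepare_pt2e(m, quantizer)  # `quantizer` as described above (not defined here)
m(*example_inputs)              # calibrate observers
m = convert_pt2e(m, fold_quantize=True)  # the buffer must survive constant folding
```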

File tree

2 files changed: +118 -14 lines

test/quantization/pt2e/test_quantize_pt2e.py

Lines changed: 93 additions & 13 deletions
```diff
@@ -11,32 +11,32 @@
 import unittest
 
 import torch
+
+import torchao
 from torch import Tensor
 from torch.ao.quantization import QConfigMapping
 from torch.ao.quantization.qconfig import (
-    QConfig,
     default_per_channel_symmetric_qnnpack_qconfig,
     per_channel_weight_observer_range_neg_127_to_127,
+    QConfig,
     weight_observer_range_neg_127_to_127,
 )
+from torch.export import ExportedProgram
 from torch.fx import Node
+from torch.fx.graph_module import GraphModule
 from torch.testing._internal.common_quantization import (
     NodeSpec as ns,
-)
-from torch.testing._internal.common_quantization import (
-    TestHelperModules,
     skipIfNoQNNPACK,
+    TestHelperModules,
 )
 from torch.testing._internal.common_utils import (
-    TEST_CUDA,
-    TemporaryFileName,
     instantiate_parametrized_tests,
     parametrize,
     run_tests,
+    TemporaryFileName,
+    TEST_CUDA,
 )
-
-import torchao
-from torchao.quantization.pt2e import ObserverOrFakeQuantize, observer
+from torchao.quantization.pt2e import observer, ObserverOrFakeQuantize
 from torchao.quantization.pt2e.quantize_pt2e import (
     convert_pt2e,
     prepare_pt2e,
@@ -58,8 +58,8 @@
     EmbeddingQuantizer,
 )
 from torchao.testing.pt2e._xnnpack_quantizer import (
-    XNNPACKQuantizer,
     get_symmetric_quantization_config,
+    XNNPACKQuantizer,
 )
 from torchao.testing.pt2e._xnnpack_quantizer_utils import (
     OP_TO_ANNOTATOR,
@@ -75,9 +75,7 @@
 DEVICE_LIST = ["cpu"] + (["cuda"] if TEST_CUDA else [])
 
 if TORCH_VERSION_AT_LEAST_2_7:
-    from torch.testing._internal.common_utils import (
-        TEST_HPU,
-    )
+    from torch.testing._internal.common_utils import TEST_HPU
 
     DEVICE_LIST += ["hpu"] if TEST_HPU else []
 
@@ -2826,6 +2824,88 @@ def check_nn_module(node):
         if node.name == "mul":
             check_nn_module(node)
 
+    def test_quantize_in_place_ops(self):
+        class TestQuantizer(Quantizer):
+            example_inputs = None
+
+            def set_example_inputs(self, example_inputs):
+                self.example_inputs = example_inputs
+
+            def transform_for_annotation(
+                self, model: torch.fx.GraphModule
+            ) -> torch.fx.GraphModule:
+                # Make a copy of the graph to ensure that we are using the
+                # return value of this function.
+                ep = torch.export.export(model, self.example_inputs)
+                ep = ep.run_decompositions({})
+                return ep.module()
+
+            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+                act_qspec = QuantizationSpec(
+                    dtype=torch.uint8,
+                    quant_min=0,
+                    quant_max=255,
+                    qscheme=torch.per_tensor_affine,
+                    is_dynamic=False,
+                    observer_or_fake_quant_ctr=observer.default_observer,
+                )
+                for node in model.graph.nodes:
+                    if (
+                        node.op == "call_function"
+                        and node.target == torch.ops.aten.add.Tensor
+                    ):
+                        input_act0 = node.args[0]
+                        assert isinstance(input_act0, torch.fx.Node)
+                        input_act1 = node.args[1]
+                        assert isinstance(input_act1, torch.fx.Node)
+                        node.meta["quantization_annotation"] = QuantizationAnnotation(
+                            input_qspec_map={
+                                input_act0: act_qspec,
+                                input_act1: act_qspec,
+                            },
+                            output_qspec=act_qspec,
+                            _annotated=True,
+                        )
+
+            def validate(self, model: torch.fx.GraphModule) -> None:
+                pass
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.register_buffer("buf", torch.randn(1, 2, 3, 3))
+
+            def forward(self, x):
+                self.buf.add_(x)
+                return self.buf
+
+        def has_inplace_ops(graph_module: GraphModule) -> bool:
+            return any(
+                n.op == "call_function"
+                and n.name.endswith("_")
+                and n.name != "copy_"
+                for n in graph_module.graph.nodes
+            )
+
+        m = M().eval()
+        quantizer = TestQuantizer()
+        example_inputs = (torch.randn(1, 2, 3, 3),)
+        quantizer.set_example_inputs(example_inputs)
+        m = export_for_training(m, example_inputs, strict=True).module()
+        # Check that the model has in-place ops
+        self.assertTrue(has_inplace_ops(m))
+        m = prepare_pt2e(m, quantizer)
+        # The graph is functionalized during transform_for_annotation, so the
+        # prepared model should no longer contain in-place ops
+        self.assertFalse(has_inplace_ops(m))
+        m(*example_inputs)
+        m = convert_pt2e(m, fold_quantize=True)
+        for node in m.graph.nodes:
+            if node.name == "quantize_per_tensor_default":
+                # Ensure the quant node was not folded into the mutable buffer
+                self.assertTrue(node.op == "call_function")
+
+        # Verify the quantized model runs
+        result = m(*example_inputs)
+        self.assertIsNotNone(result)
+
 
 @skipIfNoQNNPACK
 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Requires torch 2.7+")
```
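A note on the assertions above (context, not part of the diff): export functionalizes the graph, rewriting the in-place `add_` as a functional `aten.add.Tensor` followed by an `aten.copy_` that writes the result back into the buffer. That is why `has_inplace_ops` deliberately ignores `copy_`, and why the constant-fold patch below uses `copy_` ops to locate mutable buffers. Roughly, with `m` as in the test (node names and layout illustrative, not the exact printout):

```python
# After export_for_training(...).module(), the traced graph looks roughly like:
#
#   buf   = get_attr[target=buf]                          # the mutable buffer
#   x     = placeholder[target=x]
#   add   = call_function[target=aten.add.Tensor](buf, x)
#   copy_ = call_function[target=aten.copy_.default](buf, add)
#   output(add)
#
# Inspect the actual nodes with:
m.graph.print_tabular()
```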

torchao/quantization/pt2e/constant_fold.py

Lines changed: 25 additions & 1 deletion
```diff
@@ -92,6 +92,24 @@ def __init__(
         self.lifted_constant_names = lifted_constant_names
         self.deferred_value = object()
         self.skip_folding_node_fn = skip_folding_node_fn
+
+        # Identify mutable buffers by finding copy_ operations
+        self.mutable_buffers = self._find_mutable_buffers()
+
+    def _find_mutable_buffers(self) -> set[torch.fx.Node]:
+        """Find mutable buffers by identifying copy_ operations.
+        The first argument of a copy_ op is the mutable buffer."""
+        mutable_buffers = set()
+        for node in self.module.graph.nodes:
+            if (
+                node.op == "call_function"
+                and hasattr(node.target, "_schema")
+                and "copy_" in str(node.target)
+            ):
+                # The first argument of copy_ is the mutable buffer
+                if len(node.args) > 0 and isinstance(node.args[0], torch.fx.Node):
+                    mutable_buffers.add(node.args[0])
+        return mutable_buffers
 
     def _support_dynamic_shape(self) -> bool:
         # ConstantFolder does not support dynamic shape yet
@@ -156,6 +174,13 @@ def is_woq_int8_pattern(node: torch.fx.node.Node) -> bool:
             # We only fold fp32_weight -> q int8_weight,
             # and leave dq in the graph to be fused
             return True
+
+        # Check whether any input to this node is a mutable buffer; if so, prevent
+        # constant folding to avoid freezing the input of quantize_per_tensor_default.
+        for arg in node.args:
+            if isinstance(arg, torch.fx.Node) and arg in self.mutable_buffers:
+                return True
+
         return False
 
     def node_to_last_non_output_use(self) -> dict[torch.fx.Node, list[torch.fx.Node]]:
@@ -261,7 +286,6 @@ def set_env(arg: torch.fx.Node) -> None:
 
         if self.is_impure(node):
             return self.unknown_value
-
         self.add_node_replacement(node, out)
 
         flattened_node_inps = pytree.arg_tree_leaves(*node.args, **node.kwargs)
```
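Taken on its own, the new guard marks any node that reads a mutable buffer as impure, so the folder leaves it in the graph instead of freezing a runtime-varying value (such as `quantize_per_tensor_default` applied to the buffer) into a constant. A standalone sketch of the same two-step logic, assuming a functionalized FX `GraphModule` named `gm`:

```python
import torch

def find_mutable_buffers(gm: torch.fx.GraphModule) -> set[torch.fx.Node]:
    """Mirrors _find_mutable_buffers: after functionalization, a buffer that
    is mutated in place appears as the first argument of a copy_ op."""
    mutable: set[torch.fx.Node] = set()
    for node in gm.graph.nodes:
        if (
            node.op == "call_function"
            and hasattr(node.target, "_schema")
            and "copy_" in str(node.target)
        ):
            if node.args and isinstance(node.args[0], torch.fx.Node):
                mutable.add(node.args[0])
    return mutable

def reads_mutable_buffer(node: torch.fx.Node, mutable: set) -> bool:
    """Mirrors the new is_impure check: a node fed by a mutable buffer
    must not be constant-folded."""
    return any(isinstance(a, torch.fx.Node) and a in mutable for a in node.args)
```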
