Commit 7c1133d

cccclai authored and facebook-github-bot committed
Add inplace quantizer examples (pytorch#2345)
Summary:
Pull Request resolved: pytorch#2345

Add a quantizer example for in-place ops, and add a patch to the constant-fold pass such that the mutable buffer won't be folded.

Reviewed By: jerryzh168

Differential Revision: D76312488

1 parent 63a91d7 commit 7c1133d
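
For context, a minimal sketch of the situation this commit targets (not part of the commit; the module name and tensor shapes are illustrative): when a module updates a registered buffer in place, the exported graph contains an aten.copy_ node that writes the result back into the buffer, and the copy_'s first argument is that mutable buffer. Folding the buffer's producer into a constant would break this write-back, which is what the constant-fold patch below prevents.

import torch

class BufModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("buf", torch.randn(1, 2, 3, 3))

    def forward(self, x):
        self.buf.add_(x)  # in-place update of the buffer
        return self.buf

ep = torch.export.export(BufModule().eval(), (torch.randn(1, 2, 3, 3),))
gm = ep.module()
# Inspect the graph: the in-place add is followed by an aten.copy_ node whose
# first argument is the buffer being mutated.
gm.graph.print_tabular()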

File tree

3 files changed: +139 −0 lines

test/quantization/pt2e/test_graph_utils.py

Lines changed: 32 additions & 0 deletions
@@ -11,6 +11,7 @@
 import torch
 import torch._dynamo as torchdynamo
 from torch.testing._internal.common_utils import IS_WINDOWS, TestCase, run_tests
+from torchao.quantization.pt2e.constant_fold import constant_fold
 
 from torchao.quantization.pt2e.graph_utils import (
     find_sequential_partitions,
@@ -128,6 +129,37 @@ def forward(self, x):
             [torch.nn.Conv2d, torch.nn.ReLU6],
         )
         self.assertEqual(len(fused_partitions), 1)
+
+    def test_constant_fold(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.register_buffer("buf", torch.randn(1, 2, 3, 3))
+
+            def forward(self, x):
+                self.buf = torch.abs(self.buf)
+                self.buf.add_(x)
+                return self.buf
+        # Generated graph:
+        # opcode         name           target              args          kwargs
+        # -------------  -------------  ------------------  ------------  --------
+        # get_attr       buf            buf                 ()            {}
+        # placeholder    x              x                   ()            {}
+        # call_function  abs_1          aten.abs.default    (buf,)        {}
+        # call_function  add_           aten.add_.Tensor    (abs_1, x)    {}
+        # call_function  copy__default  aten.copy_.default  (buf, abs_1)  {}
+        # output         output         output              ((add_,),)    {}
+
+        m = M().eval()
+        example_inputs = (torch.randn(1, 2, 3, 3),)
+        ep = torch.export.export(m, example_inputs)
+
+        gm = ep.module()
+        self.assertTrue(len([n for n in gm.graph.nodes if n.op == "call_function" and n.target == torch.ops.aten.abs.default]) == 1)
+        constant_fold(gm)
+        # The mutable buffer shouldn't be folded with the abs op, because it's a mutable buffer
+        self.assertTrue(len([n for n in gm.graph.nodes if n.op == "call_function" and n.target == torch.ops.aten.abs.default]) == 1)
+
 
 
 if __name__ == "__main__":
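
A complementary check, not part of the commit, using the same module M and the same constant_fold entry point as the test above: the aten.copy_ write-back into the buffer should also survive folding, since that node is what marks the buffer as mutable. A sketch under those assumptions, not an additional test from the PR.

import torch
from torchao.quantization.pt2e.constant_fold import constant_fold

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("buf", torch.randn(1, 2, 3, 3))

    def forward(self, x):
        self.buf = torch.abs(self.buf)
        self.buf.add_(x)
        return self.buf

gm = torch.export.export(M().eval(), (torch.randn(1, 2, 3, 3),)).module()
constant_fold(gm)
copy_nodes = [
    n for n in gm.graph.nodes
    if n.op == "call_function" and "copy_" in str(n.target)
]
print("copy_ nodes after folding:", len(copy_nodes))  # expected: 1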

test/quantization/pt2e/test_quantize_pt2e.py

Lines changed: 82 additions & 0 deletions
@@ -2826,6 +2826,88 @@ def check_nn_module(node):
             if node.name == "mul":
                 check_nn_module(node)
 
+    def test_quantize_in_place_ops(self):
+        class TestQuantizer(Quantizer):
+            example_inputs = None
+
+            def set_example_inputs(self, example_inputs):
+                self.example_inputs = example_inputs
+
+            def transform_for_annotation(
+                self, model: torch.fx.GraphModule
+            ) -> torch.fx.GraphModule:
+                # Make a copy of the graph to ensure that we are using the
+                # return value of this function.
+                ep = torch.export.export(model, self.example_inputs)
+                ep = ep.run_decompositions({})
+                return ep.module()
+
+            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+                act_qspec = QuantizationSpec(
+                    dtype=torch.uint8,
+                    quant_min=0,
+                    quant_max=255,
+                    qscheme=torch.per_tensor_affine,
+                    is_dynamic=False,
+                    observer_or_fake_quant_ctr=observer.default_observer,
+                )
+                for node in model.graph.nodes:
+                    if (
+                        node.op == "call_function"
+                        and node.target == torch.ops.aten.add.Tensor
+                    ):
+                        input_act0 = node.args[0]
+                        assert isinstance(input_act0, torch.fx.Node)
+                        input_act1 = node.args[1]
+                        assert isinstance(input_act1, torch.fx.Node)
+                        node.meta["quantization_annotation"] = QuantizationAnnotation(
+                            input_qspec_map={
+                                input_act0: act_qspec,
+                                input_act1: act_qspec,
+                            },
+                            output_qspec=act_qspec,
+                            _annotated=True,
+                        )
+
+            def validate(self, model: torch.fx.GraphModule) -> None:
+                pass
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.register_buffer("buf", torch.randn(1, 2, 3, 3))
+
+            def forward(self, x):
+                self.buf.add_(x)
+                return self.buf
+
+        def has_inplace_ops(graph_module: GraphModule) -> bool:
+            return len([
+                n for n in graph_module.graph.nodes
+                if n.op == "call_function" and n.name.endswith("_") and n.name != "copy_"
+            ]) > 0
+
+        m = M().eval()
+        quantizer = TestQuantizer()
+        example_inputs = (torch.randn(1, 2, 3, 3),)
+        quantizer.set_example_inputs(example_inputs)
+        m = export_for_training(m, example_inputs, strict=True).module()
+        # Check that the model has in-place ops
+        self.assertTrue(has_inplace_ops(m))
+        m = prepare_pt2e(m, quantizer)
+        # Check that the model no longer has in-place ops because the graph is
+        # functionalized during transform_for_annotation
+        self.assertFalse(has_inplace_ops(m))
+        m(*example_inputs)
+        m = convert_pt2e(m, fold_quantize=True)
+        for node in m.graph.nodes:
+            if node.name == "quantize_per_tensor_default":
+                # Ensure the quant node is not fused with the mutable buffer
+                self.assertTrue(node.op == "call_function")
+
+        # Verify the quantized model works
+        result = m(*example_inputs)
+        self.assertIsNotNone(result)
+
 
 @skipIfNoQNNPACK
 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Requires torch 2.7+")

torchao/quantization/pt2e/constant_fold.py

Lines changed: 25 additions & 0 deletions
@@ -92,6 +92,24 @@ def __init__(
         self.lifted_constant_names = lifted_constant_names
         self.deferred_value = object()
         self.skip_folding_node_fn = skip_folding_node_fn
+
+        # Identify mutable buffers by finding copy_ operations
+        self.mutable_buffers = self._find_mutable_buffers()
+
+    def _find_mutable_buffers(self) -> set[torch.fx.Node]:
+        """Find mutable buffers by identifying copy_ operations.
+        The first argument of a copy_ op is the mutable buffer."""
+        mutable_buffers = set()
+        for node in self.module.graph.nodes:
+            if (
+                node.op == "call_function"
+                and hasattr(node.target, "_schema")
+                and "copy_" in str(node.target)
+            ):
+                # The first argument of copy_ is the mutable buffer
+                if len(node.args) > 0 and isinstance(node.args[0], torch.fx.Node):
+                    mutable_buffers.add(node.args[0])
+        return mutable_buffers
 
     def _support_dynamic_shape(self) -> bool:
         # ConstantFolder not support dynamic shape now
@@ -156,6 +174,13 @@ def is_woq_int8_pattern(node: torch.fx.node.Node) -> bool:
            # We only folding fp32_weight -> q
            # int8_weight and leave dq in graph to be fused
            return True
+
+        # Check if any input to this node is a mutable buffer
+        # If so, prevent constant folding to avoid issues with quantize_per_tensor_default
+        for arg in node.args:
+            if isinstance(arg, torch.fx.Node) and arg in self.mutable_buffers:
+                return True
+
         return False
 
     def node_to_last_non_output_use(self) -> dict[torch.fx.Node, list[torch.fx.Node]]:
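
To see the detection heuristic in isolation, here is a standalone sketch of the same logic the patch adds to the constant folder: collect the first argument of every copy_-style call_function node; those nodes are treated as mutable buffers, and anything feeding them is kept out of constant folding. find_mutable_buffers below is an illustrative helper written for this note, not a torchao API.

import torch

def find_mutable_buffers(gm: torch.fx.GraphModule) -> set[torch.fx.Node]:
    """Mirror of the patch's heuristic: the first argument of a copy_ op is
    treated as a mutable buffer."""
    mutable = set()
    for node in gm.graph.nodes:
        if (
            node.op == "call_function"
            and hasattr(node.target, "_schema")
            and "copy_" in str(node.target)
        ):
            if node.args and isinstance(node.args[0], torch.fx.Node):
                mutable.add(node.args[0])
    return mutable

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("buf", torch.randn(1, 2, 3, 3))

    def forward(self, x):
        self.buf.add_(x)
        return self.buf

gm = torch.export.export(M().eval(), (torch.randn(1, 2, 3, 3),)).module()
print([n.name for n in find_mutable_buffers(gm)])  # e.g. ['buf']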
