
Commit 0498b7d

cccclai authored and facebook-github-bot committed
Add inplace quantizer examples (#2345)
Summary:
Pull Request resolved: #2345

Add a quantizer example for in-place ops, and add a patch to the constant fold pass so that the mutable buffer won't be folded.

Reviewed By: jerryzh168

Differential Revision: D76312488
1 parent 7e7ea92 commit 0498b7d
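
The pattern this commit targets, for orientation: a module registers a buffer and then mutates it in place during forward, so the buffer is both module state and the destination of an in-place op. A minimal illustrative sketch (not part of the diff; it mirrors the module used in the new test below):

    import torch

    class BufferAdd(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # A registered buffer is module state rather than an input; writing
            # into it with an in-place op makes it a "mutable buffer".
            self.register_buffer("buf", torch.zeros(1, 2, 3, 3))

        def forward(self, x):
            self.buf.add_(x)  # in-place update of the buffer
            return self.buf

Quantizing such a model end to end is what the new test covers; keeping the buffer out of constant folding is what the constant_fold.py patch covers.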

File tree

2 files changed, +107 -0 lines changed


test/quantization/pt2e/test_quantize_pt2e.py

Lines changed: 82 additions & 0 deletions
@@ -2826,6 +2826,88 @@ def check_nn_module(node):
             if node.name == "mul":
                 check_nn_module(node)

+    def test_quantize_in_place_ops(self):
+        class TestQuantizer(Quantizer):
+            example_inputs = None
+
+            def set_example_inputs(self, example_inputs):
+                self.example_inputs = example_inputs
+
+            def transform_for_annotation(
+                self, model: torch.fx.GraphModule
+            ) -> torch.fx.GraphModule:
+                # Make a copy of the graph to ensure that we are using the
+                # return value of this function.
+                ep = torch.export.export(model, self.example_inputs)
+                ep = ep.run_decompositions({})
+                return ep.module()
+
+            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+                act_qspec = QuantizationSpec(
+                    dtype=torch.uint8,
+                    quant_min=0,
+                    quant_max=255,
+                    qscheme=torch.per_tensor_affine,
+                    is_dynamic=False,
+                    observer_or_fake_quant_ctr=observer.default_observer,
+                )
+                for node in model.graph.nodes:
+                    if (
+                        node.op == "call_function"
+                        and node.target == torch.ops.aten.add.Tensor
+                    ):
+                        input_act0 = node.args[0]
+                        assert isinstance(input_act0, torch.fx.Node)
+                        input_act1 = node.args[1]
+                        assert isinstance(input_act1, torch.fx.Node)
+                        print("input_act1 is a node")
+                        node.meta["quantization_annotation"] = QuantizationAnnotation(
+                            input_qspec_map={
+                                input_act0: act_qspec,
+                                input_act1: act_qspec,
+                            },
+                            output_qspec=act_qspec,
+                            _annotated=True,
+                        )
+
+            def validate(self, model: torch.fx.GraphModule) -> None:
+                pass
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.register_buffer("buf", torch.randn(1, 2, 3, 3))
+
+            def forward(self, x):
+                self.buf.add_(x)
+                return self.buf
+
+        def has_inplace_ops(graph_module: torch.fx.GraphModule) -> bool:
+            return len([
+                n for n in graph_module.graph.nodes
+                if n.op == "call_function"
+                and n.name.endswith("_")
+                and n.name != "copy_"
+            ]) > 0
+
+        m = M().eval()
+        quantizer = TestQuantizer()
+        example_inputs = (torch.randn(1, 2, 3, 3),)
+        quantizer.set_example_inputs(example_inputs)
+        m = export_for_training(m, example_inputs, strict=True).module()
+        # Check that the model has in-place ops.
+        self.assertTrue(has_inplace_ops(m))
+        m = prepare_pt2e(m, quantizer)
+        # Check that the model no longer has in-place ops, because the graph
+        # is functionalized during transform_for_annotation.
+        self.assertFalse(has_inplace_ops(m))
+        m(*example_inputs)
+        m = convert_pt2e(m, fold_quantize=True)
+        for node in m.graph.nodes:
+            if node.name == "quantize_per_tensor_default":
+                # Ensure the quant node is not fused with the mutable buffer.
+                self.assertTrue(node.op == "call_function")
+
+        # Verify the quantized model works.
+        result = m(*example_inputs)
+        self.assertIsNotNone(result)
+

     @skipIfNoQNNPACK
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Requires torch 2.7+")
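
Why prepare_pt2e no longer sees in-place ops: the quantizer's transform_for_annotation re-exports the module and runs decompositions, which functionalizes the graph. A rough sketch of that effect, assuming the M module defined in the test above (exact node target names can differ across torch versions, so treat the asserts as illustrative):

    ep = torch.export.export(M().eval(), (torch.randn(1, 2, 3, 3),))
    gm = ep.run_decompositions({}).module()
    targets = [str(n.target) for n in gm.graph.nodes if n.op == "call_function"]
    # The in-place aten.add_ is gone; the buffer update survives as an
    # out-of-place add followed by a copy_ that writes back into the buffer.
    assert not any("add_" in t for t in targets)
    assert any("copy_" in t for t in targets)

That surviving copy_ node is also what the constant fold patch below keys off when it looks for mutable buffers.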

torchao/quantization/pt2e/constant_fold.py

Lines changed: 25 additions & 0 deletions
@@ -92,6 +92,24 @@ def __init__(
         self.lifted_constant_names = lifted_constant_names
         self.deferred_value = object()
         self.skip_folding_node_fn = skip_folding_node_fn
+
+        # Identify mutable buffers by finding copy_ operations
+        self.mutable_buffers = self._find_mutable_buffers()
+
+    def _find_mutable_buffers(self) -> set[torch.fx.Node]:
+        """Find mutable buffers by identifying copy_ operations.
+        The first argument of copy_ op is the mutable buffer."""
+        mutable_buffers = set()
+        for node in self.module.graph.nodes:
+            if (
+                node.op == "call_function"
+                and hasattr(node.target, "_schema")
+                and "copy_" in str(node.target)
+            ):
+                # The first argument of copy_ is the mutable buffer
+                if len(node.args) > 0 and isinstance(node.args[0], torch.fx.Node):
+                    mutable_buffers.add(node.args[0])
+        return mutable_buffers

     def _support_dynamic_shape(self) -> bool:
         # ConstantFolder not support dynamic shape now

@@ -156,6 +174,13 @@ def is_woq_int8_pattern(node: torch.fx.node.Node) -> bool:
             # We only folding fp32_weight -> q
             # int8_weight and leave dq in graph to be fused
             return True
+
+        # Check if any input to this node is a mutable buffer
+        # If so, prevent constant folding to avoid issues with quantize_per_tensor_default
+        for arg in node.args:
+            if isinstance(arg, torch.fx.Node) and arg in self.mutable_buffers:
+                return True
+
        return False

    def node_to_last_non_output_use(self) -> dict[torch.fx.Node, list[torch.fx.Node]]:
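
The guard above keeps nodes that read a mutable buffer out of constant folding: the buffer looks like a constant attribute, so without the check the quantize_per_tensor_default node fed by it could be folded into a fixed tensor that would no longer reflect the in-place updates at runtime. A rough standalone sketch of the same idea (the helper names here are hypothetical, not APIs from constant_fold.py):

    import torch

    def find_mutable_buffers(gm: torch.fx.GraphModule) -> set[torch.fx.Node]:
        # A buffer written via copy_ is mutable; the copy_ node's first
        # argument is the buffer being mutated.
        return {
            n.args[0]
            for n in gm.graph.nodes
            if n.op == "call_function"
            and "copy_" in str(n.target)
            and n.args
            and isinstance(n.args[0], torch.fx.Node)
        }

    def reads_mutable_buffer(node: torch.fx.Node, mutable: set[torch.fx.Node]) -> bool:
        # Mirrors the added check: any node with a mutable buffer as input is
        # skipped, so its output stays computed in the graph instead of being
        # baked into a stale constant.
        return any(isinstance(a, torch.fx.Node) and a in mutable for a in node.args)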
