Skip to content

Commit 2779ddb

Browse files
cccclai authored and facebook-github-bot committed
Add inplace quantizer examples
Summary: Add a quantizer example for in place ops Rollback Plan: Differential Revision: D76312488
1 parent 83663b8 commit 2779ddb

File tree

1 file changed

+72
-0
lines changed

1 file changed

+72
-0
lines changed

test/quantization/pt2e/test_quantize_pt2e.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2826,6 +2826,78 @@ def check_nn_module(node):
28262826
if node.name == "mul":
28272827
check_nn_module(node)
28282828

2829+
def test_quantize_in_place_ops(self):
    """PT2E quantization round-trips a model whose op decomposes to an
    in-place variant (here ``aten.add`` from ``x + 3``).

    Defines a minimal ``Quantizer`` that re-exports and decomposes the
    model in ``transform_for_annotation``, annotates every
    ``aten.add.Tensor`` node with a static uint8 per-tensor-affine spec,
    then runs prepare/convert and checks the quantized model produces
    an output.
    """

    class TestQuantizer(Quantizer):
        # Example inputs are required by transform_for_annotation to
        # re-export the model; set them via set_example_inputs before
        # prepare_pt2e is called.
        example_inputs = None

        def set_example_inputs(self, example_inputs):
            self.example_inputs = example_inputs

        def transform_for_annotation(
            self, model: torch.fx.GraphModule
        ) -> torch.fx.GraphModule:
            # Make a copy of the graph (via re-export + decomposition) to
            # ensure that prepare_pt2e actually uses the return value of
            # this function rather than the module it passed in.
            ep = torch.export.export(model, self.example_inputs)
            ep = ep.run_decompositions({})
            return ep.module()

        def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
            # Static uint8 activation spec shared by inputs and output.
            act_qspec = QuantizationSpec(
                dtype=torch.uint8,
                quant_min=0,
                quant_max=255,
                qscheme=torch.per_tensor_affine,
                is_dynamic=False,
                observer_or_fake_quant_ctr=observer.default_observer,
            )
            for node in model.graph.nodes:
                if (
                    node.op == "call_function"
                    and node.target == torch.ops.aten.add.Tensor
                ):
                    input_act0 = node.args[0]
                    assert isinstance(input_act0, torch.fx.Node)
                    # The second addend may be a graph input (fx.Node) or a
                    # constant; only fx.Node inputs get a qspec entry.
                    input_qspec_map = {input_act0: act_qspec}
                    input_act1 = node.args[1]
                    if isinstance(input_act1, torch.fx.Node):
                        input_qspec_map[input_act1] = act_qspec
                    node.meta["quantization_annotation"] = QuantizationAnnotation(
                        input_qspec_map=input_qspec_map,
                        output_qspec=act_qspec,
                        _annotated=True,
                    )
            # Fix: the signature declares a GraphModule return, but the
            # original fell through and returned None.  Return the
            # (in-place annotated) model to honor the declared contract;
            # callers that ignore the return value are unaffected.
            return model

        def validate(self, model: torch.fx.GraphModule) -> None:
            pass

    class M(torch.nn.Module):
        def forward(self, x):
            return x + 3

    m = M().eval()
    quantizer = TestQuantizer()
    example_inputs = (torch.randn(1, 2, 3, 3),)
    quantizer.set_example_inputs(example_inputs)
    m = export_for_training(m, example_inputs, strict=True).module()
    m = prepare_pt2e(m, quantizer)
    # Calibrate with one forward pass so observers record ranges.
    m(*example_inputs)
    m = convert_pt2e(m)

    # Verify the quantized model works
    result = m(*example_inputs)
    self.assertIsNotNone(result)
2900+
28292901

28302902
@skipIfNoQNNPACK
28312903
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Requires torch 2.7+")

0 commit comments

Comments
 (0)