Add inplace quantizer examples #2345

Merged
merged 1 commit on Jun 18, 2025
91 changes: 91 additions & 0 deletions test/quantization/pt2e/test_quantize_pt2e.py
@@ -2826,6 +2826,97 @@ def check_nn_module(node):
            if node.name == "mul":
                check_nn_module(node)

    def test_quantize_in_place_ops(self):
        class TestQuantizer(Quantizer):
            example_inputs = None

            def set_example_inputs(self, example_inputs):
                self.example_inputs = example_inputs

            def transform_for_annotation(
                self, model: torch.fx.GraphModule
            ) -> torch.fx.GraphModule:
                # Make a copy of the graph to ensure that we are using the
                # return value of this function.
                ep = torch.export.export(model, self.example_inputs)
                ep = ep.run_decompositions({})
                return ep.module()

            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
                act_qspec = QuantizationSpec(
                    dtype=torch.uint8,
                    quant_min=0,
                    quant_max=255,
                    qscheme=torch.per_tensor_affine,
                    is_dynamic=False,
                    observer_or_fake_quant_ctr=observer.default_observer,
                )
                for node in model.graph.nodes:
                    if (
                        node.op == "call_function"
                        and node.target == torch.ops.aten.add.Tensor
                    ):
                        input_act0 = node.args[0]
                        assert isinstance(input_act0, torch.fx.Node)
                        input_act1 = node.args[1]
                        assert isinstance(input_act1, torch.fx.Node)
                        print("input_act1 is a node")
                        node.meta["quantization_annotation"] = QuantizationAnnotation(
                            input_qspec_map={
                                input_act0: act_qspec,
                                input_act1: act_qspec,
                            },
                            output_qspec=act_qspec,
                            _annotated=True,
                        )

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("buf", torch.randn(1, 2, 3, 3))

            def forward(self, x):
                self.buf.add_(x)
                return self.buf

        def has_inplace_ops(graph_module: torch.fx.GraphModule) -> bool:
            return (
                len(
                    [
                        n
                        for n in graph_module.graph.nodes
                        if n.op == "call_function"
                        and n.name.endswith("_")
                        and n.name != "copy_"
                    ]
                )
                > 0
            )

        m = M().eval()
        quantizer = TestQuantizer()
        example_inputs = (torch.randn(1, 2, 3, 3),)
        quantizer.set_example_inputs(example_inputs)
        m = export_for_training(m, example_inputs, strict=True).module()
        # Check that the model has in-place ops
        self.assertTrue(has_inplace_ops(m))
        m = prepare_pt2e(m, quantizer)
        # Check that the model no longer has in-place ops because the graph is
        # functionalized during transform_for_annotation
        self.assertFalse(has_inplace_ops(m))
        m(*example_inputs)
        m = convert_pt2e(m, fold_quantize=True)
        for node in m.graph.nodes:
            if node.name == "quantize_per_tensor_default":
                # Ensure the quant node is not fused with the mutable buffer
                self.assertTrue(node.op == "call_function")

        # Verify the quantized model works
        result = m(*example_inputs)
        self.assertIsNotNone(result)


@skipIfNoQNNPACK
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Requires torch 2.7+")
25 changes: 25 additions & 0 deletions torchao/quantization/pt2e/constant_fold.py
@@ -93,6 +93,24 @@ def __init__(
        self.deferred_value = object()
        self.skip_folding_node_fn = skip_folding_node_fn

        # Identify mutable buffers by finding copy_ operations
        self.mutable_buffers = self._find_mutable_buffers()

    def _find_mutable_buffers(self) -> set[torch.fx.Node]:
Contributor:
Why is this change needed? Is there a test that exercises this code path?

Contributor Author:
Yeah, the added test exercises this code path and will fail without this change: the quantize_per_tensor_default node would be folded together with the mutable buffer, which is not what we want.
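For reference, a minimal sketch of where the copy_ node comes from (not part of this PR, and assuming current torch.export behavior for buffer mutations): functionalization rewrites the in-place add_ on the buffer into a functional add plus a copy_ that writes the result back, and the first argument of that copy_ is the mutable buffer that _find_mutable_buffers collects.

```python
import torch


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("buf", torch.randn(1, 2, 3, 3))

    def forward(self, x):
        self.buf.add_(x)
        return self.buf


example_inputs = (torch.randn(1, 2, 3, 3),)
ep = torch.export.export(M().eval(), example_inputs)
gm = ep.module()

# After functionalization the in-place add_ is gone; a copy_ node writes the
# updated value back into the buffer, and its first argument is the buffer.
for node in gm.graph.nodes:
    if node.op == "call_function" and "copy_" in str(node.target):
        print(node.target, "<- mutable buffer:", node.args[0])
```

The node printed as the first argument of copy_ is what _find_mutable_buffers collects, so the constant folder can refuse to fold quantize nodes that consume it.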

"""Find mutable buffers by identifying copy_ operations.
The first argument of copy_ op is the mutable buffer."""
mutable_buffers = set()
for node in self.module.graph.nodes:
if (
node.op == "call_function"
and hasattr(node.target, "_schema")
and "copy_" in str(node.target)
):
# The first argument of copy_ is the mutable buffer
if len(node.args) > 0 and isinstance(node.args[0], torch.fx.Node):
mutable_buffers.add(node.args[0])
return mutable_buffers

    def _support_dynamic_shape(self) -> bool:
        # ConstantFolder does not support dynamic shapes yet
        return False
@@ -156,6 +174,13 @@ def is_woq_int8_pattern(node: torch.fx.node.Node) -> bool:
            # We only fold fp32_weight -> q -> int8_weight
            # and leave dq in the graph to be fused
            return True

        # Check if any input to this node is a mutable buffer
        # If so, prevent constant folding to avoid issues with quantize_per_tensor_default
        for arg in node.args:
            if isinstance(arg, torch.fx.Node) and arg in self.mutable_buffers:
                return True

        return False

    def node_to_last_non_output_use(self) -> dict[torch.fx.Node, list[torch.fx.Node]]: