Skip to content

Commit b22a2be

Browse files
Arm backend: Add support for per-channel quantization (#11752)
- Adds support for per-channel quantization in TosaQuantizer and TosaBackend - Enables per-channel quantization for MobilneNetV2 test cases cc @digantdesai @freddan80 @per @zingo --------- Signed-off-by: Oscar Andersson <oscar.andersson@arm.com>
1 parent 6af28c9 commit b22a2be

26 files changed

+359
-233
lines changed

backends/arm/_passes/annotate_channels_last_dim_order_pass.py

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,12 @@
55

66
# pyre-unsafe
77

8-
from typing import cast
98

109
import torch
1110
from executorch.backends.arm._passes.arm_pass_utils import (
1211
create_node,
1312
get_first_fake_tensor,
14-
insert_q_dq_pair,
1513
)
16-
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
1714
from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d
1815
from executorch.exir.dialects._ops import ops as exir_ops
1916
from executorch.exir.pass_base import ExportPass, PassResult
@@ -59,20 +56,10 @@ class AnnotateChannelsLastDimOrder(ExportPass):
5956

6057
def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
6158
"""
62-
returns True for dq and w in the following sequences;
59+
returns True for w in the following sequence;
6360
w -> depthwise_conv2d -> ...
64-
w -> dq -> depthwise_conv2d -> ...
6561
"""
66-
if node.op == "call_function":
67-
if node.target != dq_op:
68-
return False
69-
prev_node = node.args[0]
70-
if cast(torch.fx.Node, prev_node).op != "placeholder":
71-
return False
72-
if is_consumer_node_depthwise_conv2d(node):
73-
consumer_node = list(node.users)[0]
74-
return consumer_node.args[1] == node
75-
elif node.op == "placeholder":
62+
if node.op == "placeholder":
7663
# node is an input, weight or bias node
7764
consumer_node = list(node.users)[0]
7865
if self.is_weight_node_for_depthwise_conv2d(consumer_node):
@@ -129,8 +116,6 @@ def is_channel_reshape(input_shape, output_shape):
129116

130117
@staticmethod
131118
def insert_input_transpose(node, input_node, graph_module):
132-
quantize = input_node.target == dq_op
133-
q_params = input_node.args[1:] if quantize else None
134119
with graph_module.graph.inserting_before(node):
135120
permute_node = create_node(
136121
graph_module.graph,
@@ -143,8 +128,6 @@ def insert_input_transpose(node, input_node, graph_module):
143128
else AnnotateChannelsLastDimOrder.NHWC_inverse_order
144129
),
145130
),
146-
quantize=quantize,
147-
q_params=q_params,
148131
)
149132
node.replace_input_with(input_node, permute_node)
150133

@@ -185,11 +168,6 @@ def insert_output_transpose(node, graph_module):
185168
for user in users:
186169
user.replace_input_with(node, permute_node)
187170

188-
quantize = node.args[0] == q_op
189-
if quantize:
190-
q_params = node.args[0].args[1:]
191-
insert_q_dq_pair(graph_module.graph, node, q_params)
192-
193171
@staticmethod
194172
def _insert_view_transpose(
195173
input_shape, output_shape, node, input_node, graph_module

backends/arm/_passes/annotate_decomposed_matmul.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,14 @@
77

88
import itertools
99
import operator
10-
from typing import List
10+
from typing import cast, List
1111

1212
import torch
1313
from executorch.backends.arm._passes.arm_pass_utils import create_node
1414

15-
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op, QuantArgs
15+
from executorch.backends.arm.tosa_quant_utils import dq_ops, q_ops
1616
from executorch.exir.dialects._ops import ops as exir_ops
17+
from executorch.exir.dialects.edge._ops import EdgeOpOverload
1718
from executorch.exir.pass_base import ExportPass, PassResult
1819
from torch.fx import GraphModule
1920
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
@@ -61,7 +62,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
6162
}
6263
for partition in matmul_partitions:
6364
quantized_input = all(
64-
input_node.target == dq_op for input_node in partition.input_nodes
65+
input_node.target in dq_ops for input_node in partition.input_nodes
6566
)
6667
matmul_node = [
6768
node for node in partition.nodes if node.target in matmul_targets
@@ -74,17 +75,14 @@ def call(self, graph_module: GraphModule) -> PassResult:
7475
input_node = self._match_partition_to_node(
7576
node, partition.input_nodes
7677
)
77-
input_node_qargs = QuantArgs.from_operator(
78-
input_node.target, input_node.args
79-
)
8078
# Insert new dq-node just before the mm/bmm with input_node's qparams
8179
with graph_module.graph.inserting_before(matmul_node):
8280
# Create new dq-node before matmul
8381
dq_node = create_node(
8482
graph=graph_module.graph,
85-
op_target=dq_op,
83+
op_target=cast(EdgeOpOverload, input_node.target), # type: ignore[arg-type]
8684
)
87-
dq_node.args = (node, *input_node_qargs)
85+
dq_node.args = (node, *input_node.args[1:])
8886
matmul_node.replace_input_with(node, dq_node)
8987

9088
for partition_input in partition.input_nodes:
@@ -95,19 +93,16 @@ def call(self, graph_module: GraphModule) -> PassResult:
9593
graph_module.graph.erase_node(partition_input)
9694

9795
partition_output = list(partition.output_nodes[0].users)[0]
98-
quantized_output = partition_output.target == q_op
96+
quantized_output = partition_output.target in q_ops
9997
if quantized_output:
100-
output_node_qargs = QuantArgs.from_operator(
101-
partition_output.target, partition_output.args
102-
)
10398
with graph_module.graph.inserting_after(matmul_node):
10499
# Create q-node after matmul
105100
q_node = create_node(
106101
graph=graph_module.graph,
107-
op_target=q_op,
102+
op_target=cast(EdgeOpOverload, partition_output.target), # type: ignore[arg-type]
108103
)
109104
matmul_node.replace_all_uses_with(q_node)
110-
q_node.args = (matmul_node, *output_node_qargs)
105+
q_node.args = (matmul_node, *partition_output.args[1:])
111106
# Remove partition output q-node
112107
partition_output.replace_all_uses_with(
113108
partition_output.all_input_nodes[0]

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,6 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
9393
self.add_pass(RemoveGetItemPass())
9494
self.add_pass(ConvertSplitToSlicePass())
9595
self.add_pass(ConvertMmToBmmPass())
96-
self.add_pass(DecomposeLinearPass())
9796
self.add_pass(DecomposeLinearVectorNormPass())
9897
self.add_pass(
9998
DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec)
@@ -109,12 +108,13 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
109108
self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
110109
self.add_pass(AnnotateDecomposedMatmulPass())
111110
self.add_pass(QuantizeOperatorArguments())
112-
self.add_pass(FoldAndAnnotateQParamsPass()) # type: ignore[call-arg]
111+
self.add_pass(FoldAndAnnotateQParamsPass(exported_program)) # type: ignore[call-arg]
113112
self.add_pass(RetraceFoldedDtypesPass())
114113
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
115114
self.add_pass(MatchArgRanksPass(exported_program))
116115
if self.tosa_spec.is_U55_subset:
117116
self.add_pass(BroadcastArgsPass())
117+
self.add_pass(DecomposeLinearPass())
118118
self.add_pass(ComputeConstantOpsAOT(exported_program))
119119

120120
self.add_pass(RemoveClonePass())
@@ -168,7 +168,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
168168

169169
self.add_pass(AnnotateDecomposedMatmulPass())
170170
self.add_pass(QuantizeOperatorArguments())
171-
self.add_pass(FoldAndAnnotateQParamsPass()) # type: ignore[call-arg]
171+
self.add_pass(FoldAndAnnotateQParamsPass(exported_program)) # type: ignore[call-arg]
172172
self.add_pass(RetraceFoldedDtypesPass())
173173
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
174174
self.add_pass(MatchArgRanksPass(exported_program))

backends/arm/_passes/cast_int64_pass.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ def _assert_within_int32(self, tensor: torch.Tensor, node: torch.fx.Node):
3535

3636
def _to_int32(self, graph_module: torch.fx.GraphModule):
3737
for node in graph_module.graph.nodes:
38+
if len(node.users) == 0:
39+
continue
3840
fake_tensor = node.meta["val"]
3941
if not isinstance(fake_tensor, torch._subclasses.fake_tensor.FakeTensor):
4042
continue

backends/arm/_passes/decompose_linear_pass.py

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,28 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
2-
# All rights reserved.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
32
#
43
# This source code is licensed under the BSD-style license found in the
54
# LICENSE file in the root directory of this source tree.
65

76
# pyre-unsafe
87

98
import numpy as np
9+
from executorch.backends.arm._passes import ArmPass
1010
from executorch.backends.arm._passes.arm_pass_utils import (
1111
create_node,
1212
get_first_fake_tensor,
1313
)
14-
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
1514
from executorch.exir.dialects._ops import ops as exir_ops
16-
from executorch.exir.pass_base import ExportPass, PassResult
15+
from executorch.exir.pass_base import PassResult
1716

1817

19-
class DecomposeLinearPass(ExportPass):
18+
class DecomposeLinearPass(ArmPass):
2019
"""
2120
This pass decomposes linear into a Conv2D with the required view operations.
2221
linear(x, weights, bias) becomes:
2322
x_reshaped = view(x)
2423
weights_reshaped = view(weights)
2524
conv2d = conv2d(x_reshaped, weights_reshaped, bias)
2625
output = view(conv2d)
27-
It also inserts q/dq pairs if the linear node was quantized.
2826
"""
2927

3028
def call(self, graph_module):
@@ -47,35 +45,22 @@ def call(self, graph_module):
4745
weights_reshaped_shape = [weights_shape[0], weights_shape[1], 1, 1]
4846

4947
with graph_module.graph.inserting_before(node):
50-
quantize = input.op == "call_function" and input.target == dq_op
51-
q_params = input.args[1:] if quantize else None
5248
# Reshape input to 4D with shape (N, Ci, 1, 1)
5349
input_reshaped = create_node(
5450
graph=graph_module.graph,
5551
op_target=exir_ops.edge.aten.view_copy.default,
5652
args=(input, input_reshaped_shape),
5753
kwargs={},
58-
quantize=quantize,
59-
q_params=q_params,
6054
)
6155

62-
quantize = weights.op == "call_function" and weights.target == dq_op
63-
q_params = weights.args[1:] if quantize else None
6456
# Reshape weights to 4D with shape (Co, Ci, 1, 1)
6557
weights_reshaped = create_node(
6658
graph=graph_module.graph,
6759
op_target=exir_ops.edge.aten.view_copy.default,
6860
args=(weights, weights_reshaped_shape),
6961
kwargs={},
70-
quantize=quantize,
71-
q_params=q_params,
7262
)
7363

74-
consumer_node = list(node.users)[0]
75-
quantize = (
76-
consumer_node.op == "call_function" and consumer_node.target == q_op
77-
)
78-
q_params = consumer_node.args[1:] if quantize else None
7964
conv = create_node(
8065
graph=graph_module.graph,
8166
op_target=exir_ops.edge.aten.convolution.default,
@@ -91,8 +76,7 @@ def call(self, graph_module):
9176
1, # groups
9277
),
9378
kwargs={},
94-
quantize=quantize,
95-
q_params=q_params,
79+
from_node=node,
9680
)
9781

9882
with graph_module.graph.inserting_after(conv):

backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,13 @@
1010

1111
from typing import cast, Dict, Set, Tuple
1212

13-
from executorch.backends.arm.tosa_quant_utils import QuantArgs
13+
from executorch.backends.arm._passes import ArmPass
14+
from executorch.backends.arm._passes.arm_pass_utils import (
15+
get_param_tensor,
16+
is_param_node,
17+
)
18+
19+
from executorch.backends.arm.tosa_quant_utils import dq_ops, q_ops, QuantArgs
1420

1521
from executorch.exir.dialects._ops import ops as exir_ops
1622
from executorch.exir.dialects.edge._ops import EdgeOpOverload
@@ -24,9 +30,6 @@
2430
)
2531
from torch.fx import GraphModule, Node
2632

27-
q_op: EdgeOpOverload = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
28-
dq_op: EdgeOpOverload = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
29-
3033

3134
def get_input_qparams(node: Node) -> dict[int, QuantArgs]:
3235
"""
@@ -66,7 +69,7 @@ def get_output_qparams(node: Node) -> dict[int, QuantArgs]:
6669
return output_qparams
6770

6871

69-
class FoldAndAnnotateQParamsPass(ExportPass):
72+
class FoldAndAnnotateQParamsPass(ArmPass):
7073
"""
7174
A pass that walks the graph and removes any DQ and Q nodes before and after the target
7275
node.
@@ -96,9 +99,6 @@ class FoldAndAnnotateQParamsPass(ExportPass):
9699
97100
"""
98101

99-
def __init__(self) -> None:
100-
super().__init__()
101-
102102
def fold_and_annotate_arg(
103103
self, graph_module: GraphModule, node: Node, arg_list: list[Node], i: int
104104
) -> None:
@@ -109,8 +109,25 @@ def fold_and_annotate_arg(
109109
return
110110

111111
arg_quant_params = None
112-
if arg.target == dq_op:
113-
arg_quant_params = QuantArgs.from_operator(arg.target, arg.args)
112+
if arg.target in dq_ops:
113+
args = arg.args
114+
scales = args[1]
115+
if (
116+
isinstance(args[1], Node)
117+
and self.exported_program is not None
118+
and is_param_node(self.exported_program, args[1])
119+
):
120+
scales = get_param_tensor(self.exported_program, args[1])
121+
zps = args[2]
122+
if (
123+
isinstance(args[2], Node)
124+
and self.exported_program is not None
125+
and is_param_node(self.exported_program, args[2])
126+
):
127+
zps = get_param_tensor(self.exported_program, args[2])
128+
arg_quant_params = QuantArgs.from_operator(
129+
arg.target, (args[0], scales, zps, *args[3:])
130+
)
114131
# add arg to nodes_to_remove to fold the dq-node
115132
nodes_to_remove.add(arg)
116133
if input_qparams is not None and input_qparams != arg_quant_params:
@@ -120,10 +137,13 @@ def fold_and_annotate_arg(
120137
if input_qparams is not None:
121138
node.meta["input_qparams"][i] = input_qparams
122139
for n in nodes_to_remove:
123-
if n.target != dq_op:
124-
raise RuntimeError(f"Expected {dq_op} dq_op, got {n.target}")
140+
if n.target not in dq_ops:
141+
raise RuntimeError(
142+
f"Expected one of {dq_ops} dq_op, got {n.target}"
143+
)
125144

126-
n.replace_all_uses_with(n.args[0]) # type: ignore[arg-type]
145+
if len(n.args) > 0:
146+
n.replace_all_uses_with(n.args[0]) # type: ignore[arg-type]
127147
graph_module.graph.erase_node(n)
128148

129149
def call(self, graph_module: GraphModule) -> PassResult:
@@ -134,7 +154,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
134154
if n.op != "call_function":
135155
continue
136156
# Don't fold chains of quant-ops into each other.
137-
if n.target in (q_op, dq_op):
157+
if n.target in (*q_ops, *dq_ops):
138158
continue
139159

140160
# Make sure we haven't already set qparams meta information on the node
@@ -164,7 +184,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
164184
# Copy the users, since we are modifying it.
165185
users_copy = copy.copy(n.users)
166186
for i, user in enumerate(users_copy):
167-
if user.target != q_op:
187+
if user.target not in q_ops:
168188
continue
169189

170190
# quantization node found here, store the quantization parameters in meta value
@@ -201,7 +221,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
201221

202222
# Make sure we have a quantized operator
203223
user = list(n.users)[0]
204-
if user.target != q_op:
224+
if user.target not in q_ops:
205225
continue
206226

207227
qargs = QuantArgs.from_operator(user.target, user.args)

0 commit comments

Comments
 (0)