Skip to content

Commit 6af28c9

Browse files
authored
Arm backend: Added decomposition for MaxPool2d with dilation > 0. (#11724)
Arm backend: Added decomposition for MaxPool2D operator with dilation > 0 Signed-off-by: Elena Zhelezina <elena.zhelezina@arm.com>
1 parent be8ffd1 commit 6af28c9

File tree

4 files changed

+269
-2
lines changed

4 files changed

+269
-2
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from .decompose_leaky_relu_pass import DecomposeLeakyReLUPass # noqa
3030
from .decompose_linalg_vector_norm_pass import DecomposeLinearVectorNormPass # noqa
3131
from .decompose_linear_pass import DecomposeLinearPass # noqa
32+
from .decompose_maxpool2d_with_dilation import DecomposeMaxPool2DPass # noqa
3233
from .decompose_meandim_pass import DecomposeMeanDimPass # noqa
3334
from .decompose_ne_pass import DecomposeNotEqualPass # noqa
3435
from .decompose_select import DecomposeSelectPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
DecomposeLeakyReLUPass,
3333
DecomposeLinearPass,
3434
DecomposeLinearVectorNormPass,
35+
DecomposeMaxPool2DPass,
3536
DecomposeMeanDimPass,
3637
DecomposeNotEqualPass,
3738
DecomposeSelectPass,
@@ -123,6 +124,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
123124
self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
124125
self.add_pass(DecomposeSumPass())
125126
self.add_pass(Conv1dUnsqueezePass())
127+
self.add_pass(DecomposeMaxPool2DPass())
126128
self.add_pass(DecomposeSelectPass())
127129
self.add_pass(ConvertSqueezesToViewPass())
128130

@@ -179,6 +181,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
179181
self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
180182
self.add_pass(DecomposeSumPass())
181183
self.add_pass(Conv1dUnsqueezePass())
184+
self.add_pass(DecomposeMaxPool2DPass())
182185
self.add_pass(DecomposeSelectPass())
183186
self.add_pass(ConvertSqueezesToViewPass())
184187

Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
import operator
9+
10+
from executorch.backends.arm._passes import ArmPass
11+
from executorch.exir.dialects._ops import ops as exir_ops
12+
13+
# We'll decompose only the EXIR edge max_pool2d ops when dilation > 1
14+
EDGE_MAXPOOL2D = (
15+
exir_ops.edge.aten.max_pool2d.default,
16+
exir_ops.edge.aten.max_pool2d_with_indices.default,
17+
)
18+
19+
20+
class DecomposeMaxPool2DPass(ArmPass):
21+
"""
22+
Decompose dilated max_pool2d (EXIR edge ops) into space-to-batch -> maxpool -> batch-to-space.
23+
"""
24+
25+
def call_operator(self, op, args, kwargs, meta):
26+
# Only intercept EXIR edge max_pool2d ops
27+
if op not in EDGE_MAXPOOL2D:
28+
return super().call_operator(op, args, kwargs, meta)
29+
30+
# detect whether indices variant
31+
is_with_indices = op is exir_ops.edge.aten.max_pool2d_with_indices.default
32+
33+
# Normalize missing trailing args to their defaults
34+
x = args[0]
35+
kernel_size = args[1]
36+
stride = args[2]
37+
padding = args[3] if len(args) >= 4 else 0
38+
dilation = args[4] if len(args) >= 5 else 1
39+
40+
# Normalize attributes
41+
pad_h, pad_w = (padding, padding) if isinstance(padding, int) else padding
42+
d_h, d_w = (dilation, dilation) if isinstance(dilation, int) else dilation
43+
k_h, k_w = (
44+
(kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
45+
)
46+
s_h, s_w = (stride, stride) if isinstance(stride, int) else stride
47+
48+
# If no dilation: call EXIR edge op with only supported args (x, kernel, stride[, padding])
49+
if d_h == 1 and d_w == 1:
50+
minimal_args = [x, kernel_size, stride]
51+
# only include padding if non-zero
52+
if (pad_h, pad_w) != (0, 0):
53+
minimal_args.append((pad_h, pad_w))
54+
return super().call_operator(op, tuple(minimal_args), {}, meta)
55+
56+
# Compute padded and packed dimensions for dilation > 1
57+
N, C, H, W = x.data.size()
58+
ph, pw = pad_h, pad_w
59+
ph2, pw2 = pad_h, pad_w
60+
H_pad = H + ph + ph2
61+
W_pad = W + pw + pw2
62+
H_pack = (H_pad + d_h - 1) // d_h
63+
W_pack = (W_pad + d_w - 1) // d_w
64+
extra_h = 0 if H_pack < k_h else (s_h - ((H_pack - k_h) % s_h)) % s_h
65+
extra_w = 0 if W_pack < k_w else (s_w - ((W_pack - k_w) % s_w)) % s_w
66+
ph2 += extra_h * d_h
67+
pw2 += extra_w * d_w
68+
69+
# 1) Pad via EXIR edge pad (preserves dtype)
70+
pad_edge = exir_ops.edge.aten.constant_pad_nd.default
71+
pads = [pw, pw2, ph, ph2, 0, 0, 0, 0]
72+
x_pad = super().call_operator(
73+
pad_edge,
74+
(x, pads, 0),
75+
{},
76+
meta,
77+
)
78+
79+
# 2) Space-to-batch: reshape and permute
80+
x2 = super().call_operator(
81+
exir_ops.edge.aten.view_copy.default,
82+
(x_pad, [N, C, H_pack, d_h, W_pack, d_w]),
83+
{},
84+
meta,
85+
)
86+
x2 = super().call_operator(
87+
exir_ops.edge.aten.permute_copy.default,
88+
(x2, [3, 5, 0, 1, 2, 4]),
89+
{},
90+
meta,
91+
)
92+
x2 = super().call_operator(
93+
exir_ops.edge.aten.view_copy.default,
94+
(x2, [N * d_h * d_w, C, H_pack, W_pack]),
95+
{},
96+
meta,
97+
)
98+
99+
# 3) Core pooling on packed tensor
100+
pool_edge_op = (
101+
exir_ops.edge.aten.max_pool2d_with_indices.default
102+
if is_with_indices
103+
else exir_ops.edge.aten.max_pool2d.default
104+
)
105+
pool_args = (x2, (k_h, k_w), (s_h, s_w), (0, 0))
106+
pool_out = super().call_operator(
107+
pool_edge_op,
108+
pool_args,
109+
{},
110+
meta,
111+
)
112+
113+
# Unpack pooled result
114+
if is_with_indices:
115+
pooled_proxy = super().call_operator(
116+
operator.getitem,
117+
(pool_out, 0),
118+
{},
119+
meta,
120+
)
121+
indices_proxy = super().call_operator(
122+
operator.getitem,
123+
(pool_out, 1),
124+
{},
125+
meta,
126+
)
127+
pooled_fake, _ = pool_out.data
128+
else:
129+
pooled_proxy = pool_out
130+
pooled_fake = pool_out.data
131+
indices_proxy = None
132+
133+
_, C_out, H_out, W_out = pooled_fake.shape
134+
135+
# 4) Batch-to-space: reshape and permute back
136+
out = super().call_operator(
137+
exir_ops.edge.aten.view_copy.default,
138+
(pooled_proxy, [d_h, d_w, N, C_out, H_out, W_out]),
139+
{},
140+
meta,
141+
)
142+
out = super().call_operator(
143+
exir_ops.edge.aten.permute_copy.default,
144+
(out, [2, 3, 4, 0, 5, 1]),
145+
{},
146+
meta,
147+
)
148+
# now flatten back into (N, C, H_out*d_h, W_out*d_w)
149+
out = super().call_operator(
150+
exir_ops.edge.aten.view_copy.default,
151+
(out, [N, C_out, H_out * d_h, W_out * d_w]),
152+
{},
153+
meta,
154+
)
155+
156+
# 5) Final crop
157+
S_top = ph // d_h + (1 if ph % d_h else 0)
158+
S_left = pw // d_w + (1 if pw % d_w else 0)
159+
S_top = max(0, min(S_top, H_out * d_h - H))
160+
S_left = max(0, min(S_left, W_out * d_w - W))
161+
out = super().call_operator(
162+
exir_ops.edge.aten.slice_copy.Tensor,
163+
(out, 2, S_top, S_top + H),
164+
{},
165+
meta,
166+
)
167+
out = super().call_operator(
168+
exir_ops.edge.aten.slice_copy.Tensor,
169+
(out, 3, S_left, S_left + W),
170+
{},
171+
meta,
172+
)
173+
174+
if is_with_indices:
175+
# Reconstruct indices
176+
idx = super().call_operator(
177+
exir_ops.edge.aten.view_copy.default,
178+
(indices_proxy, [d_h, d_w, N, C_out, H_out, W_out]),
179+
{},
180+
meta,
181+
)
182+
idx = super().call_operator(
183+
exir_ops.edge.aten.permute_copy.default,
184+
(idx, [2, 3, 4, 0, 5, 1]),
185+
{},
186+
meta,
187+
)
188+
idx = super().call_operator(
189+
exir_ops.edge.aten.view_copy.default,
190+
(idx, [N, C_out, H_out * d_h, W_out * d_w]),
191+
{},
192+
meta,
193+
)
194+
idx = super().call_operator(
195+
exir_ops.edge.aten.slice_copy.Tensor,
196+
(idx, 2, S_top, S_top + H),
197+
{},
198+
meta,
199+
)
200+
idx = super().call_operator(
201+
exir_ops.edge.aten.slice_copy.Tensor,
202+
(idx, 3, S_left, S_left + W),
203+
{},
204+
meta,
205+
)
206+
return out, idx
207+
208+
return out

backends/arm/test/ops/test_max_pool.py

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
TosaPipelineMI,
2020
)
2121

22-
2322
test_data_suite = {
2423
# (test_name, test_data, [kernel_size, stride, padding])
2524
"zeros": lambda: (torch.zeros(1, 1, 4, 8), [2, 2, 1]),
@@ -34,6 +33,20 @@
3433
"randn": lambda: (torch.randn(5, 16, 50, 32), [4, 2, 0]),
3534
}
3635

36+
test_data_suite_dilation = [
37+
# Simple dilation=2 on 8x8 input, kernel=3, stride=1, no padding
38+
("dilation2", torch.rand(1, 1, 8, 8), [3, 1, 0, 2]),
39+
# Input is 6x6, kernel=3, stride=1, dilation=2.
40+
# Padding=1 expands the effective input to 8x8.
41+
("pad_then_dil2", torch.rand(1, 1, 6, 6), [3, 1, 1, 2]),
42+
# Input is 16x16, kernel=2x2, stride=2x2, dilation=1 (no dilation).
43+
# Padding of 1 ensures the input size remains divisible by stride
44+
# after padding.
45+
("even_kernel_fast", torch.rand(1, 3, 16, 16), [(2, 2), (2, 2), (1, 1), 1]),
46+
# Multi-batch, multi-channel input (N=4, C=3), kernel=3x3,
47+
# stride=3x3, no padding, dilation=1.
48+
("mb_ch_dil1", torch.rand(4, 3, 12, 12), [(3, 3), (3, 3), 0, 1]),
49+
]
3750

3851
aten_op = "torch.ops.aten.max_pool2d.default"
3952
exir_op = "executorch_exir_dialects_edge__ops_aten_max_pool2d_default"
@@ -47,10 +60,14 @@ def __init__(
4760
kernel_size: int | Tuple[int, int],
4861
stride: int | Tuple[int, int],
4962
padding: int | Tuple[int, int],
63+
dilation: int | Tuple[int, int] = 1,
5064
):
5165
super().__init__()
5266
self.max_pool_2d = torch.nn.MaxPool2d(
53-
kernel_size=kernel_size, stride=stride, padding=padding
67+
kernel_size=kernel_size,
68+
stride=stride,
69+
padding=padding,
70+
dilation=dilation,
5471
)
5572

5673
def forward(self, x):
@@ -180,3 +197,41 @@ def test_max_pool2d_u55_BI_failure_set(test_data: Tuple):
180197
)
181198
pipeline.pop_stage("check_count.exir")
182199
pipeline.run()
200+
201+
202+
# Convert the list of (name, tensor, params) into the dict-of-lambdas shape
203+
dilation_test_data = {
204+
name: (lambda data=data, params=params: (data, params))
205+
for name, data, params in test_data_suite_dilation
206+
}
207+
208+
209+
@common.parametrize("test_data", dilation_test_data)
210+
def test_max_pool2d_tosa_MI_dilation(test_data):
211+
"""
212+
TOSA MI pipeline with dilation > 1 (and dilation=1 sanity cases).
213+
"""
214+
data, model_params = test_data()
215+
pipeline = TosaPipelineMI[input_t1](
216+
MaxPool2d(*model_params),
217+
(data,),
218+
aten_op,
219+
exir_op,
220+
)
221+
pipeline.run()
222+
223+
224+
@common.parametrize("test_data", dilation_test_data)
225+
def test_max_pool2d_tosa_BI_dilation(test_data):
226+
"""
227+
TOSA BI pipeline with dilation > 1 (and dilation=1 sanity cases).
228+
"""
229+
data, model_params = test_data()
230+
pipeline = TosaPipelineBI[input_t1](
231+
MaxPool2d(*model_params),
232+
(data,),
233+
aten_op,
234+
exir_op,
235+
symmetric_io_quantization=True,
236+
)
237+
pipeline.run()

0 commit comments

Comments
 (0)