Skip to content

Commit 99b3a68

Browse files
Arm backend: Add support for aten.round (#11813)
Adds support and unittest for round-op. Signed-off-by: Oscar Andersson <oscar.andersson@arm.com>
1 parent 28b8198 commit 99b3a68

File tree

5 files changed

+174
-0
lines changed

5 files changed

+174
-0
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from .decompose_maxpool2d_with_dilation import DecomposeMaxPool2DPass # noqa
3434
from .decompose_meandim_pass import DecomposeMeanDimPass # noqa
3535
from .decompose_ne_pass import DecomposeNotEqualPass # noqa
36+
from .decompose_round_pass import DecomposeRoundPass # noqa
3637
from .decompose_select import DecomposeSelectPass # noqa
3738
from .decompose_silu_pass import DecomposeSiluPass # noqa
3839
from .decompose_softmax_pass import DecomposeSoftmaxPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
DecomposeMaxPool2DPass,
3737
DecomposeMeanDimPass,
3838
DecomposeNotEqualPass,
39+
DecomposeRoundPass,
3940
DecomposeSelectPass,
4041
DecomposeSiluPass,
4142
DecomposeSoftmaxPass,
@@ -141,6 +142,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
141142
return self._transform(exported_program.graph_module)
142143

143144
def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
145+
self.add_pass(DecomposeRoundPass())
144146
self.add_pass(DecomposeSqrtPass())
145147
self.add_pass(ConvertIntPowToMuls())
146148
self.add_pass(ReplaceScalarWithTensorArgPassTOSAMI())
@@ -222,6 +224,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
222224
self.add_pass(InsertCastForOpsWithInt64InputPass())
223225
self.add_pass(DecomposeEmbeddingPass())
224226
self.add_pass(DecomposeScaledDotProductAttention())
227+
self.add_pass(DecomposeRoundPass())
225228
self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
226229
self.add_pass(ScalarsToAttributePass())
227230
self.add_pass(DecomposeGroupNormPass())
backends/arm/_passes/decompose_round_pass.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import torch
7+
from executorch.backends.arm._passes import ArmPass
8+
from executorch.exir.dialects._ops import ops as exir_ops
9+
from executorch.exir.dialects.edge._ops import EdgeOpOverload
10+
from torch._ops import OpOverload
11+
12+
13+
# Union of the two operator-overload types this file deals with: torch's
# OpOverload and the edge dialect's EdgeOpOverload.
Op = OpOverload | EdgeOpOverload
14+
15+
16+
def _get_round_decomposition_ops(op) -> tuple[Op, Op, Op, Op, Op, Op, Op]:
    """
    Look up the operators used to decompose the round operation ``op``.

    The operators are taken from the same namespace as ``op`` itself: the edge
    dialect when ``op`` is the edge round overload, plain aten when it is the
    aten round overload.

    Returns:
        The tuple (full_op, ge_op, add_op, sub_op, floor_op, ceil_op, where_op).

    Raises:
        RuntimeError: If ``op`` is neither of the two supported round overloads.
    """
    if op == exir_ops.edge.aten.round.default:
        namespace = exir_ops.edge.aten
    elif op == torch.ops.aten.round.default:
        namespace = torch.ops.aten
    else:
        raise RuntimeError(f"Can't get round decomposition ops for op {op}")

    # Both dialects expose the same overload names, so a single lookup suffices.
    return (
        namespace.full.default,
        namespace.ge.Tensor,
        namespace.add.Scalar,
        namespace.sub.Scalar,
        namespace.floor.default,
        namespace.ceil.default,
        namespace.where.self,
    )
42+
43+
44+
class DecomposeRoundPass(ArmPass):
    """
    Decompose the round operation into a sequence of more primitive operations.

    Inputs >= 0 are computed as floor(x + 0.5) and inputs < 0 as ceil(x - 0.5);
    a where() selects between the two results per element.

    NOTE: floor(x + 0.5) / ceil(x - 0.5) break exact .5 ties away from zero
    (2.5 -> 3, -2.5 -> -3), whereas torch.round is documented to round ties to
    even (2.5 -> 2). The decomposition is therefore not bit-exact with
    torch.round for tie inputs.

    Example:
        %zero = full((1,), 0.0, dtype=torch.float32)
        %is_non_negative = ge(x, %zero)
        %plus_half = add(x, 0.5)
        %minus_half = sub(x, 0.5)
        %floor = floor(%plus_half)
        %ceil = ceil(%minus_half)
        %result = where(%is_non_negative, %floor, %ceil)
    """

    def call_operator(self, op, args, kwargs, meta, updated=False):
        """
        Replace a round op with its decomposition; forward every other op as is.

        Args:
            op: The operator being visited.
            args: Positional arguments of the call; for round, args[0] is the input.
            kwargs: Keyword arguments of the call, reused for the emitted ops.
            meta: Node metadata, reused for every emitted op.
            updated: Forwarded unchanged for ops this pass does not rewrite.

        Returns:
            The result of the final where() for round ops, otherwise whatever the
            parent call_operator returns for ``op``.
        """
        if op not in (exir_ops.edge.aten.round.default, torch.ops.aten.round.default):
            return super().call_operator(op, args, kwargs, meta, updated)

        x = args[0]
        (
            full_op,
            ge_op,
            add_op,
            sub_op,
            floor_op,
            ceil_op,
            where_op,
        ) = _get_round_decomposition_ops(op)

        # A (1,)-shaped zero to compare x against.
        # NOTE(review): dtype is fixed to float32 regardless of x's dtype -
        # confirm this is fine for non-float32 inputs.
        zero = super().call_operator(
            full_op,
            args=((1,), 0.0),
            kwargs={"dtype": torch.float32},
            meta=meta,
            updated=True,
        )
        is_non_negative = super().call_operator(
            ge_op, (x, zero), kwargs, meta, updated=True
        )

        # Both branches are always emitted; where() picks one per element.
        plus_half = super().call_operator(add_op, (x, 0.5), kwargs, meta, updated=True)
        minus_half = super().call_operator(sub_op, (x, 0.5), kwargs, meta, updated=True)
        # Results get their own names so the operator handles are not rebound.
        floored = super().call_operator(
            floor_op, (plus_half,), kwargs, meta, updated=True
        )
        ceiled = super().call_operator(
            ceil_op, (minus_half,), kwargs, meta, updated=True
        )
        return super().call_operator(
            where_op,
            (is_non_negative, floored, ceiled),
            kwargs,
            meta,
            updated=True,
        )

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ def is_node_supported(
211211
exir_ops.edge.aten.leaky_relu.default,
212212
exir_ops.edge.aten.sqrt.default,
213213
exir_ops.edge.aten.rsqrt.default,
214+
exir_ops.edge.aten.round.default,
214215
exir_ops.edge.aten._softmax.default,
215216
exir_ops.edge.aten.select_copy.int,
216217
exir_ops.edge.aten._log_softmax.default,
@@ -281,6 +282,7 @@ def is_node_supported(
281282
exir_ops.edge.aten.ne.Scalar: None,
282283
exir_ops.edge.aten.div.Scalar: None,
283284
exir_ops.edge.aten.leaky_relu.default: None,
285+
exir_ops.edge.aten.round.default: None,
284286
}
285287

286288
if node.target in needs_decomp_dict:

backends/arm/test/ops/test_round.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
7+
from typing import Tuple
8+
9+
import pytest
10+
import torch
11+
from executorch.backends.arm.test import common
12+
from executorch.backends.arm.test.tester.test_pipeline import (
13+
EthosU55PipelineBI,
14+
EthosU85PipelineBI,
15+
TosaPipelineBI,
16+
TosaPipelineMI,
17+
)
18+
19+
# Type parameter given to the pipeline classes below: a 1-tuple holding the
# input tensor x.
input_t1 = Tuple[torch.Tensor]  # Input x

# Operator names handed to the test pipelines as their aten_op / exir_op arguments.
aten_op = "torch.ops.aten.round.default"
exir_op = "executorch_exir_dialects_edge__ops_aten_round_default"

# Maps a test-case name to a zero-argument factory that builds the input tensor.
test_data_suite = {
    # (test_name, test_data)
    "zeros": lambda: torch.zeros(1, 10, 10, 10),
    "ones": lambda: torch.ones(10, 10, 10),
    # torch.rand shifted down by 0.5, i.e. values on both sides of zero.
    "rand": lambda: torch.rand(10, 10) - 0.5,
    # Normally distributed values centred on +10 and -10 respectively.
    "randn_pos": lambda: torch.randn(10) + 10,
    "randn_neg": lambda: torch.randn(10) - 10,
    # Evenly spaced values from -16 up to (but excluding) 16 in steps of 0.2.
    "ramp": lambda: torch.arange(-16, 16, 0.2),
}
33+
34+
35+
class Round(torch.nn.Module):
    """Minimal module whose forward pass rounds its input tensor."""

    def forward(self, x: torch.Tensor):
        rounded = torch.round(x)
        return rounded
38+
39+
40+
@common.parametrize("test_data", test_data_suite)
def test_round_tosa_MI(test_data: torch.Tensor):
    """Run the Round module through TosaPipelineMI for one test case."""
    pipeline_cls = TosaPipelineMI[input_t1]
    module = Round()
    # test_data is the factory stored in test_data_suite; call it for the tensor.
    inputs = (test_data(),)
    pipeline_cls(module, inputs, aten_op, exir_op).run()
49+
50+
51+
@common.parametrize("test_data", test_data_suite)
def test_round_tosa_BI(test_data: torch.Tensor):
    """Run the Round module through TosaPipelineBI for one test case."""
    pipeline_cls = TosaPipelineBI[input_t1]
    module = Round()
    # test_data is the factory stored in test_data_suite; call it for the tensor.
    inputs = (test_data(),)
    # An empty list is passed in the aten_op position for this pipeline.
    pipeline_cls(module, inputs, [], exir_op).run()
60+
61+
62+
@common.parametrize("test_data", test_data_suite)
@common.XfailIfNoCorstone300
@pytest.mark.xfail(reason="where.self not supported on U55")
def test_round_u55_BI(test_data: torch.Tensor):
    """Run the Round module through EthosU55PipelineBI; marked as expected to fail."""
    pipeline_cls = EthosU55PipelineBI[input_t1]
    module = Round()
    # test_data is the factory stored in test_data_suite; call it for the tensor.
    inputs = (test_data(),)
    # An empty list is passed in the aten_op position for this pipeline.
    pipeline_cls(module, inputs, [], exir_op).run()
73+
74+
75+
@common.parametrize("test_data", test_data_suite)
@common.XfailIfNoCorstone320
def test_round_u85_BI(test_data: torch.Tensor):
    """Run the Round module through EthosU85PipelineBI for one test case."""
    pipeline_cls = EthosU85PipelineBI[input_t1]
    module = Round()
    # test_data is the factory stored in test_data_suite; call it for the tensor.
    inputs = (test_data(),)
    # An empty list is passed in the aten_op position for this pipeline.
    pipeline_cls(module, inputs, [], exir_op).run()

0 commit comments

Comments
 (0)