Skip to content

Commit 9eb8d01

Browse files
authored
Implement ReplaceMulTensorWithMulAndFullOpsPass.
Differential Revision: D76469624 Pull Request resolved: #11577
1 parent 3a6c664 commit 9eb8d01

File tree

2 files changed

+75
-1
lines changed

2 files changed

+75
-1
lines changed

backends/cadence/aot/replace_ops.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2300,6 +2300,52 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
23002300
return result
23012301

23022302

2303+
@register_cadence_pass(CadencePassAttribute(opt_level=1))
class ReplaceMulTensorWithMulAndFullOpsPass(ExportPass):
    """
    Extracts a single value argument of mul op to a separate full op.

    A node ``mul(x, scalar)`` (in either argument order) is rewritten into
    ``full([1], scalar)`` followed by ``mul(x, full_result)``, so the scalar
    operand becomes an explicit tensor in the graph.
    """

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        graph = graph_module.graph
        for node in graph.find_nodes(
            op="call_function", target=torch.ops.aten.mul.Tensor
        ):
            lhs, rhs = node.args

            # Normalize ordering so the tensor operand comes first.
            if isinstance(rhs, torch.fx.Node):
                lhs, rhs = rhs, lhs

            # Only rewrite (tensor, scalar) multiplications; anything else
            # (e.g. tensor * tensor) is left untouched.
            if not (isinstance(lhs, torch.fx.Node) and isinstance(rhs, (float, int))):
                continue

            # Align the scalar's Python type with the tensor operand's dtype.
            fill_value = self.resolve_full_arg(lhs, rhs)

            # Materialize the scalar as a full op feeding a fresh mul, then
            # splice the new mul in place of the old one.
            with graph.inserting_before(node):
                scalar_tensor = graph.call_function(
                    exir_ops.edge.aten.full.default, args=([1], fill_value)
                )
                replacement = graph.call_function(
                    torch.ops.aten.mul.Tensor, args=(lhs, scalar_tensor)
                )
            node.replace_all_uses_with(replacement)
            graph.erase_node(node)
        return super().call(graph_module)

    def resolve_full_arg(self, x_arg, const_arg):
        # Cast the scalar so its Python type matches the tensor's dtype
        # (int constant with a float32 tensor, or float constant with an
        # int32 tensor). The two cases are mutually exclusive by dtype.
        dtype = x_arg.meta["val"].dtype
        if dtype == torch.float32 and isinstance(const_arg, int):
            return float(const_arg)
        if dtype == torch.int32 and isinstance(const_arg, float):
            return int(const_arg)
        return const_arg
23032349
# This class encapsulates all the functions that replace/switch one op in the
23042350
# graph with another.
23052351
class CadenceReplaceOpsInGraph:

backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
GraphBuilder,
1616
single_op_builder,
1717
)
18-
from executorch.backends.cadence.aot.pass_utils import count_node
18+
from executorch.backends.cadence.aot.pass_utils import count_node, op_counts_match
1919
from executorch.backends.cadence.aot.replace_ops import (
2020
ForceChannelLastForConvPass,
2121
MakeSliceAndCatDimOutermostPass,
@@ -31,6 +31,7 @@
3131
ReplaceLinearWithFullyConnectedOpPass,
3232
ReplaceMatmulWithTransposedMatmulPass,
3333
ReplaceMMWithAddMMPass,
34+
ReplaceMulTensorWithMulAndFullOpsPass,
3435
ReplaceNopTransposeOrPermuteWithViewPass,
3536
ReplacePadWithCatPass,
3637
ReplacePermuteWithTransposePass,
@@ -1875,3 +1876,30 @@ def test_empty_slice(self):
18751876
),
18761877
1,
18771878
)
1879+
1880+
@parameterized.expand(
    [
        ("int", 123),
        ("float", 456.0),
    ],
)
@torch.no_grad()
def test_extract_mul_argument_to_full(self, _, value) -> None:
    """Verify mul(x, scalar) is rewritten into full + mul(x, full)."""
    inputs = torch.randn(2, 1, 64)
    original_gm = single_op_builder(
        placeholders=(inputs,),
        op=torch.ops.aten.mul.Tensor,
        args=(inputs, value),
        kwargs={},
    )
    pass_result = ReplaceMulTensorWithMulAndFullOpsPass().call(original_gm)
    converted = pass_result.graph_module
    # The rewritten graph should contain exactly one mul fed by one full.
    expected = {
        torch.ops.aten.mul.Tensor: 1,
        exir_ops.edge.aten.full.default: 1,
    }
    self.assertTrue(op_counts_match(converted, expected))

0 commit comments

Comments
 (0)