Skip to content

Commit 358b58a

Browse files
committed
[Transform] introduce InsertIdentityOnAllTopLevelIO
1 parent d5a212e commit 358b58a

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

src/qonnx/transformation/insert.py

+16
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,22 @@
3232
from qonnx.transformation.general import SortGraph
3333

3434

35+
class InsertIdentityOnAllTopLevelIO(Transformation):
36+
"""
37+
Transformation that inserts an Identity node on all top-level inputs and outputs
38+
of the ONNX graph. This can be useful before calling transformations that do not
39+
gracefully handle edge cases where transformed tensors are top-level inputs or outputs.
40+
"""
41+
42+
def apply(self, model):
43+
graph = model.graph
44+
for inp in graph.input:
45+
model = model.transform(InsertIdentity(inp.name, "consumer"))
46+
for out in graph.output:
47+
model = model.transform(InsertIdentity(out.name, "producer"))
48+
return model, False
49+
50+
3551
class InsertIdentity(Transformation):
3652
"""
3753
Transformation that inserts an Identity node in the ONNX graph. For edge cases

0 commit comments

Comments
 (0)