Skip to content

Commit 1aac6b6

Browse files
committed
[Test] add tet for InsertIdentityOnAllTopLevelIO, remove onnx dumps
1 parent 358b58a commit 1aac6b6

File tree

1 file changed

+11
-22
lines changed

1 file changed

+11
-22
lines changed

tests/transformation/test_insert_identity.py

+11-22
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
from qonnx.core.modelwrapper import ModelWrapper
3535
from qonnx.transformation.infer_shapes import InferShapes
36-
from qonnx.transformation.insert import InsertIdentity
36+
from qonnx.transformation.insert import InsertIdentity, InsertIdentityOnAllTopLevelIO
3737

3838

3939
@pytest.fixture
@@ -49,9 +49,16 @@ def simple_model():
4949
return model
5050

5151

52-
def save_transformed_model(model, test_name):
53-
output_path = f"{test_name}.onnx"
54-
model.save(output_path)
52+
def test_insert_identity_on_all_top_level_io(simple_model):
53+
orig_top_inp_names = [inp.name for inp in simple_model.graph.input]
54+
orig_top_out_names = [out.name for out in simple_model.graph.output]
55+
model = simple_model.transform(InsertIdentityOnAllTopLevelIO())
56+
for inp in orig_top_inp_names:
57+
assert model.find_consumer(inp).op_type == "Identity"
58+
for out in orig_top_out_names:
59+
assert model.find_producer(out).op_type == "Identity"
60+
assert orig_top_inp_names == [inp.name for inp in model.graph.input]
61+
assert orig_top_out_names == [out.name for out in model.graph.output]
5562

5663

5764
def test_insert_identity_before_input(simple_model):
@@ -63,9 +70,6 @@ def test_insert_identity_before_input(simple_model):
6370
assert identity_node is not None
6471
assert identity_node.op_type == "Identity"
6572

66-
# Save the transformed model
67-
save_transformed_model(model, "test_insert_identity_before_input")
68-
6973

7074
def test_insert_identity_after_input(simple_model):
7175
# Apply the transformation
@@ -76,9 +80,6 @@ def test_insert_identity_after_input(simple_model):
7680
assert identity_node is not None
7781
assert identity_node.op_type == "Identity"
7882

79-
# Save the transformed model
80-
save_transformed_model(model, "test_insert_identity_after_input")
81-
8283

8384
def test_insert_identity_before_intermediate(simple_model):
8485
# Apply the transformation
@@ -89,9 +90,6 @@ def test_insert_identity_before_intermediate(simple_model):
8990
assert identity_node is not None
9091
assert identity_node.op_type == "Identity"
9192

92-
# Save the transformed model
93-
save_transformed_model(model, "test_insert_identity_before_intermediate")
94-
9593

9694
def test_insert_identity_after_intermediate(simple_model):
9795
# Apply the transformation
@@ -102,9 +100,6 @@ def test_insert_identity_after_intermediate(simple_model):
102100
assert identity_node is not None
103101
assert identity_node.op_type == "Identity"
104102

105-
# Save the transformed model
106-
save_transformed_model(model, "test_insert_identity_after_intermediate")
107-
108103

109104
def test_insert_identity_before_output(simple_model):
110105
# Apply the transformation
@@ -115,9 +110,6 @@ def test_insert_identity_before_output(simple_model):
115110
assert identity_node is not None
116111
assert identity_node.op_type == "Identity"
117112

118-
# Save the transformed model
119-
save_transformed_model(model, "test_insert_identity_before_output")
120-
121113

122114
def test_insert_identity_after_output(simple_model):
123115
# Apply the transformation
@@ -128,9 +120,6 @@ def test_insert_identity_after_output(simple_model):
128120
assert identity_node is not None
129121
assert identity_node.op_type == "Identity"
130122

131-
# Save the transformed model
132-
save_transformed_model(model, "test_insert_identity_after_output")
133-
134123

135124
def test_tensor_not_found(simple_model):
136125
# Apply the transformation with a non-existent tensor

0 commit comments

Comments
 (0)