33
33
34
34
from qonnx .core .modelwrapper import ModelWrapper
35
35
from qonnx .transformation .infer_shapes import InferShapes
36
- from qonnx .transformation .insert import InsertIdentity
36
+ from qonnx .transformation .insert import InsertIdentity , InsertIdentityOnAllTopLevelIO
37
37
38
38
39
39
@pytest .fixture
@@ -49,9 +49,16 @@ def simple_model():
49
49
return model
50
50
51
51
52
- def save_transformed_model (model , test_name ):
53
- output_path = f"{ test_name } .onnx"
54
- model .save (output_path )
52
+ def test_insert_identity_on_all_top_level_io (simple_model ):
53
+ orig_top_inp_names = [inp .name for inp in simple_model .graph .input ]
54
+ orig_top_out_names = [out .name for out in simple_model .graph .output ]
55
+ model = simple_model .transform (InsertIdentityOnAllTopLevelIO ())
56
+ for inp in orig_top_inp_names :
57
+ assert model .find_consumer (inp ).op_type == "Identity"
58
+ for out in orig_top_out_names :
59
+ assert model .find_producer (out ).op_type == "Identity"
60
+ assert orig_top_inp_names == [inp .name for inp in model .graph .input ]
61
+ assert orig_top_out_names == [out .name for out in model .graph .output ]
55
62
56
63
57
64
def test_insert_identity_before_input (simple_model ):
@@ -63,9 +70,6 @@ def test_insert_identity_before_input(simple_model):
63
70
assert identity_node is not None
64
71
assert identity_node .op_type == "Identity"
65
72
66
- # Save the transformed model
67
- save_transformed_model (model , "test_insert_identity_before_input" )
68
-
69
73
70
74
def test_insert_identity_after_input (simple_model ):
71
75
# Apply the transformation
@@ -76,9 +80,6 @@ def test_insert_identity_after_input(simple_model):
76
80
assert identity_node is not None
77
81
assert identity_node .op_type == "Identity"
78
82
79
- # Save the transformed model
80
- save_transformed_model (model , "test_insert_identity_after_input" )
81
-
82
83
83
84
def test_insert_identity_before_intermediate (simple_model ):
84
85
# Apply the transformation
@@ -89,9 +90,6 @@ def test_insert_identity_before_intermediate(simple_model):
89
90
assert identity_node is not None
90
91
assert identity_node .op_type == "Identity"
91
92
92
- # Save the transformed model
93
- save_transformed_model (model , "test_insert_identity_before_intermediate" )
94
-
95
93
96
94
def test_insert_identity_after_intermediate (simple_model ):
97
95
# Apply the transformation
@@ -102,9 +100,6 @@ def test_insert_identity_after_intermediate(simple_model):
102
100
assert identity_node is not None
103
101
assert identity_node .op_type == "Identity"
104
102
105
- # Save the transformed model
106
- save_transformed_model (model , "test_insert_identity_after_intermediate" )
107
-
108
103
109
104
def test_insert_identity_before_output (simple_model ):
110
105
# Apply the transformation
@@ -115,9 +110,6 @@ def test_insert_identity_before_output(simple_model):
115
110
assert identity_node is not None
116
111
assert identity_node .op_type == "Identity"
117
112
118
- # Save the transformed model
119
- save_transformed_model (model , "test_insert_identity_before_output" )
120
-
121
113
122
114
def test_insert_identity_after_output (simple_model ):
123
115
# Apply the transformation
@@ -128,9 +120,6 @@ def test_insert_identity_after_output(simple_model):
128
120
assert identity_node is not None
129
121
assert identity_node .op_type == "Identity"
130
122
131
- # Save the transformed model
132
- save_transformed_model (model , "test_insert_identity_after_output" )
133
-
134
123
135
124
def test_tensor_not_found (simple_model ):
136
125
# Apply the transformation with a non-existent tensor
0 commit comments