Skip to content

Commit d5a212e

Browse files
committed
[Test] add unit tests for InsertIdentity
1 parent b7eb655 commit d5a212e

File tree

1 file changed

+139
-0
lines changed

1 file changed

+139
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
# Copyright (c) 2025 Advanced Micro Devices, Inc.
2+
# All rights reserved.
3+
#
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted provided that the following conditions are met:
6+
#
7+
# * Redistributions of source code must retain the above copyright notice, this
8+
# list of conditions and the following disclaimer.
9+
#
10+
# * Redistributions in binary form must reproduce the above copyright notice,
11+
# this list of conditions and the following disclaimer in the documentation
12+
# and/or other materials provided with the distribution.
13+
#
14+
# * Neither the name of AMD nor the names of its
15+
# contributors may be used to endorse or promote products derived from
16+
# this software without specific prior written permission.
17+
#
18+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28+
29+
import pytest
30+
31+
from onnx import TensorProto
32+
from onnx import helper as oh
33+
34+
from qonnx.core.modelwrapper import ModelWrapper
35+
from qonnx.transformation.infer_shapes import InferShapes
36+
from qonnx.transformation.insert import InsertIdentity
37+
38+
39+
@pytest.fixture
40+
def simple_model():
41+
# Create a simple ONNX model for testing
42+
input_tensor = oh.make_tensor_value_info("input", TensorProto.FLOAT, [1, 2])
43+
output_tensor = oh.make_tensor_value_info("output", TensorProto.FLOAT, [1, 2])
44+
node1 = oh.make_node("Relu", ["input"], ["intermediate"])
45+
node2 = oh.make_node("Relu", ["intermediate"], ["output"])
46+
graph = oh.make_graph([node1, node2], "test_graph", [input_tensor], [output_tensor])
47+
model = ModelWrapper(oh.make_model(graph))
48+
model = model.transform(InferShapes())
49+
return model
50+
51+
52+
def save_transformed_model(model, test_name):
53+
output_path = f"{test_name}.onnx"
54+
model.save(output_path)
55+
56+
57+
def test_insert_identity_before_input(simple_model):
58+
# Apply the transformation
59+
transformation = InsertIdentity("input", "producer")
60+
model = simple_model.transform(transformation)
61+
62+
identity_node = model.find_producer("input")
63+
assert identity_node is not None
64+
assert identity_node.op_type == "Identity"
65+
66+
# Save the transformed model
67+
save_transformed_model(model, "test_insert_identity_before_input")
68+
69+
70+
def test_insert_identity_after_input(simple_model):
71+
# Apply the transformation
72+
transformation = InsertIdentity("input", "consumer")
73+
model = simple_model.transform(transformation)
74+
75+
identity_node = model.find_consumer("input")
76+
assert identity_node is not None
77+
assert identity_node.op_type == "Identity"
78+
79+
# Save the transformed model
80+
save_transformed_model(model, "test_insert_identity_after_input")
81+
82+
83+
def test_insert_identity_before_intermediate(simple_model):
84+
# Apply the transformation
85+
transformation = InsertIdentity("intermediate", "producer")
86+
model = simple_model.transform(transformation)
87+
88+
identity_node = model.find_producer("intermediate")
89+
assert identity_node is not None
90+
assert identity_node.op_type == "Identity"
91+
92+
# Save the transformed model
93+
save_transformed_model(model, "test_insert_identity_before_intermediate")
94+
95+
96+
def test_insert_identity_after_intermediate(simple_model):
97+
# Apply the transformation
98+
transformation = InsertIdentity("intermediate", "consumer")
99+
model = simple_model.transform(transformation)
100+
101+
identity_node = model.find_consumer("intermediate")
102+
assert identity_node is not None
103+
assert identity_node.op_type == "Identity"
104+
105+
# Save the transformed model
106+
save_transformed_model(model, "test_insert_identity_after_intermediate")
107+
108+
109+
def test_insert_identity_before_output(simple_model):
110+
# Apply the transformation
111+
transformation = InsertIdentity("output", "producer")
112+
model = simple_model.transform(transformation)
113+
114+
identity_node = model.find_producer("output")
115+
assert identity_node is not None
116+
assert identity_node.op_type == "Identity"
117+
118+
# Save the transformed model
119+
save_transformed_model(model, "test_insert_identity_before_output")
120+
121+
122+
def test_insert_identity_after_output(simple_model):
123+
# Apply the transformation
124+
transformation = InsertIdentity("output", "consumer")
125+
model = simple_model.transform(transformation)
126+
127+
identity_node = model.find_consumer("output")
128+
assert identity_node is not None
129+
assert identity_node.op_type == "Identity"
130+
131+
# Save the transformed model
132+
save_transformed_model(model, "test_insert_identity_after_output")
133+
134+
135+
def test_tensor_not_found(simple_model):
136+
# Apply the transformation with a non-existent tensor
137+
transformation = InsertIdentity("non_existent_tensor", "producer")
138+
with pytest.raises(ValueError):
139+
simple_model.transform(transformation)

0 commit comments

Comments
 (0)