@@ -41,19 +41,62 @@ def setUp(self):
41
41
testcase_name = "add_new_op_code" ,
42
42
op_code = schema_py_generated .BuiltinOperator .LOGISTIC ,
43
43
expected = 1 ,
44
+ custom_op_name = None ,
44
45
),
45
46
dict (
46
47
testcase_name = "add_existing_op_code" ,
47
48
op_code = schema_py_generated .BuiltinOperator .FULLY_CONNECTED ,
48
49
expected = 0 ,
50
+ custom_op_name = None ,
51
+ ),
52
+ dict (
53
+ testcase_name = "add_new_custom_op_code" ,
54
+ op_code = schema_py_generated .BuiltinOperator .CUSTOM ,
55
+ expected = 1 ,
56
+ custom_op_name = "random_new_custom_op" ,
49
57
),
50
58
)
51
- def test_add_op_code (self , op_code , expected ):
59
+ def test_add_op_code (self , op_code , expected , custom_op_name ):
52
60
"""Tests if the op code is added to the model."""
53
61
got = transformation_utils .add_op_code (
54
- op_code = op_code , model_op_codes = self .model .operatorCodes
62
+ op_code = op_code ,
63
+ model_op_codes = self .model .operatorCodes ,
64
+ custom_op_name = custom_op_name ,
55
65
)
56
66
self .assertEqual (expected , got )
67
+ if custom_op_name is not None :
68
+ self .assertEqual (self .model .operatorCodes [got ].customCode , custom_op_name )
69
+
70
+ def test_add_custom_op_code_without_op_string_raises_error (self ):
71
+ with self .assertRaisesRegex (ValueError , "Custom string is required" ):
72
+ transformation_utils .add_op_code (
73
+ op_code = schema_py_generated .BuiltinOperator .CUSTOM ,
74
+ model_op_codes = self .model .operatorCodes ,
75
+ custom_op_name = None ,
76
+ )
77
+
78
+ def test_add_two_custom_op_codes (self ):
79
+ custom_op_name = "random_new_custom_op"
80
+ added_index = transformation_utils .add_op_code (
81
+ op_code = schema_py_generated .BuiltinOperator .CUSTOM ,
82
+ model_op_codes = self .model .operatorCodes ,
83
+ custom_op_name = custom_op_name ,
84
+ )
85
+ self .assertEqual (1 , added_index )
86
+ self .assertEqual (
87
+ self .model .operatorCodes [added_index ].customCode , custom_op_name
88
+ )
89
+
90
+ custom_op_name_2 = "random_new_custom_op_2"
91
+ added_index = transformation_utils .add_op_code (
92
+ op_code = schema_py_generated .BuiltinOperator .CUSTOM ,
93
+ model_op_codes = self .model .operatorCodes ,
94
+ custom_op_name = custom_op_name_2 ,
95
+ )
96
+ self .assertEqual (2 , added_index )
97
+ self .assertEqual (
98
+ self .model .operatorCodes [added_index ].customCode , custom_op_name_2
99
+ )
57
100
58
101
@parameterized .named_parameters (
59
102
dict (
@@ -189,6 +232,25 @@ def test_add_new_activation_tensor_to_subgraph(
189
232
self .model .subgraphs [0 ].tensors [- 1 ].shape ,
190
233
)
191
234
235
+ def test_add_new_activation_tensor_with_dynamic_shape (self ):
236
+ """Tests adding an activation tensor with dynamic shape."""
237
+ subgraph = self .model .subgraphs [0 ]
238
+ new_id = transformation_utils .add_new_activation_tensor (
239
+ tensor_name = "test_tensor" ,
240
+ shape = [1 , - 1 , - 1 , 1 ],
241
+ tensor_type = schema_py_generated .TensorType .FLOAT32 ,
242
+ subgraph = subgraph ,
243
+ )
244
+ # Originally had 4 tensors, new tensor is added at index 4.
245
+ self .assertEqual (new_id , 4 )
246
+ self .assertLen (subgraph .tensors , 5 )
247
+ self .assertEqual (subgraph .tensors [- 1 ].name , "test_tensor" )
248
+ self .assertEqual (
249
+ subgraph .tensors [- 1 ].type , schema_py_generated .TensorType .FLOAT32
250
+ )
251
+ self .assertEqual (subgraph .tensors [- 1 ].shape , [1 , 1 , 1 , 1 ])
252
+ self .assertEqual (subgraph .tensors [- 1 ].shapeSignature , [1 , - 1 , - 1 , 1 ])
253
+
192
254
193
255
if __name__ == "__main__" :
194
256
googletest .main ()
0 commit comments