Skip to content

Commit 5d8df8d

Browse files
paulineshocopybara-github
authored andcommitted
Support custom op and dynamic shape in tensor creation
PiperOrigin-RevId: 752821626
1 parent d351ccc commit 5d8df8d

File tree

2 files changed

+91
-4
lines changed

2 files changed

+91
-4
lines changed

ai_edge_quantizer/transformations/transformation_utils.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,21 +51,39 @@ class TransformationInput:
5151
def add_op_code(
5252
op_code: schema_py_generated.OperatorCodeT,
5353
model_op_codes: list[schema_py_generated.OperatorCodeT],
54+
custom_op_name: Optional[str] = None,
5455
) -> int:
5556
"""Add an op code into a model if it's not present.
5657
5758
Args:
5859
op_code: The op code to be added.
5960
model_op_codes: The op codes of the model.
61+
custom_op_name: The custom string of the op code. If None, the op code will
62+
be added as a builtin op code.
6063
6164
Returns:
6265
The index of the op code in the model.
6366
"""
67+
if (
68+
op_code == schema_py_generated.BuiltinOperator.CUSTOM
69+
and custom_op_name is None
70+
):
71+
raise ValueError('Custom string is required for custom op code.')
72+
6473
for i, model_op_code in enumerate(model_op_codes):
74+
# If the model already has the op code, just return the index.
6575
if model_op_code.builtinCode == op_code:
66-
return i
76+
if custom_op_name is not None:
77+
if model_op_code.customCode == custom_op_name:
78+
return i
79+
else:
80+
# Built-in op
81+
return i
82+
6783
model_op_codes.append(schema_py_generated.OperatorCodeT())
6884
model_op_codes[-1].builtinCode = op_code
85+
if custom_op_name is not None:
86+
model_op_codes[-1].customCode = custom_op_name
6987
return len(model_op_codes) - 1
7088

7189

@@ -146,7 +164,14 @@ def add_new_activation_tensor(
146164
The index of the new tensor in the subgraph.
147165
"""
148166
new_tensor = schema_py_generated.TensorT()
149-
new_tensor.shape = shape
167+
# If there's a dynamic shape, we need to read from the shapeSignature field
168+
# instead of shape. Shape should contain just 1 for the dynamic dimension but
169+
# shapeSignature should contain the true shape.
170+
if -1 in shape:
171+
new_tensor.shapeSignature = shape
172+
new_tensor.shape = [1 if i == -1 else i for i in shape]
173+
else:
174+
new_tensor.shape = shape
150175
new_tensor.type = tensor_type
151176
new_tensor.name = tensor_name
152177
new_tensor.buffer = 0

ai_edge_quantizer/transformations/transformation_utils_test.py

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,19 +41,62 @@ def setUp(self):
4141
testcase_name="add_new_op_code",
4242
op_code=schema_py_generated.BuiltinOperator.LOGISTIC,
4343
expected=1,
44+
custom_op_name=None,
4445
),
4546
dict(
4647
testcase_name="add_existing_op_code",
4748
op_code=schema_py_generated.BuiltinOperator.FULLY_CONNECTED,
4849
expected=0,
50+
custom_op_name=None,
51+
),
52+
dict(
53+
testcase_name="add_new_custom_op_code",
54+
op_code=schema_py_generated.BuiltinOperator.CUSTOM,
55+
expected=1,
56+
custom_op_name="random_new_custom_op",
4957
),
5058
)
51-
def test_add_op_code(self, op_code, expected):
59+
def test_add_op_code(self, op_code, expected, custom_op_name):
5260
"""Tests if the op code is added to the model."""
5361
got = transformation_utils.add_op_code(
54-
op_code=op_code, model_op_codes=self.model.operatorCodes
62+
op_code=op_code,
63+
model_op_codes=self.model.operatorCodes,
64+
custom_op_name=custom_op_name,
5565
)
5666
self.assertEqual(expected, got)
67+
if custom_op_name is not None:
68+
self.assertEqual(self.model.operatorCodes[got].customCode, custom_op_name)
69+
70+
def test_add_custom_op_code_without_op_string_raises_error(self):
71+
with self.assertRaisesRegex(ValueError, "Custom string is required"):
72+
transformation_utils.add_op_code(
73+
op_code=schema_py_generated.BuiltinOperator.CUSTOM,
74+
model_op_codes=self.model.operatorCodes,
75+
custom_op_name=None,
76+
)
77+
78+
def test_add_two_custom_op_codes(self):
79+
custom_op_name = "random_new_custom_op"
80+
added_index = transformation_utils.add_op_code(
81+
op_code=schema_py_generated.BuiltinOperator.CUSTOM,
82+
model_op_codes=self.model.operatorCodes,
83+
custom_op_name=custom_op_name,
84+
)
85+
self.assertEqual(1, added_index)
86+
self.assertEqual(
87+
self.model.operatorCodes[added_index].customCode, custom_op_name
88+
)
89+
90+
custom_op_name_2 = "random_new_custom_op_2"
91+
added_index = transformation_utils.add_op_code(
92+
op_code=schema_py_generated.BuiltinOperator.CUSTOM,
93+
model_op_codes=self.model.operatorCodes,
94+
custom_op_name=custom_op_name_2,
95+
)
96+
self.assertEqual(2, added_index)
97+
self.assertEqual(
98+
self.model.operatorCodes[added_index].customCode, custom_op_name_2
99+
)
57100

58101
@parameterized.named_parameters(
59102
dict(
@@ -189,6 +232,25 @@ def test_add_new_activation_tensor_to_subgraph(
189232
self.model.subgraphs[0].tensors[-1].shape,
190233
)
191234

235+
def test_add_new_activation_tensor_with_dynamic_shape(self):
236+
"""Tests adding an activation tensor with dynamic shape."""
237+
subgraph = self.model.subgraphs[0]
238+
new_id = transformation_utils.add_new_activation_tensor(
239+
tensor_name="test_tensor",
240+
shape=[1, -1, -1, 1],
241+
tensor_type=schema_py_generated.TensorType.FLOAT32,
242+
subgraph=subgraph,
243+
)
244+
# Originally had 4 tensors, new tensor is added at index 4.
245+
self.assertEqual(new_id, 4)
246+
self.assertLen(subgraph.tensors, 5)
247+
self.assertEqual(subgraph.tensors[-1].name, "test_tensor")
248+
self.assertEqual(
249+
subgraph.tensors[-1].type, schema_py_generated.TensorType.FLOAT32
250+
)
251+
self.assertEqual(subgraph.tensors[-1].shape, [1, 1, 1, 1])
252+
self.assertEqual(subgraph.tensors[-1].shapeSignature, [1, -1, -1, 1])
253+
192254

193255
if __name__ == "__main__":
194256
googletest.main()

0 commit comments

Comments
 (0)