Skip to content

Commit c57e099

Browse files
marialyucopybara-github
authored andcommitted
Add support for PADV2 op for int8 and int16
PiperOrigin-RevId: 755936791
1 parent 2de738a commit c57e099

File tree

10 files changed

+280
-20
lines changed

10 files changed

+280
-20
lines changed

ai_edge_quantizer/algorithm_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ class AlgorithmName(str, enum.Enum):
108108
),
109109
_TFLOpName.STABLEHLO_COMPOSITE: common_quantize.materialize_composite,
110110
_TFLOpName.PAD: common_quantize.materialize_pad,
111+
_TFLOpName.PADV2: common_quantize.materialize_padv2,
111112
}
112113
for op_name, materialize_func in MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT.items():
113114
register_quantized_op(
@@ -242,6 +243,7 @@ class AlgorithmName(str, enum.Enum):
242243
),
243244
_TFLOpName.STABLEHLO_COMPOSITE: common_quantize.materialize_composite,
244245
_TFLOpName.PAD: common_quantize.materialize_pad,
246+
_TFLOpName.PADV2: common_quantize.materialize_padv2,
245247
})
246248

247249
for op_name, materialize_func in _OCTAV_OP_NAME_MATERIALIZE_FUNC_DICT.items():

ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -693,7 +693,24 @@ def materialize_pad(
693693
tensor_name_to_qsv,
694694
get_tensor_quant_params_fn,
695695
constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
696-
inputs_to_ignore=[1], # Padding value does not need to be quantized.
696+
inputs_to_ignore=[1], # Paddings tensor does not need to be quantized.
697+
)
698+
699+
700+
def materialize_padv2(
701+
get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
702+
op_info: qtyping.OpInfo,
703+
graph_info: qtyping.GraphInfo,
704+
tensor_name_to_qsv: dict[str, Any],
705+
) -> list[qtyping.TensorTransformationParams]:
706+
"""Materialize tensors in tfl.padv2."""
707+
return common_utils.materialize_standard_op(
708+
op_info,
709+
graph_info,
710+
tensor_name_to_qsv,
711+
get_tensor_quant_params_fn,
712+
constraint=_OpQuantConstraint.SAME_AS_OUTPUT_SCALE,
713+
inputs_to_ignore=[1], # Paddings tensor does not need to be quantized.
697714
)
698715

699716

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# Copyright 2024 The AI Edge Quantizer Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
16+
import os
17+
18+
from absl.testing import parameterized
19+
import numpy as np
20+
21+
from tensorflow.python.platform import googletest
22+
from ai_edge_quantizer import qtyping
23+
from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize
24+
from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize
25+
from ai_edge_quantizer.algorithms.uniform_quantize import octav
26+
from ai_edge_quantizer.algorithms.uniform_quantize.op_architecture_tests import test_utils as op_test_utils
27+
from ai_edge_quantizer.utils import test_utils
28+
from ai_edge_quantizer.utils import tfl_flatbuffer_utils
29+
30+
31+
_TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile(
32+
"../../../tests/models"
33+
)
34+
35+
36+
class PadV2Test(op_test_utils.BaseQuantizeTest):
37+
38+
def setUp(self):
39+
super().setUp()
40+
np.random.seed(666)
41+
self._test_model_path = os.path.join(
42+
_TEST_DATA_PREFIX_PATH, "single_padv2.tflite"
43+
)
44+
self._op_test_info = op_test_utils.OpTestInfo(
45+
test_model=tfl_flatbuffer_utils.read_model(self._test_model_path),
46+
op_tensor_names={},
47+
input_range=(np.array([[-10]]), np.array([[10]])),
48+
output_range=(np.array([[-10]]), np.array([[10]])),
49+
)
50+
# The test model has one subgraph for now.
51+
self._graph_info = qtyping.GraphInfo(
52+
subgraph_tensors=self._op_test_info.test_model.subgraphs[0].tensors,
53+
buffers=self._op_test_info.test_model.buffers,
54+
)
55+
56+
@parameterized.product(
57+
get_tensor_quant_params_func=(
58+
naive_min_max_quantize.get_tensor_quant_params,
59+
octav.get_tensor_quant_params,
60+
),
61+
activations_num_bits_and_symmetric=[
62+
(8, False),
63+
(8, True),
64+
(16, True),
65+
],
66+
)
67+
def test_materialize_padv2_succeeds(
68+
self, get_tensor_quant_params_func, activations_num_bits_and_symmetric
69+
):
70+
activation_config = test_utils.get_static_activation_quant_setting(
71+
*activations_num_bits_and_symmetric
72+
)
73+
op_quant_config = test_utils.get_static_op_quant_config(activation_config)
74+
75+
# Read from Model Explorer.
76+
subgraph0 = self._op_test_info.test_model.subgraphs[0]
77+
subgraph_op_id = 0
78+
op = subgraph0.operators[subgraph_op_id]
79+
op_info = qtyping.OpInfo(
80+
op=op,
81+
op_name=qtyping.TFLOperationName.PADV2,
82+
subgraph_op_index=subgraph_op_id,
83+
op_quant_config=op_quant_config,
84+
)
85+
86+
# Test settings.
87+
op_tensor_names = {}
88+
op_tensor_names["input"] = "serving_default_input:0"
89+
op_tensor_names["input2"] = "Const_1"
90+
op_tensor_names["input3"] = "Const"
91+
op_tensor_names["output"] = "PartitionedCall:0"
92+
self._op_test_info.op_tensor_names = op_tensor_names
93+
self._test_no_weights_op(
94+
op_info,
95+
self._graph_info,
96+
self._op_test_info,
97+
common_quantize.materialize_padv2,
98+
get_tensor_quant_params_func,
99+
same_input_output_params=True,
100+
inputs_to_ignore=[1], # Padding tensor does not need to be quantized.
101+
constant_inputs=[2], # constant_values (padding value) is quantized.
102+
)
103+
104+
105+
if __name__ == "__main__":
106+
googletest.main()

ai_edge_quantizer/algorithms/uniform_quantize/op_architecture_tests/test_utils.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ def _test_same_direction_tensors(
124124
num_tensors,
125125
indices_to_ignore,
126126
is_inbounding_tensor,
127+
constant_inputs=None,
127128
):
128129
"""Tests all input or output tensors in provided quant params.
129130
@@ -135,7 +136,13 @@ def _test_same_direction_tensors(
135136
num_tensors: Number of tensors to test.
136137
indices_to_ignore: Indices of tensors to ignore.
137138
is_inbounding_tensor: Whether to test all inbounding tensors.
139+
constant_inputs: Constant inputs indices.
138140
"""
141+
if not is_inbounding_tensor and constant_inputs is not None:
142+
raise ValueError(
143+
"Constant inputs should only be used for inbounding tensors."
144+
)
145+
constant_inputs = constant_inputs or []
139146
tensor_base_name = "input" if is_inbounding_tensor else "output"
140147
tensor_names = [tensor_base_name] + [
141148
f"{tensor_base_name}{i+2}" for i in range(num_tensors - 1)
@@ -148,9 +155,15 @@ def _test_same_direction_tensors(
148155
and i not in indices_to_ignore
149156
):
150157
if is_inbounding_tensor:
151-
transformations = [_QuantTransformation.ADD_QUANTIZE]
158+
if i in constant_inputs:
159+
transformations = [_QuantTransformation.QUANTIZE_TENSOR]
160+
else:
161+
transformations = [_QuantTransformation.ADD_QUANTIZE]
152162
else:
153-
transformations = [_QuantTransformation.ADD_DEQUANTIZE]
163+
if i in constant_inputs:
164+
transformations = []
165+
else:
166+
transformations = [_QuantTransformation.ADD_DEQUANTIZE]
154167
else:
155168
transformations = [_QuantTransformation.NO_QUANTIZE]
156169
self._test_tensor_transformation_params(
@@ -172,6 +185,8 @@ def _test_same_input_output_params(
172185
):
173186
"""Tests input and output tensor transformation parameters are the same.
174187
188+
Assumes that each of non-ignored input tensors has exactly one consumer.
189+
175190
Args:
176191
tensor_quant_params: Tensor transformation parameters.
177192
num_inputs: Number of inputs in materialization function result.
@@ -189,11 +204,17 @@ def _test_same_input_output_params(
189204
for i in range(num_inputs):
190205
if i not in inputs_to_ignore:
191206
if expected_params is None:
192-
expected_params = tensor_quant_params[i].consumers[0].parameters # pytype: disable=attribute-error
207+
# Intputs can be constants and therefore have different quantized
208+
# data. Ignoring `quantized_data` in comparison.
209+
expected_params = dataclasses.replace(
210+
tensor_quant_params[i].consumers[0].parameters, # pytype: disable=attribute-error
211+
quantized_data=None,
212+
)
193213
else:
194-
input_tensor_quant_params = (
195-
tensor_quant_params[i].consumers[0].parameters
196-
) # pytype: disable=attribute-error
214+
input_tensor_quant_params = dataclasses.replace(
215+
tensor_quant_params[i].consumers[0].parameters, # pytype: disable=attribute-error
216+
quantized_data=None,
217+
)
197218
self.assertEqual(input_tensor_quant_params, expected_params)
198219

199220
# Test outputs.
@@ -215,6 +236,7 @@ def _test_no_weights_op(
215236
same_input_output_params=False,
216237
inputs_to_ignore=None,
217238
outputs_to_ignore=None,
239+
constant_inputs=None,
218240
):
219241
"""Test an op without weights and bias.
220242
@@ -232,6 +254,7 @@ def _test_no_weights_op(
232254
transformation parameters are the same.
233255
inputs_to_ignore: Inputs to ignore.
234256
outputs_to_ignore: Outputs to ignore.
257+
constant_inputs: Indices of constant inputs.
235258
"""
236259
num_inputs = len(op_info.op.inputs)
237260
num_outputs = len(op_info.op.outputs)
@@ -261,6 +284,7 @@ def _test_no_weights_op(
261284
num_inputs,
262285
inputs_to_ignore,
263286
is_inbounding_tensor=True,
287+
constant_inputs=constant_inputs,
264288
)
265289
# Test output tensor settings.
266290
outputs_to_ignore = outputs_to_ignore or []

ai_edge_quantizer/algorithms/utils/common_utils.py

Lines changed: 45 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -351,18 +351,52 @@ def _materialize_standard_op_with_same_as_output_scale(
351351
# Use output quantization params for all input tensors.
352352
if output_tensor_params.producer is None:
353353
quant_params = None
354+
_materialize_op_tensors(
355+
op_tensor_params,
356+
input_tensors,
357+
is_inbounding_tensor=True,
358+
op_info=op_info,
359+
graph_info=graph_info,
360+
tensor_name_to_qsv=tensor_name_to_qsv,
361+
get_tensor_quant_params_fn=get_tensor_quant_params_fn,
362+
quant_params=quant_params,
363+
)
354364
else:
355-
quant_params = output_tensor_params.producer.parameters
356-
_materialize_op_tensors(
357-
op_tensor_params,
358-
input_tensors,
359-
is_inbounding_tensor=True,
360-
op_info=op_info,
361-
graph_info=graph_info,
362-
tensor_name_to_qsv=tensor_name_to_qsv,
363-
get_tensor_quant_params_fn=get_tensor_quant_params_fn,
364-
quant_params=quant_params,
365-
)
365+
output_quant_params = output_tensor_params.producer.parameters
366+
if not isinstance(output_quant_params, qtyping.UniformQuantParams):
367+
raise ValueError(
368+
"_materialize_standard_op_with_same_as_output_scale only supports"
369+
f" UniformQuantParams. For tensor {output_tensor_params.tensor_name},"
370+
f" got {type(output_quant_params)}"
371+
)
372+
# Materialize each of the input tensors separately in case there are
373+
# constants among them, requiring updating `quantized_data` first.
374+
for input_tensor in input_tensors:
375+
input_tensor_data = tfl_flatbuffer_utils.get_tensor_data(
376+
input_tensor, graph_info.buffers
377+
)
378+
# Quantize constant inputs' data with the output quantization params.
379+
if input_tensor_data is None:
380+
quant_params = output_quant_params
381+
else:
382+
quantized_data = uniform_quantize_tensor.uniform_quantize(
383+
input_tensor_data, output_quant_params
384+
)
385+
quant_params = dataclasses.replace(
386+
output_quant_params,
387+
quantized_data=quantized_data,
388+
)
389+
_materialize_op_tensors(
390+
op_tensor_params,
391+
[input_tensor],
392+
is_inbounding_tensor=True,
393+
op_info=op_info,
394+
graph_info=graph_info,
395+
tensor_name_to_qsv=tensor_name_to_qsv,
396+
get_tensor_quant_params_fn=get_tensor_quant_params_fn,
397+
quant_params=quant_params,
398+
)
399+
366400
op_tensor_params.append(output_tensor_params)
367401

368402
return op_tensor_params

ai_edge_quantizer/default_policy.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,8 @@
184184
"DYNAMIC_UPDATE_SLICE",
185185
"SELECT_V2",
186186
"STABLEHLO_COMPOSITE",
187-
"PAD"
187+
"PAD",
188+
"PADV2"
188189
],
189190
"static_wi8_ai8": [
190191
"ADD",
@@ -216,7 +217,8 @@
216217
"DYNAMIC_UPDATE_SLICE",
217218
"SELECT_V2",
218219
"STABLEHLO_COMPOSITE",
219-
"PAD"
220+
"PAD",
221+
"PADV2"
220222
],
221223
"static_wi4_ai8": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT", "EMBEDDING_LOOKUP"],
222224
"static_wi4_ai16": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT", "EMBEDDING_LOOKUP"],

ai_edge_quantizer/qtyping.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ class TFLOperationName(str, enum.Enum):
6363
DYNAMIC_UPDATE_SLICE = 'DYNAMIC_UPDATE_SLICE'
6464
STABLEHLO_COMPOSITE = 'STABLEHLO_COMPOSITE'
6565
PAD = 'PAD'
66+
PADV2 = 'PADV2'
6667

6768

6869
class QuantizeMode(enum.Enum):

0 commit comments

Comments
 (0)