Skip to content

Commit 4524503

Browse files
paulinesho and copybara-github
authored and committed
Update embedding_lookup model with signature
PiperOrigin-RevId: 755525455
1 parent a3c1aa3 commit 4524503

File tree

4 files changed

+19
-16
lines changed

4 files changed

+19
-16
lines changed

ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -531,9 +531,9 @@ def test_embedding_lookup_weight_only_succeeds(self):
531531

532532
op_tensor_names = {}
533533
op_tensor_names["weight"] = (
534-
"jax2tf_export_func_/...y_yz-_...z/pjit__einsum_/MatMul;jax2tf_export_func_/pjit__one_hot_/Equal;jax2tf_export_func_/pjit__one_hot_/Cast_1"
534+
"jit(export_func)/jit(main)/...y,yz->...z/dot_general;jit(export_func)/jit(main)/jit(_one_hot)/eq;jit(export_func)/jit(main)/jit(_one_hot)/convert_element_type"
535535
)
536-
op_tensor_names["input"] = "inputs"
536+
op_tensor_names["input"] = "lookup"
537537
op_tensor_names["output"] = "Identity_1"
538538

539539
# TODO: b/335913710 - Rename the test function.

ai_edge_quantizer/algorithms/uniform_quantize/op_architecture_tests/embedding_lookup_test.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -93,9 +93,9 @@ def test_embedding_lookup_succeeds(
9393
op = subgraph0.operators[subgraph_op_id]
9494
op_tensor_names = {}
9595
op_tensor_names["weight"] = (
96-
"jax2tf_export_func_/...y_yz-_...z/pjit__einsum_/MatMul;jax2tf_export_func_/pjit__one_hot_/Equal;jax2tf_export_func_/pjit__one_hot_/Cast_1"
96+
"jit(export_func)/jit(main)/...y,yz->...z/dot_general;jit(export_func)/jit(main)/jit(_one_hot)/eq;jit(export_func)/jit(main)/jit(_one_hot)/convert_element_type"
9797
)
98-
op_tensor_names["input"] = "inputs"
98+
op_tensor_names["input"] = "lookup"
9999
op_tensor_names["output"] = "Identity_1"
100100
self._op_test_info.op_tensor_names = op_tensor_names
101101
self._op_test_info.quantized_dimension = 0

ai_edge_quantizer/tests/end_to_end_tests/embedding_lookup_test.py

Lines changed: 15 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -52,18 +52,20 @@ def setUp(self):
5252
self._quantizer = quantizer.Quantizer(self.float_model_path)
5353

5454
@parameterized.parameters(
55-
'../../recipes/default_af32w8float_recipe.json',
56-
'../../recipes/default_af32w4float_recipe.json',
57-
'../../recipes/dynamic_legacy_wi8_afp32_recipe.json',
58-
'../../recipes/dynamic_wi8_afp32_recipe.json',
55+
('../../recipes/default_af32w8float_recipe.json', 1700),
56+
('../../recipes/default_af32w4float_recipe.json', 1600),
57+
('../../recipes/dynamic_legacy_wi8_afp32_recipe.json', 1400),
58+
('../../recipes/dynamic_wi8_afp32_recipe.json', 1400),
5959
)
60-
def test_embedding_lookup_model_int_weight_only(self, recipe_path):
60+
def test_embedding_lookup_model_int_weight_only(
61+
self, recipe_path, expected_model_size
62+
):
6163
recipe_path = test_utils.get_path_to_datafile(recipe_path)
6264
self._quantizer.load_quantization_recipe(recipe_path)
6365
self.assertFalse(self._quantizer.need_calibration)
6466
quant_result = self._quantizer.quantize()
6567
# Check model size.
66-
self.assertLess(len(quant_result.quantized_model), 2000)
68+
self.assertLess(len(quant_result.quantized_model), expected_model_size)
6769

6870
# TODO: b/364405203 - Enable after 0 signature works.
6971
# comparison_result = self._quantizer.validate(
@@ -91,8 +93,7 @@ def test_embedding_lookup_model_fp16_weight_only(self):
9193
),
9294
)
9395
quant_result = self._quantizer.quantize()
94-
print(len(quant_result.quantized_model))
95-
self.assertLess(len(quant_result.quantized_model), 2000)
96+
self.assertLess(len(quant_result.quantized_model), 1600)
9697

9798
# TODO: b/364405203 - Enable after 0 signature works.
9899
# comparion_result = self._quantizer.validate(
@@ -106,18 +107,20 @@ def test_embedding_lookup_model_fp16_weight_only(self):
106107
# )
107108

108109
@parameterized.parameters(
109-
'../../recipes/default_a8w8_recipe.json',
110-
'../../recipes/default_a16w8_recipe.json',
110+
('../../recipes/default_a8w8_recipe.json', 1400),
111+
('../../recipes/default_a16w8_recipe.json', 1400),
111112
)
112-
def test_embedding_lookup_model_full_integer(self, recipe_path):
113+
def test_embedding_lookup_model_full_integer(
114+
self, recipe_path, expected_model_size
115+
):
113116
calibration_result = {
114117
'Identity_1': {'min': -2.0, 'max': 2.0},
115118
}
116119
recipe_path = test_utils.get_path_to_datafile(recipe_path)
117120
self._quantizer.load_quantization_recipe(recipe_path)
118121
self.assertTrue(self._quantizer.need_calibration)
119122
quant_result = self._quantizer.quantize(calibration_result)
120-
self.assertLess(len(quant_result.quantized_model), 2000)
123+
self.assertLess(len(quant_result.quantized_model), expected_model_size)
121124
# TODO: b/364405203 - Enable after 0 signature works.
122125
# comparion_result = self._quantizer.validate(
123126
# error_metrics='mse',
Binary file not shown.

0 commit comments

Comments
 (0)