@@ -52,18 +52,20 @@ def setUp(self):
52
52
self ._quantizer = quantizer .Quantizer (self .float_model_path )
53
53
54
54
@parameterized .parameters (
55
- '../../recipes/default_af32w8float_recipe.json' ,
56
- '../../recipes/default_af32w4float_recipe.json' ,
57
- '../../recipes/dynamic_legacy_wi8_afp32_recipe.json' ,
58
- '../../recipes/dynamic_wi8_afp32_recipe.json' ,
55
+ ( '../../recipes/default_af32w8float_recipe.json' , 1700 ) ,
56
+ ( '../../recipes/default_af32w4float_recipe.json' , 1600 ) ,
57
+ ( '../../recipes/dynamic_legacy_wi8_afp32_recipe.json' , 1400 ) ,
58
+ ( '../../recipes/dynamic_wi8_afp32_recipe.json' , 1400 ) ,
59
59
)
60
- def test_embedding_lookup_model_int_weight_only (self , recipe_path ):
60
+ def test_embedding_lookup_model_int_weight_only (
61
+ self , recipe_path , expected_model_size
62
+ ):
61
63
recipe_path = test_utils .get_path_to_datafile (recipe_path )
62
64
self ._quantizer .load_quantization_recipe (recipe_path )
63
65
self .assertFalse (self ._quantizer .need_calibration )
64
66
quant_result = self ._quantizer .quantize ()
65
67
# Check model size.
66
- self .assertLess (len (quant_result .quantized_model ), 2000 )
68
+ self .assertLess (len (quant_result .quantized_model ), expected_model_size )
67
69
68
70
# TODO: b/364405203 - Enable after 0 signature works.
69
71
# comparison_result = self._quantizer.validate(
@@ -91,8 +93,7 @@ def test_embedding_lookup_model_fp16_weight_only(self):
91
93
),
92
94
)
93
95
quant_result = self ._quantizer .quantize ()
94
- print (len (quant_result .quantized_model ))
95
- self .assertLess (len (quant_result .quantized_model ), 2000 )
96
+ self .assertLess (len (quant_result .quantized_model ), 1600 )
96
97
97
98
# TODO: b/364405203 - Enable after 0 signature works.
98
99
# comparion_result = self._quantizer.validate(
@@ -106,18 +107,20 @@ def test_embedding_lookup_model_fp16_weight_only(self):
106
107
# )
107
108
108
109
@parameterized .parameters (
109
- '../../recipes/default_a8w8_recipe.json' ,
110
- '../../recipes/default_a16w8_recipe.json' ,
110
+ ( '../../recipes/default_a8w8_recipe.json' , 1400 ) ,
111
+ ( '../../recipes/default_a16w8_recipe.json' , 1400 ) ,
111
112
)
112
- def test_embedding_lookup_model_full_integer (self , recipe_path ):
113
+ def test_embedding_lookup_model_full_integer (
114
+ self , recipe_path , expected_model_size
115
+ ):
113
116
calibration_result = {
114
117
'Identity_1' : {'min' : - 2.0 , 'max' : 2.0 },
115
118
}
116
119
recipe_path = test_utils .get_path_to_datafile (recipe_path )
117
120
self ._quantizer .load_quantization_recipe (recipe_path )
118
121
self .assertTrue (self ._quantizer .need_calibration )
119
122
quant_result = self ._quantizer .quantize (calibration_result )
120
- self .assertLess (len (quant_result .quantized_model ), 2000 )
123
+ self .assertLess (len (quant_result .quantized_model ), expected_model_size )
121
124
# TODO: b/364405203 - Enable after 0 signature works.
122
125
# comparion_result = self._quantizer.validate(
123
126
# error_metrics='mse',
0 commit comments