
Commit 37457ea

zichuan-wei authored and copybara-github committed
blockwise: enable blockwise quantization in policy
PiperOrigin-RevId: 746515536
1 parent 17af50a · commit 37457ea
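For context on the commit title: blockwise quantization gives each contiguous block of weights its own quantization scale, sitting between per-channel and per-tensor granularity. Below is a minimal NumPy sketch of symmetric int4 blockwise quantization with a block size of 32; it is illustrative only and not code from this repository.

import numpy as np

def quantize_blockwise_int4(weights, block_size=32):
  """Symmetric int4 quantization with one scale per block of `block_size` values."""
  assert weights.size % block_size == 0, "pad weights to a multiple of block_size"
  blocks = weights.reshape(-1, block_size)
  # Symmetric int4 covers [-8, 7]; derive each block's scale from its max magnitude.
  scales = np.abs(blocks).max(axis=1, keepdims=True) / 7.0
  scales = np.where(scales == 0.0, 1.0, scales)  # guard all-zero blocks
  quantized = np.clip(np.round(blocks / scales), -8, 7).astype(np.int8)
  return quantized, scales

weights = np.random.randn(4 * 32).astype(np.float32)
quantized, scales = quantize_blockwise_int4(weights, block_size=32)
dequantized = (quantized * scales).reshape(-1)  # per-block reconstruction

The block_size list [32, 64, 96, 128, 256] added to the policy below appears to enumerate the block sizes the default config check will accept.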

2 files changed: +38 −9 lines changed

ai_edge_quantizer/algorithm_manager_api_test.py

Lines changed: 7 additions & 0 deletions
@@ -18,6 +18,7 @@
 from absl.testing import parameterized
 from tensorflow.python.platform import googletest
 from ai_edge_quantizer import algorithm_manager_api
+from ai_edge_quantizer import default_policy
 from ai_edge_quantizer import qtyping
 
 _TFLOpName = qtyping.TFLOperationName
@@ -205,6 +206,12 @@ def test_register_config_check_policy_succeeds(self):
         self._alg_manager._config_check_policy_registry[test_algorithm_name]
     )
 
+  def test_default_policy_not_empty(self):
+    """Tests that the default policy is not empty & no empty policy is generated."""
+    self.assertNotEmpty(default_policy.DEFAULT_CONFIG_CHECK_POLICY)
+    for policy in default_policy.DEFAULT_CONFIG_CHECK_POLICY.values():
+      self.assertNotEmpty(policy)
+
 
 if __name__ == "__main__":
   googletest.main()
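A quick manual check mirroring the new test's assertions, assuming only what the test itself uses (that DEFAULT_CONFIG_CHECK_POLICY is a non-empty, dict-like mapping whose values are the unrolled config lists):

from ai_edge_quantizer import default_policy

policy = default_policy.DEFAULT_CONFIG_CHECK_POLICY
assert policy, "default policy should not be empty"
for configs in policy.values():
  # Each entry should have unrolled into at least one quantization config.
  assert configs, "an entry in the default policy unrolled to nothing"
print(f"{len(policy)} policy entries, all non-empty")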

ai_edge_quantizer/default_policy.py

Lines changed: 31 additions & 9 deletions
@@ -57,6 +57,17 @@
     "explicit_dequantize": false,
     "compute_precision": "INTEGER"
   },
+  "dynamic_wi4_afp32_blockwise": {
+    "weight_tensor_config": {
+      "num_bits": 4,
+      "symmetric": [true],
+      "granularity": ["BLOCKWISE"],
+      "dtype": "INT",
+      "block_size": [32, 64, 96, 128, 256]
+    },
+    "explicit_dequantize": false,
+    "compute_precision": "INTEGER"
+  },
   "static_wi8_ai16": {
     "activation_tensor_config": {
       "num_bits": 16,
@@ -216,6 +227,7 @@
     "FULLY_CONNECTED"
   ],
   "dynamic_wi4_afp32": ["FULLY_CONNECTED", "EMBEDDING_LOOKUP", "CONV_2D"],
+  "dynamic_wi4_afp32_blockwise": ["EMBEDDING_LOOKUP", "FULLY_CONNECTED"],
   "weightonly_wi8_afp32": [
     "BATCH_MATMUL",
     "CONV_2D",
@@ -259,6 +271,7 @@ def _unroll_json_config(
 
   # Then unroll weight configs and turn them into quantization configs.
   quant_configs = []
+  weight_configs = []
   for symmetric in json_config["weight_tensor_config"]["symmetric"]:
     for granularity in json_config["weight_tensor_config"]["granularity"]:
       tensor_config = {
@@ -267,6 +280,16 @@ def _unroll_json_config(
          "granularity": granularity,
          "dtype": json_config["weight_tensor_config"]["dtype"],
       }
+      if "block_size" in json_config["weight_tensor_config"]:
+        for block_size in json_config["weight_tensor_config"]["block_size"]:
+          tensor_config["block_size"] = block_size
+          weight_configs.append(
+              qtyping.TensorQuantizationConfig.from_dict(tensor_config)
+          )
+      else:
+        weight_configs.append(
+            qtyping.TensorQuantizationConfig.from_dict(tensor_config)
+        )
 
       if activation_configs:
         for activation_config in activation_configs:
@@ -281,15 +304,14 @@
              )
          )
       else:
-        quant_configs.append(
-            qtyping.OpQuantizationConfig(
-                weight_tensor_config=qtyping.TensorQuantizationConfig.from_dict(
-                    tensor_config
-                ),
-                compute_precision=json_config["compute_precision"],
-                explicit_dequantize=json_config["explicit_dequantize"],
-            )
-        )
+        for weight_config in weight_configs:
+          quant_configs.append(
+              qtyping.OpQuantizationConfig(
+                  weight_tensor_config=weight_config,
+                  compute_precision=json_config["compute_precision"],
+                  explicit_dequantize=json_config["explicit_dequantize"],
+              )
+          )
 
   return quant_configs
 
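To make the new fan-out in _unroll_json_config concrete, here is a standalone sketch using plain dicts in place of qtyping.TensorQuantizationConfig; the dict shape mirrors the JSON policy entry above, and this is not the library's code.

weight_json = {
    "num_bits": 4,
    "symmetric": [True],
    "granularity": ["BLOCKWISE"],
    "dtype": "INT",
    "block_size": [32, 64, 96, 128, 256],
}

weight_configs = []
for symmetric in weight_json["symmetric"]:
  for granularity in weight_json["granularity"]:
    tensor_config = {
        "num_bits": weight_json["num_bits"],
        "symmetric": symmetric,
        "granularity": granularity,
        "dtype": weight_json["dtype"],
    }
    if "block_size" in weight_json:
      # Blockwise: emit one weight config per supported block size.
      for block_size in weight_json["block_size"]:
        weight_configs.append({**tensor_config, "block_size": block_size})
    else:
      # Non-blockwise configs unroll exactly as before, one per combination.
      weight_configs.append(tensor_config)

print(len(weight_configs))  # 5 configs, one per block size listed in the policy

Collecting the weight configs first and then pairing each one with compute_precision and explicit_dequantize (the new loop at the bottom of the hunk) is what lets a single policy entry expand into one OpQuantizationConfig per supported block size.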
