Skip to content

Commit 9bca2c2

Browse files
alankelly authored and copybara-github committed
De-duplicate zero points for per channel quantized tensors when all the zero points are the same.
PiperOrigin-RevId: 763795520
1 parent b3b6b22 commit 9bca2c2

File tree

2 files changed

+15
-7
lines changed

2 files changed

+15
-7
lines changed

ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py

Lines changed: 9 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -435,7 +435,8 @@ def _is_valid_quantization_params(
435435
"""Checks if the quantization parameters are valid.
436436
437437
A valid quantization params requires:
438-
1. scale and zero point have the same shape (TFL Runtime requirement).
438+
1. scale and zero point either have the same shape or the zero point is a
439+
scalar.
439440
2. scale and zero point have the same rank as the tensor content (avoid
440441
ambiguous broadcasting).
441442
@@ -446,17 +447,20 @@ def _is_valid_quantization_params(
446447
Returns:
447448
True if the quantization parameters are valid.
448449
"""
449-
if quantization_params.scale.shape != quantization_params.zero_point.shape:
450+
if (
451+
quantization_params.scale.shape != quantization_params.zero_point.shape
452+
and quantization_params.zero_point.size != 1
453+
):
450454
raise ValueError(
451-
"scale and zero_point must have the same shape. Got"
452-
f" {quantization_params.scale.shape} and"
455+
"scale and zero_point must have the same shape or zero_point must have"
456+
f" only one element. Got {quantization_params.scale.shape} and"
453457
f" {quantization_params.zero_point.shape}"
454458
)
455459

456460
tensor_rank = tensor_data.ndim
457461
scale_rank = quantization_params.scale.ndim
458462
zero_point_rank = quantization_params.zero_point.ndim
459-
if (tensor_rank != scale_rank) or (tensor_rank != zero_point_rank):
463+
if tensor_rank != scale_rank or (tensor_rank != zero_point_rank):
460464
raise ValueError(
461465
f"Ranks of scales ({scale_rank}) and zps"
462466
f" ({zero_point_rank}) must be the same as the tensor rank"

ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -160,7 +160,9 @@ def test_uniform_quantize(
160160
def test_uniform_quantize_wrong_shape(self):
161161
tensor = [-3.0, 1.3, 2.4, 16.0]
162162

163-
error_message = "scale and zero_point must have the same shape."
163+
error_message = (
164+
"Ranks of scales (3) and zps (2) must be the same as the tensor rank"
165+
)
164166
with self.assertRaisesWithPredicateMatch(
165167
ValueError, lambda err: error_message in str(err)
166168
):
@@ -233,7 +235,9 @@ def test_uniform_dequantize(
233235
def test_uniform_dequantize_wrong_shape(self):
234236
tensor = [-3.0, 1.3, 2.4, 16.0]
235237

236-
error_message = "scale and zero_point must have the same shape."
238+
error_message = (
239+
"Ranks of scales (3) and zps (2) must be the same as the tensor rank"
240+
)
237241
with self.assertRaisesWithPredicateMatch(
238242
ValueError, lambda err: error_message in str(err)
239243
):

0 commit comments

Comments
 (0)