Skip to content

Commit b6be426

Browse files
ai-edge-botcopybara-github
authored andcommitted
Rollback "De-duplicate zero points for per channel quantized tensors when all the zero points are the same."
PiperOrigin-RevId: 764435889
1 parent ed07bba commit b6be426

File tree

2 files changed

+7
-15
lines changed

2 files changed

+7
-15
lines changed

ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -435,8 +435,7 @@ def _is_valid_quantization_params(
435435
"""Checks if the quantization parameters are valid.
436436
437437
A valid quantization params requires:
438-
1. scale and zero point either have the same shape or the zero point is a
439-
scalar.
438+
1. scale and zero point have the same shape (TFL Runtime requirement).
440439
2. scale and zero point have the same rank as the tensor content (avoid
441440
ambiguous broadcasting).
442441
@@ -447,20 +446,17 @@ def _is_valid_quantization_params(
447446
Returns:
448447
True if the quantization parameters are valid.
449448
"""
450-
if (
451-
quantization_params.scale.shape != quantization_params.zero_point.shape
452-
and quantization_params.zero_point.size != 1
453-
):
449+
if quantization_params.scale.shape != quantization_params.zero_point.shape:
454450
raise ValueError(
455-
"scale and zero_point must have the same shape or zero_point must have"
456-
f" only one element. Got {quantization_params.scale.shape} and"
451+
"scale and zero_point must have the same shape. Got"
452+
f" {quantization_params.scale.shape} and"
457453
f" {quantization_params.zero_point.shape}"
458454
)
459455

460456
tensor_rank = tensor_data.ndim
461457
scale_rank = quantization_params.scale.ndim
462458
zero_point_rank = quantization_params.zero_point.ndim
463-
if tensor_rank != scale_rank or (tensor_rank != zero_point_rank):
459+
if (tensor_rank != scale_rank) or (tensor_rank != zero_point_rank):
464460
raise ValueError(
465461
f"Ranks of scales ({scale_rank}) and zps"
466462
f" ({zero_point_rank}) must be the same as the tensor rank"

ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,7 @@ def test_uniform_quantize(
160160
def test_uniform_quantize_wrong_shape(self):
161161
tensor = [-3.0, 1.3, 2.4, 16.0]
162162

163-
error_message = (
164-
"Ranks of scales (3) and zps (2) must be the same as the tensor rank"
165-
)
163+
error_message = "scale and zero_point must have the same shape."
166164
with self.assertRaisesWithPredicateMatch(
167165
ValueError, lambda err: error_message in str(err)
168166
):
@@ -235,9 +233,7 @@ def test_uniform_dequantize(
235233
def test_uniform_dequantize_wrong_shape(self):
236234
tensor = [-3.0, 1.3, 2.4, 16.0]
237235

238-
error_message = (
239-
"Ranks of scales (3) and zps (2) must be the same as the tensor rank"
240-
)
236+
error_message = "scale and zero_point must have the same shape."
241237
with self.assertRaisesWithPredicateMatch(
242238
ValueError, lambda err: error_message in str(err)
243239
):

0 commit comments

Comments
 (0)