@@ -435,8 +435,7 @@ def _is_valid_quantization_params(
435
435
"""Checks if the quantization parameters are valid.
436
436
437
437
A valid quantization params requires:
438
- 1. scale and zero point either have the same shape or the zero point is a
439
- scalar.
438
+ 1. scale and zero point have the same shape (TFL Runtime requirement).
440
439
2. scale and zero point have the same rank as the tensor content (avoid
441
440
ambiguous broadcasting).
442
441
@@ -447,20 +446,17 @@ def _is_valid_quantization_params(
447
446
Returns:
448
447
True if the quantization parameters are valid.
449
448
"""
450
- if (
451
- quantization_params .scale .shape != quantization_params .zero_point .shape
452
- and quantization_params .zero_point .size != 1
453
- ):
449
+ if quantization_params .scale .shape != quantization_params .zero_point .shape :
454
450
raise ValueError (
455
- "scale and zero_point must have the same shape or zero_point must have "
456
- f" only one element. Got { quantization_params .scale .shape } and"
451
+ "scale and zero_point must have the same shape. Got "
452
+ f" { quantization_params .scale .shape } and"
457
453
f" { quantization_params .zero_point .shape } "
458
454
)
459
455
460
456
tensor_rank = tensor_data .ndim
461
457
scale_rank = quantization_params .scale .ndim
462
458
zero_point_rank = quantization_params .zero_point .ndim
463
- if tensor_rank != scale_rank or (tensor_rank != zero_point_rank ):
459
+ if ( tensor_rank != scale_rank ) or (tensor_rank != zero_point_rank ):
464
460
raise ValueError (
465
461
f"Ranks of scales ({ scale_rank } ) and zps"
466
462
f" ({ zero_point_rank } ) must be the same as the tensor rank"
0 commit comments