@@ -435,7 +435,8 @@ def _is_valid_quantization_params(
435
435
"""Checks if the quantization parameters are valid.
436
436
437
437
A valid quantization params requires:
438
- 1. scale and zero point have the same shape (TFL Runtime requirement).
438
+ 1. scale and zero point either have the same shape or the zero point is a
439
+ scalar.
439
440
2. scale and zero point have the same rank as the tensor content (avoid
440
441
ambiguous broadcasting).
441
442
@@ -446,17 +447,20 @@ def _is_valid_quantization_params(
446
447
Returns:
447
448
True if the quantization parameters are valid.
448
449
"""
449
- if quantization_params .scale .shape != quantization_params .zero_point .shape :
450
+ if (
451
+ quantization_params .scale .shape != quantization_params .zero_point .shape
452
+ and quantization_params .zero_point .size != 1
453
+ ):
450
454
raise ValueError (
451
- "scale and zero_point must have the same shape. Got "
452
- f" { quantization_params .scale .shape } and"
455
+ "scale and zero_point must have the same shape or zero_point must have "
456
+ f" only one element. Got { quantization_params .scale .shape } and"
453
457
f" { quantization_params .zero_point .shape } "
454
458
)
455
459
456
460
tensor_rank = tensor_data .ndim
457
461
scale_rank = quantization_params .scale .ndim
458
462
zero_point_rank = quantization_params .zero_point .ndim
459
- if ( tensor_rank != scale_rank ) or (tensor_rank != zero_point_rank ):
463
+ if tensor_rank != scale_rank or (tensor_rank != zero_point_rank ):
460
464
raise ValueError (
461
465
f"Ranks of scales ({ scale_rank } ) and zps"
462
466
f" ({ zero_point_rank } ) must be the same as the tensor rank"
0 commit comments