From 916ca642e33fcd9269641cbb6b18ccf99c6c4f05 Mon Sep 17 00:00:00 2001 From: Fabien Hertschuh Date: Fri, 22 Nov 2024 14:15:32 -0800 Subject: [PATCH] Make `UnitNormalization` layer stateless. There is no need to resolve negative axes in `build`, as `tf.linalg.l2_normalize` can handle them. Kept the build method to validate the axes in the context of the `input_shape`. Also added call to `super.build(...)` per best practice on Keras 2. Note that in Keras 3, `UnitNormalization` is already stateless. PiperOrigin-RevId: 699285760 --- tf_keras/layers/normalization/unit_normalization.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tf_keras/layers/normalization/unit_normalization.py b/tf_keras/layers/normalization/unit_normalization.py index f3255bcb1..26cb546e6 100644 --- a/tf_keras/layers/normalization/unit_normalization.py +++ b/tf_keras/layers/normalization/unit_normalization.py @@ -60,7 +60,8 @@ def __init__(self, axis=-1, **kwargs): self.supports_masking = True def build(self, input_shape): - self.axis = tf_utils.validate_axis(self.axis, input_shape) + tf_utils.validate_axis(self.axis, input_shape) + super().build(input_shape) def call(self, inputs): inputs = tf.cast(inputs, self.compute_dtype)