Skip to content

Modified compute_output_shape function to handle broadcasting behavior in layers.Rescaling #21351

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 57 additions & 2 deletions keras/src/layers/preprocessing/rescaling.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import numpy as np

from keras.src import backend
from keras.src.api_export import keras_export
from keras.src.layers.preprocessing.tf_data_layer import TFDataLayer
Expand Down Expand Up @@ -27,8 +29,16 @@ class Rescaling(TFDataLayer):
(independently of which backend you're using).

Args:
scale: Float, the scale to apply to the inputs.
offset: Float, the offset to apply to the inputs.
scale: Float, int, list, tuple or np.ndarray.
The scale to apply to the inputs.
If scalar, the same scale will be applied to
all features or channels of input. If a list, tuple or
1D array, the scaling is applied per channel.
offset: Float, int, list/tuple or numpy ndarray.
The offset to apply to the inputs.
If scalar, the same scale will be applied to
all features or channels of input. If a list, tuple or
1D array, the scaling is applied per channel.
**kwargs: Base layer keyword arguments, such as `name` and `dtype`.
"""

Expand All @@ -53,6 +63,51 @@ def call(self, inputs):
return self.backend.cast(inputs, dtype) * scale + offset

def compute_output_shape(self, input_shape):
input_shape = tuple(input_shape)

if input_shape[-1] is None:
return input_shape

input_channels = input_shape[-1]

scale_len = None
offset_len = None

if isinstance(self.scale, (list, tuple)):
scale_len = len(self.scale)
elif isinstance(self.scale, np.ndarray) and self.scale.ndim == 1:
scale_len = self.scale.size
elif isinstance(self.scale, (int, float)):
scale_len = 1

if isinstance(self.offset, (list, tuple)):
offset_len = len(self.offset)
elif isinstance(self.offset, np.ndarray) and self.offset.ndim == 1:
offset_len = self.offset.size
elif isinstance(self.offset, (int, float)):
offset_len = 1

if scale_len == 1 and offset_len == 1:
return input_shape

broadcast_len = None
if scale_len is not None and scale_len != input_channels:
broadcast_len = scale_len
if offset_len is not None and offset_len != input_channels:
if broadcast_len is not None and offset_len != broadcast_len:
raise ValueError(
"Inconsistent `scale` and `offset` lengths "
f"for broadcasting."
f" Received: `scale` = {self.scale},"
f"`offset` = {self.offset}. "
f"Ensure both `scale` and `offset` are either scalar "
f"or list, tuples, arrays of the same length."
)
broadcast_len = offset_len

if broadcast_len:
return input_shape[:-1] + (broadcast_len,)

return input_shape

def get_config(self):
Expand Down
17 changes: 17 additions & 0 deletions keras/src/layers/preprocessing/rescaling_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,20 @@ def test_numpy_args(self):
expected_num_losses=0,
supports_masking=True,
)

@pytest.mark.requires_trainable_backend
def test_rescaling_broadcast_output_shape(self):
self.run_layer_test(
layers.Rescaling,
init_kwargs={
"scale": [1.0, 1.0],
"offset": [0.0, 0.0],
},
input_shape=(2, 1),
expected_output_shape=(2, 2),
expected_num_trainable_weights=0,
expected_num_non_trainable_weights=0,
expected_num_seed_generators=0,
expected_num_losses=0,
supports_masking=True,
)
Loading