Skip to content

Commit ad06941

Browse files
hertschuhtensorflower-gardener
authored andcommitted
Fixes for Sequential model with multiple inputs.
- While `Sequential` works with multiple inputs in most scenarios, `build()` did not allow building with multiple inputs. This is now fixed. - Removed the `build_input_shape` from the new serialization format. This is a legacy concept, which has been replaced with `build_config.input_shape` in the new format. Having both could cause models to be built twice. - `build_from_config` now always call `build` with `TensorShape`s, not tuples. Not all layers handle tuples correctly. PiperOrigin-RevId: 719002230
1 parent a23abb2 commit ad06941

File tree

3 files changed

+57
-13
lines changed

3 files changed

+57
-13
lines changed

tf_keras/engine/base_layer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2321,7 +2321,7 @@ def build_from_config(self, config):
23212321
"""
23222322
input_shape = config["input_shape"]
23232323
if input_shape is not None:
2324-
self.build(input_shape)
2324+
self.build(tf_utils.convert_shapes(input_shape, to_tuples=False))
23252325

23262326
############################################################################
23272327
# Methods & attributes below are all private and only used by the framework.

tf_keras/engine/sequential.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -285,12 +285,16 @@ def _build_graph_network_for_inferred_shape(
285285
):
286286
# Determine whether the input shape is novel, i.e. whether the model
287287
# should be rebuilt.
288-
input_shape = tuple(input_shape)
288+
input_shape = tf_utils.convert_shapes(input_shape)
289289
if self._inferred_input_shape is None:
290290
new_shape = input_shape
291291
else:
292-
new_shape = relax_input_shape(
293-
self._inferred_input_shape, input_shape
292+
new_shape = tf.nest.map_structure(
293+
_relax_input_shape,
294+
tf_utils.convert_shapes(
295+
self._inferred_input_shape, to_tuples=False
296+
),
297+
tf_utils.convert_shapes(input_shape, to_tuples=False),
294298
)
295299
if (
296300
new_shape is not None
@@ -299,10 +303,13 @@ def _build_graph_network_for_inferred_shape(
299303
# A novel shape has been received: we need to rebuild the model.
300304
# In case we are inside a graph function, we step out of it.
301305
with tf.init_scope():
302-
inputs = input_layer.Input(
303-
batch_shape=new_shape,
304-
dtype=input_dtype,
305-
name=self.layers[0].name + "_input",
306+
inputs = tf.nest.map_structure(
307+
lambda s: input_layer.Input(
308+
batch_shape=tf_utils.convert_shapes(s),
309+
dtype=input_dtype,
310+
name=self.layers[0].name + "_input",
311+
),
312+
tf_utils.convert_shapes(new_shape, to_tuples=False),
306313
)
307314
layer_input = inputs
308315
created_nodes = set()
@@ -370,7 +377,7 @@ def build(self, input_shape=None):
370377
raise ValueError("You must provide an `input_shape` argument.")
371378
self._build_graph_network_for_inferred_shape(input_shape)
372379
if not self.built:
373-
input_shape = tuple(input_shape)
380+
input_shape = tf_utils.convert_shapes(input_shape)
374381
self._build_input_shape = input_shape
375382
super().build(input_shape)
376383
self.built = True
@@ -435,7 +442,8 @@ def compute_mask(self, inputs, mask):
435442
def get_config(self):
436443
layer_configs = []
437444
serialize_obj_fn = serialization_lib.serialize_keras_object
438-
if getattr(self, "use_legacy_config", None):
445+
use_legacy_config = getattr(self, "use_legacy_config", False)
446+
if use_legacy_config:
439447
serialize_obj_fn = legacy_serialization.serialize_keras_object
440448
for layer in super().layers:
441449
# `super().layers` include the InputLayer if available (it is
@@ -446,7 +454,11 @@ def get_config(self):
446454
config = training.Model.get_config(self)
447455
config["name"] = self.name
448456
config["layers"] = copy.deepcopy(layer_configs)
449-
if not self._is_graph_network and self._build_input_shape is not None:
457+
if (
458+
use_legacy_config
459+
and not self._is_graph_network
460+
and self._build_input_shape
461+
):
450462
config["build_input_shape"] = self._build_input_shape
451463
return config
452464

@@ -458,6 +470,7 @@ def from_config(cls, config, custom_objects=None):
458470
layer_configs = config["layers"]
459471
else:
460472
name = None
473+
build_input_shape = None
461474
layer_configs = config
462475
model = cls(name=name)
463476
for layer_config in layer_configs:
@@ -519,11 +532,15 @@ def _get_shape_tuple(t):
519532
return None
520533

521534

522-
def relax_input_shape(shape_1, shape_2):
535+
def _relax_input_shape(shape_1, shape_2):
523536
if shape_1 is None or shape_2 is None:
524537
return None
525-
if len(shape_1) != len(shape_2):
538+
if shape_1.rank is None or shape_2.rank is None:
539+
return None
540+
if shape_1.rank != shape_2.rank:
526541
return None
542+
shape_1 = shape_1.as_list()
543+
shape_2 = shape_2.as_list()
527544
return tuple(None if d1 != d2 else d1 for d1, d2 in zip(shape_1, shape_2))
528545

529546

tf_keras/engine/sequential_test.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from absl.testing import parameterized
2020

2121
import tf_keras as keras
22+
from tf_keras.saving import object_registration
2223
from tf_keras.testing_infra import test_combinations
2324
from tf_keras.testing_infra import test_utils
2425

@@ -574,6 +575,22 @@ def test_multi_inputs_outputs(self):
574575
model(image_inputs)
575576
model.fit(x=image_inputs, y=image_inputs, steps_per_epoch=1)
576577

578+
@test_combinations.run_all_keras_modes(always_skip_v1=True)
579+
def test_multi_inputs_build(self):
580+
model = keras.Sequential([ImageMultiplyLayer()])
581+
model.build({"images": (None, 512, 512, 3), "weights": (None, 3)})
582+
583+
image_inputs = tf.ones((2, 512, 512, 3))
584+
weight_inputs = tf.ones((2, 3))
585+
output = model({"images": image_inputs, "weights": weight_inputs})
586+
587+
config = model.to_json()
588+
new_model = keras.models.model_from_json(config)
589+
new_output = new_model(
590+
{"images": image_inputs, "weights": weight_inputs}
591+
)
592+
self.assertAllClose(output, new_output)
593+
577594

578595
class TestSequentialEagerIntegration(test_combinations.TestCase):
579596
@test_combinations.run_all_keras_modes
@@ -642,10 +659,20 @@ def test_build_empty_network(self):
642659
self.assertTrue(model.built)
643660

644661

662+
@object_registration.register_keras_serializable()
645663
class ImageAugmentLayer(keras.layers.Layer):
646664
def call(self, inputs):
647665
return inputs
648666

649667

668+
@object_registration.register_keras_serializable()
669+
class ImageMultiplyLayer(keras.layers.Layer):
670+
def call(self, inputs):
671+
images = inputs["images"]
672+
weights = inputs["weights"]
673+
images = tf.reshape(images, (-1, 1, 1, 3))
674+
return images * weights
675+
676+
650677
if __name__ == "__main__":
651678
tf.test.main()

0 commit comments

Comments
 (0)