From 803835f8d2de38f16ca063a4e7837d593f027a05 Mon Sep 17 00:00:00 2001
From: Jyotinder Singh
Date: Tue, 20 May 2025 19:43:35 -0700
Subject: [PATCH] Adds support for Call-Context Arguments

Create an argument propagation flow for call-context arguments.

Currently, Keras uses the `training` argument to infer whether a layer
should be called in training or inference mode. This change introduces a
general flow for propagating arguments from a parent call to a child call
(using `call_context`), so that new control-flow arguments can be added in
the future using a generic framework.

This change does the following:

* Adds a `call_context_args` dictionary to the `call_context` object to
  store the call-context arguments being propagated.
* Changes the current layer implementation to use the general propagation
  flow instead of the hardcoded `training` argument.
* Adds utilities to query and set these context arguments in the `Layer`
  class.
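Example usage (an illustrative sketch adapted from the docstring and unit
tests added in this change; `_register_call_context_args` is a private API,
and `foo_mode` is a made-up argument name used only for illustration):

    import numpy as np
    from tf_keras import layers

    class Inner(layers.Layer):
        def __init__(self):
            super().__init__()
            # Opt in to receiving `foo_mode` from the call context.
            self._register_call_context_args("foo_mode")

        def call(self, x, foo_mode=None):
            return x + (1 if foo_mode else 0)

    class Outer(layers.Layer):
        def __init__(self):
            super().__init__()
            self._register_call_context_args("foo_mode")
            self.inner = Inner()

        def call(self, x):
            # `foo_mode` is not passed explicitly here; `Layer.__call__`
            # injects it into the nested call via the call context.
            return self.inner(x)

    layer = Outer()
    layer(np.array(0), foo_mode=True)  # resolves to 1 inside Inner
    layer(np.array(0))                 # resolves to 0 inside Inner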
PiperOrigin-RevId: 761325027
---
 tf_keras/engine/base_layer.py                 | 319 ++++++++++++------
 tf_keras/engine/base_layer_test.py            | 165 +++++++++
 tf_keras/engine/base_layer_utils.py           |  22 +-
 tf_keras/engine/base_layer_v1.py              |  11 +-
 tf_keras/layers/core/tf_op_layer.py           |   4 +
 tf_keras/layers/rnn/base_rnn_test.py          |   4 +-
 tf_keras/layers/rnn/bidirectional_test.py     |   4 +-
 tf_keras/layers/rnn/cell_wrappers.py          |  14 +-
 .../saving/legacy/saved_model/save_impl.py    |  40 ++-
 tf_keras/saving/legacy/saving_utils.py        |  16 +-
 tf_keras/utils/layer_utils.py                 |  48 ++-
 11 files changed, 525 insertions(+), 122 deletions(-)

diff --git a/tf_keras/engine/base_layer.py b/tf_keras/engine/base_layer.py
index 279e193f0..ffed039de 100644
--- a/tf_keras/engine/base_layer.py
+++ b/tf_keras/engine/base_layer.py
@@ -308,6 +308,7 @@ def __init__(
         self, trainable=True, name=None, dtype=None, dynamic=False, **kwargs
     ):
         self._instrument_layer_creation()
+        self._called = False
 
         # These properties should be set by the user via keyword arguments.
         # note that 'dtype', 'input_shape' and 'batch_input_shape'
@@ -326,6 +327,10 @@ def __init__(
         # Validate optional keyword arguments.
         generic_utils.validate_kwargs(kwargs, allowed_kwargs)
 
+        # Track the built-in call-context arguments. These are arguments that
+        # are tracked and propagated across the call-stack by default.
+        self._call_context_args = {"training"}
+
         # Mutable properties
         # Indicates whether the layer's weights are updated during training
         # and whether the layer's updates are run during training.
@@ -411,6 +416,9 @@ def __init__(
 
         self._init_call_fn_args()
 
+        # Track the built-in call-context arguments.
+        self._call_spec._update_call_context_arguments(self._call_context_args)
+
         # Whether the `call` method can be used to build a TF graph without
         # issues. This attribute has no effect if the model is created using
         # the Functional API. Instead, `model.dynamic` is determined based on
@@ -1042,6 +1050,7 @@ def __call__(self, *args, **kwargs):
         # - input_spec compatibility is only checked against `inputs`
         # - mixed precision casting (autocast) is only applied to `inputs`,
         #   not to any other argument.
+        self._called = True
         inputs, args, kwargs = self._call_spec.split_out_first_arg(args, kwargs)
         input_list = tf.nest.flatten(inputs)
 
@@ -1080,17 +1089,21 @@ def __call__(self, *args, **kwargs):
         if self._expects_mask_arg and mask_is_implicit:
             kwargs["mask"] = input_masks
 
-        # Training mode for `Layer.call` is set via (in order of priority):
-        # (1) The `training` argument passed to this `Layer.call`, if it is not
-        # None
-        # (2) The training mode of an outer `Layer.call`.
-        # (3) The default mode set by `tf.keras.backend.set_learning_phase` (if
-        # set)
-        # (4) Any non-None default value for `training` specified in the call
-        # signature
-        # (5) False (treating the layer as if it's in inference)
-        args, kwargs, training_mode = self._set_training_mode(
-            args, kwargs, call_context
-        )
+        # Call-context arguments for `Layer.call` are set via (in order of
+        # priority):
+        # (1) The argument passed to this `Layer.call`, if it is not None
+        # (2) The argument value of an outer `Layer.call`.
+        # (3) (only for "training") The default mode set by
+        #     `tf.keras.backend.set_learning_phase` (if set)
+        # (4) Any non-None default value for the argument specified in the
+        #     call signature
+        # (5) False (treating the layer as if it's in inference mode)
+        (
+            args,
+            kwargs,
+            propagated,
+        ) = self._get_propagated_call_context_arguments(
+            args, kwargs, call_context, self._call_context_args
+        )
 
         # Losses are cleared for all sublayers on the outermost `Layer.call`.
@@ -1104,7 +1117,7 @@ def __call__(self, *args, **kwargs):
             layer=self,
             inputs=inputs,
             build_graph=not eager,
-            training=training_mode,
+            call_context_args=propagated,
         ):
             input_spec.assert_input_compatibility(
                 self.input_spec, inputs, self.name
@@ -1152,6 +1165,55 @@ def __call__(self, *args, **kwargs):
 
         return outputs
 
+    def _register_call_context_args(self, *argument_names):
+        """Registers call-context args for this layer.
+
+        If this layer declares a `call()` method that accepts one or more of
+        the given args, those args will be automatically injected into the
+        call signature of this layer. This layer will also propagate the args
+        to any nested sublayers that are called from within this layer.
+
+        If this layer doesn't declare a `call()` method that accepts one or
+        more of the given args, these args will simply be propagated to any
+        nested sublayers without being injected into the call signature of
+        this layer.
+
+        This is useful for propagating custom arguments from top-level
+        layers/models to sublayers.
+
+        Example:
+
+        ```
+        class Inner(layers.Layer):
+            def __init__(self):
+                super().__init__()
+                # Register `foo_mode` as a call-context arg
+                self._register_call_context_args("foo_mode")
+
+            def call(self, x, foo_mode=False):
+                # If foo_mode=True add 1, otherwise add 0
+                add_val = ops.where(foo_mode, 1.0, 0.0)
+                return x + add_val
+
+        class Outer(layers.Layer):
+            def __init__(self):
+                super().__init__()
+                self.inner = Inner()
+
+            def call(self, x):
+                # We don't explicitly pass foo_mode here; the base
+                # Layer.__call__ should inject it into `self.inner`.
+                return self.inner(x)
+
+        sample_input = np.array([[1.0], [2.0]])
+
+        # Sequential model
+        seq = models.Sequential([Outer()])
+
+        # Tell the Sequential model to propagate foo_mode down
+        # the call-stack
+        seq._register_call_context_args("foo_mode")
+
+        # foo_mode=True -> input + 1
+        out_true = seq(sample_input, foo_mode=True)
+        ```
+        """
+        if self._called:
+            raise RuntimeError(
+                "Cannot add call-context args after the layer has been called."
+            )
+        self._call_context_args |= set(argument_names)
+        self._call_spec._update_call_context_arguments(argument_names)
+        self._call_spec._update_call_context_argument_defaults(argument_names)
+
     def _get_unnested_name_scope(self):
         if _is_name_scope_on_model_declaration_enabled:
             with _name_scope_unnester(
@@ -2535,47 +2597,57 @@ def _convert_non_tensor(x):
                 kwargs["mask"] = input_masks
                 mask_arg_passed_by_framework = True
 
-        # If `training` argument is None or not explicitly passed,
-        # propagate `training` value from this layer's calling layer.
-        training_value = None
-        training_arg_passed_by_framework = False
-        # Priority 1: `training` was explicitly passed a non-None value.
-        if self._call_spec.arg_was_passed("training", args, kwargs):
-            training_value = self._call_spec.get_arg_value(
-                "training", args, kwargs
-            )
-            if not self._expects_training_arg:
-                kwargs.pop("training")
-
-        if training_value is None:
-            # Priority 2: `training` was passed to a parent layer.
-            if call_context.training is not None:
-                training_value = call_context.training
-            # Priority 3: `learning_phase()` has been set.
-            elif backend.global_learning_phase_is_set():
-                training_value = backend.learning_phase()
-                # Force the training_value to be bool type which matches to the
-                # contract for layer/model call args.
-                if tf.is_tensor(training_value):
-                    training_value = tf.cast(training_value, tf.bool)
-                else:
-                    training_value = bool(training_value)
-            # Priority 4: trace layer with the default training argument
-            # specified in the `call` signature (or in inference mode if the
-            # `call` signature specifies no non-None default).
-            else:
-                training_value = self._call_spec.default_training_arg
-            # In cases (2), (3), (4) the training argument is passed
-            # automatically by the framework, and will not be hard-coded into
-            # the model.
-            if self._expects_training_arg:
-                args, kwargs = self._call_spec.set_arg_value(
-                    "training", training_value, args, kwargs
-                )
-                training_arg_passed_by_framework = True
+        propagated = dict()
+        args_passed_by_framework = dict()
+        for context_arg in self._call_context_args:
+            # If the context argument is None or not explicitly passed,
+            # propagate its value from this layer's calling layer.
+            value = None
+            args_passed_by_framework[context_arg] = False
+            # Priority 1: the argument was explicitly passed a non-None value.
+            if self._call_spec.arg_was_passed(context_arg, args, kwargs):
+                value = self._call_spec.get_arg_value(context_arg, args, kwargs)
+                if not self._expects_context_arg(context_arg):
+                    kwargs.pop(context_arg)
+
+            if value is None:
+                # Priority 2: the argument was passed to a parent layer.
+                if call_context.get_call_context_arg(context_arg) is not None:
+                    value = call_context.get_call_context_arg(context_arg)
+                # Priority 3 (only for "training"): `learning_phase()` has
+                # been set.
+                elif (
+                    context_arg == "training"
+                    and backend.global_learning_phase_is_set()
+                ):
+                    value = backend.learning_phase()
+                    # Force the value to be of bool type, which matches the
+                    # contract for layer/model call args.
+                    if tf.is_tensor(value):
+                        value = tf.cast(value, tf.bool)
+                    else:
+                        value = bool(value)
+                # Priority 4: trace the layer with the default value of the
+                # argument specified in the `call` signature (or in inference
+                # mode if the `call` signature specifies no non-None default).
+                else:
+                    value = self._call_spec.get_context_arg_default(context_arg)
+                # In cases (2), (3), (4) the argument is passed automatically
+                # by the framework, and will not be hard-coded into the model.
+                if self._expects_context_arg(context_arg):
+                    args, kwargs = self._call_spec.set_arg_value(
+                        context_arg, value, args, kwargs
+                    )
+                    args_passed_by_framework[context_arg] = True
+
+            if value is not None:
+                propagated[context_arg] = value
 
         with call_context.enter(
-            layer=self, inputs=inputs, build_graph=True, training=training_value
+            layer=self,
+            inputs=inputs,
+            build_graph=True,
+            call_context_args=propagated,
         ):
             # Check input assumptions set after layer building, e.g. input
             # shape.
@@ -2601,10 +2673,13 @@ def _convert_non_tensor(x):
                     "Tensor or a list of Tensors, not None "
                     "(layer: " + self.name + ")."
                 )
-            if training_arg_passed_by_framework:
-                args, kwargs = self._call_spec.set_arg_value(
-                    "training", None, args, kwargs, pop_kwarg_if_none=True
-                )
+
+            for context_arg, is_passed in args_passed_by_framework.items():
+                if is_passed:
+                    args, kwargs = self._call_spec.set_arg_value(
+                        context_arg, None, args, kwargs, pop_kwarg_if_none=True
+                    )
+
             if mask_arg_passed_by_framework:
                 kwargs.pop("mask")
             # Node connectivity does not special-case the first argument.
@@ -2613,52 +2688,100 @@ def _convert_non_tensor(x):
             )
             return outputs
 
-    def _set_training_mode(self, args, kwargs, call_context):
-        training_mode = None
-        if self._expects_training_arg:
-            # (1) `training` was passed to this `Layer.call`.
-            if self._call_spec.arg_was_passed("training", args, kwargs):
-                training_mode = self._call_spec.get_arg_value(
-                    "training", args, kwargs
-                )
-            # If no `training` arg was passed, or `None` was explicitly passed,
-            # the framework will make a decision about the training mode is.
-            if training_mode is None:
-                call_ctx_training = call_context.training
-                # (2) `training` mode is inferred from an outer `Layer.call`.
-                if call_ctx_training is not None:
-                    training_mode = call_ctx_training
-                # (3) User set `tf.keras.backend.set_learning_phase`.
-                elif backend.global_learning_phase_is_set():
-                    training_mode = backend.learning_phase()
-                    # Ensure value is a `bool` or `tf.bool`.
-                    if isinstance(training_mode, bool):
-                        pass
-                    elif tf.is_tensor(training_mode):
-                        training_mode = tf.cast(training_mode, tf.bool)
+    def _get_propagated_call_context_arguments(
+        self, args, kwargs, call_context, local_call_context_arguments
+    ):
+        """Resolves the values of the propagated call-context arguments for
+        the current layer.
+
+        Args:
+            args: The arguments passed to the current layer's `call` method.
+            kwargs: The keyword arguments passed to the current layer's `call`
+                method.
+            call_context: The `CallContext` for the current call-stack.
+            local_call_context_arguments: The call-context arguments registered
+                with the current layer's `Layer.call` method.
+
+        Returns:
+            A tuple of the following:
+                1. Updated args
+                2. Updated kwargs
+                3. A dictionary of the resolved call-context arguments that
+                   should be propagated to the next layer in the call-stack.
+        """
+        propagated_context = dict()
+        relevant_arguments = call_context.call_context_args.keys() | set(
+            local_call_context_arguments
+        )
+
+        for argument in relevant_arguments:
+            authoritative_value = None
+            was_explicitly_passed = self._call_spec.arg_was_passed(
+                argument, args, kwargs
+            )
+            if self._expects_context_arg(argument):
+                # (1) The argument was passed to this `Layer.call`.
+                if was_explicitly_passed:
+                    authoritative_value = self._call_spec.get_arg_value(
+                        argument, args, kwargs
+                    )
+                # If the argument was not passed, or `None` was explicitly
+                # passed, the framework will decide the value to use.
+                if authoritative_value is None:
+                    value_from_context = call_context.get_call_context_arg(
+                        argument
+                    )
+                    # (2) The argument's value is inferred from an outer
+                    # `Layer.call`.
+                    if value_from_context is not None:
+                        authoritative_value = value_from_context
+                    # (3) User set `tf.keras.backend.set_learning_phase`.
+                    elif (
+                        argument == "training"
+                        and backend.global_learning_phase_is_set()
+                    ):
+                        authoritative_value = backend.learning_phase()
+                        # Ensure value is a `bool` or `tf.bool`.
+                        if isinstance(authoritative_value, bool):
+                            pass
+                        elif tf.is_tensor(authoritative_value):
+                            authoritative_value = tf.cast(
+                                authoritative_value, tf.bool
+                            )
+                        else:
+                            authoritative_value = bool(authoritative_value)
+                    # (4) We default to using `call`'s default value for the
+                    # argument, or treating the layer as if it is in inference
+                    # mode if no non-None default is specified in the `call`
+                    # signature.
+                    else:
+                        authoritative_value = (
+                            self._call_spec.get_context_arg_default(argument)
+                        )
+
+                    # For cases (2), (3), (4) the argument is passed by the
+                    # framework.
+                    args, kwargs = self._call_spec.set_arg_value(
+                        argument, authoritative_value, args, kwargs
+                    )
+            else:
+                if argument in kwargs:
+                    # The argument was passed to this `Layer` but is not
+                    # needed for `Layer.call`. It will set the default mode
+                    # for inner `Layer.call`s.
+                    authoritative_value = kwargs.pop(argument)
+                else:
+                    # Grab the current value of the argument from any outer
+                    # `Layer.call`.
+                    authoritative_value = call_context.get_call_context_arg(
                        argument
+                    )
+
+            if authoritative_value is not None:
+                propagated_context[argument] = authoritative_value
+
+        return args, kwargs, propagated_context
-                else:
-                    training_mode = bool(training_mode)
-            # (4) We default to using `call`'s default value for `training`,
-            # or treating the layer as if it is in inference if no non-None
-            # default is specified in the `call` signature.
-            else:
-                training_mode = self._call_spec.default_training_arg
-
-            # For case (2), (3), (4) `training` arg is passed by framework.
-            args, kwargs = self._call_spec.set_arg_value(
-                "training", training_mode, args, kwargs
-            )
-        else:
-            if "training" in kwargs:
-                # `training` was passed to this `Layer` but is not needed for
-                # `Layer.call`. It will set the default mode for inner
-                # `Layer.call`s.
-                training_mode = kwargs.pop("training")
-            else:
-                # Grab the current `training` mode from any outer `Layer.call`.
-                training_mode = call_context.training
-
-        return args, kwargs, training_mode
 
     def _autographed_call(self):
         # Wrapping `call` function in autograph to allow for dynamic control
@@ -3351,7 +3474,8 @@ def _is_layer(self):
 
     def _init_call_fn_args(self, expects_training_arg=None):
         self._call_spec = layer_utils.CallFunctionSpec(
-            tf_inspect.getfullargspec(self.call)
+            tf_inspect.getfullargspec(self.call),
+            getattr(self, "_call_context_args", set()),
         )
         if expects_training_arg is not None:
            self._call_spec.expects_training_arg = expects_training_arg
@@ -3361,6 +3485,9 @@ def _expects_training_arg(self):
         """Whether the call function uses 'training' as a parameter."""
         return self._call_spec.expects_training_arg
 
+    def _expects_context_arg(self, argument_name):
+        return argument_name in self._call_spec.expected_context_args
+
     @property
     def _expects_mask_arg(self):
         return self._call_spec.expects_mask_arg
diff --git a/tf_keras/engine/base_layer_test.py b/tf_keras/engine/base_layer_test.py
index 19d2700b5..44ff934b0 100644
--- a/tf_keras/engine/base_layer_test.py
+++ b/tf_keras/engine/base_layer_test.py
@@ -1106,6 +1106,171 @@ def __init__(self, var1, var2, var3=None, **kwargs):
         with self.assertRaises(NotImplementedError):
             config = layer.get_config()
 
+    def test_call_context_args_with_custom_layers_propagates_args(self):
+        class Inner(layers.Layer):
+            def __init__(self):
+                super().__init__()
+                self._register_call_context_args("foo_mode")
+
+            def call(self, x, foo_mode=None):
+                return x + (1 if foo_mode else 0)
+
+        class Outer(layers.Layer):
+            def __init__(self):
+                super().__init__()
+                self._register_call_context_args("foo_mode")
+                self.inner = Inner()
+
+            def call(self, x):
+                # Outer doesn't even need to re-inject explicitly:
+                # our base class will propagate foo_mode automatically.
+                return self.inner(x)
+
+        layer = Outer()
+        self.assertEqual(int(layer(np.array(0), foo_mode=True)), 1)
+        self.assertEqual(int(layer(np.array(0))), 0)
+
+    def test_register_call_context_arguments_success(self):
+        """Validate that registering call-context args works as expected."""
+
+        class MyLayer(layers.Layer):
+            def call(self, x):
+                return x
+
+        layer = MyLayer()
+
+        layer._register_call_context_args("foo_mode")
+
+        self.assertCountEqual(
+            layer._call_context_args, ("foo_mode", "training")
+        )
+
+    def test_register_call_context_arguments_after_call_raises_error(self):
+        """Validate that registering call-context args after the layer has
+        been called raises an error."""
+
+        class MyLayer(layers.Layer):
+            def call(self, x):
+                return x
+
+        layer = MyLayer()
+        layer(np.array(0))
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "Cannot add call-context args after the layer has been called.",
+        ):
+            layer._register_call_context_args("foo_mode")
+
+    def test_nested_context_args_follow_priority_order(self):
+        """Validate that call-context args are propagated correctly
+        through multiple layers, and that the most specific value is used
+        when multiple values are passed down the call-stack.
+        """
+
+        class Inner(base_layer.Layer):
+            def __init__(self):
+                super().__init__(name="inner_layer")
+                self._register_call_context_args("foo_mode")
+
+            def call(self, inputs, foo_mode=None):
+                return inputs + (1 if foo_mode else 0)
+
+        class Middle(base_layer.Layer):
+            def __init__(self):
+                super().__init__(name="middle_layer")
+                self._inner_layer = Inner()
+
+            def call(self, inputs):
+                return self._inner_layer(inputs)
+
+        class Outer(base_layer.Layer):
+            def __init__(self):
+                super().__init__(name="outer_layer")
+                self._middle = Middle()
+
+            def call(self, inputs):
+                return self._middle(inputs)
+
+        layer = Outer()
+        layer._register_call_context_args("foo_mode")
+
+        # The value of foo_mode is set to True in the call to Outer,
+        # so it should automatically propagate to Inner through Middle.
+        self.assertEqual(int(layer(np.array(0), foo_mode=True)), 1)
+        self.assertEqual(int(layer(np.array(0))), 0)
+
+    def test_context_arg_propagation_without_declaration_does_not_resolve(self):
+        """Validate that a layer does not resolve a propagated arg if it is
+        not declared as a call-context arg in the layer itself."""
+
+        class Inner(layers.Layer):
+            def call(self, x, foo_mode=None):
+                return x + (1 if foo_mode else 0)
+
+        class Wrapper(layers.Layer):
+            def __init__(self):
+                super().__init__()
+                self.inner = Inner()
+
+            def call(self, x):
+                return self.inner(x)
+
+        layer = Wrapper()
+        layer._register_call_context_args("foo_mode")
+
+        # The value of foo_mode is set to True in the call to Wrapper.
+        # However, it is not declared as a call-context arg in Inner,
+        # so it should not resolve to True inside Inner (and instead
+        # default to False).
+        self.assertEqual(int(layer(np.array(0), foo_mode=True)), 0)
+
+    def test_call_context_args_with_models_as_layers_propagates_args(self):
+        """Validate that call-context args are propagated correctly
+        through functional and sequential models when used as layers.
+ """ + + class InnerLayer(base_layer.Layer): + def __init__(self): + super().__init__(name="inner_layer") + self._register_call_context_args("foo") + + def call(self, inputs, foo=None): + if foo: + return inputs + 1.0 + return inputs + + class OuterLayer(base_layer.Layer): + def __init__(self): + super().__init__(name="outer_layer") + self._inner_layer = InnerLayer() + + def call(self, inputs): + return self._inner_layer(inputs) + + sample_input = tf.constant([[1.0, 2.0], [3.0, 4.0]], dtype="float32") + + # Sequential model + seq = sequential.Sequential([OuterLayer()]) + seq._register_call_context_args("foo") + + out_true = seq(sample_input, foo=True) + self.assertAllEqual(out_true, sample_input + 1.0) + + out_false = seq(sample_input, foo=False) + self.assertAllEqual(out_false, sample_input) + + # Functional model + inp = input_layer.Input((2,)) + outer = OuterLayer()(inp) + model = training_lib.Model(inputs=[inp], outputs=[outer]) + model._register_call_context_args("foo") + + out_true = model(sample_input, foo=True) + self.assertAllEqual(out_true, sample_input + 1.0) + + out_false = model(sample_input, foo=False) + self.assertAllEqual(out_false, sample_input) + @test_utils.run_v2_only class SymbolicSupportTest(test_combinations.TestCase): diff --git a/tf_keras/engine/base_layer_utils.py b/tf_keras/engine/base_layer_utils.py index f6f45ca73..ce8a17bac 100644 --- a/tf_keras/engine/base_layer_utils.py +++ b/tf_keras/engine/base_layer_utils.py @@ -480,7 +480,8 @@ class CallContext: layer: The `Layer` whose `call` is currently active. inputs: The inputs to the currently active `Layer`. build_graph: Whether currently inside a Graph or FuncGraph. - training: Whether currently executing in training or inference mode. + call_context_args: The call-context arguments being propagated through the + the call-stack. saving: Whether currently saving to SavedModel. frozen: Whether currently executing inside a `Layer` with `trainable` set to `False`. @@ -495,6 +496,7 @@ def __init__(self): "layer": None, "inputs": None, "build_graph": False, + "call_context_args": dict(), "training": None, "saving": None, } @@ -502,14 +504,17 @@ def __init__(self): # refactor. self._in_keras_graph = False - def enter(self, layer, inputs, build_graph, training, saving=None): + def enter( + self, layer, inputs, build_graph, call_context_args=dict(), saving=None + ): """Push a Layer and its inputs and state onto the current call context. Args: layer: The `Layer` whose `call` is currently active. inputs: The inputs to the currently active `Layer`. build_graph: Whether currently inside a Graph or FuncGraph. - training: Whether currently executing in training or inference mode. + call_context_args: The call-context arguments being propagated through + the call-stack. saving: Whether currently saving to SavedModel. 
 
         Returns:
@@ -519,7 +524,7 @@
             "layer": layer,
             "inputs": inputs,
             "build_graph": build_graph,
-            "training": training,
+            "call_context_args": call_context_args,
             "saving": saving,
         }
         return CallContextManager(self, state)
@@ -538,7 +543,14 @@ def build_graph(self):
 
     @property
     def training(self):
-        return self._state["training"]
+        return self.call_context_args.get("training", None)
+
+    @property
+    def call_context_args(self):
+        return self._state["call_context_args"]
+
+    def get_call_context_arg(self, arg_name):
+        return self.call_context_args.get(arg_name, None)
 
     @property
     def saving(self):
diff --git a/tf_keras/engine/base_layer_v1.py b/tf_keras/engine/base_layer_v1.py
index 55bc3cfc6..0d27f6595 100644
--- a/tf_keras/engine/base_layer_v1.py
+++ b/tf_keras/engine/base_layer_v1.py
@@ -132,6 +132,7 @@ def __init__(
         self, trainable=True, name=None, dtype=None, dynamic=False, **kwargs
     ):
         self._instrument_layer_creation()
+        self._called = False
 
         # These properties should be set by the user via keyword arguments.
         # note that 'dtype', 'input_shape' and 'batch_input_shape'
@@ -165,6 +166,8 @@ def __init__(
         self._input_spec = None
         self.supports_masking = False
 
+        self._call_context_args = {"training"}
+
         self._init_set_name(name)
         self._activity_regularizer = regularizers.get(
             kwargs.pop("activity_regularizer", None)
@@ -705,6 +708,7 @@ def __call__(self, *args, **kwargs):
             RuntimeError: if `super().__init__()` was not called in the
                 constructor.
         """
+        self._called = True
         self._assert_built_as_v1()
 
         if not hasattr(self, "_thread_local"):
@@ -803,7 +807,12 @@ def _convert_non_tensor(x):
         if build_graph and base_layer_utils.needs_keras_history(inputs):
             base_layer_utils.create_keras_history(inputs)
 
-        with call_context.enter(self, inputs, build_graph, training_value):
+        with call_context.enter(
+            self,
+            inputs,
+            build_graph,
+            call_context_args={"training": training_value},
+        ):
             # Check input assumptions set after layer building, e.g. input
             # shape.
             if build_graph:
diff --git a/tf_keras/layers/core/tf_op_layer.py b/tf_keras/layers/core/tf_op_layer.py
index e9f68abed..5ced8be77 100644
--- a/tf_keras/layers/core/tf_op_layer.py
+++ b/tf_keras/layers/core/tf_op_layer.py
@@ -259,6 +259,10 @@ def _call_wrapper(*args, **kwargs):
         self._call_spec.expects_training_arg = False
         self._call_spec.expects_mask_arg = False
+        # Clear the call-context arguments for the layer's call method.
+        # Otherwise, Keras ends up injecting context arguments into the
+        # op-call when the call method accepts kwargs.
+        self._call_spec._expected_context_args.clear()
 
     def _call_wrapper(self, *args, **kwargs):
         created_variables = []
diff --git a/tf_keras/layers/rnn/base_rnn_test.py b/tf_keras/layers/rnn/base_rnn_test.py
index 2e0fcf59b..4d9ecb321 100644
--- a/tf_keras/layers/rnn/base_rnn_test.py
+++ b/tf_keras/layers/rnn/base_rnn_test.py
@@ -639,7 +639,9 @@ def test_stacked_rnn_attributes(self):
             cells[0].kernel, tf.ones_like(cells[0].kernel)
         )
         # TODO(b/128682878): Remove when RNNCells are __call__'d.
-        with base_layer_utils.call_context().enter(layer, x, True, None):
+        with base_layer_utils.call_context().enter(
+            layer, x, True, {"training": True}
+        ):
             cells[0].add_update(update_1)
             cells[0].add_update(update_2)
         self.assertEqual(len(layer.updates), 2)
diff --git a/tf_keras/layers/rnn/bidirectional_test.py b/tf_keras/layers/rnn/bidirectional_test.py
index 8eab3bda4..5dcf2365c 100644
--- a/tf_keras/layers/rnn/bidirectional_test.py
+++ b/tf_keras/layers/rnn/bidirectional_test.py
@@ -472,7 +472,9 @@ def test_Bidirectional_updates(self):
         _ = layer(x)
         assert not layer.updates
         # TODO(b/128684069): Remove when Wrapper sublayers are __call__'d.
-        with base_layer_utils.call_context().enter(layer, x, True, None):
+        with base_layer_utils.call_context().enter(
+            layer, x, True, {"training": True}
+        ):
             layer.forward_layer.add_update(x_reachable_update)
             layer.forward_layer.add_update(1)
             layer.backward_layer.add_update(x_reachable_update)
diff --git a/tf_keras/layers/rnn/cell_wrappers.py b/tf_keras/layers/rnn/cell_wrappers.py
index de02f0704..957059805 100644
--- a/tf_keras/layers/rnn/cell_wrappers.py
+++ b/tf_keras/layers/rnn/cell_wrappers.py
@@ -52,9 +52,21 @@ def __init__(self, cell, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.cell = cell
         cell_call_spec = tf_inspect.getfullargspec(cell.call)
+        accepts_kwargs = cell_call_spec.varkw is not None
+
         self._call_spec.expects_training_arg = (
             "training" in cell_call_spec.args
-        ) or (cell_call_spec.varkw is not None)
+        ) or accepts_kwargs
+
+        # Filter `_expected_context_args`. An argument is kept if:
+        #   1. it is an explicit argument in `cell_call_spec.args`, OR
+        #   2. the cell accepts arbitrary keyword arguments (**kwargs),
+        #      meaning it could potentially handle the context argument.
+        self._call_spec._expected_context_args = {
+            arg
+            for arg in self._call_spec._expected_context_args
+            if (arg in cell_call_spec.args) or accepts_kwargs
+        }
 
     def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
         """Calls the wrapped cell and performs the wrapping logic.
diff --git a/tf_keras/saving/legacy/saved_model/save_impl.py b/tf_keras/saving/legacy/saved_model/save_impl.py
index 18ece3179..69a5d5097 100644
--- a/tf_keras/saving/legacy/saved_model/save_impl.py
+++ b/tf_keras/saving/legacy/saved_model/save_impl.py
@@ -219,7 +219,11 @@ def wrap_layer_functions(layer, serialization_cache):
     with tracing_scope():
         call_collection.trace_with_input_signature()
         with base_layer_utils.call_context().enter(
-            layer, inputs=None, build_graph=True, training=None, saving=True
+            layer,
+            inputs=None,
+            build_graph=True,
+            call_context_args={},
+            saving=True,
         ):
             for fn in fns.values():
                 if fn is not None and not isinstance(fn, LayerCall):
@@ -515,19 +519,28 @@ def trace_with_training(value, fn=fn):
         else:
             add_trace_to_queue(fn, args, kwargs)
 
-    def training_arg_was_passed(self, args, kwargs):
+    def arg_was_passed(self, arg_name, args, kwargs):
+        """Returns True if the argument was passed to the call function."""
         return self._call_spec.arg_was_passed(
-            "training", args, kwargs, inputs_in_args=True
+            arg_name, args, kwargs, inputs_in_args=True
         )
 
-    def get_training_arg_value(self, args, kwargs):
+    def training_arg_was_passed(self, args, kwargs):
+        """Returns True if the training arg was passed to the call function."""
+        return self.arg_was_passed("training", args, kwargs)
+
+    def get_arg_value(self, arg_name, args, kwargs):
+        """Returns the value of the given argument, or None if not found."""
         try:
             return self._call_spec.get_arg_value(
-                "training", args, kwargs, inputs_in_args=True
+                arg_name, args, kwargs, inputs_in_args=True
             )
-        except KeyError:  # Training is not in args or kwargs.
+        except KeyError:  # Arg not found in args or kwargs.
             return None
 
+    def get_training_arg_value(self, args, kwargs):
+        return self.get_arg_value("training", args, kwargs)
+
     def get_input_arg_value(self, args, kwargs):
         return self._call_spec.get_arg_value(
             self._input_arg_name, args, kwargs, inputs_in_args=True
@@ -613,20 +626,23 @@ def layer_call_wrapper(call_collection, method, name):
     def wrapper(*args, **kwargs):
         """Calls method within call context."""
         layer = call_collection.layer
-        training = None
+        propagated = {"training": None}
         inputs = _filtered_inputs([args, kwargs])
-        if (args or kwargs) and call_collection.training_arg_was_passed(
-            args, kwargs
-        ):
-            training = call_collection.get_training_arg_value(args, kwargs)
+
+        for context_arg in layer._call_context_args:
+            if (args or kwargs) and call_collection.arg_was_passed(
+                context_arg, args, kwargs
+            ):
+                propagated[context_arg] = call_collection.get_arg_value(
+                    context_arg, args, kwargs
+                )
 
         original_losses = _reset_layer_losses(layer)
         with base_layer_utils.call_context().enter(
             layer,
             inputs=inputs,
             build_graph=False,
-            training=training,
+            call_context_args=propagated,
             saving=True,
         ):
             with autocast_variable.enable_auto_cast_variables(
diff --git a/tf_keras/saving/legacy/saving_utils.py b/tf_keras/saving/legacy/saving_utils.py
index 473dc482c..a5629ee36 100644
--- a/tf_keras/saving/legacy/saving_utils.py
+++ b/tf_keras/saving/legacy/saving_utils.py
@@ -138,12 +138,24 @@ def trace_model_call(model, input_signature=None):
     @tf.function
     def _wrapped_model(*args, **kwargs):
         """A concrete tf.function that wraps the model's call function."""
+        call_context = base_layer_utils.call_context()
+
+        args, kwargs, propagated = model._get_propagated_call_context_arguments(
+            args, kwargs, call_context, model._call_context_args
+        )
+
         (args, kwargs,) = model._call_spec.set_arg_value(
             "training", False, args, kwargs, inputs_in_args=True
         )
-        with base_layer_utils.call_context().enter(
-            model, inputs=None, build_graph=False, training=False, saving=True
-        ):
+        propagated["training"] = False
+
+        with call_context.enter(
+            model,
+            inputs=None,
+            build_graph=False,
+            call_context_args=propagated,
+            saving=True,
+        ):
             outputs = model(*args, **kwargs)
 
diff --git a/tf_keras/utils/layer_utils.py b/tf_keras/utils/layer_utils.py
index ae60d874f..6a85420f9 100644
--- a/tf_keras/utils/layer_utils.py
+++ b/tf_keras/utils/layer_utils.py
@@ -775,11 +775,13 @@ class CallFunctionSpec:
     """Caches the spec and provides utilities for handling call function
     args."""
 
-    def __init__(self, full_argspec):
+    def __init__(self, full_argspec, call_context_args=set()):
         """Initializes a `CallFunctionSpec`.
 
         Args:
             full_argspec: the FullArgSpec of a call function of a layer.
+            call_context_args: The set of call-context arguments registered
+                with the current layer.
         """
         self._full_argspec = full_argspec
@@ -797,6 +799,18 @@ def __init__(self, full_argspec):
             "mask" in self._arg_names or call_accepts_kwargs
         )
 
+        # Track the set of call-context arguments that the current layer's
+        # `call` method accepts.
+        self._expected_context_args = set()
+        self._update_call_context_arguments(call_context_args)
+
+        self._context_arg_defaults = dict()
+        self._update_call_context_argument_defaults(call_context_args)
+
+    def _update_call_context_argument_defaults(self, context_args):
+        """Updates the call-context argument defaults for the current layer's
+        `call` method.
+        """
         call_fn_defaults = self._full_argspec.defaults or []
         defaults = dict()
         # The call arg defaults are an n-tuple of the last n elements of the
@@ -806,7 +820,21 @@ def __init__(self, full_argspec):
         # The default training arg will be any (non-None) default specified in
         # the method signature, or None if no value is specified.
         defaults.update(self._full_argspec.kwonlydefaults or {})
-        self._default_training_arg = defaults.get("training")
+
+        for arg in context_args:
+            self._context_arg_defaults[arg] = defaults.get(arg)
+
+    def _update_call_context_arguments(self, context_args):
+        """Updates the set of call-context arguments that the current layer's
+        `call` method accepts.
+        """
+        call_accepts_kwargs = self._full_argspec.varkw is not None
+        args_to_add = {
+            arg
+            for arg in context_args
+            if call_accepts_kwargs or arg in self._arg_names
+        }
+        self._expected_context_args.update(args_to_add)
 
     @property
     def full_argspec(self):
@@ -843,6 +871,16 @@ def expects_training_arg(self):
     def expects_training_arg(self, value):
         self._expects_training_arg = value
 
+    @property
+    def expected_context_args(self):
+        """The set of call-context arguments that the current layer's
+        `call` method accepts."""
+        return self._expected_context_args
+
+    @expected_context_args.setter
+    def expected_context_args(self, value):
+        self._expected_context_args = value
+
     @property
     def expects_mask_arg(self):
         """Whether the call function uses `mask` as a parameter."""
@@ -855,7 +893,11 @@ def expects_mask_arg(self, value):
     @property
     def default_training_arg(self):
         """The default value given to the "training" argument."""
-        return self._default_training_arg
+        return self.get_context_arg_default("training")
+
+    def get_context_arg_default(self, arg_name):
+        """Returns the default value for the given call-context argument."""
+        return self._context_arg_defaults.get(arg_name, None)
 
     def arg_was_passed(self, arg_name, args, kwargs, inputs_in_args=False):
         """Returns true if argument is present in `args` or `kwargs`.
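
Note: the priority order implemented by `_get_propagated_call_context_arguments`
above can be summarized, for a single registered argument, by the following
simplified sketch. This is a hypothetical standalone helper, not part of the
patch; the real method also rewrites `args`/`kwargs` and handles arguments
that the layer's `call` does not declare:

    from tf_keras import backend

    def resolve_context_arg(explicit_value, outer_value, is_training_arg,
                            call_default):
        """Simplified mirror of the resolution order in base_layer.py."""
        # (1) A non-None value passed directly to this `Layer.call` wins.
        if explicit_value is not None:
            return explicit_value
        # (2) Otherwise, inherit the value propagated by an outer `Layer.call`.
        if outer_value is not None:
            return outer_value
        # (3) For `training` only: fall back to the global learning phase
        #     (`tf.keras.backend.set_learning_phase`), when one has been set.
        if is_training_arg and backend.global_learning_phase_is_set():
            return bool(backend.learning_phase())
        # (4)/(5) Finally, use the default from the `call` signature, which is
        # None when unspecified (the layer then runs in inference mode).
        return call_default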