diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml
index b9e785dfc949..16bdb759a9c0 100644
--- a/.github/workflows/actions.yml
+++ b/.github/workflows/actions.yml
@@ -16,7 +16,7 @@ jobs:
       fail-fast: false
       matrix:
         python-version: ['3.10']
-        backend: [tensorflow, jax, torch, numpy, openvino]
+        backend: [tensorflow, jax, torch, numpy, openvino, nnx]
     name: Run tests
     runs-on: ubuntu-latest
     env:
diff --git a/.github/workflows/config/nnx/keras.json b/.github/workflows/config/nnx/keras.json
new file mode 100644
index 000000000000..d6bb3e7fd4d5
--- /dev/null
+++ b/.github/workflows/config/nnx/keras.json
@@ -0,0 +1,7 @@
+{
+    "floatx": "float32",
+    "epsilon": 1e-07,
+    "backend": "jax",
+    "image_data_format": "channels_last",
+    "nnx_enabled": true
+}
diff --git a/integration_tests/import_test.py b/integration_tests/import_test.py
index e7af37f23c83..45d933a1e12d 100644
--- a/integration_tests/import_test.py
+++ b/integration_tests/import_test.py
@@ -11,7 +11,9 @@
         "torch",
         "--extra-index-url https://download.pytorch.org/whl/cpu ",
     ),
-    "jax": ("jax[cpu]", ""),
+    # Please update the JAX version here whenever it is updated in the
+    # requirements file.
+    "jax": ("jax[cpu]==0.5.0 flax>=0.10.0", ""),
     "openvino": ("openvino", ""),
 }
diff --git a/keras/api/_tf_keras/keras/config/__init__.py b/keras/api/_tf_keras/keras/config/__init__.py
index 106fd46a3291..8cf3a1c30abd 100644
--- a/keras/api/_tf_keras/keras/config/__init__.py
+++ b/keras/api/_tf_keras/keras/config/__init__.py
@@ -17,6 +17,7 @@
 from keras.src.backend.config import (
     is_flash_attention_enabled as is_flash_attention_enabled,
 )
+from keras.src.backend.config import is_nnx_enabled as is_nnx_enabled
 from keras.src.backend.config import max_epochs as max_epochs
 from keras.src.backend.config import max_steps_per_epoch as max_steps_per_epoch
 from keras.src.backend.config import set_epsilon as set_epsilon
diff --git a/keras/api/config/__init__.py b/keras/api/config/__init__.py
index 106fd46a3291..8cf3a1c30abd 100644
--- a/keras/api/config/__init__.py
+++ b/keras/api/config/__init__.py
@@ -17,6 +17,7 @@
 from keras.src.backend.config import (
     is_flash_attention_enabled as is_flash_attention_enabled,
 )
+from keras.src.backend.config import is_nnx_enabled as is_nnx_enabled
 from keras.src.backend.config import max_epochs as max_epochs
 from keras.src.backend.config import max_steps_per_epoch as max_steps_per_epoch
 from keras.src.backend.config import set_epsilon as set_epsilon
diff --git a/keras/src/backend/__init__.py b/keras/src/backend/__init__.py
index 15f1af2145d5..a200b17c914e 100644
--- a/keras/src/backend/__init__.py
+++ b/keras/src/backend/__init__.py
@@ -39,7 +39,7 @@
     from keras.src.backend.tensorflow.core import Variable as BackendVariable
 elif backend() == "jax":
     from keras.src.backend.jax import *  # noqa: F403
-    from keras.src.backend.jax.core import Variable as BackendVariable
+    from keras.src.backend.jax import Variable as BackendVariable
 elif backend() == "torch":
     from keras.src.backend.torch import *  # noqa: F403
     from keras.src.backend.torch.core import Variable as BackendVariable
diff --git a/keras/src/backend/config.py b/keras/src/backend/config.py
index 68f8e1014639..b33607bc0ff7 100644
--- a/keras/src/backend/config.py
+++ b/keras/src/backend/config.py
@@ -15,6 +15,9 @@
 # Default backend: TensorFlow.
 _BACKEND = "tensorflow"
 
+# Whether NNX is enabled.
+_NNX_ENABLED = False
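+# This default can be overridden via the `nnx_enabled` key in keras.json,
+# the KERAS_NNX_ENABLED environment variable, or `set_nnx_enabled()` below.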
+
 # Cap run duration for debugging.
 _MAX_EPOCHS = None
 _MAX_STEPS_PER_EPOCH = None
@@ -230,6 +233,33 @@ def is_flash_attention_enabled():
     return global_state.get_global_attribute("flash_attention", default=None)
 
 
+@keras_export("keras.config.is_nnx_enabled")
+def is_nnx_enabled():
+    """Checks whether NNX-specific features are enabled for the JAX backend.
+
+    Returns:
+        bool: `True` if NNX backend features are enabled, `False` otherwise.
+        Defaults to `False`.
+    """
+    return _NNX_ENABLED
+
+
+def set_nnx_enabled(value):
+    global _NNX_ENABLED
+    from keras.src.backend.common import global_state
+
+    _NNX_ENABLED = bool(value)
+    if _NNX_ENABLED:
+        try:
+            from flax import nnx  # noqa: F401
+        except ImportError:
+            raise ImportError(
+                "To use the NNX backend, you must install `flax`. "
+                "Try: `pip install flax`"
+            )
+    global_state.set_global_attribute("nnx_enabled", bool(value))
+
+
 def standardize_data_format(data_format):
     if data_format is None:
         return image_data_format()
@@ -261,6 +291,7 @@ def keras_home():
 
 # Attempt to read Keras config file.
 _config_path = os.path.expanduser(os.path.join(_KERAS_DIR, "keras.json"))
+
 if os.path.exists(_config_path):
     try:
         with open(_config_path) as f:
@@ -274,36 +305,18 @@ def keras_home():
     _backend = _config.get("backend", _BACKEND)
     _image_data_format = _config.get("image_data_format", image_data_format())
     assert _image_data_format in {"channels_last", "channels_first"}
+    _nnx_enabled_config = _config.get("nnx_enabled", _NNX_ENABLED)
+    if not isinstance(_nnx_enabled_config, bool):
+        _NNX_ENABLED = str(_nnx_enabled_config).lower() == "true"
+    else:
+        _NNX_ENABLED = _nnx_enabled_config
+
+    # Apply basic configs that don't cause circular imports.
     set_floatx(_floatx)
     set_epsilon(_epsilon)
     set_image_data_format(_image_data_format)
     _BACKEND = _backend
 
-# Save config file, if possible.
-if not os.path.exists(_KERAS_DIR):
-    try:
-        os.makedirs(_KERAS_DIR)
-    except OSError:
-        # Except permission denied and potential race conditions
-        # in multi-threaded environments.
-        pass
-
-if not os.path.exists(_config_path):
-    _config = {
-        "floatx": floatx(),
-        "epsilon": epsilon(),
-        "backend": _BACKEND,
-        "image_data_format": image_data_format(),
-    }
-    try:
-        with open(_config_path, "w") as f:
-            f.write(json.dumps(_config, indent=4))
-    except IOError:
-        # Except permission denied.
-        pass
-
-# Set backend based on KERAS_BACKEND flag, if applicable.
 if "KERAS_BACKEND" in os.environ:
     _backend = os.environ["KERAS_BACKEND"]
     if _backend:
@@ -313,6 +326,7 @@ def keras_home():
 if "KERAS_MAX_STEPS_PER_EPOCH" in os.environ:
     _MAX_STEPS_PER_EPOCH = int(os.environ["KERAS_MAX_STEPS_PER_EPOCH"])
 
+
 if _BACKEND != "tensorflow":
     # If we are not running on the tensorflow backend, we should stop tensorflow
     # from using all available GPU memory. See
@@ -403,3 +417,35 @@ def max_steps_per_epoch():
         `None`, no limit is applied.
     """
     return _MAX_STEPS_PER_EPOCH
+
+
+if not os.path.exists(_KERAS_DIR):
+    try:
+        os.makedirs(_KERAS_DIR)
+    except OSError:
+        # Except permission denied and potential race conditions.
+        pass
+
+if not os.path.exists(_config_path):
+    _config_to_save = {
+        "floatx": floatx(),
+        "epsilon": epsilon(),
+        "backend": _BACKEND,  # Use the final _BACKEND value.
+        "image_data_format": image_data_format(),
+        "nnx_enabled": _NNX_ENABLED,
+    }
+    try:
+        with open(_config_path, "w") as f:
+            f.write(json.dumps(_config_to_save, indent=4))
+    except IOError:
+        # Except permission denied.
+        pass
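+
+# KERAS_NNX_ENABLED, read below, takes precedence over the `nnx_enabled`
+# value loaded from keras.json.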
+if "KERAS_NNX_ENABLED" in os.environ:
+    env_val = os.environ["KERAS_NNX_ENABLED"].lower()
+    if env_val == "true":
+        _NNX_ENABLED = True
+    elif env_val == "false":
+        _NNX_ENABLED = False
+
+set_nnx_enabled(_NNX_ENABLED)
diff --git a/keras/src/backend/jax/__init__.py b/keras/src/backend/jax/__init__.py
index 12d25effa6fc..335eed660b46 100644
--- a/keras/src/backend/jax/__init__.py
+++ b/keras/src/backend/jax/__init__.py
@@ -1,3 +1,4 @@
+from keras.src.backend.config import is_nnx_enabled
 from keras.src.backend.jax import core
 from keras.src.backend.jax import distribution_lib
 from keras.src.backend.jax import image
@@ -10,7 +11,11 @@
 from keras.src.backend.jax.core import IS_THREAD_SAFE
 from keras.src.backend.jax.core import SUPPORTS_RAGGED_TENSORS
 from keras.src.backend.jax.core import SUPPORTS_SPARSE_TENSORS
-from keras.src.backend.jax.core import Variable
+
+if is_nnx_enabled():
+    from keras.src.backend.jax.core import NnxVariable as Variable
+else:
+    from keras.src.backend.jax.core import JaxVariable as Variable
 from keras.src.backend.jax.core import cast
 from keras.src.backend.jax.core import compute_output_spec
 from keras.src.backend.jax.core import cond
diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py
index 747c5881106b..12934aa60b53 100644
--- a/keras/src/backend/jax/core.py
+++ b/keras/src/backend/jax/core.py
@@ -5,12 +5,15 @@
 import numpy as np
 
 from keras.src import tree
+from keras.src.backend import config
 from keras.src.backend.common import KerasVariable
 from keras.src.backend.common import global_state
 from keras.src.backend.common import standardize_dtype
 from keras.src.backend.common.keras_tensor import KerasTensor
 from keras.src.backend.common.name_scope import name_scope as base_name_scope
 from keras.src.backend.common.stateless_scope import StatelessScope
+from keras.src.backend.common.stateless_scope import get_stateless_scope
+from keras.src.backend.common.stateless_scope import in_stateless_scope
 from keras.src.backend.common.symbolic_scope import SymbolicScope
 from keras.src.backend.jax import distribution_lib
 
@@ -19,7 +22,7 @@
 IS_THREAD_SAFE = True
 
 
-class Variable(KerasVariable):
+class JaxVariable(KerasVariable):
     def __init__(self, *args, layout=None, **kwargs):
         # Intercept layout parameter so that it is available
         # during initialization.
@@ -55,6 +58,192 @@ def __jax_array__(self):
         return self.value
 
 
+_JAX_VARIABLE_TYPE = JaxVariable
+if config.is_nnx_enabled():
+    from flax import nnx
+
+    class NnxVariable(JaxVariable, nnx.Variable):
+        def __init__(
+            self,
+            initializer,
+            shape=None,
+            dtype=None,
+            trainable=True,
+            autocast=True,
+            aggregation="none",
+            synchronization="auto",
+            name=None,
+            layout=None,
+            mutable=None,
+            **nnx_metadata,
+        ):
+            # Ensure 'mutable' is in nnx_metadata, but an explicit 'mutable'
+            # argument takes precedence.
+            nnx_metadata["mutable"] = trainable if mutable is None else mutable
+
+            # Initialize nnx.Variable first, with a placeholder value that
+            # only carries the shape and dtype.
+            _placeholder_value = jax.ShapeDtypeStruct(
+                shape or (), dtype=standardize_dtype(dtype)
+            )
+
+            # Call nnx.Variable.__init__ directly.
+            nnx.Variable.__init__(
+                self, value=_placeholder_value, **nnx_metadata
+            )
+
+            # Store the JAX-specific layout with object.__setattr__ BEFORE
+            # the KerasVariable init, because KerasVariable.__init__ calls
+            # self._initialize, which uses self._layout.
+            object.__setattr__(self, "_layout", layout)
+
+            # Initialize JaxVariable (which will call KerasVariable.__init__).
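+            # At this point the nnx.Variable side only holds the
+            # ShapeDtypeStruct placeholder; the call below runs the Keras
+            # initializer (unless initialization is deferred) and assigns
+            # the real array through _direct_assign.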
+            JaxVariable.__init__(
+                self,
+                initializer=initializer,
+                shape=shape,
+                dtype=dtype,
+                trainable=trainable,
+                autocast=autocast,
+                aggregation=aggregation,
+                synchronization=synchronization,
+                name=name,
+            )
+
+        @property
+        def _value(self):
+            if hasattr(self, "raw_value"):
+                return self.raw_value
+            return None
+
+        @_value.setter
+        def _value(self, new_keras_value):
+            self._direct_assign(new_keras_value)
+
+        def __getstate__(self):
+            # Get the state from KerasVariable (attributes in __dict__).
+            # KerasVariable does not have a custom __getstate__, so we mimic
+            # the default behavior.
+            try:
+                keras_state = KerasVariable.__getstate__(self)
+            except AttributeError:
+                keras_state = object.__getstate__(self)
+
+            # Get the state from nnx.Variable.
+            nnx_specific_state = nnx.Variable.__getstate__(self)
+
+            # Merge them: Keras state is primary; NNX-specific state adds
+            # to it.
+            if "raw_value" in nnx_specific_state:
+                keras_state["_value"] = nnx_specific_state["raw_value"]
+
+            # Add NNX attributes that are not in Keras's __dict__.
+            if "_trace_state" in nnx_specific_state:
+                keras_state["_trace_state"] = nnx_specific_state["_trace_state"]
+            if "_var_metadata" in nnx_specific_state:
+                keras_state["_var_metadata"] = nnx_specific_state[
+                    "_var_metadata"
+                ]
+
+            # Drop "raw_value" if present; its contents were already copied
+            # to "_value" above.
+            keras_state.pop("raw_value", None)
+
+            return keras_state
+
+        def __setstate__(self, state):
+            # __getstate__ merged the NNX-specific keys into the main state
+            # dictionary, so pull them back out before restoring the Keras
+            # attributes.
+            nnx_raw_value = state["_value"]  # This was raw_value.
+            nnx_trace_state = state.pop("_trace_state", None)
+            nnx_var_metadata = state.pop("_var_metadata", None)
+
+            # Populate the instance's __dict__ with the Keras attributes.
+            self.__dict__.update(state)
+
+            # Restore the nnx.Variable-specific slotted attributes.
+            object.__setattr__(self, "raw_value", nnx_raw_value)
+            if nnx_trace_state is not None:
+                object.__setattr__(self, "_trace_state", nnx_trace_state)
+            if nnx_var_metadata is not None:
+                object.__setattr__(self, "_var_metadata", nnx_var_metadata)
+
+            # Ensure Keras's self._value is also consistent with the
+            # restored raw_value.
+            self._value = nnx_raw_value
+
+            if hasattr(self, "_shape") and self._shape is not None:
+                self._ndim = len(self._shape)
+            else:
+                # Fallback if shape isn't immediately available.
+                self._ndim = len(self.raw_value.shape)
+
+        def _direct_assign(self, value):
+            # Apply JAX-specific distribution if a layout is present.
+            if self._layout is not None:
+                value = distribution_lib.distribute_variable(
+                    value, self._layout
+                )
+
+            # Ensure that the nnx.Variable part is initialized.
+            if not hasattr(self, "_var_metadata"):
+                # TODO: should add a warning here.
+                pass
+
+            # Apply the on_set_value hook if it exists.
+            if (
+                hasattr(self, "_var_metadata")
+                and "on_set_value" in self._var_metadata
+            ):
+                value = self._var_metadata["on_set_value"](self, value)
+
+            # Directly set raw_value; nnx.Variable handles mutable array
+            # updates.
+            object.__setattr__(self, "raw_value", value)
+
+        @property
+        def value(self):
+            if in_stateless_scope():
+                scope = get_stateless_scope()
+                stateless_value = scope.get_current_value(self)
+                if stateless_value is not None:
+                    return self._maybe_autocast(stateless_value)
+            if not hasattr(self, "raw_value"):
+                if self._initializer is not None:
+                    self._initialize(
+                        self._initializer(self.shape, dtype=self.dtype)
+                    )
+                else:
+                    raise AttributeError(
+                        "Variable is not properly initialized (raw_value "
+                        "missing) and has no initializer."
+                    )
+            current_value = self.raw_value
+            if (
+                hasattr(self, "_var_metadata")
+                and "on_get_value" in self._var_metadata
+            ):
+                current_value = self._var_metadata["on_get_value"](
+                    self, current_value
+                )
+            return self._maybe_autocast(current_value)
+
+        # TODO: NNX has agreed to fix this on their end; remove this override
+        # once that is done.
+        def __hash__(self):
+            return id(self)
+
+    _JAX_VARIABLE_TYPE = NnxVariable
+
+
 def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
     if ragged:
         raise ValueError("`ragged=True` is not supported with jax backend")
@@ -68,7 +257,7 @@ def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
         # an existing distributed jax array will raise error.
         return x
 
-    if isinstance(x, Variable):
+    if isinstance(x, _JAX_VARIABLE_TYPE):
         if dtype is not None and x.dtype != dtype:
             return x.value.astype(dtype)
         return x.value
@@ -352,7 +541,7 @@ def fori_loop(lower, upper, body_fun, init_val):
 
 
 def stop_gradient(variable):
-    if isinstance(variable, Variable):
+    if isinstance(variable, _JAX_VARIABLE_TYPE):
         variable = variable.value
     return jax.lax.stop_gradient(variable)
diff --git a/keras/src/backend/jax/core_test.py b/keras/src/backend/jax/core_test.py
new file mode 100644
index 000000000000..0578c97f4964
--- /dev/null
+++ b/keras/src/backend/jax/core_test.py
@@ -0,0 +1,68 @@
+import os
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import pytest
+
+import keras
+from keras.src import backend
+from keras.src import testing
+from keras.src.backend.config import is_nnx_enabled
+
+if is_nnx_enabled():
+    from flax import nnx
+
+    from keras.src.backend.jax.core import NnxVariable
+
+
+@pytest.mark.skipif(
+    backend.backend() != "jax",
+    reason="JAX backend specific test for core Variable integration with NNX.",
+)
+@pytest.mark.skipif(
+    not is_nnx_enabled(),
+    reason="Test requires NNX backend to be enabled by default for setup.",
+)
+class JaxCoreVariableTest(testing.TestCase):
+    def setup(self):
+        super().setup()
+
+        class NNXModel(nnx.Module):
+            def __init__(self, rngs):
+                self.linear = nnx.Linear(2, 3, rngs=rngs)
+                # Use NnxVariable directly, since the Keras-level Variable
+                # alias might be JaxVariable if NNX is disabled globally.
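+                # NnxVariable takes the same arguments as a Keras variable;
+                # here a concrete jnp array serves as the initializer.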
+                self.custom_variable = NnxVariable(jnp.ones((1, 3)))
+
+            def __call__(self, x):
+                return self.linear(x) + self.custom_variable
+
+        self.nnx_model = NNXModel(rngs=nnx.Rngs(0))
+        self.keras_nnx_model = keras.Sequential(
+            [keras.layers.Dense(units=1, input_shape=(10,))]
+        )
+        self.single_dummy_input = np.random.rand(1, 10)
+
+    def test_variable_in_nnx_module(self):
+        self.assertTrue(hasattr(self.nnx_model.custom_variable, "_trace_state"))
+        self.assertIsNotNone(self.nnx_model.custom_variable._trace_state)
+        self.assertAllEqual(self.nnx_model.custom_variable.value, [[1, 1, 1]])
+        self.assertTrue(
+            isinstance(self.nnx_model.custom_variable, nnx.Variable)
+        )
+
+    def test_model_saving(self):
+        path = os.path.join(self.get_temp_dir(), "model.keras")
+        original_outputs = self.keras_nnx_model(self.single_dummy_input)
+        self.keras_nnx_model.save(path, save_format="keras_v3")
+        restored_model = keras.models.load_model(path)
+        restored_outputs = restored_model(self.single_dummy_input)
+        self.assertAllEqual(original_outputs, restored_outputs)
+
+    def test_keras_variable_nnx_split_merge_sync(self):
+        variable1 = keras.Variable(jnp.array(1.0))
+        graphdef, state = nnx.split(variable1)
+        state = jax.tree.map(lambda x: x + 1, state)
+        variable2 = nnx.merge(graphdef, state)
+        self.assertEqual(variable2._value, variable2.value)
diff --git a/keras/src/backend/jax/layer.py b/keras/src/backend/jax/layer.py
index fbcc4fe5b5c6..7784bae431ed 100644
--- a/keras/src/backend/jax/layer.py
+++ b/keras/src/backend/jax/layer.py
@@ -1,2 +1,11 @@
+from keras.src.backend.config import is_nnx_enabled
+
+if is_nnx_enabled():
+    from flax import nnx
+
+    class NnxLayer(nnx.Module):
+        pass
+
+
 class JaxLayer:
     pass
diff --git a/keras/src/backend/jax/trainer.py b/keras/src/backend/jax/trainer.py
index 2577de297d78..327b7968f953 100644
--- a/keras/src/backend/jax/trainer.py
+++ b/keras/src/backend/jax/trainer.py
@@ -12,6 +12,7 @@
 from keras.src import tree
 from keras.src.backend import config
 from keras.src.backend import distribution_lib as jax_distribution_lib
+from keras.src.backend.config import is_nnx_enabled
 from keras.src.distribution import distribution_lib
 from keras.src.trainers import trainer as base_trainer
 from keras.src.trainers.data_adapters import array_slicing
@@ -19,6 +20,13 @@
 from keras.src.trainers.epoch_iterator import EpochIterator
 from keras.src.utils import traceback_utils
 
+if is_nnx_enabled():
+    from flax import nnx
+
+    jit = nnx.jit
+else:
+    jit = jax.jit
+
 
 class JAXTrainer(base_trainer.Trainer):
     def __init__(self):
@@ -233,7 +241,7 @@ def concatenate(outputs):
             return output
 
         if not self.run_eagerly and self.jit_compile:
-            concatenate = jax.jit(concatenate)
+            concatenate = jit(concatenate)
 
         def iterator_step(state, iterator):
             data = next(iterator)
@@ -277,7 +285,7 @@ def make_train_function(self, force=False):
             # so that jax will reuse the memory buffer for outputs.
             # This will reduce the memory usage of the training function by
             # half.
-            train_step = jax.jit(self.train_step, donate_argnums=0)
+            train_step = jit(self.train_step, donate_argnums=0)
         else:
             train_step = self.train_step
 
@@ -293,7 +301,8 @@ def make_test_function(self, force=False):
             # so that jax will reuse the memory buffer for outputs.
             # This will reduce the memory usage of the training function by
             # half.
-            test_step = jax.jit(self.test_step, donate_argnums=0)
+            test_step = jit(self.test_step, donate_argnums=0)
+
         else:
             test_step = self.test_step
 
@@ -310,7 +319,7 @@ def predict_step(state, data):
             return outputs, (state[0], non_trainable_variables)
 
         if not self.run_eagerly and self.jit_compile:
-            predict_step = jax.jit(predict_step, donate_argnums=0)
+            predict_step = jit(predict_step, donate_argnums=0)
 
         _step_function = self._make_function(
             predict_step, concatenate_outputs=True
@@ -904,7 +913,7 @@ def _enforce_jax_state_sharding(
 
         Since the output of the train/eval step will be used as inputs to
         next step, we need to ensure that they have the same sharding spec, so
-        that jax.jit won't have to recompile the train/eval function.
+        that nnx.jit/jax.jit won't have to recompile the train/eval function.
 
         Note that this function will also rely on the recorded sharding spec
         for each of states.
diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py
index eaff1a8376a2..280a99506acf 100644
--- a/keras/src/layers/layer.py
+++ b/keras/src/layers/layer.py
@@ -38,6 +38,7 @@
 from keras.src.backend.common.name_scope import current_path
 from keras.src.backend.common.remat import get_current_remat_mode
 from keras.src.backend.common.symbolic_scope import in_symbolic_scope
+from keras.src.backend.config import is_nnx_enabled
 from keras.src.distribution import distribution_lib
 from keras.src.dtype_policies import DTypePolicyMap
 from keras.src.layers import input_spec
@@ -53,7 +54,10 @@
 if backend.backend() == "tensorflow":
     from keras.src.backend.tensorflow.layer import TFLayer as BackendLayer
 elif backend.backend() == "jax":
-    from keras.src.backend.jax.layer import JaxLayer as BackendLayer
+    if is_nnx_enabled():
+        from keras.src.backend.jax.layer import NnxLayer as BackendLayer
+    else:
+        from keras.src.backend.jax.layer import JaxLayer as BackendLayer
 elif backend.backend() == "torch":
     from keras.src.backend.torch.layer import TorchLayer as BackendLayer
 elif backend.backend() == "numpy":
@@ -220,7 +224,6 @@ def call(self, inputs):
 
     def __new__(cls, *args, **kwargs):
         obj = super().__new__(cls, *args, **kwargs)
-
         # Wrap the user-provided `build` method in the `build_wrapper`
         # to add name scope support and serialization support.
         original_build_method = obj.build
@@ -1533,7 +1536,19 @@ def __setattr__(self, name, value):
             if not hasattr(self, "_tracker"):
                 self._initialize_tracker()
             value = self._tracker.track(value)
-        return super().__setattr__(name, value)
+
+        # NNX-specific bypass for the `_called` and `built` attributes.
+        if (
+            backend.backend() == "jax"
+            and is_nnx_enabled()
+            and (name == "_called" or name == "built")
+        ):
+            object.__setattr__(self, name, value)
+            return
+
+        # Default path, including for NnxLayer -> nnx.Module.
+        super().__setattr__(name, value)
 
     def __delattr__(self, name):
         obj = getattr(self, name)
@@ -1646,8 +1661,18 @@ def get_config(self):
         return {**base_config, **config}
 
     def _open_name_scope(self):
+        from keras.src.utils import jax_utils  # avoid circular imports
+
         if self._parent_path is None:
-            self._parent_path = current_path()
+            # Avoid mutating _parent_path during a JAX trace if it is part
+            # of nnx.Object state and the object was created at a different
+            # trace level. Recording the path is only skipped when NNX mode
+            # is enabled and we are inside a JAX trace.
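+            # (nnx.Object state generally cannot be mutated from a different
+            # trace level, hence the guard below.)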
+            if not (is_nnx_enabled() and jax_utils.is_in_jax_tracing_scope()):
+                try:
+                    self._parent_path = current_path()
+                except Exception:
+                    pass
         return backend.name_scope(self.name, caller=self)
 
     def rematerialized_call(self, layer_call, *args, **kwargs):
diff --git a/keras/src/ops/operation.py b/keras/src/ops/operation.py
index 9529a8e689f1..3b934761c866 100644
--- a/keras/src/ops/operation.py
+++ b/keras/src/ops/operation.py
@@ -6,6 +6,7 @@
 from keras.src import tree
 from keras.src.api_export import keras_export
 from keras.src.backend.common.keras_tensor import any_symbolic_tensors
+from keras.src.backend.config import is_nnx_enabled
 from keras.src.ops.node import Node
 from keras.src.utils import python_utils
 from keras.src.utils import traceback_utils
@@ -118,7 +119,10 @@ def __new__(cls, *args, **kwargs):
         to manually implement `get_config()`.
         """
         instance = super(Operation, cls).__new__(cls)
+        if backend.backend() == "jax" and is_nnx_enabled():
+            from flax import nnx
+            vars(instance)["_object__state"] = nnx.object.ObjectState()
         # Generate a config to be returned by default by `get_config()`.
         arg_names = inspect.getfullargspec(cls.__init__).args
         kwargs.update(dict(zip(arg_names[1 : len(args) + 1], args)))
diff --git a/keras/src/optimizers/base_optimizer.py b/keras/src/optimizers/base_optimizer.py
index a996e9945cc8..a4dcbeab0d40 100644
--- a/keras/src/optimizers/base_optimizer.py
+++ b/keras/src/optimizers/base_optimizer.py
@@ -774,6 +774,8 @@ def _get_current_learning_rate(self):
             self._learning_rate, learning_rate_schedule.LearningRateSchedule
         ):
             return self._learning_rate(self._iterations)
+        elif isinstance(self._learning_rate, backend.Variable):
+            return self._learning_rate
         elif callable(self._learning_rate):
             return self._learning_rate()
         return self._learning_rate
diff --git a/requirements-jax-cuda.txt b/requirements-jax-cuda.txt
index 7fd5763924b5..765263e82696 100644
--- a/requirements-jax-cuda.txt
+++ b/requirements-jax-cuda.txt
@@ -9,6 +9,5 @@ torch==2.6.0
 # Jax with cuda support.
 --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
 jax[cuda12]==0.6.0
-flax
-
+flax>=0.10.1
 -r requirements-common.txt
diff --git a/requirements.txt b/requirements.txt
index 8d150a4e989e..730f1fb2601c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -14,7 +14,6 @@ torch-xla==2.6.0;sys_platform != 'darwin'
 # Pinned to 0.5.0 on CPU. JAX 0.5.1 requires Tensorflow 2.19 for saved_model_test.
 # Note that we test against the latest JAX on GPU.
 jax[cpu]==0.5.0
-flax
-
+flax>=0.10.1
 # Common deps.
 -r requirements-common.txt
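
Usage sketch (not part of the diff above): with the JAX backend selected and
`flax` installed, the NNX integration can be enabled through keras.json
("nnx_enabled": true, as in the new CI config) or through the
KERAS_NNX_ENABLED environment variable before Keras is imported. A minimal,
hypothetical example; exact behavior depends on the installed flax version:

    import os

    os.environ["KERAS_BACKEND"] = "jax"
    os.environ["KERAS_NNX_ENABLED"] = "true"  # read at import time

    import keras
    from flax import nnx

    assert keras.config.is_nnx_enabled()

    # With NNX enabled, Keras layers are nnx.Module subclasses and Keras
    # variables are nnx.Variable subclasses, so they compose with native
    # flax.nnx code (e.g. nnx.split/nnx.merge as in core_test.py above).
    layer = keras.layers.Dense(4)
    assert isinstance(layer, nnx.Module)
    assert isinstance(keras.Variable(1.0), nnx.Variable)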