Keras <> NNX integration #21252

Open
wants to merge 65 commits into
base: master
8e1c008
_valu
divyashreepathihalli May 5, 2025
7159709
update variables
divyashreepathihalli May 7, 2025
e378cfb
add nnx.jit
divyashreepathihalli May 8, 2025
8701fc7
revert changes to JaxLayer
divyashreepathihalli May 8, 2025
b599727
Merge branch 'master' into nnx
divyashreepathihalli May 8, 2025
4f7b3b8
format fix
divyashreepathihalli May 8, 2025
0234e27
make variables subclass nnx.Variable
divyashreepathihalli May 13, 2025
b87c4f9
more tweaks
divyashreepathihalli May 13, 2025
91b9a73
update init
divyashreepathihalli May 27, 2025
aee1789
refactor jax Variable class
divyashreepathihalli May 28, 2025
141487f
code reformat
divyashreepathihalli May 28, 2025
5ccc31e
more cleanup
divyashreepathihalli May 28, 2025
4e35416
Merge branch 'keras-team:master' into nnx
divyashreepathihalli May 28, 2025
dd9c77d
update flax version
divyashreepathihalli May 28, 2025
48983c6
update flax version
divyashreepathihalli May 28, 2025
b22f9ef
fix jax error
divyashreepathihalli May 29, 2025
4dbffa6
update Variables implementation
divyashreepathihalli May 29, 2025
627e581
fix import
divyashreepathihalli May 29, 2025
f58ef60
add a test
divyashreepathihalli May 29, 2025
c2b73b7
needs updates in operation
divyashreepathihalli May 29, 2025
a662f5e
remove __new__ from JaxLayer
divyashreepathihalli May 30, 2025
396f973
update base optimizers
divyashreepathihalli May 30, 2025
30e971d
code reformat+ model saving tests
divyashreepathihalli May 30, 2025
968d804
add __hash__
divyashreepathihalli Jun 2, 2025
b99571a
update variable value updates
divyashreepathihalli Jun 2, 2025
ed0bc00
sync value properly
divyashreepathihalli Jun 2, 2025
460e0e2
update flag based routing between nnx and jax
divyashreepathihalli Jun 3, 2025
34f27e9
clean up
divyashreepathihalli Jun 3, 2025
427ff82
fix circular import error
divyashreepathihalli Jun 3, 2025
c4ee191
fix is nnx call enabled flag
divyashreepathihalli Jun 3, 2025
44414dc
attemptto fix circular import error
divyashreepathihalli Jun 3, 2025
0953d99
try again
divyashreepathihalli Jun 3, 2025
6d54a7e
fix import error
divyashreepathihalli Jun 4, 2025
64adbaf
reformat# Please enter the commit message for your changes. Lines sta…
divyashreepathihalli Jun 4, 2025
6454800
This has to fix it
divyashreepathihalli Jun 4, 2025
001f112
api gen
divyashreepathihalli Jun 4, 2025
5f26958
remove enable diisable configs -that does not work
divyashreepathihalli Jun 4, 2025
782c653
adrress some comments
divyashreepathihalli Jun 6, 2025
561f70a
update conditional imports
divyashreepathihalli Jun 6, 2025
e7caa03
fix tests
divyashreepathihalli Jun 6, 2025
8e3f460
add github workflow for nnx
divyashreepathihalli Jun 6, 2025
d70d51c
fix test
divyashreepathihalli Jun 6, 2025
38dbd4b
address comments
divyashreepathihalli Jun 6, 2025
1c60c5e
fix test
divyashreepathihalli Jun 6, 2025
74835fd
address comments
divyashreepathihalli Jun 6, 2025
6f11c0c
fix test
divyashreepathihalli Jun 6, 2025
c05166e
fix test -_-
divyashreepathihalli Jun 7, 2025
8582c7e
put the set attr in operation
divyashreepathihalli Jun 11, 2025
9471e4e
Merge branch 'master' into nnx
divyashreepathihalli Jun 11, 2025
297775a
fix jax error
divyashreepathihalli Jun 11, 2025
f01cc0d
fix trace error
divyashreepathihalli Jun 12, 2025
6810848
remove installation
divyashreepathihalli Jun 12, 2025
dc79329
import fixes
divyashreepathihalli Jun 12, 2025
f280dd0
update jax version
divyashreepathihalli Jun 12, 2025
75f9cc8
ugh the jax version issue
divyashreepathihalli Jun 12, 2025
68261d4
update jax version
divyashreepathihalli Jun 12, 2025
1e09246
update installations
divyashreepathihalli Jun 12, 2025
8a142a1
update jax utils
divyashreepathihalli Jun 12, 2025
c7b2347
another requirents file fix
divyashreepathihalli Jun 12, 2025
99d4307
fix test
divyashreepathihalli Jun 12, 2025
d544a0b
add back flax to req common
divyashreepathihalli Jun 12, 2025
3b8d90b
address review comments
divyashreepathihalli Jun 13, 2025
0e0fcd1
fix tests
divyashreepathihalli Jun 13, 2025
bd66ec8
fix tests address more comments
divyashreepathihalli Jun 13, 2025
8637c18
fix tests
divyashreepathihalli Jun 13, 2025
2 changes: 1 addition & 1 deletion .github/workflows/actions.yml
@@ -16,7 +16,7 @@ jobs:
       fail-fast: false
       matrix:
         python-version: ['3.10']
-        backend: [tensorflow, jax, torch, numpy, openvino]
+        backend: [tensorflow, jax, torch, numpy, openvino, nnx]
Review comment (Collaborator):
I think you need to do something in addition to this (or instead of it). As it stands, this does not turn on NNX.

Reply (Collaborator, Author):
It does: I am seeing the NNX-related error in this test workflow.

     name: Run tests
     runs-on: ubuntu-latest
     env:
7 changes: 7 additions & 0 deletions .github/workflows/config/nnx/keras.json
@@ -0,0 +1,7 @@
+{
+    "floatx": "float32",
+    "epsilon": 1e-07,
+    "backend": "jax",
+    "image_data_format": "channels_last",
+    "nnx_enabled": true
+}
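As a rough illustration of how a config file like the one above could be consumed, the sketch below writes an equivalent JSON file to a temporary directory and reads the `nnx_enabled` key back, accepting either a JSON boolean or a string such as `"true"`. The helper name `read_nnx_enabled` and the temporary path are hypothetical; this is not the Keras loader itself.

```python
import json
import os
import tempfile


def read_nnx_enabled(config_path, default=False):
    """Return the `nnx_enabled` value from a keras.json-style file.

    Hypothetical helper for illustration only.
    """
    if not os.path.exists(config_path):
        return default
    try:
        with open(config_path) as f:
            config = json.load(f)
    except ValueError:
        # Malformed JSON: fall back to the default.
        return default
    value = config.get("nnx_enabled", default)
    if isinstance(value, bool):
        return value
    # Tolerate string values such as "true" or "True".
    return str(value).lower() == "true"


with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "keras.json")
    sample = {
        "floatx": "float32",
        "epsilon": 1e-07,
        "backend": "jax",
        "image_data_format": "channels_last",
        "nnx_enabled": True,
    }
    with open(path, "w") as f:
        json.dump(sample, f, indent=4)
    enabled_from_file = read_nnx_enabled(path)
    missing_file_default = read_nnx_enabled(os.path.join(tmp_dir, "absent.json"))
```

A missing or unreadable file falls back to `default`, which keeps NNX off unless it is requested explicitly.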
5 changes: 4 additions & 1 deletion guides/distributed_training_with_jax.py
@@ -53,6 +53,9 @@
 from jax.sharding import Mesh
 from jax.sharding import NamedSharding
 from jax.sharding import PartitionSpec as P
+from keras.src.backend.config import is_nnx_enabled
+from keras.src.utils.jax_utils import jit
+from flax import nnx


 def get_model():
@@ -186,7 +189,7 @@ def compute_loss(trainable_variables, non_trainable_variables, x, y):


 # Training step, Keras provides a pure functional optimizer.stateless_apply
-@jax.jit
+@jit
 def train_step(train_state, x, y):
     (
         trainable_variables,
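The hunk above swaps `@jax.jit` for a `jit` helper imported from `keras.src.utils.jax_utils`, whose body is not shown in this part of the diff. Assuming it dispatches between `nnx.jit` and `jax.jit` depending on the NNX flag (an assumption, not something confirmed here), the routing could be sketched as follows. `choose_jit` and the two stand-in decorators are hypothetical names used only so the example runs without JAX or Flax installed.

```python
def choose_jit(nnx_enabled, jax_jit=None, nnx_jit=None):
    """Pick a jit decorator based on a flag (hypothetical helper).

    The real libraries are imported lazily, so a caller that injects its
    own decorators never needs JAX or Flax.
    """
    if nnx_enabled:
        if nnx_jit is None:
            from flax import nnx  # requires `flax`

            nnx_jit = nnx.jit
        return nnx_jit
    if jax_jit is None:
        import jax  # requires `jax`

        jax_jit = jax.jit
    return jax_jit


# Stand-in decorators that only tag the function they wrap.
def fake_jax_jit(fn):
    fn.compiled_with = "jax.jit"
    return fn


def fake_nnx_jit(fn):
    fn.compiled_with = "nnx.jit"
    return fn


jit = choose_jit(nnx_enabled=True, jax_jit=fake_jax_jit, nnx_jit=fake_nnx_jit)


@jit
def train_step(x):
    return x + 1
```

Calling `choose_jit(True)` or `choose_jit(False)` with no injected decorators would return the real `nnx.jit` or `jax.jit`, provided the corresponding package is installed.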
4 changes: 3 additions & 1 deletion integration_tests/import_test.py
@@ -11,7 +11,9 @@
         "torch",
         "--extra-index-url https://download.pytorch.org/whl/cpu ",
     ),
-    "jax": ("jax[cpu]", ""),
+    # please update the jax version here if jax version is updated in
+    # requirements file
+    "jax": ("jax[cpu]==0.5.0", ""),
     "openvino": ("openvino", ""),
 }
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/config/__init__.py
@@ -17,6 +17,7 @@
 from keras.src.backend.config import (
     is_flash_attention_enabled as is_flash_attention_enabled,
 )
+from keras.src.backend.config import is_nnx_enabled as is_nnx_enabled
 from keras.src.backend.config import max_epochs as max_epochs
 from keras.src.backend.config import max_steps_per_epoch as max_steps_per_epoch
 from keras.src.backend.config import set_epsilon as set_epsilon
1 change: 1 addition & 0 deletions keras/api/config/__init__.py
@@ -17,6 +17,7 @@
 from keras.src.backend.config import (
     is_flash_attention_enabled as is_flash_attention_enabled,
 )
+from keras.src.backend.config import is_nnx_enabled as is_nnx_enabled
 from keras.src.backend.config import max_epochs as max_epochs
 from keras.src.backend.config import max_steps_per_epoch as max_steps_per_epoch
 from keras.src.backend.config import set_epsilon as set_epsilon
2 changes: 1 addition & 1 deletion keras/src/backend/__init__.py
@@ -39,7 +39,7 @@
     from keras.src.backend.tensorflow.core import Variable as BackendVariable
 elif backend() == "jax":
     from keras.src.backend.jax import *  # noqa: F403
-    from keras.src.backend.jax.core import Variable as BackendVariable
+    from keras.src.backend.jax import Variable as BackendVariable
 elif backend() == "torch":
     from keras.src.backend.torch import *  # noqa: F403
     from keras.src.backend.torch.core import Variable as BackendVariable
94 changes: 70 additions & 24 deletions keras/src/backend/config.py
@@ -15,6 +15,9 @@
 # Default backend: TensorFlow.
 _BACKEND = "tensorflow"

+# Whether NNX is enabled.
+_NNX_ENABLED = False
+
 # Cap run duration for debugging.
 _MAX_EPOCHS = None
 _MAX_STEPS_PER_EPOCH = None
@@ -230,6 +233,33 @@ def is_flash_attention_enabled():
     return global_state.get_global_attribute("flash_attention", default=None)


+@keras_export("keras.config.is_nnx_enabled")
+def is_nnx_enabled():
+    """Checks whether NNX specific features are enabled for the JAX backend.
+
+    Returns:
+        bool: `True` if NNX backend features are enabled, `False` otherwise.
+        Defaults to `False`.
+    """
+    return _NNX_ENABLED
+
+
+def set_nnx_enabled(value):
+    global _NNX_ENABLED
+    from keras.src.backend.common import global_state
+
+    _NNX_ENABLED = bool(value)
+    if _NNX_ENABLED:
+        try:
+            from flax import nnx  # noqa F401
+        except ImportError:
+            raise ImportError(
+                "To use the NNX backend, you must install `flax`. "
+                "Try: `pip install flax`"
+            )
+    global_state.set_global_attribute("nnx_enabled", bool(value))
+
+
 def standardize_data_format(data_format):
     if data_format is None:
         return image_data_format()
@@ -261,6 +291,7 @@ def keras_home():

 # Attempt to read Keras config file.
 _config_path = os.path.expanduser(os.path.join(_KERAS_DIR, "keras.json"))
+
 if os.path.exists(_config_path):
     try:
         with open(_config_path) as f:
@@ -274,36 +305,18 @@ def keras_home():
     _backend = _config.get("backend", _BACKEND)
     _image_data_format = _config.get("image_data_format", image_data_format())
     assert _image_data_format in {"channels_last", "channels_first"}
+    _nnx_enabled_config = _config.get("nnx_enabled", _NNX_ENABLED)
+    if not isinstance(_nnx_enabled_config, bool):
+        _NNX_ENABLED = str(_nnx_enabled_config).lower() == "true"
+    else:
+        _NNX_ENABLED = _nnx_enabled_config

+    # Apply basic configs that don't cause circular import
     set_floatx(_floatx)
     set_epsilon(_epsilon)
     set_image_data_format(_image_data_format)
     _BACKEND = _backend

-# Save config file, if possible.
-if not os.path.exists(_KERAS_DIR):
-    try:
-        os.makedirs(_KERAS_DIR)
-    except OSError:
-        # Except permission denied and potential race conditions
-        # in multi-threaded environments.
-        pass
-
-if not os.path.exists(_config_path):
-    _config = {
-        "floatx": floatx(),
-        "epsilon": epsilon(),
-        "backend": _BACKEND,
-        "image_data_format": image_data_format(),
-    }
-    try:
-        with open(_config_path, "w") as f:
-            f.write(json.dumps(_config, indent=4))
-    except IOError:
-        # Except permission denied.
-        pass
-
-# Set backend based on KERAS_BACKEND flag, if applicable.
 if "KERAS_BACKEND" in os.environ:
     _backend = os.environ["KERAS_BACKEND"]
     if _backend:
@@ -313,6 +326,7 @@ def keras_home():
 if "KERAS_MAX_STEPS_PER_EPOCH" in os.environ:
     _MAX_STEPS_PER_EPOCH = int(os.environ["KERAS_MAX_STEPS_PER_EPOCH"])

+
 if _BACKEND != "tensorflow":
     # If we are not running on the tensorflow backend, we should stop tensorflow
     # from using all available GPU memory. See
@@ -403,3 +417,35 @@ def max_steps_per_epoch():
         `None`, no limit is applied.
     """
     return _MAX_STEPS_PER_EPOCH
+
+
+if not os.path.exists(_KERAS_DIR):
+    try:
+        os.makedirs(_KERAS_DIR)
+    except OSError:
+        # Except permission denied and potential race conditions
+        pass
+
+if not os.path.exists(_config_path):
+    _config_to_save = {
+        "floatx": floatx(),
+        "epsilon": epsilon(),
+        "backend": _BACKEND,  # Use the final _BACKEND value
+        "image_data_format": image_data_format(),
+        "nnx_enabled": _NNX_ENABLED,
+    }
+    try:
+        with open(_config_path, "w") as f:
+            f.write(json.dumps(_config_to_save, indent=4))
+    except IOError:
+        # Except permission denied.
+        pass
+
+if "KERAS_NNX_ENABLED" in os.environ:
+    env_val = os.environ["KERAS_NNX_ENABLED"].lower()
+    if env_val == "true":
+        _NNX_ENABLED = True
+    elif env_val == "false":
+        _NNX_ENABLED = False
+
+set_nnx_enabled(_NNX_ENABLED)
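Taken together, the `config.py` hunks appear to resolve the flag in two stages: the `nnx_enabled` key from `keras.json` first, then the `KERAS_NNX_ENABLED` environment variable, which overrides it only when set to `true` or `false` (case-insensitive). A minimal sketch of that precedence, written as a pure function over plain dicts, might look like this. The function name `resolve_nnx_enabled` is hypothetical and does not exist in Keras.

```python
def resolve_nnx_enabled(config, environ, default=False):
    """Resolve the NNX flag from a config dict and an environment mapping.

    Hypothetical helper mirroring the order of operations in the diff.
    """
    value = config.get("nnx_enabled", default)
    if isinstance(value, bool):
        enabled = value
    else:
        # Non-boolean config values are compared as strings.
        enabled = str(value).lower() == "true"

    env_val = environ.get("KERAS_NNX_ENABLED")
    if env_val is not None:
        env_val = env_val.lower()
        if env_val == "true":
            enabled = True
        elif env_val == "false":
            enabled = False
        # Any other value leaves the config-file setting untouched.
    return enabled


from_config_only = resolve_nnx_enabled({"nnx_enabled": True}, {})
env_overrides = resolve_nnx_enabled(
    {"nnx_enabled": True}, {"KERAS_NNX_ENABLED": "false"}
)
unrecognized_env = resolve_nnx_enabled(
    {"nnx_enabled": "true"}, {"KERAS_NNX_ENABLED": "1"}
)
```

Note that a value such as `1` or `yes` is ignored under this logic, so the config-file setting stays in effect.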
7 changes: 6 additions & 1 deletion keras/src/backend/jax/__init__.py
@@ -1,3 +1,4 @@
+from keras.src.backend.config import is_nnx_enabled
 from keras.src.backend.jax import core
 from keras.src.backend.jax import distribution_lib
 from keras.src.backend.jax import image
@@ -10,7 +11,11 @@
 from keras.src.backend.jax.core import IS_THREAD_SAFE
 from keras.src.backend.jax.core import SUPPORTS_RAGGED_TENSORS
 from keras.src.backend.jax.core import SUPPORTS_SPARSE_TENSORS
-from keras.src.backend.jax.core import Variable
+
+if is_nnx_enabled():
+    from keras.src.backend.jax.core import NnxVariable as Variable
+else:
+    from keras.src.backend.jax.core import JaxVariable as Variable
 from keras.src.backend.jax.core import cast
 from keras.src.backend.jax.core import compute_output_spec
 from keras.src.backend.jax.core import cond
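The conditional import above binds the name `Variable` to one of two classes when the module is first imported. The sketch below reproduces that pattern with two stand-in classes to show that the alias is fixed at selection time, so changing the flag afterwards does not rebind it. `JaxVariableSketch`, `NnxVariableSketch`, and `select_variable_class` are hypothetical names, not the classes in `keras.src.backend.jax.core`.

```python
class JaxVariableSketch:
    """Stand-in for the plain JAX variable class (hypothetical)."""

    kind = "jax"


class NnxVariableSketch:
    """Stand-in for the NNX-aware variable class (hypothetical)."""

    kind = "nnx"


def select_variable_class(nnx_enabled):
    """Return the class that the `Variable` alias should point to."""
    return NnxVariableSketch if nnx_enabled else JaxVariableSketch


# Selection happens once, as it would at import time.
flag = False
Variable = select_variable_class(flag)

# Flipping the flag later does not change the existing alias.
flag = True
alias_after_flip = Variable
```

To pick up a different class, the selection has to run again, which for a module-level import means the flag must be set before the backend module is first imported.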