Keras <> NNX integration #21252

Open · wants to merge 47 commits into master

Conversation

@divyashreepathihalli (Collaborator) commented May 5, 2025

This PR integrates NNX into the JAX backend!

The following snippet shows how you would enable the NNX backend:

import os
os.environ["KERAS_BACKEND"] = "jax"
os.environ["KERAS_NNX_ENABLED"] = "true"
import keras

Demo colab here: https://colab.sandbox.google.com/drive/1mK-4qbce2HGRIkcb4v5n4niWGDezL_6n#scrollTo=m-ZH9Mpnphfz
Added a GitHub workflow action for the NNX backend. Note this will fail, because it needs a new release of Flax to work.
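
As an illustrative sketch of the intended end-to-end usage (not code from this PR; it assumes Keras layers become traceable by nnx.jit, as discussed in the review thread below):

import os
os.environ["KERAS_BACKEND"] = "jax"
os.environ["KERAS_NNX_ENABLED"] = "true"

import jax.numpy as jnp
from flax import nnx
import keras

# A plain Keras layer, built eagerly so its variables exist.
layer = keras.layers.Dense(4)
layer.build((None, 3))

# Per the discussion below, use nnx.jit (not jax.jit) for now.
@nnx.jit
def forward(model, x):
    return model(x)

y = forward(layer, jnp.ones((2, 3)))
print(y.shape)  # (2, 4)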

@divyashreepathihalli marked this pull request as draft May 5, 2025 23:05
@codecov-commenter commented May 5, 2025

Codecov Report

Attention: Patch coverage is 15.26316% with 161 lines in your changes missing coverage. Please review.

Project coverage is 77.57%. Comparing base (81821e0) to head (c05166e).
Report is 18 commits behind head on master.

Files with missing lines                        Patch %   Lines
keras/src/backend/jax/core.py                     4.46%   105 Missing and 2 partials ⚠️
keras/src/backend/config.py                      30.55%   20 Missing and 5 partials ⚠️
keras/src/layers/layer.py                        30.76%   8 Missing and 1 partial ⚠️
keras/src/utils/jax_utils.py                     22.22%   7 Missing ⚠️
keras/src/backend/jax/trainer.py                  0.00%   5 Missing ⚠️
keras/src/backend/jax/layer.py                    0.00%   4 Missing ⚠️
keras/src/backend/jax/__init__.py                50.00%   1 Missing and 1 partial ⚠️
keras/api/_tf_keras/keras/config/__init__.py      0.00%   1 Missing ⚠️
keras/src/backend/__init__.py                     0.00%   1 Missing ⚠️

❗ There is a different number of reports uploaded between BASE (81821e0) and HEAD (c05166e).

HEAD has 2 fewer uploads than BASE:

Flag        BASE (81821e0)   HEAD (c05166e)
keras       5                4
keras-jax   1                0
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #21252      +/-   ##
==========================================
- Coverage   82.65%   77.57%   -5.09%     
==========================================
  Files         565      565              
  Lines       54802    55062     +260     
  Branches     8508     8552      +44     
==========================================
- Hits        45297    42714    -2583     
- Misses       7413    10337    +2924     
+ Partials     2092     2011      -81     
Flag               Coverage Δ
keras              77.41% <15.26%> (-5.05%) ⬇️
keras-jax          ?
keras-numpy        58.59% <15.26%> (-0.16%) ⬇️
keras-openvino     33.48% <14.21%> (+0.33%) ⬆️
keras-tensorflow   63.81% <15.26%> (-0.19%) ⬇️
keras-torch        63.45% <15.26%> (-0.20%) ⬇️

Flags with carried forward coverage won't be shown.


import jax.numpy as jnp

x = ops.ones(3)

-@jax.jit
+@nnx.jit
Collaborator: Would the integration prevent the use of jax.jit with Keras layers?

@divyashreepathihalli (Author): Yes! It would only work with nnx.jit for now (they might be working on adding support for jax.jit).
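
To illustrate the distinction (a sketch, assuming Keras layers become NNX graph nodes under the flag rather than plain JAX pytrees):

import jax
from flax import nnx

# nnx.jit knows how to split and merge the layer's state across the trace:
fast_forward = nnx.jit(lambda model, x: model(x))

# jax.jit has no notion of that state; passing the layer as an argument
# (or closing over it) would fail, since its NNX-backed variables are not
# plain JAX arrays:
# broken_forward = jax.jit(lambda model, x: model(x))  # fails at call time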

@divyashreepathihalli (Author): Added NNX as an opt-in with this flag: os.environ["KERAS_NNX_ENABLED"].
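
For context, a hypothetical sketch of how such an opt-in flag is typically read; the actual helper in keras/src/backend/config.py may differ in name and details:

import os

def is_nnx_enabled() -> bool:
    # Treat the env var as a boolean opt-in; anything else means "off".
    return os.environ.get("KERAS_NNX_ENABLED", "false").lower() in ("true", "1")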

@divyashreepathihalli marked this pull request as ready for review June 4, 2025 05:56
@mattdangerw (Member) left a comment:

This needs a testing plan.

  • How do we make sure things work without flax installed on CI?
  • How do we make sure things work when NNX is enabled?
  • And when NNX is not enabled?

We need an automated testing path for each of these options we are writing logic for, or they will silently break at some point; a sketch follows below.
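
A hypothetical pytest sketch of what such guards could look like (test names and structure are assumptions, not from this PR):

import importlib.util
import pytest

flax_installed = importlib.util.find_spec("flax") is not None

@pytest.mark.skipif(not flax_installed, reason="flax not installed")
def test_with_nnx_enabled(monkeypatch):
    monkeypatch.setenv("KERAS_NNX_ENABLED", "true")
    ...  # exercise the NNX code path here

def test_without_nnx():
    # The default path must keep working when the flag is unset.
    ...  # exercise the plain JAX code path here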

@divyashreepathihalli (Author) commented Jun 6, 2025

Added a GitHub workflow action for the NNX backend. Note this will FAIL, because it needs a new release of Flax to work.

@hertschuh (Collaborator) left a comment:

Thanks for the PR!

- name: Install Flax for NNX backend
  if: matrix.backend == 'nnx'
  run: |
    pip install flax --progress-bar off --upgrade
Collaborator: Not needed, it's already part of the requirements file.

@@ -16,7 +16,7 @@ jobs:
     fail-fast: false
     matrix:
       python-version: ['3.10']
-      backend: [tensorflow, jax, torch, numpy, openvino]
+      backend: [tensorflow, jax, torch, numpy, openvino, nnx]
Collaborator: I think you need to do something else in addition (or instead of this?); right now, this does not turn on NNX.

elif "mutable" not in nnx_metadata:
nnx_metadata["mutable"] = actual_nnx_mutable

# Initialize nnx.Variable first.
Collaborator: Do we need a placeholder value? Does it allocate HBM? The issue is that we may be using twice the memory we need by initializing the variable twice.
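
One possible way to avoid the double allocation (an assumption for illustration, not necessarily what this PR does) is to describe the value abstractly instead of materializing it:

import jax
import jax.numpy as jnp

concrete = jnp.zeros((1024, 1024), dtype=jnp.float32)       # allocates a real device buffer
abstract = jax.ShapeDtypeStruct((1024, 1024), jnp.float32)  # shape/dtype only, no HBM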

else:
    _placeholder_value = jnp.array(0.0, dtype=jnp.float32)

# Call nnx.Variable.__init__ directly.
Collaborator: What about the sharding on the placeholder_value? Do we need it?
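
To illustrate the concern (a hypothetical sketch, not this PR's code): if the real variable is sharded, a placeholder created without that sharding may later force a resharding transfer.

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(jax.devices(), axis_names=("data",))
spec = NamedSharding(mesh, PartitionSpec("data"))

# Creating the placeholder directly under the target sharding avoids a
# device-to-device move when the real value replaces it.
placeholder = jax.device_put(jnp.zeros((8, 4)), spec)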


        return self._maybe_autocast(current_value)

    def __hash__(self):
Collaborator: Per the discussion we had, we can't do that. It's not compatible with the semantics of __eq__ that we have.

I think my preferred fix would be to change NNX to not use variables as keys, but instead use id(variable) in dictionaries.
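
A minimal sketch of that suggested direction (hypothetical, not NNX code): key the internal bookkeeping by object identity so Keras variables can keep value-based __eq__.

registry = {}

def remember(variable, entry):
    # id() is stable for the object's lifetime and sidesteps __hash__/__eq__.
    registry[id(variable)] = entry

def recall(variable):
    return registry.get(id(variable))

A real implementation would also have to manage object lifetime, e.g. with weak references, since id() values can be reused after garbage collection.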
