Commit a1ac51b

fix import issue
1 parent 6a411ef commit a1ac51b

File tree

6 files changed (+12, -19 lines)

examples/demo_jax_distributed.py

Lines changed: 2 additions & 3 deletions
@@ -12,7 +12,6 @@
 import jax.numpy as jnp
 import tensorflow as tf # just for tf.data
 import keras # Keras multi-backend
-from flax import nnx
 import numpy as np
 from tqdm import tqdm
 
@@ -264,7 +263,7 @@ def compute_loss(trainable_variables, non_trainable_variables, x, y):
 
 
 # Training step: Keras provides a pure functional optimizer.stateless_apply
-@nnx.jit
+@jax.jit
 def train_step(train_state, x, y):
     (loss_value, non_trainable_variables), grads = compute_gradients(
         train_state.trainable_variables,
@@ -302,7 +301,7 @@ def train_step(train_state, x, y):
 sharded_data = jax.device_put(data.numpy(), data_sharding)
 
 
-@nnx.jit
+@jax.jit
 def predict(data):
     predictions, updated_non_trainable_variables = model.stateless_call(
         device_train_state.trainable_variables,
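
The pattern behind this file's change, sketched below: with Keras 3's stateless APIs the training step is a pure function of its inputs, so plain jax.jit is enough and no flax/nnx dependency is needed. The tiny model, optimizer, and loss here are illustrative stand-ins, not the example's actual code.

import os

os.environ["KERAS_BACKEND"] = "jax"  # must be set before importing keras

import jax
import jax.numpy as jnp
import keras

model = keras.Sequential([keras.layers.Dense(1)])
model.build((None, 4))
optimizer = keras.optimizers.SGD(learning_rate=0.1)
optimizer.build(model.trainable_variables)


def compute_loss(trainable_vars, non_trainable_vars, x, y):
    # stateless_call threads variable values explicitly instead of mutating state
    y_pred, non_trainable_vars = model.stateless_call(
        trainable_vars, non_trainable_vars, x
    )
    return jnp.mean((y_pred - y) ** 2), non_trainable_vars


@jax.jit  # pure function: every piece of state goes in and comes back out
def train_step(trainable_vars, non_trainable_vars, optimizer_vars, x, y):
    (loss, non_trainable_vars), grads = jax.value_and_grad(
        compute_loss, has_aux=True
    )(trainable_vars, non_trainable_vars, x, y)
    trainable_vars, optimizer_vars = optimizer.stateless_apply(
        optimizer_vars, grads, trainable_vars
    )
    return loss, trainable_vars, non_trainable_vars, optimizer_vars


# Callers pass variable *values* (e.g. [v.value for v in model.trainable_variables])
# and feed the returned values back in on the next step.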

guides/distributed_training_with_jax.py

Lines changed: 2 additions & 2 deletions
@@ -48,7 +48,7 @@
 import numpy as np
 import tensorflow as tf
 import keras
-from flax import nnx
+import flax
 from jax.experimental import mesh_utils
 from jax.sharding import Mesh
 from jax.sharding import NamedSharding
@@ -186,7 +186,7 @@ def compute_loss(trainable_variables, non_trainable_variables, x, y):
 
 
 # Training step, Keras provides a pure functional optimizer.stateless_apply
-@nnx.jit
+@flax.jax.jit
 def train_step(train_state, x, y):
     (
         trainable_variables,
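
For reference, a minimal sketch of how the sharding utilities imported above are usually combined to shard a data batch across all local devices. The mesh shape, axis name, and batch shape are illustrative and need not match the guide's actual values.

import jax
import numpy as np
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec

num_devices = jax.device_count()
devices = mesh_utils.create_device_mesh((num_devices,))
mesh = Mesh(devices, axis_names=("batch",))

# Split the leading (batch) dimension of each array across the "batch" axis.
data_sharding = NamedSharding(mesh, PartitionSpec("batch"))

batch = np.zeros((num_devices * 8, 28, 28, 1), dtype="float32")
sharded_batch = jax.device_put(batch, data_sharding)  # one shard per device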

keras/src/backend/jax/layer.py

Lines changed: 0 additions & 3 deletions
@@ -1,6 +1,3 @@
-from flax import nnx
-
-
 class JaxLayer(nnx.Module):
     def __init_subclass__(cls):
         super().__init_subclass__()

keras/src/backend/jax/trainer.py

Lines changed: 5 additions & 6 deletions
@@ -5,7 +5,6 @@
 
 import jax
 import numpy as np
-from flax import nnx
 
 from keras.src import backend
 from keras.src import callbacks as callbacks_module
@@ -234,7 +233,7 @@ def concatenate(outputs):
                 return output
 
             if not self.run_eagerly and self.jit_compile:
-                concatenate = nnx.jit(concatenate)
+                concatenate = jax.jit(concatenate)
 
             def iterator_step(state, iterator):
                 data = next(iterator)
@@ -278,7 +277,7 @@ def make_train_function(self, force=False):
             # so that jax will reuse the memory buffer for outputs.
             # This will reduce the memory usage of the training function by
             # half.
-            train_step = nnx.jit(self.train_step, donate_argnums=0)
+            train_step = jax.jit(self.train_step, donate_argnums=0)
         else:
             train_step = self.train_step
 
@@ -294,7 +293,7 @@ def make_test_function(self, force=False):
             # so that jax will reuse the memory buffer for outputs.
             # This will reduce the memory usage of the training function by
             # half.
-            test_step = nnx.jit(self.test_step, donate_argnums=0)
+            test_step = jax.jit(self.test_step, donate_argnums=0)
         else:
             test_step = self.test_step
 
@@ -311,7 +310,7 @@ def predict_step(state, data):
             return outputs, (state[0], non_trainable_variables)
 
         if not self.run_eagerly and self.jit_compile:
-            predict_step = nnx.jit(predict_step, donate_argnums=0)
+            predict_step = jax.jit(predict_step, donate_argnums=0)
 
         _step_function = self._make_function(
             predict_step, concatenate_outputs=True
@@ -905,7 +904,7 @@ def _enforce_jax_state_sharding(
 
         Since the output of the train/eval step will be used as inputs to next
         step, we need to ensure that they have the same sharding spec, so that
-        nnx.jit won't have to recompile the train/eval function.
+        jax.jit won't have to recompile the train/eval function.
 
         Note that this function will also rely on the recorded sharding spec
         for each of states.
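
A toy sketch of the donate_argnums=0 pattern in the hunks above: donating the first argument lets XLA reuse that input's buffers for the outputs, which is what allows the jitted train/test/predict steps to roughly halve their peak memory. The function below is illustrative, not the trainer's actual step.

import jax
import jax.numpy as jnp


def step(state, x):
    # stand-in for a real step that returns an updated state of the same shape
    return state + x


jitted_step = jax.jit(step, donate_argnums=0)

state = jnp.zeros((1024, 1024))
x = jnp.ones((1024, 1024))
new_state = jitted_step(state, x)
# The old `state` buffer may now have been donated to `new_state`; touching it
# again can raise a "buffer has been donated/deleted" error on some backends.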

keras/src/random/random_test.py

Lines changed: 2 additions & 3 deletions
@@ -1,7 +1,6 @@
 import numpy as np
 import pytest
 from absl.testing import parameterized
-from flax import nnx
 
 import keras
 from keras.src import backend
@@ -385,7 +384,7 @@ def test_dropout_jax_jit_stateless(self):
 
         x = ops.ones(3)
 
-        @nnx.jit
+        @jax.jit
         def train_step(x):
             with keras.src.backend.StatelessScope():
                 x = keras.layers.Dropout(rate=0.1)(x, training=True)
@@ -414,7 +413,7 @@ def test_jax_rngkey_seed(self):
         reason="This test requires `jax` as the backend.",
     )
     def test_jax_unseed_disallowed_during_tracing(self):
-        @nnx.jit
+        @jax.jit
         def jit_fn():
             return random.randint((2, 2), 0, 10, seed=None)
 
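
Mirroring the tests above, a small sketch (assumed usage, not the test code itself) of what does and does not work under jax.jit with Keras random ops: an explicit seed is traceable, while seed=None is the case the last test asserts is rejected during tracing.

import os

os.environ["KERAS_BACKEND"] = "jax"  # must be set before importing keras

import jax
import keras


@jax.jit
def sample_seeded():
    # explicit seed: fine to trace and compile
    return keras.random.randint((2, 2), minval=0, maxval=10, seed=42)


@jax.jit
def sample_unseeded():
    # seed=None inside a trace is what test_jax_unseed_disallowed_during_tracing
    # expects to be rejected
    return keras.random.randint((2, 2), minval=0, maxval=10, seed=None)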

keras/src/random/seed_generator_test.py

Lines changed: 1 addition & 2 deletions
@@ -1,6 +1,5 @@
 import numpy as np
 import pytest
-from flax import nnx
 
 from keras.src import backend
 from keras.src import ops
@@ -79,7 +78,7 @@ def test_seed_generator_unexpected_kwargs(self):
         backend.backend() != "jax", reason="This test requires the JAX backend"
     )
     def test_jax_tracing_with_global_seed_generator(self):
-        @nnx.jit
+        @jax.jit
         def traced_function():
             return seed_generator.global_seed_generator().next()
 
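
For context on the object under test, a brief sketch (assumed usage, not this test's code) of keras.random.SeedGenerator, the stateful seed source that global_seed_generator() returns an instance of: passing a generator as the seed makes successive draws reproducible without hard-coding a seed per call.

import keras

seed_gen = keras.random.SeedGenerator(seed=1337)

a = keras.random.normal((2, 2), seed=seed_gen)
b = keras.random.normal((2, 2), seed=seed_gen)  # different draw, same reproducible stream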
