Skip to content

Commit 6a411ef

Browse files
fix import error
1 parent 4e35416 commit 6a411ef

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

keras/src/backend/jax/core.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1+
import flax
12
import jax
23
import jax.experimental.sparse as jax_sparse
34
import jax.numpy as jnp
45
import ml_dtypes
56
import numpy as np
6-
from flax import nnx
77

88
from keras.src import tree
99
from keras.src.backend.common import KerasVariable
@@ -57,7 +57,7 @@ def __jax_array__(self):
5757
return self.value
5858

5959

60-
class Variable(JaxVariable, nnx.Variable):
60+
class Variable(JaxVariable, flax.nnx.Variable):
6161
def __init__(
6262
self,
6363
initializer,
@@ -123,7 +123,7 @@ def _complete_nnx_init(self):
123123
current_nnx_mutable = self.trainable # A sensible default link
124124

125125
# initialize the nnx.Variable
126-
nnx.Variable.__init__(
126+
flax.nnx.Variable.__init__(
127127
self,
128128
value=self._value,
129129
mutable=current_nnx_mutable,
@@ -173,8 +173,8 @@ def value(self, new_value):
173173
# Overriding NNX methods that modify `raw_value` or `_var_metadata` directly
174174
# to ensure Keras's `_value` and other Keras states are in sync.
175175

176-
def copy_from(self, other: nnx.Variable): # type: ignore
177-
if not isinstance(other, nnx.Variable): # Basic check from nnx
176+
def copy_from(self, other: flax.nnx.Variable): # type: ignore
177+
if not isinstance(other, flax.nnx.Variable): # Basic check from nnx
178178
raise TypeError(
179179
f"Expected nnx.Variable, got {type(other).__name__}"
180180
)
@@ -184,12 +184,12 @@ def copy_from(self, other: nnx.Variable): # type: ignore
184184
# Let nnx.Variable handle its part (updates self.raw_value and
185185
# self._var_metadata)
186186
# Need to call nnx.Variable.copy_from specifically.
187-
nnx.Variable.copy_from(self, other)
187+
flax.nnx.Variable.copy_from(self, other)
188188

189189
# Now, self.raw_value is updated. Sync Keras's self._value.
190190
# Extract the JAX array if raw_value is a nnx.mutable_array
191191
keras_value_to_assign = self.raw_value
192-
if nnx.utils.is_mutable_array(keras_value_to_assign):
192+
if flax.nnx.utils.is_mutable_array(keras_value_to_assign):
193193
keras_value_to_assign = keras_value_to_assign.__array__()
194194

195195
self.assign(keras_value_to_assign)
@@ -205,11 +205,11 @@ def copy_from(self, other: nnx.Variable): # type: ignore
205205
def update_from_state(self, variable_state: nnx.graph.VariableState):
206206
# Let nnx.Variable handle its part (updates self.raw_value and
207207
# self._var_metadata)
208-
nnx.Variable.update_from_state(self, variable_state)
208+
flax.nnx.Variable.update_from_state(self, variable_state)
209209

210210
# Sync Keras's self._value
211211
keras_value_to_assign = self.raw_value
212-
if nnx.utils.is_mutable_array(keras_value_to_assign):
212+
if flax.nnx.utils.is_mutable_array(keras_value_to_assign):
213213
keras_value_to_assign = keras_value_to_assign.__array__()
214214

215215
self.assign(keras_value_to_assign)
@@ -242,7 +242,7 @@ def __getstate__(self):
242242
"_nnx_metadata_arg": self._nnx_metadata_arg,
243243
"_nnx_init_pending": self._nnx_init_pending,
244244
}
245-
nnx_state = nnx.Variable.__getstate__(self)
245+
nnx_state = flax.nnx.Variable.__getstate__(self)
246246
return {"keras_state": keras_state, "nnx_state": nnx_state}
247247

248248
def __setstate__(self, state):
@@ -254,7 +254,7 @@ def __setstate__(self, state):
254254
object.__setattr__(self, k, v)
255255

256256
# Restore NNX attributes using its __setstate__
257-
nnx.Variable.__setstate__(self, nnx_state)
257+
flax.nnx.Variable.__setstate__(self, nnx_state)
258258

259259
if (
260260
self._initializer is not None and self._value is None
@@ -276,7 +276,7 @@ def __setstate__(self, state):
276276
if current_nnx_mutable is None:
277277
current_nnx_mutable = self.trainable
278278

279-
if current_nnx_mutable and nnx.utils.is_mutable_array(
279+
if current_nnx_mutable and flax.nnx.utils.is_mutable_array(
280280
self.raw_value
281281
):
282282
self.raw_value[...] = self._value

0 commit comments

Comments
 (0)