1
+ import flax
1
2
import jax
2
3
import jax .experimental .sparse as jax_sparse
3
4
import jax .numpy as jnp
4
5
import ml_dtypes
5
6
import numpy as np
6
- from flax import nnx
7
7
8
8
from keras .src import tree
9
9
from keras .src .backend .common import KerasVariable
@@ -57,7 +57,7 @@ def __jax_array__(self):
57
57
return self .value
58
58
59
59
60
- class Variable (JaxVariable , nnx .Variable ):
60
+ class Variable (JaxVariable , flax . nnx .Variable ):
61
61
def __init__ (
62
62
self ,
63
63
initializer ,
@@ -123,7 +123,7 @@ def _complete_nnx_init(self):
123
123
current_nnx_mutable = self .trainable # A sensible default link
124
124
125
125
# initialize the nnx.Variable
126
- nnx .Variable .__init__ (
126
+ flax . nnx .Variable .__init__ (
127
127
self ,
128
128
value = self ._value ,
129
129
mutable = current_nnx_mutable ,
@@ -173,8 +173,8 @@ def value(self, new_value):
173
173
# Overriding NNX methods that modify `raw_value` or `_var_metadata` directly
174
174
# to ensure Keras's `_value` and other Keras states are in sync.
175
175
176
- def copy_from (self , other : nnx .Variable ): # type: ignore
177
- if not isinstance (other , nnx .Variable ): # Basic check from nnx
176
+ def copy_from (self , other : flax . nnx .Variable ): # type: ignore
177
+ if not isinstance (other , flax . nnx .Variable ): # Basic check from nnx
178
178
raise TypeError (
179
179
f"Expected nnx.Variable, got { type (other ).__name__ } "
180
180
)
@@ -184,12 +184,12 @@ def copy_from(self, other: nnx.Variable): # type: ignore
184
184
# Let nnx.Variable handle its part (updates self.raw_value and
185
185
# self._var_metadata)
186
186
# Need to call nnx.Variable.copy_from specifically.
187
- nnx .Variable .copy_from (self , other )
187
+ flax . nnx .Variable .copy_from (self , other )
188
188
189
189
# Now, self.raw_value is updated. Sync Keras's self._value.
190
190
# Extract the JAX array if raw_value is a nnx.mutable_array
191
191
keras_value_to_assign = self .raw_value
192
- if nnx .utils .is_mutable_array (keras_value_to_assign ):
192
+ if flax . nnx .utils .is_mutable_array (keras_value_to_assign ):
193
193
keras_value_to_assign = keras_value_to_assign .__array__ ()
194
194
195
195
self .assign (keras_value_to_assign )
@@ -205,11 +205,11 @@ def copy_from(self, other: nnx.Variable): # type: ignore
205
205
def update_from_state (self , variable_state : nnx .graph .VariableState ):
206
206
# Let nnx.Variable handle its part (updates self.raw_value and
207
207
# self._var_metadata)
208
- nnx .Variable .update_from_state (self , variable_state )
208
+ flax . nnx .Variable .update_from_state (self , variable_state )
209
209
210
210
# Sync Keras's self._value
211
211
keras_value_to_assign = self .raw_value
212
- if nnx .utils .is_mutable_array (keras_value_to_assign ):
212
+ if flax . nnx .utils .is_mutable_array (keras_value_to_assign ):
213
213
keras_value_to_assign = keras_value_to_assign .__array__ ()
214
214
215
215
self .assign (keras_value_to_assign )
@@ -242,7 +242,7 @@ def __getstate__(self):
242
242
"_nnx_metadata_arg" : self ._nnx_metadata_arg ,
243
243
"_nnx_init_pending" : self ._nnx_init_pending ,
244
244
}
245
- nnx_state = nnx .Variable .__getstate__ (self )
245
+ nnx_state = flax . nnx .Variable .__getstate__ (self )
246
246
return {"keras_state" : keras_state , "nnx_state" : nnx_state }
247
247
248
248
def __setstate__ (self , state ):
@@ -254,7 +254,7 @@ def __setstate__(self, state):
254
254
object .__setattr__ (self , k , v )
255
255
256
256
# Restore NNX attributes using its __setstate__
257
- nnx .Variable .__setstate__ (self , nnx_state )
257
+ flax . nnx .Variable .__setstate__ (self , nnx_state )
258
258
259
259
if (
260
260
self ._initializer is not None and self ._value is None
@@ -276,7 +276,7 @@ def __setstate__(self, state):
276
276
if current_nnx_mutable is None :
277
277
current_nnx_mutable = self .trainable
278
278
279
- if current_nnx_mutable and nnx .utils .is_mutable_array (
279
+ if current_nnx_mutable and flax . nnx .utils .is_mutable_array (
280
280
self .raw_value
281
281
):
282
282
self .raw_value [...] = self ._value
0 commit comments