
Commit e0bf7f8

Remove references to jax.experimental.layout.Layout. (#21400)
It is being renamed to jax.experimental.layout.Format. The comments can be kept generic anyway, to support other layout instances in the future.
1 parent e99164e commit e0bf7f8
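
For context, a minimal sketch (not part of the commit) of the class-agnostic pattern the message describes, assuming only that jax.experimental.layout exposes either Format (newer JAX) or Layout (older JAX); the attribute-based check mirrors the change to distribution_lib.py below:

from jax.experimental import layout as jax_layout

# Pick whichever explicit-layout wrapper class the installed JAX provides.
_LAYOUT_CLS = getattr(jax_layout, "Format", None) or getattr(
    jax_layout, "Layout", None
)


def is_explicit_layout(obj):
    # Duck-typed detection: treat any object exposing a `layout` attribute as
    # an explicit layout, rather than isinstance-checking a specific class.
    return hasattr(obj, "layout")

Keeping the check generic means distribution_lib.py no longer names the class at all, so the rename on the JAX side does not require a matching class reference in Keras.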


2 files changed, +16 -9 lines changed


keras/src/backend/jax/distribution_lib.py

Lines changed: 4 additions & 6 deletions
@@ -2,7 +2,6 @@
 
 import jax
 import numpy as np
-from jax.experimental import layout as jax_layout
 
 from keras.src.backend.common import global_state
 from keras.src.random import seed_generator
@@ -40,8 +39,7 @@ def distribute_variable(value, layout):
     Args:
         value: the initial value of the variable.
         layout: `TensorLayout` for the created variable, or a
-            JAX-supported layout instance
-            (e.g. `jax.experimental.layout.Layout`, `jax.sharding.Sharding`).
+            JAX-supported layout instance (e.g. `jax.sharding.Sharding`).
 
     Returns:
         jax.Array which is the distributed variable.
@@ -58,8 +56,7 @@ def distribute_tensor(tensor, layout):
     Args:
         tensor: `jax.Array` that need to be distributed.
         layout: `TensorLayout` for the created variable, or a
-            JAX-supported layout instance
-            (e.g. `jax.experimental.layout.Layout`, `jax.sharding.Sharding`).
+            JAX-supported layout instance (e.g. `jax.sharding.Sharding`).
 
     Returns:
         Distributed value.
@@ -81,7 +78,8 @@ def distribute_tensor(tensor, layout):
         layout, jax.sharding.Sharding
     ) and tensor.sharding.is_equivalent_to(layout, ndim=len(tensor.shape)):
         return tensor
-    elif isinstance(layout, jax_layout.Layout):
+    # JAX explicit layout support.
+    elif hasattr(layout, "layout"):
         current_layout = getattr(tensor, "layout", None)
         if current_layout == layout:
             return tensor
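
A minimal usage sketch, not taken from the commit, showing how the two branches above are reached; the 8-device mesh mirrors the test setup below and is an assumption here:

import jax
import numpy as np
from jax.experimental import layout as jax_layout

from keras.src.backend.jax import distribution_lib

# Assumes 8 local devices are visible, matching the 8-device assertion in the
# tests below.
devices = np.array(jax.devices()).reshape(4, 2)
mesh = jax.sharding.Mesh(devices, ("batch", "model"))
sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec("batch", None)
)

x = jax.numpy.ones((16, 8))

# Plain sharding: taken by the isinstance(layout, jax.sharding.Sharding) branch.
out = distribution_lib.distribute_tensor(x, sharding)

# Explicit layout: wrapping the sharding (as _create_jax_layout does in the
# test file below) yields an object meant to be caught by the new
# hasattr(layout, "layout") branch on recent JAX versions.
if hasattr(jax_layout, "Format"):
    out_explicit = distribution_lib.distribute_tensor(
        x, jax_layout.Format(sharding=sharding)
    )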

keras/src/backend/jax/distribution_lib_test.py

Lines changed: 12 additions & 3 deletions
@@ -33,6 +33,15 @@
     reason="Backend specific test",
 )
 class JaxDistributionLibTest(testing.TestCase):
+    def _create_jax_layout(self, sharding):
+        # Use jax_layout.Format or jax_layout.Layout if available.
+        if hasattr(jax_layout, "Format"):
+            return jax_layout.Format(sharding=sharding)
+        elif hasattr(jax_layout, "Layout"):
+            return jax_layout.Layout(sharding=sharding)
+
+        return sharding
+
     def test_list_devices(self):
         self.assertEqual(len(distribution_lib.list_devices()), 8)
         self.assertEqual(len(distribution_lib.list_devices("cpu")), 8)
@@ -132,7 +141,7 @@ def test_distribute_tensor_with_jax_layout(self):
         )
 
         inputs = jax.numpy.array(np.random.normal(size=(16, 8)))
-        target_layout = jax_layout.Layout(
+        target_layout = self._create_jax_layout(
             sharding=jax.sharding.NamedSharding(
                 jax_mesh, jax.sharding.PartitionSpec("batch", None)
             )
@@ -163,7 +172,7 @@ def test_distribute_variable_with_jax_layout(self):
         )
 
         variable = jax.numpy.array(np.random.normal(size=(16, 8)))
-        target_layout = jax_layout.Layout(
+        target_layout = self._create_jax_layout(
             sharding=jax.sharding.NamedSharding(
                 jax_mesh, jax.sharding.PartitionSpec("model", None)
             )
@@ -184,7 +193,7 @@ def test_distribute_input_data_with_jax_layout(self):
         )
 
         input_data = jax.numpy.array(np.random.normal(size=(16, 8)))
-        target_layout = jax_layout.Layout(
+        target_layout = self._create_jax_layout(
             sharding=jax.sharding.NamedSharding(
                 jax_mesh, jax.sharding.PartitionSpec("batch", None)
             )