     reason="Backend specific test",
 )
 class JaxDistributionLibTest(testing.TestCase):
+    def _create_jax_layout(self, sharding):
+        # Use jax_layout.Format or jax_layout.Layout if available.
+        if hasattr(jax_layout, "Format"):
+            return jax_layout.Format(sharding=sharding)
+        elif hasattr(jax_layout, "Layout"):
+            return jax_layout.Layout(sharding=sharding)
+
+        return sharding
+
     def test_list_devices(self):
         self.assertEqual(len(distribution_lib.list_devices()), 8)
         self.assertEqual(len(distribution_lib.list_devices("cpu")), 8)
@@ -132,7 +141,7 @@ def test_distribute_tensor_with_jax_layout(self):
         )
 
         inputs = jax.numpy.array(np.random.normal(size=(16, 8)))
-        target_layout = jax_layout.Layout(
+        target_layout = self._create_jax_layout(
             sharding=jax.sharding.NamedSharding(
                 jax_mesh, jax.sharding.PartitionSpec("batch", None)
             )
@@ -163,7 +172,7 @@ def test_distribute_variable_with_jax_layout(self):
         )
 
         variable = jax.numpy.array(np.random.normal(size=(16, 8)))
-        target_layout = jax_layout.Layout(
+        target_layout = self._create_jax_layout(
             sharding=jax.sharding.NamedSharding(
                 jax_mesh, jax.sharding.PartitionSpec("model", None)
             )
@@ -184,7 +193,7 @@ def test_distribute_input_data_with_jax_layout(self):
         )
 
         input_data = jax.numpy.array(np.random.normal(size=(16, 8)))
-        target_layout = jax_layout.Layout(
+        target_layout = self._create_jax_layout(
             sharding=jax.sharding.NamedSharding(
                 jax_mesh, jax.sharding.PartitionSpec("batch", None)
             )
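
For context, here is a minimal standalone sketch of the compatibility pattern this diff introduces, assuming jax_layout refers to jax.experimental.layout (which exposes Format in newer JAX releases and Layout in older ones); the mesh setup and axis names below are illustrative, not taken from the test file.

import jax
import numpy as np
from jax.experimental import layout as jax_layout  # assumed import alias


def create_jax_layout(sharding):
    # Prefer the newer Format wrapper, fall back to Layout, and return the
    # bare sharding when neither class exists in the installed JAX version.
    if hasattr(jax_layout, "Format"):
        return jax_layout.Format(sharding=sharding)
    elif hasattr(jax_layout, "Layout"):
        return jax_layout.Layout(sharding=sharding)
    return sharding


# Hypothetical usage mirroring the updated call sites: build a mesh over the
# local devices, shard rows along the "batch" axis, and wrap the sharding.
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = jax.sharding.Mesh(devices, ("batch", "model"))
sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec("batch", None)
)
target_layout = create_jax_layout(sharding)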