Skip to content

Commit ed08e89

Browse files
dougalmKfacJaxDev
authored andcommitted
Stackless yashful
PiperOrigin-RevId: 681582933
1 parent a4531e9 commit ed08e89

File tree

1 file changed

+1
-9
lines changed

1 file changed

+1
-9
lines changed

kfac_jax/_src/utils/parallel.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,7 @@ def in_pmap(axis_name: str | None) -> bool:
3535
if axis_name is None:
3636
return False
3737

38-
try:
39-
# The only way to know if we are under `jax.pmap` is to check if the
40-
# function call below raises a `NameError` or not.
41-
core.axis_frame(axis_name)
42-
43-
return True
44-
45-
except NameError:
46-
return False
38+
return axis_name in core.unsafe_get_axis_names_DO_NOT_USE()
4739

4840

4941
def wrap_if_pmap(

0 commit comments

Comments
 (0)