We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent a4531e9 commit ed08e89Copy full SHA for ed08e89
kfac_jax/_src/utils/parallel.py
@@ -35,15 +35,7 @@ def in_pmap(axis_name: str | None) -> bool:
35
if axis_name is None:
36
return False
37
38
- try:
39
- # The only way to know if we are under `jax.pmap` is to check if the
40
- # function call below raises a `NameError` or not.
41
- core.axis_frame(axis_name)
42
-
43
- return True
44
45
- except NameError:
46
- return False
+ return axis_name in core.unsafe_get_axis_names_DO_NOT_USE()
47
48
49
def wrap_if_pmap(
0 commit comments