Skip to content

Commit 5c92598

Browse files
authored
[MRG] Test Jax version in backend (#794)
* Fix test jax_version in backend * Fix test jax_version in backend
1 parent 3ee4386 commit 5c92598

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

RELEASES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ This new release adds support for sparse cost matrices and a new lazy EMD solver
2121
- Fix openmp flags on macOS (PR #789)
2222
- Clean documentation (PR #787)
2323
- Fix code coverage (PR #791)
24+
- Fix test of the version of jax in `ot.backend` (PR #794)
2425

2526

2627
## 0.9.6.post1

ot/backend.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,12 @@
122122
from jax.extend.backend import get_backend as _jax_get_backend
123123

124124
jax_type = jax.numpy.ndarray
125-
jax_new_version = float(".".join(jax.__version__.split(".")[1:])) > 4.24
125+
jax_new_version = tuple([float(s) for s in jax.__version__.split(".")]) > (
126+
0,
127+
4,
128+
24,
129+
0,
130+
)
126131
except ImportError:
127132
jax = False
128133
jax_type = float

0 commit comments

Comments
 (0)