Skip to content

Commit 6172cb5

Browse files
authored
Merge pull request #23 from btalamini/bugfix/fix_failing_test_due_to_incompatible_integer_types
Fix a test that fails on update of jax due to mismatched integer types
2 parents 8bc8b1b + 5b45c16 commit 6172cb5

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

optimism/TensorMath.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -616,9 +616,13 @@ def cond_f(loopData):
616616

617617
def compute_pade_degree(diff, j, itk):
618618
j += 1
619-
p = np.searchsorted(log_pade_coefficients[2:16], diff, side='right')
619+
# Manually force the return type of searchsorted to be 64-bit int, because it
620+
# returns 32-bit ints, ignoring the global `jax_enable_x64` flag. This looks
621+
# like a bug. I filed an issue (#11375) with Jax to correct this.
622+
# If they fix it, the conversions on p and q can be removed.
623+
p = np.searchsorted(log_pade_coefficients[2:16], diff, side='right').astype(np.int64)
620624
p += 2
621-
q = np.searchsorted(log_pade_coefficients[2:16], diff/2.0, side='right')
625+
q = np.searchsorted(log_pade_coefficients[2:16], diff/2.0, side='right').astype(np.int64)
622626
q += 2
623627
m,j,converged = if_then_else((2 * (p - q) // 3 < itk) | (j == 2),
624628
(p+1,j,True), (0,j,False))

0 commit comments

Comments
 (0)