Description
When calling pymbar.MBAR with both n_bootstraps and initialize='BAR' set, as in
mbar = pymbar.MBAR(u_kn, N_k=N_k, solver_protocol='robust', initialize='BAR', n_bootstraps=nboots)
with JAX enabled, we get the following error:
Traceback (most recent call last):
File "/home/XXXX/miniconda3/envs/cmp_gpu/lib/python3.12/site-packages/pymbar/mbar.py", line 1950, in _initialize_with_bar
f_k_init[l] = (
~~~~~~~~^^^
File "/home/XXXXXX/miniconda3/envs/cmp_gpu/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py", line 592, in _unimplemented_setitem
raise TypeError(msg.format(type(self)))
TypeError: JAX arrays are immutable and do not support in-place item assignment. Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method: https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html
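The error does not occur without JAX, so it appears to be specific to the JAX code path. A test case covering this combination (n_bootstraps together with initialize='BAR' under JAX) should probably be added, and the bug should be fixed at some point. It hasn't come up before, so it may be a rare case.
I'm not sure yet about the priority or timeline for a fix.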
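One way the failing assignment could work on both backends is a functional update, as the JAX error message suggests. This is a minimal sketch, not pymbar's actual code: `set_item` is a hypothetical helper, and the loop only loosely mimics the `f_k_init[l] = (...)` line from the traceback (the real right-hand side is not shown in the report, so `delta` is a placeholder). It assumes NumPy is installed. JAX arrays are detected by duck typing on their `.at` property, so JAX itself is not imported.

```python
import numpy as np


def set_item(arr, idx, value):
    """Return a copy of `arr` with position `idx` set to `value`.

    Hypothetical helper (not part of pymbar). JAX arrays are immutable but
    expose an `.at` property for out-of-place updates; NumPy arrays have no
    such attribute, so they are copied and assigned the usual way.
    """
    if hasattr(arr, "at"):
        # JAX path: the functional update suggested by the error message
        return arr.at[idx].set(value)
    # NumPy path: copy so the caller's array is never mutated,
    # matching the JAX semantics
    out = np.array(arr, copy=True)
    out[idx] = value
    return out


# Hypothetical cumulative loop, loosely modeled on `f_k_init[l] = (...)`:
# each entry builds on the previous one, so the result must be rebound.
K = 3
f_k_init = np.zeros(K)
for l in range(1, K):
    delta = 1.0  # placeholder for the real per-pair free-energy difference
    f_k_init = set_item(f_k_init, l, f_k_init[l - 1] + delta)

print(f_k_init)  # [0. 1. 2.]
```

Because the JAX branch returns a new array instead of mutating in place, the caller always has to rebind the result. Copying in the NumPy branch keeps the two backends behaving identically, at the cost of one extra allocation per update.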