
Possible JAX-incompatible code path? #557

@mrshirts

When calling pymbar.MBAR with both n_bootstraps and initialize='BAR' set, for example:

mbar = pymbar.MBAR(u_kn, N_k=N_k, solver_protocol='robust', initialize='BAR', n_bootstraps=nboots)

with JAX enabled, we get the following error:

Traceback (most recent call last):
  File "/home/XXXX/miniconda3/envs/cmp_gpu/lib/python3.12/site-packages/pymbar/mbar.py", line 1950, in _initialize_with_bar
    f_k_init[l] = (
    ~~~~~~~~^^^
  File "/home/XXXXXX/miniconda3/envs/cmp_gpu/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py", line 592, in _unimplemented_setitem
    raise TypeError(msg.format(type(self)))
TypeError: JAX arrays are immutable and do not support in-place item assignment. Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method: https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html
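
The report doesn't show the body of _initialize_with_bar, so the following is only a sketch under stated assumptions, not pymbar's code. It assumes the failing loop builds f_k_init by chaining free energy differences between neighbouring states, and that f_k_init can arrive as a JAX array. One backend-agnostic fix is to copy into a writable NumPy array before the loop; the error message's own suggestion, x = x.at[idx].set(y), is the JAX-native alternative. The function name chain_initial_f_k and the delta_f argument are hypothetical.

```python
import numpy as np


def chain_initial_f_k(f_k_init, delta_f):
    """Hypothetical sketch of a BAR-style initial guess (not pymbar's code).

    Assumes delta_f[l - 1] is the free energy difference between
    neighbouring states l - 1 and l, and f_k_init[0] is the reference value.
    """
    # np.array always copies, so f_k is writable even if f_k_init is a JAX
    # array or a read-only NumPy array; the caller's array is left untouched.
    f_k = np.array(f_k_init, dtype=np.float64)
    for l in range(1, len(f_k)):
        # In-place item assignment is legal on the NumPy copy.
        f_k[l] = f_k[l - 1] + delta_f[l - 1]
    return f_k


# A read-only NumPy array behaves like a JAX array here: writing to it raises.
frozen = np.zeros(4)
frozen.setflags(write=False)
print(chain_initial_f_k(frozen, [1.0, 2.0, 3.0]))  # [0. 1. 3. 6.]
```

The copy costs one extra allocation of length K (the number of states), which is negligible next to the MBAR solve, and it keeps the loop readable. If the result needs to stay on the JAX side, it can be converted back afterwards.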

This error does not appear without JAX, so a test case covering this combination should probably be added, and the code path itself fixed at some point. It hasn't come up before, so it may be a rare case.
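
A test for this case could run even on installations without JAX by using a read-only NumPy array as a stand-in for an immutable JAX array, since item assignment on it also raises. The sketch below pairs that trick with a loop-free variant of the initialization that never assigns into an array. Everything here is hypothetical: the function and test names, the assumed meaning of delta_f (difference between states l - 1 and l), and the claim that passing xp=jax.numpy would work unchanged, which is untested.

```python
import numpy as np


def chain_initial_f_k_functional(f_k_init, delta_f, xp=np):
    """Hypothetical loop-free BAR-style initial guess (not pymbar's code).

    Builds f_k = f_k_init[0] + sum_{j<k} delta_f[j] without any in-place
    item assignment. xp is the array module (NumPy by default).
    """
    f_k_init = xp.asarray(f_k_init)
    delta_f = xp.asarray(delta_f, dtype=f_k_init.dtype)
    # Running sum of neighbour differences, with 0 for the reference state.
    offsets = xp.concatenate([xp.zeros(1, dtype=f_k_init.dtype), xp.cumsum(delta_f)])
    return f_k_init[0] + offsets


def test_bar_init_accepts_immutable_input():
    # A read-only NumPy array mimics an immutable JAX array.
    f_k_init = np.zeros(4)
    f_k_init.setflags(write=False)
    f_k = chain_initial_f_k_functional(f_k_init, [1.0, 2.0, 3.0])
    assert np.allclose(f_k, [0.0, 1.0, 3.0, 6.0])
    assert not f_k_init.flags.writeable  # input left as it was
```

A real regression test would instead construct pymbar.MBAR with initialize='BAR' and n_bootstraps set while JAX is active; the read-only trick only covers the "no in-place writes" property.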

I'm not sure yet about the priority or timeline for a fix.