### Report

Can be fixed by downgrading jax to 0.6.2, e.g. via `pip install "jax==0.6.2"` (restarting the environment might be needed).

### Version information

_No response_
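
To confirm that the downgrade mentioned above took effect after a restart, something like the following could be used. This is only a sketch: the helper names `installed_version` and `matches_pin` are made up for illustration, and it reads the package metadata through the standard library's `importlib.metadata` rather than importing jax itself.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version string of `package`, or None if it is not installed."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


def matches_pin(package: str, expected: str) -> bool:
    """Return True only if `package` is installed at exactly the `expected` version."""
    return installed_version(package) == expected


if __name__ == "__main__":
    # "0.6.2" is the version named in the report above.
    found = installed_version("jax")
    print(f"jax installed: {found!r}; matches 0.6.2: {matches_pin('jax', '0.6.2')}")
```

If this still shows a different version after `pip install "jax==0.6.2"`, the running interpreter may be using a different environment than the one pip installed into, or it may need the restart mentioned above.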