With JAX 0.7.2, the following code raises "TypeError: cannot create weak reference to 'Flatten' object":
import jax.numpy as jnp
import optimistix as omx
def test(w):
    return (w - jnp.array([5.0, 42.0, -2.0])) ** 2
solver = omx.LevenbergMarquardt(rtol=1e-4, atol=1e-4)
res = omx.least_squares(lambda w, _: test(w), solver, jnp.zeros(3))
print(res.value)