Description
Required prerequisites
- I have searched the Issue Tracker and Discussions to confirm that this hasn't already been reported. (+1 or comment there if it has.)
- Consider asking first in a Discussion.
Motivation
I am interested in Jacobians and Hessians of implicitly differentiated root-finding problems, something that comes up regularly in scientific computing. With JAX, this already works out of the box via function transforms (e.g., `jax.jacrev`); see the jaxopt example under Alternatives. Is this something you plan to support in torchopt, too?
Solution
I already tried, but a few pieces of code currently prevent this (a sketch of the failing attempt follows the list):
- missing `setup_context` for the vmap rule in `ImplicitMetaGradient` (very easy to adapt)
- `.item()` in `_vdot_real_kernel`
- `make_rmatvec` in `normal_cg`
- conditionals in `_cg_solve`
- tree operations in `_cg_solve`
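For completeness, this is the kind of composition I am after on the torchopt side. A minimal sketch, assuming `torchopt.diff.implicit.custom_root` takes the optimality function and an `argnums` index as in the TorchOpt docs, and using `torch.func.jacrev`; both derivative calls currently fail because of the items above:

```python
import torch
import torchopt


def F(x, factor):
    # Same optimality condition as the jaxopt example below:
    # the root x* satisfies F(x*, factor) = 0.
    return factor * x**3 - x - 2


# Assumption: custom_root takes the optimality function and the index of the
# differentiable solver argument, mirroring the TorchOpt examples.
@torchopt.diff.implicit.custom_root(F, argnums=1)
def root_solver(init_x, factor):
    # Plain fixed-point iteration on the residual.
    x = init_x
    for _ in range(100):
        x = x - 1e-1 * F(x, factor)
    return x


x_init = torch.tensor(3.0)
fac = torch.tensor(2.0)

# Both of these currently raise, for the reasons listed above:
jac_fn = torch.func.jacrev(root_solver, argnums=1)
print(jac_fn(x_init, fac))                                # Jacobian dx*/dfactor
print(torch.func.jacrev(jac_fn, argnums=1)(x_init, fac))  # Hessian
```

As a self-contained illustration of one blocker: `.item()` returns a Python scalar, which cannot be traced through `torch.func.vmap` (the transform `jacrev` is built on). The function below is a hypothetical stand-in for `_vdot_real_kernel`:

```python
import torch


def vdot_with_item(a, b):
    # .item() drops to a Python float, which vmap cannot batch over.
    return torch.dot(a, b).item()


# torch.func.vmap(vdot_with_item)(torch.randn(3, 4), torch.randn(3, 4))
# raises a RuntimeError under vmap because .item() is not batchable.
```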
Alternatives
For reference, the jaxopt version:
```python
import jax
import jax.numpy as jnp
from jaxopt import Bisection
from jaxopt.implicit_diff import custom_root

jax.config.update("jax_platform_name", "cpu")


def F(x, factor):
    # Optimality condition: the root x* satisfies F(x*, factor) = 0.
    return factor * x**3 - x - 2


def bisection_root_solver(init_x, factor):
    # init_x is unused; kept so the signature matches custom_root_solver.
    # jaxopt's Bisection applies implicit differentiation automatically.
    bisec = Bisection(optimality_fun=F, lower=1, upper=2)
    return bisec.run(factor=factor).params


@custom_root(F)
def custom_root_solver(init_x, factor):
    """Root solver using gradient descent."""
    maxiter = 100
    lr = 1e-1
    x = init_x
    for _ in range(maxiter):
        grad = F(x, factor)
        x = x - lr * grad
    return x


x_init = jnp.array(3.0)
fac = jnp.array(2.0)

# Roots found by both solvers
print(custom_root_solver(x_init, fac))
print(bisection_root_solver(x_init, fac))

# First-order implicit derivatives w.r.t. factor
print(jax.grad(custom_root_solver, argnums=1)(x_init, fac))
print(jax.grad(bisection_root_solver, argnums=1)(x_init, fac))

# Second derivatives (Hessians) by composing jacrev with itself
custom_jac_fcn = jax.jacrev(custom_root_solver, argnums=1)
print(jax.jacrev(custom_jac_fcn, argnums=1)(x_init, fac))

bisection_jac_fcn = jax.jacrev(bisection_root_solver, argnums=1)
print(jax.jacrev(bisection_jac_fcn, argnums=1)(x_init, fac))
```
Additional context
No response