
[Feature Request] Support for functorch transforms #223

Open
@marvinfriede


Required prerequisites

Motivation

I am interested in Jacobians and Hessians of solutions to root-finding problems, obtained via implicit differentiation. This comes up regularly in scientific computing. With JAX (together with jaxopt's implicit differentiation), this already works out of the box with function transforms such as `jax.jacrev`. Do you plan to support this in TorchOpt as well?
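For reference, here is what these transforms compute in the scalar case. If the root x\*(θ) satisfies F(x\*(θ), θ) = 0, differentiating that identity once and then twice with respect to θ (the implicit function theorem) gives the first and second derivatives without differentiating through the solver's iterations. All partial derivatives are evaluated at (x\*(θ), θ):

$$
\frac{\partial F}{\partial x}\,\frac{dx^*}{d\theta} + \frac{\partial F}{\partial \theta} = 0
\quad\Longrightarrow\quad
\frac{dx^*}{d\theta} = -\left(\frac{\partial F}{\partial x}\right)^{-1}\frac{\partial F}{\partial \theta},
$$

$$
\frac{d^2x^*}{d\theta^2} = -\left(\frac{\partial F}{\partial x}\right)^{-1}\left(\frac{\partial^2 F}{\partial x^2}\left(\frac{dx^*}{d\theta}\right)^2 + 2\,\frac{\partial^2 F}{\partial x\,\partial \theta}\,\frac{dx^*}{d\theta} + \frac{\partial^2 F}{\partial \theta^2}\right).
$$

For vector-valued x, the inverse becomes a linear solve with the Jacobian ∂F/∂x. That solve is where iterative linear solvers such as the CG routines listed below come in.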

Solution

I already tried, but several pieces of code currently prevent this:

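  • missing `setup_context` (required for the vmap rule) in `ImplicitMetaGradient` (very easy to adapt)
  • `.item()` in `_vdot_real_kernel` (converting a tensor to a Python number is not supported under `vmap`)
  • `make_rmatvec` in `normal_cg`
  • Python conditionals on tensor values in `_cg_solve` (data-dependent control flow)
  • tree operations in `_cg_solve`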
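Two of these blockers are data-dependent control flow: `.item()` and an `if`/`break` on a tensor value both need a concrete Python number, which a batched (`vmap`) computation cannot supply. A common workaround is a conjugate-gradient loop that runs a fixed number of iterations and masks converged steps with `where` instead of branching. The sketch below illustrates that pattern in NumPy as a stand-in for torch tensors. It is not TorchOpt's `_cg_solve`; the name `cg_solve_fixed_iters` and its parameters are hypothetical, and in PyTorch the masking would use `torch.where`.

```python
import numpy as np


def cg_solve_fixed_iters(matvec, b, maxiter=50, tol=1e-10):
    """Hypothetical branch-free conjugate gradient for an SPD operator.

    Runs exactly `maxiter` iterations. Once the residual norm drops below
    `tol`, the update is masked to zero with np.where, so the iterate stays
    frozen. There is no Python `if`/`break` on array values and no `.item()`.
    """
    x = np.zeros_like(b)
    r = b - matvec(x)
    p = r.copy()
    rs = np.dot(r, r)
    for _ in range(maxiter):
        Ap = matvec(p)
        pAp = np.dot(p, Ap)
        active = rs > tol**2  # stays an array, never converted to a Python bool
        # Inner np.where guards the division so masked steps never divide by zero.
        alpha = np.where(active, rs / np.where(active, pAp, 1.0), 0.0)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = np.dot(r, r)
        beta = np.where(active, rs_new / np.where(active, rs, 1.0), 0.0)
        p = r + beta * p
        rs = rs_new
    return x


if __name__ == "__main__":
    A = np.array([[4.0, 1.0], [1.0, 3.0]])
    b = np.array([1.0, 2.0])
    print(cg_solve_fixed_iters(lambda v: A @ v, b, maxiter=10))
```

The price of this design is that every call pays for `maxiter` iterations. In exchange, the loop's structure no longer depends on tensor values, which is what batching transforms need.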

Alternatives

The jaxopt version, for comparison:

```python
import jax
import jax.numpy as jnp
from jaxopt import Bisection
from jaxopt.implicit_diff import custom_root

jax.config.update("jax_platform_name", "cpu")


def F(x, factor):
    """Optimality function: the root x* satisfies F(x*, factor) = 0."""
    return factor * x**3 - x - 2


def bisection_root_solver(init_x, factor):
    """Root solver using jaxopt's Bisection (init_x is unused; the bracket is [1, 2])."""
    bisec = Bisection(optimality_fun=F, lower=1, upper=2)
    return bisec.run(factor=factor).params


@custom_root(F)
def custom_root_solver(init_x, factor):
    """Root solver using the damped fixed-point iteration x <- x - lr * F(x, factor).

    custom_root makes JAX differentiate the returned root implicitly
    through F instead of unrolling this loop.
    """
    maxiter = 100
    lr = 1e-1

    x = init_x
    for _ in range(maxiter):
        x = x - lr * F(x, factor)

    return x


x_init = jnp.array(3.0)
fac = jnp.array(2.0)

# Root x* of F(x, 2.0) = 2 x^3 - x - 2.
print(custom_root_solver(x_init, fac))
print(bisection_root_solver(x_init, fac))

# First derivative dx*/dfactor.
print(jax.grad(custom_root_solver, argnums=1)(x_init, fac))
print(jax.grad(bisection_root_solver, argnums=1)(x_init, fac))

# Second derivative (Hessian) d^2x*/dfactor^2 via nested jacrev.
custom_jac_fcn = jax.jacrev(custom_root_solver, argnums=1)
print(jax.jacrev(custom_jac_fcn, argnums=1)(x_init, fac))
bisection_jac_fcn = jax.jacrev(bisection_root_solver, argnums=1)
print(jax.jacrev(bisection_jac_fcn, argnums=1)(x_init, fac))
```
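The derivatives printed by the script can be cross-checked without any autodiff library. The plain-Python sketch below solves F(x, factor) = factor·x³ − x − 2 = 0 with Newton's method (a stand-in for either solver above) and applies the implicit-function-theorem formulas for the first and second derivative with respect to `factor`. The helper names `solve_root` and `implicit_derivatives` are hypothetical.

```python
def F(x, factor):
    return factor * x**3 - x - 2


def solve_root(factor, x0=1.5, iters=50):
    """Newton's method on F(., factor); any converged root solver would do."""
    x = x0
    for _ in range(iters):
        x = x - F(x, factor) / (3 * factor * x**2 - 1)
    return x


def implicit_derivatives(factor):
    """Return x*, dx*/dfactor and d^2x*/dfactor^2 via the implicit function theorem."""
    x = solve_root(factor)
    F_x = 3 * factor * x**2 - 1  # dF/dx
    F_xx = 6 * factor * x        # d2F/dx2
    F_t = x**3                   # dF/dfactor
    F_xt = 3 * x**2              # d2F/dx dfactor
    F_tt = 0.0                   # F is linear in factor
    dx = -F_t / F_x
    d2x = -(F_xx * dx**2 + 2 * F_xt * dx + F_tt) / F_x
    return x, dx, d2x


if __name__ == "__main__":
    print(implicit_derivatives(2.0))
```

Because the problem is scalar, the Jacobian and Hessian that `jax.jacrev` returns reduce to these two numbers.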

Additional context

No response

Metadata

Labels: enhancement (New feature or request)