Description
Like #30, I am trying to batch over different starting values when the obj function isn't globally convex.
I am coming from jax, where you setup the objective function and then use vmap to signal which axes to batch over.
Here is what I have for a simple example (function is globally convex, but for demo purposes go ahead anyway):
```python
import torch
from torchmin import minimize as pyt_minimize

def objfun(x):
    return .1 * x + 3 * x ** 2
```

We will set up starting values for x at three different points:

```python
init_batched = torch.tensor([[1.], [3.], [-1.5]])
```

where the first axis is the batching dimension (shape[0] = number of different starting values).
Setting up vmap for the function and evaluating it, shows things work as expected:
```python
batched_obj_fun = torch.vmap(objfun, in_dims=0)
batched_obj_fun(init_batched)
```

yields

```
tensor([[ 3.1000],
        [27.3000],
        [ 6.6000]], device='cuda:0')
```
But using this vmap'd function with pytorch-minimize isn't working:
```python
res = pyt_minimize(lambda parms: objfun(parms),
                   init_batched, method='bfgs', tol=1e-5, disp=True)
```

throws this error:

```
<lots of trace>
RuntimeError: ScalarFunction was supplied a function that does not return scalar outputs.
```
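If I read the error correctly, the callable handed to pyt_minimize has to return a single scalar per call, so the batched (3, 1) output is rejected. For reference, an unbatched call along these lines is what I would expect to work (untested sketch; the `.sum()` is only there to collapse the (1,)-shaped output to a 0-dim scalar):

```python
# Hypothetical unbatched call: one starting point, scalar-valued objective.
x0 = torch.tensor([1.])
res_single = pyt_minimize(lambda parms: objfun(parms).sum(),  # .sum() -> 0-dim tensor
                          x0, method='bfgs', tol=1e-5, disp=True)
print(res_single.x)  # assuming a SciPy-style result object with an .x field
```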
So I tried to vmap a wrapped pyt_minimize call:
```python
def minimize_fun(coords):
    res = pyt_minimize(lambda parms: objfun(parms),
                       coords, method='bfgs', tol=1e-5, disp=True)
    return res

batched_minimize = torch.vmap(minimize_fun, in_dims=0)
```

and call it to do the batch minimization:

```python
batched_minimize(init_batched)
```

which fails with this error:

```
RuntimeError: You are attempting to call Tensor.requires_grad_() (or perhaps using torch.autograd.functional.* APIs) inside of a function being transformed by a functorch transform. This is unsupported, please attempt to use the functorch transforms (e.g. grad, vjp, jacrev, jacfwd, hessian) or call requires_grad_() outside of a function being transformed instead.
```
From #30 it sounds like this should be possible, but maybe the "jax" way is the wrong way to proceed....
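In the meantime, the only thing I can see working is a plain Python loop over the starting values, which gives up the batching I'm after (untested sketch, again assuming a SciPy-style result object with an `.x` field):

```python
# Hypothetical sequential fallback: run the optimizer once per starting value.
results = []
for x0 in init_batched:                      # each x0 has shape (1,)
    res = pyt_minimize(lambda parms: objfun(parms).sum(),
                       x0, method='bfgs', tol=1e-5, disp=True)
    results.append(res.x)
solutions = torch.stack(results)             # shape (3, 1), one solution per start
```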