Batch Optimization #35

@roblem

Like #30, I am trying to batch over different starting values when the objective function isn't globally convex.

I am coming from JAX, where you set up the objective function and then use vmap to signal which axes to batch over.

Here is what I have for a simple example (the function is globally convex, but it will do for demo purposes):

import torch
from torchmin import minimize as pyt_minimize

def objfun(x):
    return 0.1 * x + 3 * x ** 2

We will set up starting values for x at three different points:

init_batched = torch.tensor([[1.], [3.], [-1.5]])

where the first axis is the batching dimension (shape[0] = number of different starting values).

Setting up vmap for the function and evaluating it shows that things work as expected:

batched_obj_fun = torch.vmap(objfun, in_dims=0)
batched_obj_fun(init_batched)

yields

: tensor([[ 3.1000],
:         [27.3000],
:         [ 6.6000]], device='cuda:0')

But minimizing over the batched starting values with pytorch-minimize isn't working:

res = pyt_minimize(lambda parms: objfun(parms),
             init_batched, method='bfgs', tol=1e-5, disp=True)

throws this error:

<lots of trace>
RuntimeError: ScalarFunction was supplied a function that does not return scalar outputs.
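The error says the minimizer wants a scalar, while the batched objective returns one value per starting point. One possible workaround, sketched here under the assumption that the batch entries are independent, is to sum the per-sample values: the result is a scalar, and its gradient with respect to each row is just that row's own gradient. The name `summed_objfun` is hypothetical. Whether passing it to `pyt_minimize` gives good BFGS behaviour across batch entries is not verified here, so the sketch only checks the gradient claim with plain PyTorch.

```python
import torch

def objfun(x):
    return 0.1 * x + 3 * x ** 2

def summed_objfun(x_batched):
    # Hypothetical helper: collapse the per-sample values into one scalar.
    return objfun(x_batched).sum()

init_batched = torch.tensor([[1.0], [3.0], [-1.5]])

# Differentiate the summed objective with respect to the whole batch.
x = init_batched.clone().requires_grad_(True)
loss = summed_objfun(x)
(grad_of_sum,) = torch.autograd.grad(loss, x)

# Analytic per-sample gradient of 0.1 * x + 3 * x ** 2 is 0.1 + 6 * x.
analytic = 0.1 + 6 * init_batched
print(loss.ndim)                              # 0, i.e. a scalar
print(torch.allclose(grad_of_sum, analytic))  # True
```

Because the rows do not interact, each row of `grad_of_sum` depends only on its own starting value, which is why the summed problem has the same stationary points as the three separate ones.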

So I tried to vmap a wrapped pyt_minimize call:

def minimize_fun(coords):
    res = pyt_minimize(lambda parms: objfun(parms),
               coords, method='bfgs', tol=1e-5, disp=True)
    return res

batched_minimize = torch.vmap(minimize_fun, in_dims=0)

and call it to do the batch minimization:

batched_minimize(init_batched)

with this error:

RuntimeError: You are attempting to call Tensor.requires_grad_() (or perhaps using torch.autograd.functional.* APIs) inside of a function being transformed by a functorch transform. This is unsupported, please attempt to use the functorch transforms (e.g. grad, vjp, jacrev, jacfwd, hessian) or call requires_grad_() outside of a function being transformed instead.
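The message points toward the functorch-style transforms, which is the closer analogue of the JAX workflow. As a rough sketch of what that looks like, the block below composes `torch.func.grad` with `torch.func.vmap` to get per-sample gradients and runs a plain fixed-step gradient descent over the batch. This is a swapped-in technique, not BFGS and not anything pytorch-minimize does internally; `scalar_objfun`, `batched_gradient_descent`, and the step size and iteration count are all illustrative assumptions.

```python
import torch
from torch.func import grad, vmap

def objfun(x):
    return 0.1 * x + 3 * x ** 2

def scalar_objfun(x):
    # torch.func.grad needs a scalar output; each sample has shape (1,).
    return objfun(x).sum()

# Per-sample gradient: grad differentiates one sample, vmap maps over dim 0.
per_sample_grad = vmap(grad(scalar_objfun))

def batched_gradient_descent(x0, lr=0.1, steps=200):
    # Plain gradient descent applied to every starting value at once.
    x = x0.clone()
    for _ in range(steps):
        x = x - lr * per_sample_grad(x)
    return x

init_batched = torch.tensor([[1.0], [3.0], [-1.5]])
x_min = batched_gradient_descent(init_batched)
print(x_min)  # every row should approach the analytic minimizer -0.1 / 6
```

For this quadratic, all three starting points converge to the same minimizer; with a non-convex objective the rows would generally settle in different local minima, which is the point of batching over starting values.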

According to #30 this should be possible, but maybe the "JAX" way is the wrong way to proceed here.
