Scalable optimization for large data #201

@stli

First, a big thanks to all the maintainers for this great library. I am trying to optimize a cost function on potentially large-scale input data, for cases where classical mini-batch optimizers converge very slowly. With mini-batch optimizers (e.g. optax), large inputs are easily handled using a dataloader (e.g. grain). However, it seems a bit tricky to do the same in optimistix. I am hoping for some insights and/or ideas on how to achieve this with optimistix.

Below is a list of the variants I have considered, followed by an MWE:

  1. GPU sharding (needs more physical GPUs)
  2. move to CPU (plenty of memory, but slow)
  3. alternating solves for different data batches (needs some logic to ensure convergence)
  4. alternating solves for different NN layers (needs some logic to ensure convergence)
  5. dataloader to calculate the cost function iteratively (seems ideal, but difficult to integrate)
import jax
import jax.numpy as jnp
import optimistix as optx
import jax.random as jr
from jax.example_libraries import stax

jax.config.update("jax_enable_x64", True)

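key = jr.PRNGKey(0)
key_X, key_y, key_init = jr.split(key, 3)  # independent keys for inputs, targets and parameters
dim = 10
num_samples = 100_000  # some large number

X = jr.normal(key_X, (num_samples, dim))
y = jr.normal(key_y, (num_samples, 1))

activation = stax.Sigmoid
layer_sizes = [80] * 3
# Three Dense + Sigmoid hidden layers followed by a scalar output layer.
hidden_layers = [layer for size in layer_sizes for layer in (stax.Dense(size), activation)]
init_fun, predict_fun = stax.serial(*hidden_layers, stax.Dense(1))
_, params = init_fun(key_init, (X.shape[1],))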

def loss(params, args):
    X, y = args # <- ideally an iterator + fori loop over data batches here
    return jnp.mean( (predict_fun(params, X).flatten() - y.flatten())**2 )

bfgs_tol = 1e-12
solver = optx.BFGS(rtol=bfgs_tol, atol=bfgs_tol)
sol = optx.minimise(
    loss,
    solver,
    params,
    max_steps=100_000,
    throw=False,
    args=(X,y) # <- ideally a dataloader here
)
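
To make variant 5 more concrete, here is a minimal sketch of what I have in mind. It is an assumption on my side, not something taken from the optimistix docs. The data are reshaped into `(num_batches, batch_size, ...)`, the squared error is accumulated with `jax.lax.scan`, and the per-batch term is wrapped in `jax.checkpoint` so that its intermediates are recomputed in the backward pass instead of being stored for every batch. The names `chunked_loss`, `full_loss`, `batch_sse` and `batch_size` are made up for illustration, and the sizes are kept small on purpose.

```python
import jax
import jax.numpy as jnp
import jax.random as jr
from jax.example_libraries import stax

jax.config.update("jax_enable_x64", True)

key_x, key_y, key_init = jr.split(jr.PRNGKey(0), 3)
dim, num_samples, batch_size = 10, 4_096, 512  # small sizes, for illustration only

X = jr.normal(key_x, (num_samples, dim))
y = jr.normal(key_y, (num_samples, 1))

init_fun, predict_fun = stax.serial(
    stax.Dense(80), stax.Sigmoid,
    stax.Dense(80), stax.Sigmoid,
    stax.Dense(80), stax.Sigmoid,
    stax.Dense(1),
)
_, params = init_fun(key_init, (dim,))


def full_loss(params, args):
    # Reference: mean squared error evaluated on all samples at once.
    X, y = args
    return jnp.mean((predict_fun(params, X).flatten() - y.flatten()) ** 2)


def chunked_loss(params, args):
    # args holds pre-batched arrays of shape (num_batches, batch_size, ...).
    Xb, yb = args

    @jax.checkpoint  # recompute per-batch intermediates during the backward pass
    def batch_sse(params, x_batch, y_batch):
        residual = predict_fun(params, x_batch).flatten() - y_batch.flatten()
        return jnp.sum(residual**2)

    def body(total, batch):
        x_batch, y_batch = batch
        return total + batch_sse(params, x_batch, y_batch), None

    total, _ = jax.lax.scan(body, jnp.zeros((), dtype=Xb.dtype), (Xb, yb))
    return total / (Xb.shape[0] * Xb.shape[1])


# Assumes num_samples is divisible by batch_size; otherwise pad or drop the remainder.
num_batches = num_samples // batch_size
Xb = X.reshape(num_batches, batch_size, dim)
yb = y.reshape(num_batches, batch_size, 1)

value_full, grad_full = jax.value_and_grad(full_loss)(params, (X, y))
value_chunked, grad_chunked = jax.value_and_grad(chunked_loss)(params, (Xb, yb))
```

Since `chunked_loss` keeps the `(params, args)` signature, I would expect it to work as a drop-in replacement for `loss` in the `optx.minimise` call above, with `args=(Xb, yb)`. Note that this only limits the memory used for activations: the whole dataset still has to sit on the device, so it is not yet a real dataloader that streams batches from the host.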
