First, a big thanks to all maintainers for this great library. I am trying to optimize a cost function on potentially large-scale input data, for cases where classical mini-batch optimizers converge very slowly. With a mini-batch optimizer (e.g. optax) this could easily be done using a dataloader (e.g. grain). However, it seems a bit tricky to do in optimistix. I am hoping for some insights and/or ideas on how to achieve this with optimistix.
Below is an MWE and a list of the variants I have considered:
- GPU sharding (needs more physical GPUs)
- move to CPU (plenty of memory, but slow)
- alternating solves over different data batches (needs some logic to ensure convergence)
- alternating solves over different NN layers (needs some logic to ensure convergence)
- dataloader to evaluate the cost function iteratively (seems ideal, but difficult to integrate)
```python
import jax
import jax.numpy as jnp
import jax.random as jr
import optimistix as optx
from jax.example_libraries import stax

jax.config.update("jax_enable_x64", True)

key = jr.PRNGKey(0)
dim = 10
num_samples = 100_000  # some large number
X = jr.normal(key, (num_samples, dim))
y = jr.normal(key, (num_samples, 1))

activation = stax.Sigmoid
layer_sizes = [80] * 3
init_fun, predict_fun = stax.serial(
    *sum([[stax.Dense(size), activation] for size in layer_sizes], []),
    stax.Dense(1),
)
_, params = init_fun(key, (X.shape[1],))

def loss(params, args):
    X, y = args  # <- ideally an iterator + fori loop over data batches here
    return jnp.mean((predict_fun(params, X).flatten() - y.flatten()) ** 2)

bfgs_tol = 1e-12
solver = optx.BFGS(rtol=bfgs_tol, atol=bfgs_tol)
sol = optx.minimise(
    loss,
    solver,
    params,
    max_steps=100_000,
    throw=False,
    args=(X, y),  # <- ideally a dataloader here
)
```
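One way to approximate the last variant without an external dataloader is to evaluate the full-batch loss in fixed-size chunks inside the cost function, e.g. via `jax.lax.map`, which (under `jit`) bounds peak memory while keeping the gradient exact. A sketch, assuming `num_samples` is divisible by `num_chunks`; `chunked_mse` and `num_chunks` are my own names:

```python
import jax
import jax.numpy as jnp


def chunked_mse(params, predict_fun, X, y, num_chunks):
    # Evaluate the mean-squared error chunk by chunk: lax.map runs the
    # body sequentially over the leading axis, so only one chunk of
    # activations needs to be materialized at a time.
    Xc = X.reshape(num_chunks, -1, X.shape[-1])
    yc = y.reshape(num_chunks, -1)

    def chunk_sse(batch):
        Xb, yb = batch
        resid = predict_fun(params, Xb).flatten() - yb.flatten()
        return jnp.sum(resid ** 2)  # sum of squared errors for this chunk

    return jnp.sum(jax.lax.map(chunk_sse, (Xc, yc))) / y.size
```

This would plug into the MWE as `loss = lambda params, args: chunked_mse(params, predict_fun, *args, num_chunks=100)`; whether the sequential evaluation pays off versus device paging would need profiling.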