I want to use the lookahead optimizer (`optax.lookahead`), but it does not seem to fit the usual optax pattern.
```python
import optax
import jax
import jax.numpy as jnp


def fn_to_optimize(x):
    return jnp.sum(x ** 2)


params = jnp.array([2.0, 2.0])
fast_optimizer = optax.adam(1e-1)
solver = optax.lookahead(fast_optimizer, sync_period=5, slow_step_size=0.5)
# params = optax.LookaheadParams.init_synced(params)
state = solver.init(params)  # init accepts plain params (it only logs a warning)

for step in range(100):
    loss, grads = jax.value_and_grad(fn_to_optimize)(params)
    # Fails here: update() reads params.fast and params.slow, which a plain array lacks.
    updates, state = solver.update(grads, state, params)
    params = optax.apply_updates(params, updates)
    if step % 10 == 0:
        print(f"Step {step}, Loss: {loss}, Params: {params}")
```
I would have expected it to be a drop-in replacement like the other optimizers, but that does not seem to be the case.
When I look at the source code (optax/optax/_src/lookahead.py, line 99 at commit 5bd9095, `def init_fn(params: base.Params) -> LookaheadState:`), the `init` function is very clear: it takes ordinary parameters.
However, the `update` function is confusing: it expects `params: LookaheadParams`, unlike `init`.
The following seems to be the right way to do it, but it is not very intuitive:
```python
import optax
import jax
import jax.numpy as jnp


def fn_to_optimize(x):
    return jnp.sum(x ** 2)


params = jnp.array([2.0, 2.0])
fast_optimizer = optax.adam(1e-1)
solver = optax.lookahead(fast_optimizer, sync_period=5, slow_step_size=0.5)
# Wrap the plain parameters in a (fast, slow) pair before init/update.
params = optax.LookaheadParams.init_synced(params)
state = solver.init(params)

for step in range(100):
    # The loss is a function of the fast weights only.
    loss, grads = jax.value_and_grad(fn_to_optimize)(params.fast)
    # update() takes the full LookaheadParams and returns LookaheadParams updates.
    updates, state = solver.update(grads, state, params)
    params = optax.apply_updates(params, updates)
    if step % 10 == 0:
        print(f"Step {step}, Loss: {loss}, Params: {params}")
```