Skip to content

Check last point when using best so far minimiser#34

Open
ColCarroll wants to merge 1 commit intopatrick-kidger:mainfrom
ColCarroll:bestsofar
Open

Check last point when using best so far minimiser#34
ColCarroll wants to merge 1 commit intopatrick-kidger:mainfrom
ColCarroll:bestsofar

Conversation

@ColCarroll
Copy link
Contributor

Fixes #33

There is some unnecessary looking lines

        best_f, best_aux = fn(state.best_y, args)
        best_loss = self._to_loss(state.best_y, best_f)

where I would expect to just use state.best_loss, but the test doesn't pass without it!

Comment on lines +119 to +120
best_f, best_aux = fn(state.best_y, args)
best_loss = self._to_loss(state.best_y, best_f)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, I definitely don't understand these lines. What do you think is going on?

solver = optx.BestSoFarMinimiser(optx.BFGS(rtol=1e-5, atol=1e-5))
sol = optx.minimise(fn, solver, jnp.array(0.0))
assert sol.value == 3.0
# assert fn(sol.value, None) <= fn(sol.state.state.y_eval, None)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we remove this line?

@patrick-kidger
Copy link
Owner

Thank you for the fix! Always happy to squash bugs :)

@ASKabalan
Copy link

@patrick-kidger Hello 👋

Can this be merged

This issue causing a bug for solvers that converge too quickly

import jax.numpy as jnp
import optimistix as optx

def fn(y , arg):
    return jnp.sum((y - 2.0)**2) + jnp.sum(y **2) * 0.0

DIM = 10
key = jax.random.PRNGKey(42)
y0 = jax.random.normal(key, (DIM,))

solver = optx.BestSoFarMinimiser(optx.BFGS(atol=1e-8 , rtol=1e-8))
sol = optx.minimise(fn, solver, y0, max_steps=1000)
print(f"Solution found: {sol.value}")
print(f"Best loss: {sol.state.best_loss}")
print(f"Loss using last params {fn(sol.value , None)}")

This PR solves it

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

BestSoFarMinimiser behavior

3 participants