Check last point when using best so far minimiser#34
Open
ColCarroll wants to merge 1 commit intopatrick-kidger:mainfrom
Open
Check last point when using best so far minimiser#34ColCarroll wants to merge 1 commit intopatrick-kidger:mainfrom
ColCarroll wants to merge 1 commit intopatrick-kidger:mainfrom
Conversation
Comment on lines
+119
to
+120
| best_f, best_aux = fn(state.best_y, args) | ||
| best_loss = self._to_loss(state.best_y, best_f) |
Owner
There was a problem hiding this comment.
Agreed, I definitely don't understand these lines. What do you think is going on?
| solver = optx.BestSoFarMinimiser(optx.BFGS(rtol=1e-5, atol=1e-5)) | ||
| sol = optx.minimise(fn, solver, jnp.array(0.0)) | ||
| assert sol.value == 3.0 | ||
| # assert fn(sol.value, None) <= fn(sol.state.state.y_eval, None) |
Owner
|
Thank you for the fix! Always happy to squash bugs :) |
|
@patrick-kidger Hello 👋 Can this be merged This issue causing a bug for solvers that converge too quickly import jax.numpy as jnp
import optimistix as optx
def fn(y , arg):
return jnp.sum((y - 2.0)**2) + jnp.sum(y **2) * 0.0
DIM = 10
key = jax.random.PRNGKey(42)
y0 = jax.random.normal(key, (DIM,))
solver = optx.BestSoFarMinimiser(optx.BFGS(atol=1e-8 , rtol=1e-8))
sol = optx.minimise(fn, solver, y0, max_steps=1000)
print(f"Solution found: {sol.value}")
print(f"Best loss: {sol.state.best_loss}")
print(f"Loss using last params {fn(sol.value , None)}")This PR solves it |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Fixes #33
There is some unnecessary looking lines
where I would expect to just use
state.best_loss, but the test doesn't pass without it!