Skip to content

Commit 2deaacd

Browse files
committed
Add callback
1 parent 32c3418 commit 2deaacd

File tree

4 files changed

+49
-2
lines changed

4 files changed

+49
-2
lines changed

src/AdaptiveRegularization.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,23 @@ Additional `kwargs` are used for stopping criterion, see `Stopping.jl`.
5858
# Output
5959
The returned value is a `GenericExecutionStats`, see `SolverCore.jl`.
6060
61+
# Callback
62+
The callback is called at each iteration.
63+
The expected signature of the callback is `callback(nlp, solver, stats)`, and its output is ignored.
64+
Changing any of the input arguments will affect the subsequent iterations.
65+
In particular, setting `stats.status = :user` will stop the algorithm.
66+
All relevant information should be available in `nlp` and `solver`.
67+
Notably, you can access, and modify, the following:
68+
- `solver.stp`: stopping object used for the algorithm;
69+
- `solver.workspace`: additional allocations;
70+
- `stats`: structure holding the output of the algorithm (`GenericExecutionStats`), which contains, among other things:
71+
- `stats.dual_feas`: norm of current gradient;
72+
- `stats.iter`: current iteration counter;
73+
- `stats.objective`: current objective function value;
74+
- `stats.status`: current status of the algorithm. Should be `:unknown` unless the algorithm has attained a stopping criterion. Changing this to anything will stop the algorithm, but you should use `:user` to properly indicate the intention.
75+
- `stats.elapsed_time`: elapsed time in seconds.
76+
77+
6178
This implementation uses `Stopping.jl`. Therefore, it is also possible to used
6279
6380
TRARC(stp; kwargs...)

src/main.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ function SolverCore.solve!(
103103
stats::GenericExecutionStats{T, S};
104104
robust::Bool = true,
105105
verbose::Integer = false,
106+
callback = (args...) -> nothing,
106107
kwargs...,
107108
) where {Pb, M, SRC, MStp, LoS, Score, S, T}
108109
PData = solver.meta
@@ -139,6 +140,7 @@ function SolverCore.solve!(
139140
set_solution!(stats, nlp_at_x.x)
140141
set_objective!(stats, nlp_at_x.fx)
141142
set_dual_residual!(stats, nlp_at_x.current_score)
143+
set_iter!(stats, 0)
142144
set_time!(stats, nlp_at_x.current_time - nlp_stop.meta.start_time)
143145

144146
verbose > 0 && @info log_header(
@@ -147,7 +149,9 @@ function SolverCore.solve!(
147149
)
148150
verbose > 0 && @info log_row(Any[iter, ft, norm_∇f, 0.0, "First iteration", α])
149151

150-
while !OK
152+
callback(nlp, solver, stats)
153+
154+
while !OK && (stats.status != :user)
151155
preprocess!(nlp_stop, PData, workspace, ∇f, norm_∇f, α)
152156

153157
if ~PData.OK
@@ -232,7 +236,7 @@ function SolverCore.solve!(
232236
set_dual_residual!(stats, nlp_at_x.current_score)
233237
set_iter!(stats, nlp_stop.meta.nb_of_stop)
234238
set_time!(stats, nlp_at_x.current_time - nlp_stop.meta.start_time)
235-
# TODO: callback(nlp, solver, stats)
239+
callback(nlp, solver, stats)
236240
end # while !OK
237241

238242
stats

test/callback.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
@testset "Test callback" begin
2+
f(x) = (x[1] - 1)^2 + 4 * (x[2] - x[1]^2)^2
3+
nlp = ADNLPModel(f, [-1.2; 1.0])
4+
function cb(nlp, solver, stats)
5+
if stats.iter == 8
6+
stats.status = :user
7+
end
8+
end
9+
stats = TRARC(nlp, callback = cb)
10+
@test stats.iter == 8
11+
end
12+
13+
@testset "Test callback for NLS" begin
14+
F(x) = [x[1] - 1; 2 * (x[2] - x[1]^2)]
15+
nls = ADNLSModel(F, [-1.2; 1.0], 2)
16+
function cb(nlp, solver, stats)
17+
if stats.iter == 8
18+
stats.status = :user
19+
end
20+
end
21+
22+
stats = TRARC(nls, callback = cb)
23+
@test stats.iter == 8
24+
end

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ using ADNLPModels, NLPModels, OptimizationProblems.ADNLPProblems, SolverCore, So
77
# Stopping
88
using Stopping
99

10+
include("callback.jl")
11+
1012
@testset "Testing NLP solvers" begin
1113
@testset "$name" for name in ALL_solvers
1214
solver = eval(name)

0 commit comments

Comments
 (0)