Commit 389d83d

convert dummy_solver

1 parent a8af8a4

2 files changed: +135 −32 lines

test/dummy_solver.jl

Lines changed: 129 additions & 29 deletions
@@ -1,56 +1,157 @@
+# non-allocating reshape
+# see https://github.com/JuliaLang/julia/issues/36313
+reshape_array(a, dims) = invoke(Base._reshape, Tuple{AbstractArray, typeof(dims)}, a, dims)
+
+mutable struct DummySolver{S} <: AbstractOptimizationSolver
+  x::S    # primal approximation
+  gx::S   # gradient of objective
+  y::S    # multiplier estimates
+  rhs::S  # right-hand side of Newton system
+  jval::S # flattened Jacobian
+  hval::S # flattened Hessian
+  wval::S # flattened augmented matrix
+  Δxy::S  # search direction
+end
+
+function DummySolver(nlp::AbstractNLPModel{T, S}) where {T, S <: AbstractVector{T}}
+  nvar, ncon = nlp.meta.nvar, nlp.meta.ncon
+  x = similar(nlp.meta.x0)
+  gx = similar(nlp.meta.x0)
+  y = similar(nlp.meta.y0)
+  rhs = similar(nlp.meta.x0, nvar + ncon)
+  jval = similar(nlp.meta.x0, ncon * nvar)
+  hval = similar(nlp.meta.x0, nvar * nvar)
+  wval = similar(nlp.meta.x0, (nvar + ncon) * (nvar + ncon))
+  Δxy = similar(nlp.meta.x0, nvar + ncon)
+  DummySolver{S}(x, gx, y, rhs, jval, hval, wval, Δxy)
+end
+
 function dummy_solver(
-  nlp::AbstractNLPModel;
-  x::AbstractVector = nlp.meta.x0,
-  atol::Real = sqrt(eps(eltype(x))),
-  rtol::Real = sqrt(eps(eltype(x))),
+  nlp::AbstractNLPModel{T, S},
+  args...;
+  kwargs...,
+) where {T, S <: AbstractVector{T}}
+  solver = DummySolver(nlp)
+  stats = GenericExecutionStats(nlp)
+  solve!(solver, nlp, stats, args...; kwargs...)
+end
+
+function SolverCore.solve!(
+  solver::DummySolver{S},
+  nlp::AbstractNLPModel{T, S},
+  stats::GenericExecutionStats;
+  callback = (args...) -> nothing,
+  x0::S = nlp.meta.x0,
+  atol::Real = sqrt(eps(T)),
+  rtol::Real = sqrt(eps(T)),
   max_eval::Int = 1000,
   max_time::Float64 = 30.0,
-)
+  verbose::Bool = true,
+) where {T, S <: AbstractVector{T}}
   start_time = time()
   elapsed_time = 0.0
 
   nvar, ncon = nlp.meta.nvar, nlp.meta.ncon
-  T = eltype(x)
+  x = solver.x .= x0
+  rhs = solver.rhs
+  dual = view(rhs, 1:nvar)
+  cx = view(rhs, (nvar + 1):(nvar + ncon))
+  gx = solver.gx
+  y = solver.y
+  jval = solver.jval
+  hval = solver.hval
+  wval = solver.wval
+  Δxy = solver.Δxy
+  nnzh = Int(nvar * (nvar + 1) / 2)
+  nnzh == nlp.meta.nnzh || error("solver assumes Hessian is dense")
+  nvar * ncon == nlp.meta.nnzj || error("solver assumes Jacobian is dense")
 
-  cx = ncon > 0 ? cons(nlp, x) : zeros(T, 0)
-  gx = grad(nlp, x)
-  Jx = ncon > 0 ? Matrix(jac(nlp, x)) : zeros(T, 0, nvar)
-  y = -Jx' \ gx
-  Hxy = ncon > 0 ? hess(nlp, x, y) : hess(nlp, x)
+  grad!(nlp, x, gx)
+  dual .= gx
 
-  dual = gx + Jx' * y
+  # assume the model returns a dense Jacobian in column-major order
+  if ncon > 0
+    cons!(nlp, x, cx)
+    jac_coord!(nlp, x, jval)
+    Jx = reshape_array(jval, (ncon, nvar))
+    Jqr = qr(Jx')
 
-  iter = 0
+    # compute least-squares multipliers
+    # by solving Jx' y = -gx
+    gx .*= -1
+    ldiv!(y, Jqr, gx)
+
+    # update dual <- dual + Jx' * y
+    mul!(dual, Jx', y, one(T), one(T))
+  end
 
+  iter = 0
   ϵd = atol + rtol * norm(dual)
   ϵp = atol
 
   fx = obj(nlp, x)
-  @info log_header([:iter, :f, :c, :dual, :t, :x], [Int, T, T, T, Float64, Char])
-  @info log_row(Any[iter, fx, norm(cx), norm(dual), elapsed_time, 'c'])
+  verbose && @info log_header([:iter, :f, :c, :dual, :t, :x], [Int, T, T, T, Float64, Char])
+  verbose && @info log_row(Any[iter, fx, norm(cx), norm(dual), elapsed_time, 'c'])
   solved = norm(dual) < ϵd && norm(cx) < ϵp
   tired = neval_obj(nlp) + neval_cons(nlp) > max_eval || elapsed_time > max_time
 
   while !(solved || tired)
-    Hxy = ncon > 0 ? hess(nlp, x, y) : hess(nlp, x)
-    W = Symmetric([Hxy zeros(T, nvar, ncon); Jx zeros(T, ncon, ncon)], :L)
-    Δxy = -W \ [dual; cx]
-    Δx = Δxy[1:nvar]
-    Δy = Δxy[(nvar + 1):end]
-    x += Δx
-    y += Δy
-
-    cx = ncon > 0 ? cons(nlp, x) : zeros(T, 0)
-    gx = grad(nlp, x)
-    Jx = ncon > 0 ? Matrix(jac(nlp, x)) : zeros(T, 0, nvar)
-    dual = gx + Jx' * y
+    # assume the model returns a dense Hessian in column-major order
+    # NB: hess_coord!() only returns values in the lower triangle
+    hess_coord!(nlp, x, y, view(hval, 1:nnzh))
+
+    # rearrange nonzeros so they correspond to a dense nvar x nvar matrix
+    j = nvar * nvar
+    i = nnzh
+    k = 1
+    while i > nvar
+      for _ = 1:k
+        hval[j] = hval[i]
+        hval[i] = 0
+        j -= 1
+        i -= 1
+      end
+      j -= nvar - k
+      k += 1
+    end
+
+    # fill in augmented matrix
+    # W = [H J']
+    #     [J 0 ]
+    wval .= 0
+    Wxy = reshape_array(wval, (nvar + ncon, nvar + ncon))
+    Hxy = reshape_array(hval, (nvar, nvar))
+    Wxy[1:nvar, 1:nvar] .= Hxy
+    if ncon > 0
+      Wxy[(nvar + 1):(nvar + ncon), 1:nvar] .= Jx
+    end
+    LBL = factorize(Symmetric(Wxy, :L))
+
+    ldiv!(Δxy, LBL, rhs)
+    Δxy .*= -1
+    @views Δx = Δxy[1:nvar]
+    @views Δy = Δxy[(nvar + 1):(nvar + ncon)]
+    x .+= Δx
+    y .+= Δy
+
+    grad!(nlp, x, gx)
+    dual .= gx
+    if ncon > 0
+      cons!(nlp, x, cx)
+      jac_coord!(nlp, x, jval)
+      Jx = reshape_array(jval, (ncon, nvar))
+      Jqr = qr(Jx')
+      gx .*= -1
+      ldiv!(y, Jqr, gx)
+      mul!(dual, Jx', y, one(T), one(T))
+    end
     elapsed_time = time() - start_time
     solved = norm(dual) < ϵd && norm(cx) < ϵp
     tired = neval_obj(nlp) + neval_cons(nlp) > max_eval || elapsed_time > max_time
 
     iter += 1
     fx = obj(nlp, x)
-    @info log_row(Any[iter, fx, norm(cx), norm(dual), elapsed_time, 'd'])
+    verbose && @info log_row(Any[iter, fx, norm(cx), norm(dual), elapsed_time, 'd'])
   end
 
   status = if solved
@@ -61,7 +162,6 @@ function dummy_solver(
     :max_eval
   end
 
-  stats = GenericExecutionStats(nlp)
   set_status!(stats, status)
   set_objective!(stats, fx)
   set_residuals!(stats, norm(cx), norm(dual))
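
A note on the `reshape_array` helper added at the top of this file: plain `reshape` on an `Array` allocates a fresh array object (the subject of the linked Julia issue), whereas invoking the generic `Base._reshape` fallback returns a lightweight wrapper that shares memory with its argument. A minimal sketch of the behavior (not part of the commit):

v = zeros(6)
A = reshape_array(v, (2, 3))  # wrapper sharing memory with v, no copy of the data
A[1, 1] = 5.0
v[1] == 5.0                   # true: writes go through to the flat vector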
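The trickiest new code is the loop that unpacks the Hessian: `hess_coord!` fills only the first nnzh = nvar * (nvar + 1) / 2 entries of `hval` with the lower triangle, packed column by column, and the loop walks backwards, spreading those entries into their column-major slots of a dense nvar × nvar matrix and zeroing the slots it vacates, so the strict upper triangle stays zero for `Symmetric(Wxy, :L)`. A standalone check of the index arithmetic for nvar = 3 (the helper name `unpack_lower!` is mine, not part of the commit):

# Spread a packed lower triangle into dense column-major storage, in place.
function unpack_lower!(hval, nvar)
  nnzh = nvar * (nvar + 1) ÷ 2
  j = nvar * nvar  # write cursor: last column-major slot
  i = nnzh         # read cursor: last packed entry
  k = 1            # lower-triangle entries in the current (last) column
  while i > nvar
    for _ = 1:k
      hval[j] = hval[i]
      hval[i] = 0
      j -= 1
      i -= 1
    end
    j -= nvar - k  # skip the strict upper-triangle slots of this column
    k += 1
  end
  hval
end

# packed lower triangle of a 3×3 Hessian, padded with zeros to nvar^2 entries
hval = [11.0, 21.0, 31.0, 22.0, 32.0, 33.0, 0.0, 0.0, 0.0]
reshape(unpack_lower!(hval, 3), 3, 3)  # == [11 0 0; 21 22 0; 31 32 33]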

test/test_stats.jl

Lines changed: 6 additions & 3 deletions
@@ -66,9 +66,10 @@ function test_stats()
   @testset "Testing Dummy Solver with multi-precision" begin
     for T in (Float16, Float32, Float64, BigFloat)
       nlp = ADNLPModel(x -> dot(x, x), ones(T, 2))
+      solver = DummySolver(nlp)
 
-      with_logger(NullLogger()) do
-        stats = dummy_solver(nlp)
+      stats = with_logger(NullLogger()) do
+        solve!(solver, nlp)
       end
       @test typeof(stats.objective) == T
       @test typeof(stats.dual_feas) == T
@@ -79,9 +80,11 @@ function test_stats()
       @test eltype(stats.multipliers_U) == T
 
       nlp = ADNLPModel(x -> dot(x, x), ones(T, 2), x -> [sum(x) - 1], T[0], T[0])
+      solver = DummySolver(nlp)
+      stats = GenericExecutionStats(nlp)
 
       with_logger(NullLogger()) do
-        stats = dummy_solver(nlp)
+        solve!(solver, nlp, stats)
       end
       @test typeof(stats.objective) == T
       @test typeof(stats.dual_feas) == T
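
For reference, the calling pattern the converted tests exercise: build the solver workspace and the output container once, then solve in place, so repeated runs reuse storage instead of reallocating. A minimal sketch, assuming `DummySolver`, `solve!`, and `GenericExecutionStats` are in scope as in the test suite, with `ADNLPModels` providing the toy model:

using ADNLPModels, LinearAlgebra

nlp = ADNLPModel(x -> dot(x, x), ones(2))  # unconstrained toy problem from the tests
solver = DummySolver(nlp)                  # workspace allocated once
stats = GenericExecutionStats(nlp)         # output container allocated once
solve!(solver, nlp, stats, verbose = false)
stats.status, stats.objective              # results were written into stats in place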
