Skip to content

Commit 7add2a0

Browse files
committed
convert dummy_solver
1 parent 6b39dc4 commit 7add2a0

File tree

2 files changed

+131
-28
lines changed

2 files changed

+131
-28
lines changed

test/dummy_solver.jl

Lines changed: 125 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,90 @@
1+
# non-allocating reshape
2+
# see https://github.com/JuliaLang/julia/issues/36313
3+
reshape_array(a, dims) = invoke(Base._reshape, Tuple{AbstractArray, typeof(dims)}, a, dims)
4+
5+
mutable struct DummySolver{S} <: AbstractOptimizationSolver
6+
x::S # primal approximation
7+
gx::S # gradient of objective
8+
y::S # multipliers estimates
9+
rhs::S # right-hand size of Newton system
10+
jval::S # flattened Jacobian
11+
hval::S # flattened Hessian
12+
wval::S # flattened augmented matrix
13+
Δxy::S # search direction
14+
end
15+
16+
function DummySolver(nlp::AbstractNLPModel{T, S}) where {T, S <: AbstractVector{T}}
17+
nvar, ncon = nlp.meta.nvar, nlp.meta.ncon
18+
x = similar(nlp.meta.x0)
19+
gx = similar(nlp.meta.x0)
20+
y = similar(nlp.meta.y0)
21+
rhs = similar(nlp.meta.x0, nvar + ncon)
22+
jval = similar(nlp.meta.x0, ncon * nvar)
23+
hval = similar(nlp.meta.x0, nvar * nvar)
24+
wval = similar(nlp.meta.x0, (nvar + ncon) * (nvar + ncon))
25+
Δxy = similar(nlp.meta.x0, nvar + ncon)
26+
DummySolver{S}(x, gx, y, rhs, jval, hval, wval, Δxy)
27+
end
28+
129
function dummy_solver(
2-
nlp::AbstractNLPModel;
3-
x::AbstractVector = nlp.meta.x0,
4-
atol::Real = sqrt(eps(eltype(x))),
5-
rtol::Real = sqrt(eps(eltype(x))),
30+
nlp::AbstractNLPModel{T, S},
31+
args...;
32+
kwargs...,
33+
) where {T, S <: AbstractVector{T}}
34+
solver = DummySolver(nlp)
35+
stats = GenericExecutionStats(nlp)
36+
solve!(solver, nlp, stats, args...; kwargs...)
37+
end
38+
39+
function solve!(
40+
solver::DummySolver{S},
41+
nlp::AbstractNLPModel{T, S},
42+
stats::GenericExecutionStats;
43+
callback = (args...) -> nothing,
44+
x0::S = nlp.meta.x0,
45+
atol::Real = sqrt(eps(T)),
46+
rtol::Real = sqrt(eps(T)),
647
max_eval::Int = 1000,
748
max_time::Float64 = 30.0,
8-
)
49+
) where {T, S <: AbstractVector{T}}
950
start_time = time()
1051
elapsed_time = 0.0
1152

1253
nvar, ncon = nlp.meta.nvar, nlp.meta.ncon
13-
T = eltype(x)
54+
x = solver.x .= x0
55+
rhs = solver.rhs
56+
dual = view(rhs, 1:nvar)
57+
cx = view(rhs, (nvar + 1):(nvar + ncon))
58+
gx = solver.gx
59+
y = solver.y
60+
jval = solver.jval
61+
hval = solver.hval
62+
wval = solver.wval
63+
Δxy = solver.Δxy
64+
nnzh = Int(nvar * (nvar + 1) / 2)
65+
nnzh == nlp.meta.nnzh || error("solver assumes Hessian is dense")
66+
nvar * ncon == nlp.meta.nnzj || error("solver assumes Jacobian is dense")
1467

15-
cx = ncon > 0 ? cons(nlp, x) : zeros(T, 0)
16-
gx = grad(nlp, x)
17-
Jx = ncon > 0 ? Matrix(jac(nlp, x)) : zeros(T, 0, nvar)
18-
y = -Jx' \ gx
19-
Hxy = ncon > 0 ? hess(nlp, x, y) : hess(nlp, x)
68+
grad!(nlp, x, gx)
69+
dual .= gx
2070

21-
dual = gx + Jx' * y
71+
# assume the model returns a dense Jacobian in column-major order
72+
if ncon > 0
73+
cons!(nlp, x, cx)
74+
jac_coord!(nlp, x, jval)
75+
Jx = reshape_array(jval, (ncon, nvar))
76+
Jqr = qr(Jx')
2277

23-
iter = 0
78+
# compute least-squares multipliers
79+
# by solving Jx' y = -gx
80+
gx .*= -1
81+
ldiv!(y, Jqr, gx)
82+
83+
# update dual <- dual + Jx' * y
84+
mul!(dual, Jx', y, one(T), one(T))
85+
end
2486

87+
iter = 0
2588
ϵd = atol + rtol * norm(dual)
2689
ϵp = atol
2790

@@ -32,18 +95,55 @@ function dummy_solver(
3295
tired = neval_obj(nlp) + neval_cons(nlp) > max_eval || elapsed_time > max_time
3396

3497
while !(solved || tired)
35-
Hxy = ncon > 0 ? hess(nlp, x, y) : hess(nlp, x)
36-
W = Symmetric([Hxy zeros(T, nvar, ncon); Jx zeros(T, ncon, ncon)], :L)
37-
Δxy = -W \ [dual; cx]
38-
Δx = Δxy[1:nvar]
39-
Δy = Δxy[(nvar + 1):end]
40-
x += Δx
41-
y += Δy
42-
43-
cx = ncon > 0 ? cons(nlp, x) : zeros(T, 0)
44-
gx = grad(nlp, x)
45-
Jx = ncon > 0 ? Matrix(jac(nlp, x)) : zeros(T, 0, nvar)
46-
dual = gx + Jx' * y
98+
# assume the model returns a dense Hessian in column-major order
99+
# NB: hess_coord!() only returns values in the lower triangle
100+
hess_coord!(nlp, x, y, view(hval, 1:nnzh))
101+
102+
# rearrange nonzeros so they correspond to a dense nvar x nvar matrix
103+
j = nvar * nvar
104+
i = nnzh
105+
k = 1
106+
while i > nvar
107+
for _ = 1:k
108+
hval[j] = hval[i]
109+
hval[i] = 0
110+
j -= 1
111+
i -= 1
112+
end
113+
j -= nvar - k
114+
k += 1
115+
end
116+
117+
# fill in augmented matrix
118+
# W = [H J']
119+
# [J 0 ]
120+
wval .= 0
121+
Wxy = reshape_array(wval, (nvar + ncon, nvar + ncon))
122+
Hxy = reshape_array(hval, (nvar, nvar))
123+
Wxy[1:nvar, 1:nvar] .= Hxy
124+
if ncon > 0
125+
Wxy[(nvar + 1):(nvar + ncon), 1:nvar] .= Jx
126+
end
127+
LBL = factorize(Symmetric(Wxy, :L))
128+
129+
ldiv!(Δxy, LBL, rhs)
130+
Δxy .*= -1
131+
@views Δx = Δxy[1:nvar]
132+
@views Δy = Δxy[(nvar + 1):(nvar + ncon)]
133+
x .+= Δx
134+
y .+= Δy
135+
136+
grad!(nlp, x, gx)
137+
dual .= gx
138+
if ncon > 0
139+
cons!(nlp, x, cx)
140+
jac_coord!(nlp, x, jval)
141+
Jx = reshape_array(jval, (ncon, nvar))
142+
Jqr = qr(Jx')
143+
gx .*= -1
144+
ldiv!(y, Jqr, gx)
145+
mul!(dual, Jx', y, one(T), one(T))
146+
end
47147
elapsed_time = time() - start_time
48148
solved = norm(dual) < ϵd && norm(cx) < ϵp
49149
tired = neval_obj(nlp) + neval_cons(nlp) > max_eval || elapsed_time > max_time
@@ -61,7 +161,6 @@ function dummy_solver(
61161
:max_eval
62162
end
63163

64-
stats = GenericExecutionStats(nlp)
65164
set_status!(stats, status)
66165
set_objective!(stats, fx)
67166
set_residuals!(stats, norm(cx), norm(dual))

test/test_stats.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,11 @@ function test_stats()
6666
@testset "Testing Dummy Solver with multi-precision" begin
6767
for T in (Float16, Float32, Float64, BigFloat)
6868
nlp = ADNLPModel(x -> dot(x, x), ones(T, 2))
69+
solver = DummySolver(nlp)
70+
stats = GenericExecutionStats(nlp)
6971

7072
with_logger(NullLogger()) do
71-
stats = dummy_solver(nlp)
73+
solve!(solver, nlp, stats)
7274
end
7375
@test typeof(stats.objective) == T
7476
@test typeof(stats.dual_feas) == T
@@ -79,9 +81,11 @@ function test_stats()
7981
@test eltype(stats.multipliers_U) == T
8082

8183
nlp = ADNLPModel(x -> dot(x, x), ones(T, 2), x -> [sum(x) - 1], T[0], T[0])
84+
solver = DummySolver(nlp)
85+
stats = GenericExecutionStats(nlp)
8286

8387
with_logger(NullLogger()) do
84-
stats = dummy_solver(nlp)
88+
solve!(solver, nlp, stats)
8589
end
8690
@test typeof(stats.objective) == T
8791
@test typeof(stats.dual_feas) == T

0 commit comments

Comments
 (0)