Commit 3958d46
More type stability in ADNLPModels.jl
1 parent 63b5bc2

10 files changed: +52, -66 lines

src/ADNLPModels.jl
Lines changed: 13 additions & 11 deletions

@@ -160,7 +160,7 @@ function ADNLSModel!(model::AbstractNLSModel; kwargs...)
   end
 end
 
-export get_adbackend, set_adbackend!
+export get_adbackend
 
 """
     get_c(nlp)
@@ -244,22 +244,24 @@ Returns the value `adbackend` from nlp.
 get_adbackend(nlp::ADModel) = nlp.adbackend
 
 """
-    set_adbackend!(nlp, new_adbackend)
-    set_adbackend!(nlp; kwargs...)
+    new_nlp = set_adbackend(nlp, new_adbackend)
+    new_nlp = set_adbackend(nlp; kwargs...)
 
-Replace the current `adbackend` value of nlp by `new_adbackend` or instantiate a new one with `kwargs`, see `ADModelBackend`.
+Create a copy of nlp that replaces the current `adbackend` with `new_adbackend` or instantiate a new one with `kwargs`, see `ADModelBackend`.
 By default, the setter with kwargs will reuse existing backends.
 """
-function set_adbackend!(nlp::ADModel, new_adbackend::ADModelBackend)
-  nlp.adbackend = new_adbackend
-  return nlp
+function _set_adbackend(nlp::ADM, new_adbackend::ADModelBackend) where{ADM}
+  values = [f == :adbackend ? new_adbackend : getfield(nlp, f) for f in fieldnames(ADM)]
+  base_type = Base.typename(ADM).wrapper
+  return base_type(values...)
 end
-function set_adbackend!(nlp::ADModel; kwargs...)
+
+function _set_adbackend(nlp::ADModel; kwargs...)
   args = []
   for field in fieldnames(ADNLPModels.ADModelBackend)
     push!(args, if field in keys(kwargs) && typeof(kwargs[field]) <: ADBackend
       kwargs[field]
-    elseif field in keys(kwargs) && typeof(kwargs[field]) <: DataType
+    elseif field in keys(kwargs) && kwargs[field] <: ADBackend
      if typeof(nlp) <: ADNLPModel
        kwargs[field](nlp.meta.nvar, nlp.f, nlp.meta.ncon; kwargs...)
      elseif typeof(nlp) <: ADNLSModel
@@ -269,8 +271,8 @@ function set_adbackend!(nlp::ADModel; kwargs...)
       getfield(nlp.adbackend, field)
     end)
   end
-  nlp.adbackend = ADModelBackend(args...)
-  return nlp
+  new_nlp = set_adbackend(nlp, ADModelBackend(args...))
+  return new_nlp
 end
 
 end # module
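
The new `_set_adbackend` builds a fresh model instead of mutating the old one: it copies every field, swaps in the new backend, and rebuilds through the parametric constructor obtained from `Base.typename(ADM).wrapper`, so the concrete backend type stored in the struct can change. A minimal sketch of that copy-and-swap pattern on a toy struct (the `Point` and `set_y` names are illustrative only, not part of the package):

# Copy-and-swap on a toy parametric struct, mirroring the approach in _set_adbackend.
struct Point{X, Y}
  x::X
  y::Y
end

function set_y(p::P, new_y) where {P <: Point}
  vals = [f == :y ? new_y : getfield(p, f) for f in fieldnames(P)]
  base_type = Base.typename(P).wrapper  # Point without its type parameters
  return base_type(vals...)             # rebuilds with possibly different parameters
end

p = Point(1, 2.0)      # Point{Int64, Float64}
q = set_y(p, 2.0f0)    # Point{Int64, Float32}; p itself is unchanged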

src/forward.jl
Lines changed: 3 additions & 3 deletions

@@ -4,8 +4,8 @@ function gradient!(::GenericForwardDiffADGradient, g, f, x)
   return ForwardDiff.gradient!(g, f, x)
 end
 
-struct ForwardDiffADGradient <: ADBackend
-  cfg
+struct ForwardDiffADGradient{GC} <: ADBackend
+  cfg::GC
 end
 function ForwardDiffADGradient(
   nvar::Integer,
@@ -109,7 +109,7 @@ function GenericForwardDiffADJtprod(
   return GenericForwardDiffADJtprod()
 end
 function Jtprod!(::GenericForwardDiffADJtprod, Jtv, f, x, v, ::Val)
-  Jtv .= ForwardDiff.gradient(x -> dot(f(x), v), x)
+  ForwardDiff.gradient!(Jtv, x -> dot(f(x), v), x)
   return Jtv
 end
 
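Two kinds of changes here: the gradient configuration becomes a typed field, so its concrete type is part of the backend's type, and the Jacobian-transpose product now writes the gradient in place rather than broadcasting a freshly allocated result. A standalone sketch of the in-place pattern with plain ForwardDiff (the residual `F` and the data are made up for illustration):

using ForwardDiff, LinearAlgebra

# Jᵀv computed as the gradient of x -> dot(F(x), v), written directly into Jtv.
F(x) = [x[1] + x[2], x[1] * x[2]]
x = [1.0, 2.0]
v = [1.0, -1.0]
Jtv = similar(x)
ForwardDiff.gradient!(Jtv, x -> dot(F(x), v), x)  # no intermediate gradient array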

src/nlp.jl
Lines changed: 8 additions & 10 deletions

@@ -1,18 +1,18 @@
 export ADNLPModel, ADNLPModel!
 
-mutable struct ADNLPModel{T, S, Si} <: AbstractADNLPModel{T, S}
+mutable struct ADNLPModel{T, S, Si, F1, F2, ADMB <: ADModelBackend} <: AbstractADNLPModel{T, S}
   meta::NLPModelMeta{T, S}
   counters::Counters
-  adbackend::ADModelBackend
+  adbackend::ADMB
 
   # Functions
-  f
+  f::F1
 
   clinrows::Si
   clincols::Si
   clinvals::S
 
-  c!
+  c!::F2
 end
 
 ADNLPModel(
@@ -127,7 +127,7 @@ function ADNLPModel(f, x0::S; name::String = "Generic", minimize::Bool = true, k
   meta =
     NLPModelMeta{T, S}(nvar, x0 = x0, nnzh = nnzh, minimize = minimize, islp = false, name = name)
 
-  return ADNLPModel(meta, Counters(), adbackend, f, x -> T[])
+  return ADNLPModel(meta, Counters(), adbackend, f, (c, x) -> similar(S, 0))
 end
 
 function ADNLPModel(
@@ -157,7 +157,7 @@ function ADNLPModel(
     name = name,
   )
 
-  return ADNLPModel(meta, Counters(), adbackend, f, x -> T[])
+  return ADNLPModel(meta, Counters(), adbackend, f, x -> similar(S, 0))
 end
 
 function ADNLPModel(f, x0::S, c, lcon::S, ucon::S; kwargs...) where {S}
@@ -222,8 +222,7 @@ function ADNLPModel(
   ucon::S;
   kwargs...,
 ) where {S}
-  T = eltype(S)
-  return ADNLPModel(f, x0, clinrows, clincols, clinvals, x -> T[], lcon, ucon; kwargs...)
+  return ADNLPModel(f, x0, clinrows, clincols, clinvals, x -> similar(S, 0), lcon, ucon; kwargs...)
 end
 
 function ADNLPModel(
@@ -339,7 +338,6 @@ function ADNLPModel(
   ucon::S;
   kwargs...,
 ) where {S}
-  T = eltype(S)
   return ADNLPModel(
     f,
     x0,
@@ -348,7 +346,7 @@ function ADNLPModel(
     clinrows,
     clincols,
     clinvals,
-    x -> T[],
+    x -> similar(S, 0),
     lcon,
     ucon;
     kwargs...,
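
The struct now carries type parameters for its function fields (`F1`, `F2`) and its backend (`ADMB`), and the dummy constraint closures build their empty output with `similar(S, 0)` so the container matches the storage type `S`. A toy sketch (the struct names here are illustrative, not the package's) of why typed function fields matter for inference:

# Untyped vs. parametric function fields: only the second is inferable.
struct LooseModel
  f          # field type Any: every call to f dispatches dynamically
end
struct TightModel{F}
  f::F       # concrete function type is part of the struct's type
end

objective(m, x) = m.f(x)

m1 = LooseModel(x -> sum(abs2, x))
m2 = TightModel(x -> sum(abs2, x))
# @code_warntype objective(m1, [1.0, 2.0])  # return type ::Any
# @code_warntype objective(m2, [1.0, 2.0])  # return type ::Float64

# The empty-constraint change: similar(S, 0) keeps the storage type of x0.
empty_c = similar(Vector{Float32}, 0)  # a 0-element Vector{Float32}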

src/nls.jl
Lines changed: 4 additions & 4 deletions

@@ -1,19 +1,19 @@
 export ADNLSModel, ADNLSModel!
 
-mutable struct ADNLSModel{T, S, Si} <: AbstractADNLSModel{T, S}
+mutable struct ADNLSModel{T, S, Si, F1, F2, ADMB <: ADModelBackend} <: AbstractADNLSModel{T, S}
   meta::NLPModelMeta{T, S}
   nls_meta::NLSMeta{T, S}
   counters::NLSCounters
-  adbackend::ADModelBackend
+  adbackend::ADMB
 
   # Function
-  F!
+  F!::F1
 
   clinrows::Si
   clincols::Si
   clinvals::S
 
-  c!
+  c!::F2
 end
 
 ADNLSModel(

src/reverse.jl
Lines changed: 2 additions & 2 deletions

@@ -7,8 +7,8 @@ end
 struct GenericReverseDiffADJprod <: ADBackend end
 struct GenericReverseDiffADJtprod <: ADBackend end
 
-struct ReverseDiffADGradient <: ADBackend
-  cfg
+struct ReverseDiffADGradient{GC} <: ADBackend
+  cfg::GC
 end
 
 function ReverseDiffADGradient(
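As in src/forward.jl, the ReverseDiff gradient backend now stores its configuration in a typed field, so the concrete type of whatever `cfg` holds is known to the compiler. A sketch, assuming ReverseDiff is available, of the kind of compiled-tape object such a field could hold and how it is used (the function and data are illustrative only):

using ReverseDiff

f(x) = sum(x .^ 2)
x = rand(3)
g = similar(x)

# A compiled tape has a concrete type; keeping it in a cfg::GC field keeps the backend concrete.
tape  = ReverseDiff.GradientTape(f, x)
ctape = ReverseDiff.compile(tape)
ReverseDiff.gradient!(g, ctape, x)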

test/Project.toml
Lines changed: 2 additions & 0 deletions

@@ -1,5 +1,6 @@
 [deps]
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
+JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 ManualNLPModels = "30dfa513-9b2f-4fb3-9796-781eabac1617"
 NLPModels = "a4795742-8479-5a88-8948-cc11e1c8c1a6"
@@ -12,6 +13,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [compat]
 ForwardDiff = "0.10"
+JET = "0.9, 0.10"
 ManualNLPModels = "0.1"
 NLPModels = "0.21"
 NLPModelsModifiers = "0.7"

test/nlp/basic.jl
Lines changed: 1 addition & 1 deletion

@@ -25,7 +25,7 @@ function test_autodiff_model(name; kwargs...)
   @test abs(obj(nlp, β) - norm(y .- β[1] - β[2] * x)^2 / 2) < 1e-12
   @test norm(grad(nlp, β)) < 1e-12
 
-  test_getter_setter(nlp)
+  test_allocations(nlp)
 
   @testset "Constructors for ADNLPModel with $name" begin
     lvar, uvar, lcon, ucon, y0 = -ones(2), ones(2), -ones(1), ones(1), zeros(1)

test/nls/basic.jl
Lines changed: 1 addition & 1 deletion

@@ -5,7 +5,7 @@ function autodiff_nls_test(name; kwargs...)
 
   @test isapprox(residual(nls, ones(2)), zeros(2), rtol = 1e-8)
 
-  test_getter_setter(nls)
+  test_allocations(nls)
 end
 
 @testset "Constructors for ADNLSModel" begin

test/runtests.jl
Lines changed: 1 addition & 0 deletions

@@ -1,5 +1,6 @@
 using LinearAlgebra, SparseArrays, Test
 using SparseMatrixColorings
+using ForwardDiff, JET
 using ADNLPModels, ManualNLPModels, NLPModels, NLPModelsModifiers, NLPModelsTest
 using ADNLPModels:
   gradient, gradient!, jacobian, hessian, Jprod!, Jtprod!, directional_second_derivative, Hvprod!

test/utils.jl
Lines changed: 17 additions & 34 deletions

@@ -1,36 +1,19 @@
-ReverseDiffAD(nvar, f) = ADNLPModels.ADModelBackend(
-  nvar,
-  f,
-  gradient_backend = ADNLPModels.ReverseDiffADGradient,
-  hprod_backend = ADNLPModels.ReverseDiffADHvprod,
-  jprod_backend = ADNLPModels.ReverseDiffADJprod,
-  jtprod_backend = ADNLPModels.ReverseDiffADJtprod,
-  jacobian_backend = ADNLPModels.ReverseDiffADJacobian,
-  hessian_backend = ADNLPModels.ReverseDiffADHessian,
-)
+function test_allocations(nlp::ADNLPModel)
+  x = nlp.meta.x0
+  y = zeros(eltype(nlp.meta.x0), nlp.meta.ncon)
+  g = zeros(eltype(nlp.meta.x0), nlp.meta.nvar)
+  @test_opt target_modules=(ADNLPModels,) obj(nlp, x)
+  @test_opt target_modules=(ADNLPModels,) cons!(nlp, x, y)
+  @test_opt target_modules=(ADNLPModels,) grad!(nlp, x, g)
+end
 
-function test_getter_setter(nlp)
-  @test get_adbackend(nlp) == nlp.adbackend
-  if typeof(nlp) <: ADNLPModel
-    set_adbackend!(nlp, ReverseDiffAD(nlp.meta.nvar, nlp.f))
-  elseif typeof(nlp) <: ADNLSModel
-    function F(x; nequ = nlp.nls_meta.nequ)
-      Fx = similar(x, nequ)
-      nlp.F!(Fx, x)
-      return Fx
-    end
-    set_adbackend!(nlp, ReverseDiffAD(nlp.meta.nvar, x -> sum(F(x) .^ 2)))
-  end
-  @test typeof(get_adbackend(nlp).gradient_backend) <: ADNLPModels.ReverseDiffADGradient
-  @test typeof(get_adbackend(nlp).hprod_backend) <: ADNLPModels.ReverseDiffADHvprod
-  @test typeof(get_adbackend(nlp).hessian_backend) <: ADNLPModels.ReverseDiffADHessian
-  set_adbackend!(
-    nlp,
-    gradient_backend = ADNLPModels.ForwardDiffADGradient,
-    jtprod_backend = ADNLPModels.GenericForwardDiffADJtprod(),
-  )
-  @test typeof(get_adbackend(nlp).gradient_backend) <: ADNLPModels.ForwardDiffADGradient
-  @test typeof(get_adbackend(nlp).hprod_backend) <: ADNLPModels.ReverseDiffADHvprod
-  @test typeof(get_adbackend(nlp).jtprod_backend) <: ADNLPModels.GenericForwardDiffADJtprod
-  @test typeof(get_adbackend(nlp).hessian_backend) <: ADNLPModels.ReverseDiffADHessian
+function test_allocations(nlp::ADNLSModel)
+  x = nlp.meta.x0
+  y = zeros(eltype(nlp.meta.x0), nlp.meta.ncon)
+  g = zeros(eltype(nlp.meta.x0), nlp.meta.nvar)
+  Fx = zeros(eltype(nlp.meta.x0), nlp.nls_meta.nequ)
+  @test_opt target_modules=(ADNLPModels,) function_filter=(@nospecialize(f) -> f != ForwardDiff.gradient!) obj(nlp, x)
+  @test_opt target_modules=(ADNLPModels,) function_filter=(@nospecialize(f) -> f != ForwardDiff.gradient!) cons!(nlp, x, y)
+  @test_opt target_modules=(ADNLPModels,) function_filter=(@nospecialize(f) -> f != ForwardDiff.gradient!) grad!(nlp, x, g, Fx)
+  @test_opt target_modules=(ADNLPModels,) function_filter=(@nospecialize(f) -> f != ForwardDiff.gradient!) residual!(nlp, x, Fx)
 end
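
The new helpers rely on JET.jl's optimization analysis: `@test_opt` fails the test if JET detects runtime dispatch, `target_modules=(ADNLPModels,)` limits the report to frames inside the package, and the NLS variant additionally filters out frames from `ForwardDiff.gradient!`. A minimal sketch of running the same kind of check interactively, on a made-up model:

using JET, Test
using ADNLPModels, NLPModels

# Toy unconstrained model; @test_opt reports only dispatch originating in ADNLPModels.
nlp = ADNLPModel(x -> sum(x .^ 2), ones(2))
@testset "type stability of obj" begin
  @test_opt target_modules=(ADNLPModels,) obj(nlp, nlp.meta.x0)
end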
