Skip to content

Commit b8c6130

Browse files
Merge pull request #675 from SciML/opfjvvj
Add jvp and vjp fields to OptimizationFunction
2 parents 0ec4159 + 20e4799 commit b8c6130

File tree

3 files changed

+83
-43
lines changed

3 files changed

+83
-43
lines changed

src/problems/bvp_problems.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,8 @@ struct SecondOrderBVProblem{uType, tType, isinplace, nlls, P, F, PT, K} <:
320320
problem_type::PT
321321
kwargs::K
322322

323-
@add_kwonly function SecondOrderBVProblem{iip}(f::DynamicalBVPFunction{iip, TP}, u0, tspan,
323+
@add_kwonly function SecondOrderBVProblem{iip}(
324+
f::DynamicalBVPFunction{iip, TP}, u0, tspan,
324325
p = NullParameters(); problem_type = nothing, nlls = nothing,
325326
kwargs...) where {iip, TP}
326327
_u0 = prepare_initial_state(u0)
@@ -331,16 +332,19 @@ struct SecondOrderBVProblem{uType, tType, isinplace, nlls, P, F, PT, K} <:
331332
typeof(problem_type), typeof(kwargs)}(f, _u0, _tspan, p, problem_type, kwargs)
332333
end
333334

334-
function SecondOrderBVProblem{iip}(f, bc, u0, tspan, p = NullParameters(); kwargs...) where {iip}
335+
function SecondOrderBVProblem{iip}(
336+
f, bc, u0, tspan, p = NullParameters(); kwargs...) where {iip}
335337
SecondOrderBVProblem(DynamicalBVPFunction{iip}(f, bc), u0, tspan, p; kwargs...)
336338
end
337339
end
338340

339341
function SecondOrderBVProblem(f, bc, u0, tspan, p = NullParameters(); kwargs...)
340342
iip = isinplace(f, 5)
341-
return SecondOrderBVProblem{iip}(DynamicalBVPFunction{iip}(f, bc), u0, tspan, p; kwargs...)
343+
return SecondOrderBVProblem{iip}(
344+
DynamicalBVPFunction{iip}(f, bc), u0, tspan, p; kwargs...)
342345
end
343346

344-
function SecondOrderBVProblem(f::DynamicalBVPFunction, u0, tspan, p = NullParameters(); kwargs...)
347+
function SecondOrderBVProblem(
348+
f::DynamicalBVPFunction, u0, tspan, p = NullParameters(); kwargs...)
345349
return SecondOrderBVProblem{isinplace(f)}(f, u0, tspan, p; kwargs...)
346-
end
350+
end

src/scimlfunctions.jl

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1789,7 +1789,8 @@ and more. For all cases, `u` is the state and `p` are the parameters.
17891789
```julia
17901790
OptimizationFunction{iip}(f, adtype::AbstractADType = NoAD();
17911791
grad = nothing, hess = nothing, hv = nothing,
1792-
cons = nothing, cons_j = nothing, cons_h = nothing,
1792+
cons = nothing, cons_j = nothing, cons_jvp = nothing,
1793+
cons_vjp = nothing, cons_h = nothing,
17931794
hess_prototype = nothing,
17941795
cons_jac_prototype = nothing,
17951796
cons_hess_prototype = nothing,
@@ -1827,6 +1828,8 @@ function described in [Callback Functions](https://docs.sciml.ai/Optimization/st
18271828
bounds passed as `lcons` and `ucons` to [`OptimizationProblem`](@ref), in case of equality
18281829
constraints `lcons` and `ucons` should be passed equal values.
18291830
- `cons_j(J,x,p)` or `J=cons_j(x,p)`: the Jacobian of the constraints.
1831+
- `cons_jvp(Jv,v,x,p)` or `Jv=cons_jvp(v,x,p)`: the Jacobian-vector product of the constraints.
1832+
- `cons_vjp(Jv,v,x,p)` or `Jv=cons_vjp(v,x,p)`: the Jacobian-vector product of the constraints.
18301833
- `cons_h(H,x,p)` or `H=cons_h(x,p)`: the Hessian of the constraints, provided as
18311834
an array of Hessians with `res[i]` being the Hessian with respect to the `i`th output on `cons`.
18321835
- `hess_prototype`: a prototype matrix matching the type that matches the Hessian. For example,
@@ -1892,7 +1895,7 @@ For more details on this argument, see the ODEFunction documentation.
18921895
18931896
The fields of the OptimizationFunction type directly match the names of the inputs.
18941897
"""
1895-
struct OptimizationFunction{iip, AD, F, G, H, HV, C, CJ, CH, HP, CJP, CHP, O,
1898+
struct OptimizationFunction{iip, AD, F, G, H, HV, C, CJ, CJV, CVJ, CH, HP, CJP, CHP, O,
18961899
EX, CEX, SYS, LH, LHP, HCV, CJCV, CHCV, LHCV} <:
18971900
AbstractOptimizationFunction{iip}
18981901
f::F
@@ -1902,6 +1905,8 @@ struct OptimizationFunction{iip, AD, F, G, H, HV, C, CJ, CH, HP, CJP, CHP, O,
19021905
hv::HV
19031906
cons::C
19041907
cons_j::CJ
1908+
cons_jvp::CJV
1909+
cons_vjp::CVJ
19051910
cons_h::CH
19061911
hess_prototype::HP
19071912
cons_jac_prototype::CJP
@@ -2137,7 +2142,8 @@ For more details on this argument, see the ODEFunction documentation.
21372142
21382143
The fields of the DynamicalBVPFunction type directly match the names of the inputs.
21392144
"""
2140-
struct DynamicalBVPFunction{iip, specialize, twopoint, F, BF, TMM, Ta, Tt, TJ, BCTJ, JVP, VJP,
2145+
struct DynamicalBVPFunction{
2146+
iip, specialize, twopoint, F, BF, TMM, Ta, Tt, TJ, BCTJ, JVP, VJP,
21412147
JP, BCJP, BCRP, SP, TW, TWt, TPJ, O, TCV, BCTCV,
21422148
SYS} <: AbstractBVPFunction{iip, twopoint}
21432149
f::F
@@ -3758,7 +3764,8 @@ OptimizationFunction(args...; kwargs...) = OptimizationFunction{true}(args...; k
37583764

37593765
function OptimizationFunction{iip}(f, adtype::AbstractADType = NoAD();
37603766
grad = nothing, hess = nothing, hv = nothing,
3761-
cons = nothing, cons_j = nothing, cons_h = nothing,
3767+
cons = nothing, cons_j = nothing, cons_jvp = nothing,
3768+
cons_vjp = nothing, cons_h = nothing,
37623769
hess_prototype = nothing,
37633770
cons_jac_prototype = __has_jac_prototype(f) ?
37643771
f.jac_prototype : nothing,
@@ -3780,7 +3787,8 @@ function OptimizationFunction{iip}(f, adtype::AbstractADType = NoAD();
37803787
sys = sys_or_symbolcache(sys, syms, paramsyms)
37813788
OptimizationFunction{iip, typeof(adtype), typeof(f), typeof(grad), typeof(hess),
37823789
typeof(hv),
3783-
typeof(cons), typeof(cons_j), typeof(cons_h),
3790+
typeof(cons), typeof(cons_j), typeof(cons_jvp),
3791+
typeof(cons_vjp), typeof(cons_h),
37843792
typeof(hess_prototype),
37853793
typeof(cons_jac_prototype), typeof(cons_hess_prototype),
37863794
typeof(observed),
@@ -3789,7 +3797,8 @@ function OptimizationFunction{iip}(f, adtype::AbstractADType = NoAD();
37893797
typeof(cons_jac_colorvec), typeof(cons_hess_colorvec),
37903798
typeof(lag_hess_colorvec)
37913799
}(f, adtype, grad, hess,
3792-
hv, cons, cons_j, cons_h,
3800+
hv, cons, cons_j, cons_jvp,
3801+
cons_vjp, cons_h,
37933802
hess_prototype, cons_jac_prototype,
37943803
cons_hess_prototype, observed, expr, cons_expr, sys,
37953804
lag_h, lag_hess_prototype, hess_colorvec, cons_jac_colorvec,
@@ -3987,7 +3996,6 @@ function DynamicalBVPFunction{iip, specialize, twopoint}(f, bc;
39873996
colorvec = __has_colorvec(f) ? f.colorvec : nothing,
39883997
bccolorvec = __has_colorvec(bc) ? bc.colorvec : nothing,
39893998
sys = __has_sys(f) ? f.sys : nothing) where {iip, specialize, twopoint}
3990-
39913999
if mass_matrix === I && f isa Tuple
39924000
mass_matrix = ((I for i in 1:length(f))...,)
39934001
end
@@ -4095,7 +4103,7 @@ function DynamicalBVPFunction{iip, specialize, twopoint}(f, bc;
40954103
_f = prepare_function(f)
40964104

40974105
sys = something(sys, SymbolCache(syms, paramsyms, indepsym))
4098-
4106+
40994107
if specialize === NoSpecialize
41004108
DynamicalBVPFunction{iip, specialize, twopoint, Any, Any, Any, Any, Any,
41014109
Any, Any, Any, Any, Any, Any, Any, Any, Any, Any,
@@ -4127,11 +4135,11 @@ function DynamicalBVPFunction{iip}(f, bc; twopoint::Union{Val, Bool} = Val(false
41274135
end
41284136
DynamicalBVPFunction{iip}(f::DynamicalBVPFunction, bc; kwargs...) where {iip} = f
41294137
function DynamicalBVPFunction(f, bc; twopoint::Union{Val, Bool} = Val(false), kwargs...)
4130-
DynamicalBVPFunction{isinplace(f, 5), FullSpecialize, _unwrap_val(twopoint)}(f, bc; kwargs...)
4138+
DynamicalBVPFunction{isinplace(f, 5), FullSpecialize, _unwrap_val(twopoint)}(
4139+
f, bc; kwargs...)
41314140
end
41324141
DynamicalBVPFunction(f::DynamicalBVPFunction; kwargs...) = f
41334142

4134-
41354143
function IntegralFunction{iip, specialize}(f, integrand_prototype) where {iip, specialize}
41364144
_f = prepare_function(f)
41374145
IntegralFunction{iip, specialize, typeof(_f), typeof(integrand_prototype)}(_f,

test/function_building_error_messages.jl

Lines changed: 56 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -706,24 +706,35 @@ DynamicalBVPFunction(dbfiip, dbciip, jac = dbjac, bcjac = dbcjac)
706706
DynamicalBVPFunction(dbfoop, dbcoop, jac = dbjac, bcjac = dbcjac)
707707

708708
dbWfact(du, u, t) = [1.0]
709-
@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(dbfiip, dbciip, Wfact = dbWfact)
710-
@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(dbfoop, dbciip, Wfact = dbWfact)
709+
@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(
710+
dbfiip, dbciip, Wfact = dbWfact)
711+
@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(
712+
dbfoop, dbciip, Wfact = dbWfact)
711713
dbWfact(du, u, p, t) = [1.0]
712-
@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(dbfiip, dbciip, Wfact = dbWfact)
713-
@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(dbfoop, dbciip, Wfact = dbWfact)
714+
@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(
715+
dbfiip, dbciip, Wfact = dbWfact)
716+
@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(
717+
dbfoop, dbciip, Wfact = dbWfact)
714718
dbWfact(du, u, p, gamma, t) = [1.0]
715-
@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(dbfiip, dbciip, Wfact = dbWfact)
716-
@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(dbfoop, dbciip, Wfact = dbWfact)
719+
@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(
720+
dbfiip, dbciip, Wfact = dbWfact)
721+
@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(
722+
dbfoop, dbciip, Wfact = dbWfact)
717723
dbWfact(ddu, du, u, p, gamma, t) = [1.0]
718724
DynamicalBVPFunction(dbfiip, dbciip, Wfact = dbWfact)
719-
@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(dbfoop, dbciip, Wfact = dbWfact)
725+
@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(
726+
dbfoop, dbciip, Wfact = dbWfact)
720727

721728
dbWfact_t(du, u, t) = [1.0]
722-
@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(dbfiip, dbciip, Wfact_t = dbWfact_t)
723-
@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(dbfoop, dbciip, Wfact_t = dbWfact_t)
729+
@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(
730+
dbfiip, dbciip, Wfact_t = dbWfact_t)
731+
@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(
732+
dbfoop, dbciip, Wfact_t = dbWfact_t)
724733
dbWfact_t(du, u, p, t) = [1.0]
725-
@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(dbfiip, dbciip, Wfact_t = dbWfact_t)
726-
@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(dbfoop, dbciip, Wfact_t = dbWfact_t)
734+
@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(
735+
dbfiip, dbciip, Wfact_t = dbWfact_t)
736+
@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(
737+
dbfoop, dbciip, Wfact_t = dbWfact_t)
727738
dbWfact_t(du, u, p, gamma, t) = [1.0]
728739
@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(dbfiip,
729740
dbciip,
@@ -738,18 +749,25 @@ DynamicalBVPFunction(dbfiip, dbciip, Wfact_t = dbWfact_t)
738749
Wfact_t = dbWfact_t)
739750

740751
dbtgrad(du, u, t) = [1.0]
741-
@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(dbfiip, dbciip, tgrad = dbtgrad)
742-
@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(dbfoop, dbciip, tgrad = dbtgrad)
752+
@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(
753+
dbfiip, dbciip, tgrad = dbtgrad)
754+
@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(
755+
dbfoop, dbciip, tgrad = dbtgrad)
743756
dbtgrad(du, u, p, t) = [1.0]
744-
@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(dbfiip, dbciip, tgrad = dbtgrad)
745-
@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(dbfoop, dbciip, tgrad = dbtgrad)
757+
@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(
758+
dbfiip, dbciip, tgrad = dbtgrad)
759+
@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(
760+
dbfoop, dbciip, tgrad = dbtgrad)
746761
dbtgrad(ddu, du, u, p, t) = [1.0]
747762
DynamicalBVPFunction(dbfiip, dbciip, tgrad = dbtgrad)
748-
@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(dbfoop, dbciip, tgrad = dbtgrad)
763+
@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(
764+
dbfoop, dbciip, tgrad = dbtgrad)
749765

750766
dbparamjac(du, u, t) = [1.0]
751-
@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(dbfiip, dbciip, paramjac = dbparamjac)
752-
@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(dbfoop, dbciip, paramjac = dbparamjac)
767+
@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(
768+
dbfiip, dbciip, paramjac = dbparamjac)
769+
@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(
770+
dbfoop, dbciip, paramjac = dbparamjac)
753771
dbparamjac(du, u, p, t) = [1.0]
754772
@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(dbfiip,
755773
dbciip,
@@ -764,25 +782,35 @@ DynamicalBVPFunction(dbfiip, dbciip, paramjac = dbparamjac)
764782
paramjac = dbparamjac)
765783

766784
dbjvp(du, u, p, t) = [1.0]
767-
@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(dbfiip, dbciip, jvp = dbjvp)
768-
@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(dbfoop, dbciip, jvp = dbjvp)
785+
@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(
786+
dbfiip, dbciip, jvp = dbjvp)
787+
@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(
788+
dbfoop, dbciip, jvp = dbjvp)
769789
dbjvp(du, u, v, p, t) = [1.0]
770-
@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(dbfiip, dbciip, jvp = dbjvp)
771-
@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(dbfoop, dbciip, jvp = dbjvp)
790+
@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(
791+
dbfiip, dbciip, jvp = dbjvp)
792+
@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(
793+
dbfoop, dbciip, jvp = dbjvp)
772794
dbjvp(ddu, du, u, v, p, t) = [1.0]
773795
DynamicalBVPFunction(dbfiip, dbciip, jvp = dbjvp)
774-
@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(dbfoop, dbciip, jvp = dbjvp)
796+
@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(
797+
dbfoop, dbciip, jvp = dbjvp)
775798

776799
dbvjp(du, u, p, t) = [1.0]
777-
@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(dbfiip, dbciip, vjp = dbvjp)
778-
@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(dbfoop, dbciip, vjp = dbvjp)
800+
@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(
801+
dbfiip, dbciip, vjp = dbvjp)
802+
@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(
803+
dbfoop, dbciip, vjp = dbvjp)
779804
dbvjp(du, u, v, p, t) = [1.0]
780-
@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(dbfiip, dbciip, vjp = dbvjp)
781-
@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(dbfoop, dbciip, vjp = dbvjp)
805+
@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(
806+
dbfiip, dbciip, vjp = dbvjp)
807+
@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(
808+
dbfoop, dbciip, vjp = dbvjp)
782809
dbvjp(ddu, du, u, v, p, t) = [1.0]
783810
DynamicalBVPFunction(dbfiip, dbciip, vjp = dbvjp)
784811

785-
@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(dbfoop, dbciip, vjp = dbvjp)
812+
@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(
813+
dbfoop, dbciip, vjp = dbvjp)
786814

787815
# IntegralFunction
788816

0 commit comments

Comments
 (0)