Skip to content

Commit a15f5a4

Browse files
Merge pull request #996 from vyudu/ctrl
feaet: add ControlFunction
2 parents 5f4ac56 + 97115af commit a15f5a4

File tree

4 files changed

+250
-2
lines changed

4 files changed

+250
-2
lines changed

src/SciMLBase.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -822,7 +822,7 @@ export ODEFunction, DiscreteFunction, ImplicitDiscreteFunction, SplitFunction, D
822822
DDEFunction, SDEFunction, SplitSDEFunction, RODEFunction, SDDEFunction,
823823
IncrementingODEFunction, NonlinearFunction, HomotopyNonlinearFunction,
824824
IntervalNonlinearFunction, BVPFunction,
825-
DynamicalBVPFunction, IntegralFunction, BatchIntegralFunction
825+
DynamicalBVPFunction, IntegralFunction, BatchIntegralFunction, ODEInputFunction
826826

827827
export OptimizationFunction, MultiObjectiveOptimizationFunction
828828

src/problems/implicit_discrete_problems.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ dt: the time step
2727
2828
### Constructors
2929
30-
- `ImplicitDiscreteProblem(f::ODEFunction,u0,tspan,p=NullParameters();kwargs...)` :
30+
- `ImplicitDiscreteProblem(f::ImplicitDiscreteFunction,u0,tspan,p=NullParameters();kwargs...)` :
3131
Defines the discrete problem with the specified functions.
3232
- `ImplicitDiscreteProblem{isinplace,specialize}(f,u0,tspan,p=NullParameters();kwargs...)` :
3333
Defines the discrete problem with the specified functions.

src/scimlfunctions.jl

Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2094,6 +2094,109 @@ struct MultiObjectiveOptimizationFunction{
20942094
initialization_data::ID
20952095
end
20962096

2097+
"""
2098+
$(TYPEDEF)
2099+
"""
2100+
abstract type AbstractODEInputFunction{iip} <: AbstractDiffEqFunction{iip} end
2101+
2102+
@doc doc"""
2103+
$(TYPEDEF)
2104+
2105+
A representation of a ODE function `f` with inputs, defined by:
2106+
2107+
```math
2108+
\frac{dx}{dt} = f(x, u, p, t)
2109+
```
2110+
where `x` are the states of the system and `u` are the inputs (which may represent
2111+
different things in different contexts, such as control variables in optimal control).
2112+
2113+
Includes all of its related functions, such as the Jacobian of `f`, its gradient
2114+
with respect to time, and more. For all cases, `u0` is the initial condition,
2115+
`p` are the parameters, and `t` is the independent variable.
2116+
2117+
```julia
2118+
ODEInputFunction{iip, specialize}(f;
2119+
mass_matrix = __has_mass_matrix(f) ? f.mass_matrix : I,
2120+
analytic = __has_analytic(f) ? f.analytic : nothing,
2121+
tgrad= __has_tgrad(f) ? f.tgrad : nothing,
2122+
jac = __has_jac(f) ? f.jac : nothing,
2123+
control_jac = __has_controljac(f) ? f.controljac : nothing,
2124+
jvp = __has_jvp(f) ? f.jvp : nothing,
2125+
vjp = __has_vjp(f) ? f.vjp : nothing,
2126+
jac_prototype = __has_jac_prototype(f) ? f.jac_prototype : nothing,
2127+
controljac_prototype = __has_controljac_prototype(f) ? f.controljac_prototype : nothing,
2128+
sparsity = __has_sparsity(f) ? f.sparsity : jac_prototype,
2129+
paramjac = __has_paramjac(f) ? f.paramjac : nothing,
2130+
syms = nothing,
2131+
indepsym = nothing,
2132+
paramsyms = nothing,
2133+
colorvec = __has_colorvec(f) ? f.colorvec : nothing,
2134+
sys = __has_sys(f) ? f.sys : nothing)
2135+
```
2136+
2137+
`f` should be given as `f(x_out,x,u,p,t)` or `out = f(x,u,p,t)`.
2138+
See the section on `iip` for more details on in-place vs out-of-place handling.
2139+
2140+
- `mass_matrix`: the mass matrix `M` represented in the BVP function. Can be used
2141+
to determine that the equation is actually a BVP for differential algebraic equation (DAE)
2142+
if `M` is singular.
2143+
- `jac(J,dx,x,u,p,gamma,t)` or `J=jac(dx,x,u,p,gamma,t)`: returns ``\frac{df}{dx}``
2144+
- `control_jac(J,du,x,u,p,gamma,t)` or `J=control_jac(du,x,u,p,gamma,t)`: returns ``\frac{df}{du}``
2145+
- `jvp(Jv,v,du,x,u,p,gamma,t)` or `Jv=jvp(v,du,x,u,p,gamma,t)`: returns the directional
2146+
derivative ``\frac{df}{du} v``
2147+
- `vjp(Jv,v,du,x,u,p,gamma,t)` or `Jv=vjp(v,du,x,u,p,gamma,t)`: returns the adjoint
2148+
derivative ``\frac{df}{du}^\ast v``
2149+
- `jac_prototype`: a prototype matrix matching the type that matches the Jacobian. For example,
2150+
if the Jacobian is tridiagonal, then an appropriately sized `Tridiagonal` matrix can be used
2151+
as the prototype and integrators will specialize on this structure where possible. Non-structured
2152+
sparsity patterns should use a `SparseMatrixCSC` with a correct sparsity pattern for the Jacobian.
2153+
The default is `nothing`, which means a dense Jacobian.
2154+
- `controljac_prototype`: a prototype matrix matching the type that matches the Jacobian. For example,
2155+
if the Jacobian is tridiagonal, then an appropriately sized `Tridiagonal` matrix can be used
2156+
as the prototype and integrators will specialize on this structure where possible. Non-structured
2157+
sparsity patterns should use a `SparseMatrixCSC` with a correct sparsity pattern for the Jacobian.
2158+
The default is `nothing`, which means a dense Jacobian.
2159+
- `paramjac(pJ,x,u,p,t)`: returns the parameter Jacobian ``\frac{df}{dp}``.
2160+
- `colorvec`: a color vector according to the SparseDiffTools.jl definition for the sparsity
2161+
pattern of the `jac_prototype`. This specializes the Jacobian construction when using
2162+
finite differences and automatic differentiation to be computed in an accelerated manner
2163+
based on the sparsity pattern. Defaults to `nothing`, which means a color vector will be
2164+
internally computed on demand when required. The cost of this operation is highly dependent
2165+
on the sparsity pattern.
2166+
2167+
## iip: In-Place vs Out-Of-Place
2168+
For more details on this argument, see the ODEFunction documentation.
2169+
2170+
## specialize: Controlling Compilation and Specialization
2171+
For more details on this argument, see the ODEFunction documentation.
2172+
2173+
## Fields
2174+
The fields of the ODEInputFunction type directly match the names of the inputs.
2175+
"""
2176+
struct ODEInputFunction{iip, specialize, F, TMM, Ta, Tt, TJ, CTJ, JVP, VJP,
2177+
JP, CJP, SP, TW, TWt, WP, TPJ, O, TCV,
2178+
SYS, ID} <: AbstractODEInputFunction{iip}
2179+
f::F
2180+
mass_matrix::TMM
2181+
analytic::Ta
2182+
tgrad::Tt
2183+
jac::TJ
2184+
controljac::CTJ
2185+
jvp::JVP
2186+
vjp::VJP
2187+
jac_prototype::JP
2188+
controljac_prototype::CJP
2189+
sparsity::SP
2190+
Wfact::TW
2191+
Wfact_t::TWt
2192+
W_prototype::WP
2193+
paramjac::TPJ
2194+
observed::O
2195+
colorvec::TCV
2196+
sys::SYS
2197+
initialization_data::ID
2198+
end
2199+
20972200
"""
20982201
$(TYPEDEF)
20992202
"""
@@ -2493,6 +2596,7 @@ end
24932596
(f::ImplicitDiscreteFunction)(args...) = f.f(args...)
24942597
(f::DAEFunction)(args...) = f.f(args...)
24952598
(f::DDEFunction)(args...) = f.f(args...)
2599+
(f::ODEInputFunction)(args...) = f.f(args...)
24962600

24972601
function (f::DynamicalDDEFunction)(u, h, p, t)
24982602
ArrayPartition(f.f1(u.x[1], u.x[2], h, p, t), f.f2(u.x[1], u.x[2], h, p, t))
@@ -4595,6 +4699,149 @@ function BatchIntegralFunction(f, integrand_prototype; kwargs...)
45954699
BatchIntegralFunction{calculated_iip}(f, integrand_prototype; kwargs...)
45964700
end
45974701

4702+
function ODEInputFunction{iip, specialize}(f;
4703+
mass_matrix = __has_mass_matrix(f) ? f.mass_matrix :
4704+
I,
4705+
analytic = __has_analytic(f) ? f.analytic : nothing,
4706+
tgrad = __has_tgrad(f) ? f.tgrad : nothing,
4707+
jac = __has_jac(f) ? f.jac : nothing,
4708+
controljac = __has_controljac(f) ? f.controljac : nothing,
4709+
jvp = __has_jvp(f) ? f.jvp : nothing,
4710+
vjp = __has_vjp(f) ? f.vjp : nothing,
4711+
jac_prototype = __has_jac_prototype(f) ?
4712+
f.jac_prototype :
4713+
nothing,
4714+
controljac_prototype = __has_controljac_prototype(f) ?
4715+
f.controljac_prototype :
4716+
nothing,
4717+
sparsity = __has_sparsity(f) ? f.sparsity :
4718+
jac_prototype,
4719+
Wfact = __has_Wfact(f) ? f.Wfact : nothing,
4720+
Wfact_t = __has_Wfact_t(f) ? f.Wfact_t : nothing,
4721+
W_prototype = __has_W_prototype(f) ? f.W_prototype : nothing,
4722+
paramjac = __has_paramjac(f) ? f.paramjac : nothing,
4723+
syms = nothing,
4724+
indepsym = nothing,
4725+
paramsyms = nothing,
4726+
observed = __has_observed(f) ? f.observed :
4727+
DEFAULT_OBSERVED,
4728+
colorvec = __has_colorvec(f) ? f.colorvec : nothing,
4729+
sys = __has_sys(f) ? f.sys : nothing,
4730+
initializeprob = __has_initializeprob(f) ? f.initializeprob : nothing,
4731+
update_initializeprob! = __has_update_initializeprob!(f) ?
4732+
f.update_initializeprob! : nothing,
4733+
initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing,
4734+
initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing,
4735+
initialization_data = __has_initialization_data(f) ? f.initialization_data :
4736+
nothing,
4737+
nlprob_data = __has_nlprob_data(f) ? f.nlprob_data : nothing
4738+
) where {iip,
4739+
specialize
4740+
}
4741+
if mass_matrix === I && f isa Tuple
4742+
mass_matrix = ((I for i in 1:length(f))...,)
4743+
end
4744+
4745+
if (specialize === FunctionWrapperSpecialize) &&
4746+
!(f isa FunctionWrappersWrappers.FunctionWrappersWrapper)
4747+
error("FunctionWrapperSpecialize must be used on the problem constructor for access to u0, p, and t types!")
4748+
end
4749+
4750+
if jac === nothing && isa(jac_prototype, AbstractSciMLOperator)
4751+
if iip
4752+
jac = (J, x, u, p, t) -> update_coefficients!(J, x, p, t) #(J,x,u,p,t)
4753+
else
4754+
jac = (x, u, p, t) -> update_coefficients(deepcopy(jac_prototype), x, p, t)
4755+
end
4756+
end
4757+
4758+
if controljac === nothing && isa(controljac_prototype, AbstractSciMLOperator)
4759+
if iip_bc
4760+
controljac = (J, x, u, p, t) -> update_coefficients!(J, u, p, t) #(J,x,u,p,t)
4761+
else
4762+
controljac = (x, u, p, t) -> update_coefficients(deepcopy(controljac_prototype), u, p, t)
4763+
end
4764+
end
4765+
4766+
if jac_prototype !== nothing && colorvec === nothing &&
4767+
ArrayInterface.fast_matrix_colors(jac_prototype)
4768+
_colorvec = ArrayInterface.matrix_colors(jac_prototype)
4769+
else
4770+
_colorvec = colorvec
4771+
end
4772+
4773+
jaciip = jac !== nothing ? isinplace(jac, 5, "jac", iip) : iip
4774+
controljaciip = controljac !== nothing ? isinplace(controljac, 5, "controljac", iip) : iip
4775+
tgradiip = tgrad !== nothing ? isinplace(tgrad, 5, "tgrad", iip) : iip
4776+
jvpiip = jvp !== nothing ? isinplace(jvp, 6, "jvp", iip) : iip
4777+
vjpiip = vjp !== nothing ? isinplace(vjp, 6, "vjp", iip) : iip
4778+
Wfactiip = Wfact !== nothing ? isinplace(Wfact, 6, "Wfact", iip) : iip
4779+
Wfact_tiip = Wfact_t !== nothing ? isinplace(Wfact_t, 6, "Wfact_t", iip) : iip
4780+
paramjaciip = paramjac !== nothing ? isinplace(paramjac, 5, "paramjac", iip) : iip
4781+
4782+
nonconforming = (jaciip, tgradiip, jvpiip, vjpiip, Wfactiip, Wfact_tiip,
4783+
paramjaciip) .!= iip
4784+
if any(nonconforming)
4785+
nonconforming = findall(nonconforming)
4786+
functions = ["jac", "tgrad", "jvp", "vjp", "Wfact", "Wfact_t", "paramjac"][nonconforming]
4787+
throw(NonconformingFunctionsError(functions))
4788+
end
4789+
4790+
_f = prepare_function(f)
4791+
4792+
sys = sys_or_symbolcache(sys, syms, paramsyms, indepsym)
4793+
initdata = reconstruct_initialization_data(
4794+
initialization_data, initializeprob, update_initializeprob!,
4795+
initializeprobmap, initializeprobpmap)
4796+
4797+
if specialize === NoSpecialize
4798+
ODEInputFunction{iip, specialize,
4799+
Any, Any, Any, Any,
4800+
Any, Any, Any, Any, typeof(jac_prototype), typeof(controljac_prototype),
4801+
typeof(sparsity), Any, Any, typeof(W_prototype), Any,
4802+
Any,
4803+
typeof(_colorvec),
4804+
typeof(sys), Union{Nothing, OverrideInitData}}(
4805+
_f, mass_matrix, analytic, tgrad, jac, controljac,
4806+
jvp, vjp, jac_prototype, controljac_prototype, sparsity, Wfact,
4807+
Wfact_t, W_prototype, paramjac,
4808+
observed, _colorvec, sys, initdata)
4809+
elseif specialize === false
4810+
ODEInputFunction{iip, FunctionWrapperSpecialize,
4811+
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
4812+
typeof(jac), typeof(controljac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(controljac_prototype),
4813+
typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(W_prototype),
4814+
typeof(paramjac),
4815+
typeof(observed),
4816+
typeof(_colorvec),
4817+
typeof(sys), typeof(initdata)}(_f, mass_matrix,
4818+
analytic, tgrad, jac, controljac,
4819+
jvp, vjp, jac_prototype, controljac_prototype, sparsity, Wfact,
4820+
Wfact_t, W_prototype, paramjac,
4821+
observed, _colorvec, sys, initdata)
4822+
else
4823+
ODEInputFunction{iip, specialize,
4824+
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
4825+
typeof(jac), typeof(controljac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(controljac_prototype),
4826+
typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(W_prototype),
4827+
typeof(paramjac),
4828+
typeof(observed),
4829+
typeof(_colorvec),
4830+
typeof(sys), typeof(initdata)}(
4831+
_f, mass_matrix, analytic, tgrad,
4832+
jac, controljac, jvp, vjp, jac_prototype, controljac_prototype, sparsity, Wfact,
4833+
Wfact_t, W_prototype, paramjac,
4834+
observed, _colorvec, sys, initdata)
4835+
end
4836+
end
4837+
4838+
function ODEInputFunction{iip}(f; kwargs...) where {iip}
4839+
ODEInputFunction{iip, FullSpecialize}(f; kwargs...)
4840+
end
4841+
ODEInputFunction{iip}(f::ODEInputFunction; kwargs...) where {iip} = f
4842+
ODEInputFunction(f; kwargs...) = ODEInputFunction{isinplace(f, 5), FullSpecialize}(f; kwargs...)
4843+
ODEInputFunction(f::ODEInputFunction; kwargs...) = f
4844+
45984845
########## Utility functions
45994846

46004847
function sys_or_symbolcache(sys, syms, paramsyms, indepsym = nothing)
@@ -4628,6 +4875,7 @@ __has_Wfact_t(f) = isdefined(f, :Wfact_t)
46284875
__has_W_prototype(f) = isdefined(f, :W_prototype)
46294876
__has_paramjac(f) = isdefined(f, :paramjac)
46304877
__has_jac_prototype(f) = isdefined(f, :jac_prototype)
4878+
__has_controljac_prototype(f) = isdefined(f, :controljac_prototype)
46314879
__has_sparsity(f) = isdefined(f, :sparsity)
46324880
__has_mass_matrix(f) = isdefined(f, :mass_matrix)
46334881
__has_syms(f) = isdefined(f, :syms)

src/solutions/solution_utils.jl

Whitespace-only changes.

0 commit comments

Comments
 (0)