@@ -2094,6 +2094,109 @@ struct MultiObjectiveOptimizationFunction{
2094
2094
initialization_data:: ID
2095
2095
end
2096
2096
2097
+ """
2098
+ $(TYPEDEF)
2099
+ """
2100
+ abstract type AbstractODEInputFunction{iip} <: AbstractDiffEqFunction{iip} end
2101
+
2102
+ @doc doc"""
2103
+ $(TYPEDEF)
2104
+
2105
+ A representation of a ODE function `f` with inputs, defined by:
2106
+
2107
+ ```math
2108
+ \f rac{dx}{dt} = f(x, u, p, t)
2109
+ ```
2110
+ where `x` are the states of the system and `u` are the inputs (which may represent
2111
+ different things in different contexts, such as control variables in optimal control).
2112
+
2113
+ Includes all of its related functions, such as the Jacobian of `f`, its gradient
2114
+ with respect to time, and more. For all cases, `u0` is the initial condition,
2115
+ `p` are the parameters, and `t` is the independent variable.
2116
+
2117
+ ```julia
2118
+ ODEInputFunction{iip, specialize}(f;
2119
+ mass_matrix = __has_mass_matrix(f) ? f.mass_matrix : I,
2120
+ analytic = __has_analytic(f) ? f.analytic : nothing,
2121
+ tgrad= __has_tgrad(f) ? f.tgrad : nothing,
2122
+ jac = __has_jac(f) ? f.jac : nothing,
2123
+ control_jac = __has_controljac(f) ? f.controljac : nothing,
2124
+ jvp = __has_jvp(f) ? f.jvp : nothing,
2125
+ vjp = __has_vjp(f) ? f.vjp : nothing,
2126
+ jac_prototype = __has_jac_prototype(f) ? f.jac_prototype : nothing,
2127
+ controljac_prototype = __has_controljac_prototype(f) ? f.controljac_prototype : nothing,
2128
+ sparsity = __has_sparsity(f) ? f.sparsity : jac_prototype,
2129
+ paramjac = __has_paramjac(f) ? f.paramjac : nothing,
2130
+ syms = nothing,
2131
+ indepsym = nothing,
2132
+ paramsyms = nothing,
2133
+ colorvec = __has_colorvec(f) ? f.colorvec : nothing,
2134
+ sys = __has_sys(f) ? f.sys : nothing)
2135
+ ```
2136
+
2137
+ `f` should be given as `f(x_out,x,u,p,t)` or `out = f(x,u,p,t)`.
2138
+ See the section on `iip` for more details on in-place vs out-of-place handling.
2139
+
2140
+ - `mass_matrix`: the mass matrix `M` represented in the BVP function. Can be used
2141
+ to determine that the equation is actually a BVP for differential algebraic equation (DAE)
2142
+ if `M` is singular.
2143
+ - `jac(J,dx,x,u,p,gamma,t)` or `J=jac(dx,x,u,p,gamma,t)`: returns ``\f rac{df}{dx}``
2144
+ - `control_jac(J,du,x,u,p,gamma,t)` or `J=control_jac(du,x,u,p,gamma,t)`: returns ``\f rac{df}{du}``
2145
+ - `jvp(Jv,v,du,x,u,p,gamma,t)` or `Jv=jvp(v,du,x,u,p,gamma,t)`: returns the directional
2146
+ derivative ``\f rac{df}{du} v``
2147
+ - `vjp(Jv,v,du,x,u,p,gamma,t)` or `Jv=vjp(v,du,x,u,p,gamma,t)`: returns the adjoint
2148
+ derivative ``\f rac{df}{du}^\a st v``
2149
+ - `jac_prototype`: a prototype matrix matching the type that matches the Jacobian. For example,
2150
+ if the Jacobian is tridiagonal, then an appropriately sized `Tridiagonal` matrix can be used
2151
+ as the prototype and integrators will specialize on this structure where possible. Non-structured
2152
+ sparsity patterns should use a `SparseMatrixCSC` with a correct sparsity pattern for the Jacobian.
2153
+ The default is `nothing`, which means a dense Jacobian.
2154
+ - `controljac_prototype`: a prototype matrix matching the type that matches the Jacobian. For example,
2155
+ if the Jacobian is tridiagonal, then an appropriately sized `Tridiagonal` matrix can be used
2156
+ as the prototype and integrators will specialize on this structure where possible. Non-structured
2157
+ sparsity patterns should use a `SparseMatrixCSC` with a correct sparsity pattern for the Jacobian.
2158
+ The default is `nothing`, which means a dense Jacobian.
2159
+ - `paramjac(pJ,x,u,p,t)`: returns the parameter Jacobian ``\f rac{df}{dp}``.
2160
+ - `colorvec`: a color vector according to the SparseDiffTools.jl definition for the sparsity
2161
+ pattern of the `jac_prototype`. This specializes the Jacobian construction when using
2162
+ finite differences and automatic differentiation to be computed in an accelerated manner
2163
+ based on the sparsity pattern. Defaults to `nothing`, which means a color vector will be
2164
+ internally computed on demand when required. The cost of this operation is highly dependent
2165
+ on the sparsity pattern.
2166
+
2167
+ ## iip: In-Place vs Out-Of-Place
2168
+ For more details on this argument, see the ODEFunction documentation.
2169
+
2170
+ ## specialize: Controlling Compilation and Specialization
2171
+ For more details on this argument, see the ODEFunction documentation.
2172
+
2173
+ ## Fields
2174
+ The fields of the ODEInputFunction type directly match the names of the inputs.
2175
+ """
2176
+ struct ODEInputFunction{iip, specialize, F, TMM, Ta, Tt, TJ, CTJ, JVP, VJP,
2177
+ JP, CJP, SP, TW, TWt, WP, TPJ, O, TCV,
2178
+ SYS, ID} <: AbstractODEInputFunction{iip}
2179
+ f:: F
2180
+ mass_matrix:: TMM
2181
+ analytic:: Ta
2182
+ tgrad:: Tt
2183
+ jac:: TJ
2184
+ controljac:: CTJ
2185
+ jvp:: JVP
2186
+ vjp:: VJP
2187
+ jac_prototype:: JP
2188
+ controljac_prototype:: CJP
2189
+ sparsity:: SP
2190
+ Wfact:: TW
2191
+ Wfact_t:: TWt
2192
+ W_prototype:: WP
2193
+ paramjac:: TPJ
2194
+ observed:: O
2195
+ colorvec:: TCV
2196
+ sys:: SYS
2197
+ initialization_data:: ID
2198
+ end
2199
+
2097
2200
"""
2098
2201
$(TYPEDEF)
2099
2202
"""
@@ -2493,6 +2596,7 @@ end
2493
2596
(f:: ImplicitDiscreteFunction )(args... ) = f. f (args... )
2494
2597
(f:: DAEFunction )(args... ) = f. f (args... )
2495
2598
(f:: DDEFunction )(args... ) = f. f (args... )
2599
+ (f:: ODEInputFunction )(args... ) = f. f (args... )
2496
2600
2497
2601
function (f:: DynamicalDDEFunction )(u, h, p, t)
2498
2602
ArrayPartition (f. f1 (u. x[1 ], u. x[2 ], h, p, t), f. f2 (u. x[1 ], u. x[2 ], h, p, t))
@@ -4595,6 +4699,149 @@ function BatchIntegralFunction(f, integrand_prototype; kwargs...)
4595
4699
BatchIntegralFunction {calculated_iip} (f, integrand_prototype; kwargs... )
4596
4700
end
4597
4701
4702
+ function ODEInputFunction {iip, specialize} (f;
4703
+ mass_matrix = __has_mass_matrix (f) ? f. mass_matrix :
4704
+ I,
4705
+ analytic = __has_analytic (f) ? f. analytic : nothing ,
4706
+ tgrad = __has_tgrad (f) ? f. tgrad : nothing ,
4707
+ jac = __has_jac (f) ? f. jac : nothing ,
4708
+ controljac = __has_controljac (f) ? f. controljac : nothing ,
4709
+ jvp = __has_jvp (f) ? f. jvp : nothing ,
4710
+ vjp = __has_vjp (f) ? f. vjp : nothing ,
4711
+ jac_prototype = __has_jac_prototype (f) ?
4712
+ f. jac_prototype :
4713
+ nothing ,
4714
+ controljac_prototype = __has_controljac_prototype (f) ?
4715
+ f. controljac_prototype :
4716
+ nothing ,
4717
+ sparsity = __has_sparsity (f) ? f. sparsity :
4718
+ jac_prototype,
4719
+ Wfact = __has_Wfact (f) ? f. Wfact : nothing ,
4720
+ Wfact_t = __has_Wfact_t (f) ? f. Wfact_t : nothing ,
4721
+ W_prototype = __has_W_prototype (f) ? f. W_prototype : nothing ,
4722
+ paramjac = __has_paramjac (f) ? f. paramjac : nothing ,
4723
+ syms = nothing ,
4724
+ indepsym = nothing ,
4725
+ paramsyms = nothing ,
4726
+ observed = __has_observed (f) ? f. observed :
4727
+ DEFAULT_OBSERVED,
4728
+ colorvec = __has_colorvec (f) ? f. colorvec : nothing ,
4729
+ sys = __has_sys (f) ? f. sys : nothing ,
4730
+ initializeprob = __has_initializeprob (f) ? f. initializeprob : nothing ,
4731
+ update_initializeprob! = __has_update_initializeprob! (f) ?
4732
+ f. update_initializeprob! : nothing ,
4733
+ initializeprobmap = __has_initializeprobmap (f) ? f. initializeprobmap : nothing ,
4734
+ initializeprobpmap = __has_initializeprobpmap (f) ? f. initializeprobpmap : nothing ,
4735
+ initialization_data = __has_initialization_data (f) ? f. initialization_data :
4736
+ nothing ,
4737
+ nlprob_data = __has_nlprob_data (f) ? f. nlprob_data : nothing
4738
+ ) where {iip,
4739
+ specialize
4740
+ }
4741
+ if mass_matrix === I && f isa Tuple
4742
+ mass_matrix = ((I for i in 1 : length (f)). .. ,)
4743
+ end
4744
+
4745
+ if (specialize === FunctionWrapperSpecialize) &&
4746
+ ! (f isa FunctionWrappersWrappers. FunctionWrappersWrapper)
4747
+ error (" FunctionWrapperSpecialize must be used on the problem constructor for access to u0, p, and t types!" )
4748
+ end
4749
+
4750
+ if jac === nothing && isa (jac_prototype, AbstractSciMLOperator)
4751
+ if iip
4752
+ jac = (J, x, u, p, t) -> update_coefficients! (J, x, p, t) # (J,x,u,p,t)
4753
+ else
4754
+ jac = (x, u, p, t) -> update_coefficients (deepcopy (jac_prototype), x, p, t)
4755
+ end
4756
+ end
4757
+
4758
+ if controljac === nothing && isa (controljac_prototype, AbstractSciMLOperator)
4759
+ if iip_bc
4760
+ controljac = (J, x, u, p, t) -> update_coefficients! (J, u, p, t) # (J,x,u,p,t)
4761
+ else
4762
+ controljac = (x, u, p, t) -> update_coefficients (deepcopy (controljac_prototype), u, p, t)
4763
+ end
4764
+ end
4765
+
4766
+ if jac_prototype != = nothing && colorvec === nothing &&
4767
+ ArrayInterface. fast_matrix_colors (jac_prototype)
4768
+ _colorvec = ArrayInterface. matrix_colors (jac_prototype)
4769
+ else
4770
+ _colorvec = colorvec
4771
+ end
4772
+
4773
+ jaciip = jac != = nothing ? isinplace (jac, 5 , " jac" , iip) : iip
4774
+ controljaciip = controljac != = nothing ? isinplace (controljac, 5 , " controljac" , iip) : iip
4775
+ tgradiip = tgrad != = nothing ? isinplace (tgrad, 5 , " tgrad" , iip) : iip
4776
+ jvpiip = jvp != = nothing ? isinplace (jvp, 6 , " jvp" , iip) : iip
4777
+ vjpiip = vjp != = nothing ? isinplace (vjp, 6 , " vjp" , iip) : iip
4778
+ Wfactiip = Wfact != = nothing ? isinplace (Wfact, 6 , " Wfact" , iip) : iip
4779
+ Wfact_tiip = Wfact_t != = nothing ? isinplace (Wfact_t, 6 , " Wfact_t" , iip) : iip
4780
+ paramjaciip = paramjac != = nothing ? isinplace (paramjac, 5 , " paramjac" , iip) : iip
4781
+
4782
+ nonconforming = (jaciip, tgradiip, jvpiip, vjpiip, Wfactiip, Wfact_tiip,
4783
+ paramjaciip) .!= iip
4784
+ if any (nonconforming)
4785
+ nonconforming = findall (nonconforming)
4786
+ functions = [" jac" , " tgrad" , " jvp" , " vjp" , " Wfact" , " Wfact_t" , " paramjac" ][nonconforming]
4787
+ throw (NonconformingFunctionsError (functions))
4788
+ end
4789
+
4790
+ _f = prepare_function (f)
4791
+
4792
+ sys = sys_or_symbolcache (sys, syms, paramsyms, indepsym)
4793
+ initdata = reconstruct_initialization_data (
4794
+ initialization_data, initializeprob, update_initializeprob!,
4795
+ initializeprobmap, initializeprobpmap)
4796
+
4797
+ if specialize === NoSpecialize
4798
+ ODEInputFunction{iip, specialize,
4799
+ Any, Any, Any, Any,
4800
+ Any, Any, Any, Any, typeof (jac_prototype), typeof (controljac_prototype),
4801
+ typeof (sparsity), Any, Any, typeof (W_prototype), Any,
4802
+ Any,
4803
+ typeof (_colorvec),
4804
+ typeof (sys), Union{Nothing, OverrideInitData}}(
4805
+ _f, mass_matrix, analytic, tgrad, jac, controljac,
4806
+ jvp, vjp, jac_prototype, controljac_prototype, sparsity, Wfact,
4807
+ Wfact_t, W_prototype, paramjac,
4808
+ observed, _colorvec, sys, initdata)
4809
+ elseif specialize === false
4810
+ ODEInputFunction{iip, FunctionWrapperSpecialize,
4811
+ typeof (_f), typeof (mass_matrix), typeof (analytic), typeof (tgrad),
4812
+ typeof (jac), typeof (controljac), typeof (jvp), typeof (vjp), typeof (jac_prototype), typeof (controljac_prototype),
4813
+ typeof (sparsity), typeof (Wfact), typeof (Wfact_t), typeof (W_prototype),
4814
+ typeof (paramjac),
4815
+ typeof (observed),
4816
+ typeof (_colorvec),
4817
+ typeof (sys), typeof (initdata)}(_f, mass_matrix,
4818
+ analytic, tgrad, jac, controljac,
4819
+ jvp, vjp, jac_prototype, controljac_prototype, sparsity, Wfact,
4820
+ Wfact_t, W_prototype, paramjac,
4821
+ observed, _colorvec, sys, initdata)
4822
+ else
4823
+ ODEInputFunction{iip, specialize,
4824
+ typeof (_f), typeof (mass_matrix), typeof (analytic), typeof (tgrad),
4825
+ typeof (jac), typeof (controljac), typeof (jvp), typeof (vjp), typeof (jac_prototype), typeof (controljac_prototype),
4826
+ typeof (sparsity), typeof (Wfact), typeof (Wfact_t), typeof (W_prototype),
4827
+ typeof (paramjac),
4828
+ typeof (observed),
4829
+ typeof (_colorvec),
4830
+ typeof (sys), typeof (initdata)}(
4831
+ _f, mass_matrix, analytic, tgrad,
4832
+ jac, controljac, jvp, vjp, jac_prototype, controljac_prototype, sparsity, Wfact,
4833
+ Wfact_t, W_prototype, paramjac,
4834
+ observed, _colorvec, sys, initdata)
4835
+ end
4836
+ end
4837
+
4838
+ function ODEInputFunction {iip} (f; kwargs... ) where {iip}
4839
+ ODEInputFunction {iip, FullSpecialize} (f; kwargs... )
4840
+ end
4841
+ ODEInputFunction {iip} (f:: ODEInputFunction ; kwargs... ) where {iip} = f
4842
+ ODEInputFunction (f; kwargs... ) = ODEInputFunction {isinplace(f, 5), FullSpecialize} (f; kwargs... )
4843
+ ODEInputFunction (f:: ODEInputFunction ; kwargs... ) = f
4844
+
4598
4845
# ######### Utility functions
4599
4846
4600
4847
function sys_or_symbolcache (sys, syms, paramsyms, indepsym = nothing )
@@ -4628,6 +4875,7 @@ __has_Wfact_t(f) = isdefined(f, :Wfact_t)
4628
4875
__has_W_prototype (f) = isdefined (f, :W_prototype )
4629
4876
__has_paramjac (f) = isdefined (f, :paramjac )
4630
4877
__has_jac_prototype (f) = isdefined (f, :jac_prototype )
4878
+ __has_controljac_prototype (f) = isdefined (f, :controljac_prototype )
4631
4879
__has_sparsity (f) = isdefined (f, :sparsity )
4632
4880
__has_mass_matrix (f) = isdefined (f, :mass_matrix )
4633
4881
__has_syms (f) = isdefined (f, :syms )
0 commit comments