Skip to content

Commit 169d419

Browse files
Merge pull request #860 from oscardssmith/os/refactor-nlprob-to-nlprob_data
fix nlprob to match the initialization system
2 parents 4e3fa90 + 067a926 commit 169d419

File tree

4 files changed

+75
-34
lines changed

4 files changed

+75
-34
lines changed

src/ODE_nlsolve.jl

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
"""
2+
$(TYPEDEF)
3+
4+
A collection of all the data required for custom ODE Nonlinear problem solving
5+
"""
6+
struct ODE_NLProbData{NLProb, UNLProb, NLProbMap, NLProbPmap}
7+
"""
8+
The `AbstractNonlinearProblem` to define custom nonlinear problems to be used for
9+
implicit time discretizations. This allows to use extra structure of the ODE function (e.g.
10+
multi-level structure). The nonlinear function must match that form of the function implicit
11+
ODE integration algorithms need do solve the a nonlinear problems,
12+
specifically of the form `z = outer_tmp + dt⋅f(γ⋅z+inner_tmp,p,t)`.
13+
Here `z` is the stage solution vector, `p` is the parameter of the ODE problem, `t` is
14+
the time, `dt` the respective time increment`, `γ` is some scaling factor and the temporary
15+
variables are some compatible vectors set by the specific solver.
16+
Note that this field will not be used for integrators such as fully-implicit Runge-Kutta methods
17+
that need to solve different nonlinear systems.
18+
The inner nonlinear function of the nonlinear problem is in general of the form `g(z,p') = 0`
19+
where `p'` is a NamedTuple with all information about the specific nonlinear problem at hand to solve
20+
for a specific time discretization. Specifically, it is `(;dt, γ, inner_tmp, outer_tmp, t, p)`, such that
21+
`g(z,p') = dt⋅f(γ⋅z+inner_tmp,p,t) + outer_tmp - z = 0`.
22+
"""
23+
nlprob::NLProb
24+
"""
25+
A function which takes `(nlprob, value_provider)` and updates
26+
the parameters of the former with their values in the latter.
27+
If absent (`nothing`) this will not be called, and the parameters
28+
in `nlprob` will be used without modification. `value_provider`
29+
refers to a value provider as defined by SymbolicIndexingInterface.jl.
30+
Usually this will refer to a problem or integrator.
31+
"""
32+
update_nlprob!::UNLProb
33+
"""
34+
A function which takes the solution of `nlprob` and returns
35+
the state vector of the original problem.
36+
"""
37+
nlprobmap::NLProbMap
38+
"""
39+
A function which takes the solution of `nlprob` and returns
40+
the parameter object of the original problem. If absent (`nothing`),
41+
this will not be called and the parameters of the problem being
42+
solved will be returned as-is.
43+
"""
44+
nlprobpmap::NLProbPmap
45+
end
46+

src/SciMLBase.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -658,6 +658,8 @@ Internal. Used for signifying the AD context comes from a Tracker.jl context.
658658
"""
659659
struct TrackerOriginator <: ADOriginator end
660660

661+
include("initialization.jl")
662+
include("ODE_nlsolve.jl")
661663
include("utils.jl")
662664
include("function_wrappers.jl")
663665
include("scimlfunctions.jl")
@@ -744,7 +746,6 @@ include("ensemble/ensemble_problems.jl")
744746
include("ensemble/basic_ensemble_solve.jl")
745747
include("ensemble/ensemble_analysis.jl")
746748

747-
include("initialization.jl")
748749
include("solve.jl")
749750
include("interpolation.jl")
750751
include("integrator_interface.jl")

src/initialization.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ Check if the algebraic constraints are satisfied, and error if they aren't. Retu
100100
the `u0` and `p` as-is, and is always successful if it returns. Valid only for
101101
`ODEProblem` and `DAEProblem`. Requires a `DEIntegrator` as its second argument.
102102
"""
103-
function get_initial_values(prob::ODEProblem, integrator, f, alg::CheckInit,
103+
function get_initial_values(prob::AbstractODEProblem, integrator, f, alg::CheckInit,
104104
isinplace::Union{Val{true}, Val{false}}; kwargs...)
105105
u0 = state_values(integrator)
106106
p = parameter_values(integrator)
@@ -135,7 +135,7 @@ function _evaluate_f_dae(integrator, f, isinplace::Val{false}, args...)
135135
return f(args...)
136136
end
137137

138-
function get_initial_values(prob::DAEProblem, integrator, f, alg::CheckInit,
138+
function get_initial_values(prob::AbstractDAEProblem, integrator, f, alg::CheckInit,
139139
isinplace::Union{Val{true}, Val{false}}; kwargs...)
140140
u0 = state_values(integrator)
141141
p = parameter_values(integrator)

src/scimlfunctions.jl

Lines changed: 25 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -289,11 +289,6 @@ the usage of `f`. These include:
289289
based on the sparsity pattern. Defaults to `nothing`, which means a color vector will be
290290
internally computed on demand when required. The cost of this operation is highly dependent
291291
on the sparsity pattern.
292-
- `nlprob`: a `NonlinearProblem` that solves `f(u, t, p) = u_tmp`
293-
where the nonlinear parameters are the tuple `(t, u_tmp, p)`.
294-
This will be used as the nonlinear problem inside an implicit solver by specifying `u, u_tmp` and `t`
295-
such that solving this function produces a solution to the implicit step of your solver.
296-
297292
## iip: In-Place vs Out-Of-Place
298293
299294
`iip` is the optional boolean for determining whether a given function is written to
@@ -406,7 +401,7 @@ numerically-defined functions.
406401
"""
407402
struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, WP, TPJ,
408403
O, TCV,
409-
SYS, ID, NLP} <: AbstractODEFunction{iip}
404+
SYS, ID<:Union{Nothing, OverrideInitData}, NLP<:Union{Nothing, ODE_NLProbData}} <: AbstractODEFunction{iip}
410405
f::F
411406
mass_matrix::TMM
412407
analytic::Ta
@@ -424,7 +419,7 @@ struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TW
424419
colorvec::TCV
425420
sys::SYS
426421
initialization_data::ID
427-
nlprob::NLP
422+
nlprob_data::NLP
428423
end
429424

430425
@doc doc"""
@@ -527,8 +522,7 @@ information on generating the SplitFunction from this symbolic engine.
527522
"""
528523
struct SplitFunction{
529524
iip, specialize, F1, F2, TMM, C, Ta, Tt, TJ, JVP, VJP, JP, WP, SP, TW, TWt,
530-
TPJ, O,
531-
TCV, SYS, ID, NLP} <: AbstractODEFunction{iip}
525+
TPJ, O, TCV, SYS, ID<:Union{Nothing, OverrideInitData}, NLP<:Union{Nothing, ODE_NLProbData}} <: AbstractODEFunction{iip}
532526
f1::F1
533527
f2::F2
534528
mass_matrix::TMM
@@ -547,8 +541,8 @@ struct SplitFunction{
547541
observed::O
548542
colorvec::TCV
549543
sys::SYS
550-
nlprob::NLP
551544
initialization_data::ID
545+
nlprob_data::NLP
552546
end
553547

554548
@doc doc"""
@@ -2446,9 +2440,9 @@ function ODEFunction{iip, specialize}(f;
24462440
f.update_initializeprob! : nothing,
24472441
initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing,
24482442
initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing,
2449-
nlprob = __has_nlprob(f) ? f.nlprob : nothing,
24502443
initialization_data = __has_initialization_data(f) ? f.initialization_data :
2451-
nothing
2444+
nothing,
2445+
nlprob_data = __has_nlprob_data(f) ? f.nlprob_data : nothing,
24522446
) where {iip,
24532447
specialize
24542448
}
@@ -2506,10 +2500,10 @@ function ODEFunction{iip, specialize}(f;
25062500
typeof(sparsity), Any, Any, typeof(W_prototype), Any,
25072501
Any,
25082502
typeof(_colorvec),
2509-
typeof(sys), Any, Any}(_f, mass_matrix, analytic, tgrad, jac,
2503+
typeof(sys), Union{Nothing, OverrideInitData}, Union{Nothing, ODE_NLProbData}}(_f, mass_matrix, analytic, tgrad, jac,
25102504
jvp, vjp, jac_prototype, sparsity, Wfact,
25112505
Wfact_t, W_prototype, paramjac,
2512-
observed, _colorvec, sys, initdata, nlprob)
2506+
observed, _colorvec, sys, initdata, nlprob_data)
25132507
elseif specialize === false
25142508
ODEFunction{iip, FunctionWrapperSpecialize,
25152509
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
@@ -2518,11 +2512,11 @@ function ODEFunction{iip, specialize}(f;
25182512
typeof(paramjac),
25192513
typeof(observed),
25202514
typeof(_colorvec),
2521-
typeof(sys), typeof(initdata), typeof(nlprob)}(_f, mass_matrix,
2515+
typeof(sys), typeof(initdata), typeof(nlprob_data)}(_f, mass_matrix,
25222516
analytic, tgrad, jac,
25232517
jvp, vjp, jac_prototype, sparsity, Wfact,
25242518
Wfact_t, W_prototype, paramjac,
2525-
observed, _colorvec, sys, initdata, nlprob)
2519+
observed, _colorvec, sys, initdata, nlprob_data)
25262520
else
25272521
ODEFunction{iip, specialize,
25282522
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
@@ -2531,11 +2525,11 @@ function ODEFunction{iip, specialize}(f;
25312525
typeof(paramjac),
25322526
typeof(observed),
25332527
typeof(_colorvec),
2534-
typeof(sys), typeof(initdata), typeof(nlprob)}(
2528+
typeof(sys), typeof(initdata), typeof(nlprob_data)}(
25352529
_f, mass_matrix, analytic, tgrad,
25362530
jac, jvp, vjp, jac_prototype, sparsity, Wfact,
25372531
Wfact_t, W_prototype, paramjac,
2538-
observed, _colorvec, sys, initdata, nlprob)
2532+
observed, _colorvec, sys, initdata, nlprob_data)
25392533
end
25402534
end
25412535

@@ -2552,23 +2546,23 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f))
25522546
Any, Any, Any, Any, typeof(f.jac_prototype),
25532547
typeof(f.sparsity), Any, Any, Any,
25542548
Any, typeof(f.colorvec),
2555-
typeof(f.sys), Any, Any}(
2549+
typeof(f.sys), Union{Nothing, OverrideInitData}, Union{Nothing, ODE_NLProbData}}(
25562550
newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
25572551
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
25582552
f.Wfact_t, f.W_prototype, f.paramjac,
2559-
f.observed, f.colorvec, f.sys, f.initialization_data, f.nlprob)
2553+
f.observed, f.colorvec, f.sys, f.initialization_data, f.nlprob_data)
25602554
else
25612555
ODEFunction{isinplace(f), specialization(f), typeof(newf), typeof(f.mass_matrix),
25622556
typeof(f.analytic), typeof(f.tgrad),
25632557
typeof(f.jac), typeof(f.jvp), typeof(f.vjp), typeof(f.jac_prototype),
25642558
typeof(f.sparsity), typeof(f.Wfact), typeof(f.Wfact_t), typeof(f.W_prototype),
25652559
typeof(f.paramjac),
25662560
typeof(f.observed), typeof(f.colorvec),
2567-
typeof(f.sys), typeof(f.initialization_data), typeof(f.nlprob)}(
2561+
typeof(f.sys), typeof(f.initialization_data), typeof(f.nlprob_data)}(
25682562
newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
25692563
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
25702564
f.Wfact_t, f.W_prototype, f.paramjac,
2571-
f.observed, f.colorvec, f.sys, f.initialization_data, f.nlprob)
2565+
f.observed, f.colorvec, f.sys, f.initialization_data, f.nlprob_data)
25722566
end
25732567
end
25742568

@@ -2703,7 +2697,7 @@ end
27032697
@add_kwonly function SplitFunction(f1, f2, mass_matrix, cache, analytic, tgrad, jac, jvp,
27042698
vjp, jac_prototype, W_prototype, sparsity, Wfact, Wfact_t, paramjac,
27052699
observed, colorvec, sys, initializeprob = nothing, update_initializeprob! = nothing,
2706-
initializeprobmap = nothing, initializeprobpmap = nothing, nlprob = nothing, initialization_data = nothing)
2700+
initializeprobmap = nothing, initializeprobpmap = nothing, initialization_data = nothing, nlprob_data = nothing)
27072701
f1 = ODEFunction(f1)
27082702
f2 = ODEFunction(f2)
27092703

@@ -2721,11 +2715,11 @@ end
27212715
typeof(cache), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp),
27222716
typeof(vjp), typeof(jac_prototype), typeof(W_prototype), typeof(sparsity),
27232717
typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed), typeof(colorvec),
2724-
typeof(sys), typeof(initdata), typeof(nlprob)}(
2718+
typeof(sys), typeof(initdata), typeof(nlprob_data)}(
27252719
f1, f2, mass_matrix,
27262720
cache, analytic, tgrad, jac, jvp, vjp,
27272721
jac_prototype, W_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys,
2728-
initdata, nlprob)
2722+
initdata, nlprob_data)
27292723
end
27302724
function SplitFunction{iip, specialize}(f1, f2;
27312725
mass_matrix = __has_mass_matrix(f1) ?
@@ -2762,7 +2756,7 @@ function SplitFunction{iip, specialize}(f1, f2;
27622756
f1.update_initializeprob! : nothing,
27632757
initializeprobmap = __has_initializeprobmap(f1) ? f1.initializeprobmap : nothing,
27642758
initializeprobpmap = __has_initializeprobpmap(f1) ? f1.initializeprobpmap : nothing,
2765-
nlprob = __has_nlprob(f1) ? f1.nlprob : nothing,
2759+
nlprob_data = __has_nlprob_data(f1) ? f1.nlprob_data : nothing,
27662760
initialization_data = __has_initialization_data(f1) ? f1.initialization_data :
27672761
nothing
27682762
) where {iip,
@@ -2776,23 +2770,23 @@ function SplitFunction{iip, specialize}(f1, f2;
27762770
if specialize === NoSpecialize
27772771
SplitFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, Any,
27782772
Any, Any, Any, Any, Any, Any, Any,
2779-
Any, Any, Any, Any}(f1, f2, mass_matrix, _func_cache,
2773+
Any, Any, Union{Nothing, OverrideInitData}, Union{Nothing, ODE_NLProbData}}(f1, f2, mass_matrix, _func_cache,
27802774
analytic,
27812775
tgrad, jac, jvp, vjp, jac_prototype, W_prototype,
27822776
sparsity, Wfact, Wfact_t, paramjac,
2783-
observed, colorvec, sys, initdata, nlprob)
2777+
observed, colorvec, sys, initdata, nlprob_data)
27842778
else
27852779
SplitFunction{iip, specialize, typeof(f1), typeof(f2), typeof(mass_matrix),
27862780
typeof(_func_cache), typeof(analytic),
27872781
typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp),
27882782
typeof(jac_prototype), typeof(W_prototype), typeof(sparsity),
27892783
typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed),
27902784
typeof(colorvec),
2791-
typeof(sys), typeof(initdata), typeof(nlprob)}(f1, f2,
2785+
typeof(sys), typeof(initdata), typeof(nlprob_data)}(f1, f2,
27922786
mass_matrix, _func_cache, analytic, tgrad, jac,
27932787
jvp, vjp, jac_prototype, W_prototype,
27942788
sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys,
2795-
initdata, nlprob)
2789+
initdata, nlprob_data)
27962790
end
27972791
end
27982792

@@ -4488,7 +4482,7 @@ __has_colorvec(f) = isdefined(f, :colorvec)
44884482
__has_sys(f) = isdefined(f, :sys)
44894483
__has_analytic_full(f) = isdefined(f, :analytic_full)
44904484
__has_resid_prototype(f) = isdefined(f, :resid_prototype)
4491-
__has_nlprob(f) = isdefined(f, :nlprob)
4485+
__has_nlprob_data(f) = isdefined(f, :nlprob_data)
44924486
function __has_initializeprob(f)
44934487
has_initialization_data(f) && isdefined(f.initialization_data, :initializeprob)
44944488
end

0 commit comments

Comments
 (0)