Skip to content

Commit bcd324a

Browse files
feat: support parameter updates in initialize_dae!
1 parent 457611c commit bcd324a

File tree

3 files changed

+13
-0
lines changed

3 files changed

+13
-0
lines changed

lib/OrdinaryDiffEqCore/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ version = "1.6.1"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
8+
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
89
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
910
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
1011
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
@@ -42,6 +43,7 @@ OrdinaryDiffEqCoreEnzymeCoreExt = "EnzymeCore"
4243

4344
[compat]
4445
ADTypes = "0.2, 1"
46+
Accessors = "0.1.36"
4547
Adapt = "3.0, 4"
4648
ArrayInterface = "7"
4749
DataStructures = "0.18"

lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ import DiffEqBase: calculate_residuals,
7070
import Polyester
7171
using MacroTools, Adapt
7272
import ADTypes: AutoFiniteDiff, AutoForwardDiff
73+
import Accessors: @reset
7374

7475
using SciMLStructures: canonicalize, Tunable, isscimlstructure
7576

lib/OrdinaryDiffEqCore/src/initialize_dae.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,10 @@ function _initialize_dae!(integrator, prob::Union{ODEProblem, DAEProblem},
153153
alg::OverrideInit, isinplace::Union{Val{true}, Val{false}})
154154
initializeprob = prob.f.initializeprob
155155

156+
if SciMLBase.has_update_initializeprob!(prob.f)
157+
prob.f.update_initializeprob!(initializeprob, prob)
158+
end
159+
156160
# If it doesn't have autodiff, assume it comes from symbolic system like ModelingToolkit
157161
# Since then it's the case of not a DAE but has initializeprob
158162
# In which case, it should be differentiable
@@ -173,6 +177,12 @@ function _initialize_dae!(integrator, prob::Union{ODEProblem, DAEProblem},
173177
else
174178
error("Unreachable reached. Report this error.")
175179
end
180+
if SciMLBase.has_initializeprobpmap(prob.f)
181+
integrator.p = prob.f.initializeprobpmap(prob, nlsol)
182+
sol = integrator.sol
183+
@reset sol.prob.p = integrator.p
184+
integrator.sol = sol
185+
end
176186

177187
if nlsol.retcode != ReturnCode.Success
178188
integrator.sol = SciMLBase.solution_new_retcode(integrator.sol,

0 commit comments

Comments
 (0)