Skip to content

Commit e12e8bf

Browse files
Merge pull request #1025 from DhairyaLGandhi/dg/trim
Make `get_initial_values` differentiate
2 parents ca86756 + 88a10de commit e12e8bf

File tree

2 files changed

+7
-9
lines changed

2 files changed

+7
-9
lines changed

src/initialization.jl

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ function get_initial_values(prob, valp, f, alg::OverrideInit,
260260
end
261261

262262
if is_trivial_initialization(initdata)
263-
nlsol = initprob
263+
nlsol = initdata
264264
success = true
265265
else
266266
nlsolve_alg = something(nlsolve_alg, alg.nlsolve, Some(nothing))
@@ -295,18 +295,13 @@ function get_initial_values(prob, valp, f, alg::OverrideInit,
295295
end
296296

297297
if initdata.initializeprobmap !== nothing
298-
u02 = initdata.initializeprobmap(nlsol)
298+
u0 = initdata.initializeprobmap(choose_branch(nlsol))
299299
end
300300
if initdata.initializeprobpmap !== nothing
301-
p2 = initdata.initializeprobpmap(valp, nlsol)
301+
p = initdata.initializeprobpmap(valp, choose_branch(nlsol))
302302
end
303303

304-
# specifically needs to be written this way for Zygote
305-
# See https://github.com/SciML/ModelingToolkit.jl/pull/3585#issuecomment-2883919162
306-
u03 = isnothing(initdata.initializeprobmap) ? u0 : u02
307-
p3 = isnothing(initdata.initializeprobpmap) ? p : p2
308-
309-
return u03, p3, success
304+
return u0, p, success
310305
end
311306

312307
"""

src/utils.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -550,3 +550,6 @@ Strips a SciMLSolution object and its interpolation of their functions to better
550550
function strip_solution(sol::AbstractSciMLSolution)
551551
sol
552552
end
553+
554+
choose_branch(x::OverrideInitData) = x.initializeprob
555+
choose_branch(sol::AbstractSciMLSolution) = sol

0 commit comments

Comments
 (0)