Skip to content

Commit 8656a17

Browse files
refactor: rewrite OverrideInit to be more Zygote compatible
Co-authored-by: Dhairya Gandhi <dhairya@juliahub.com>
1 parent ecdc172 commit 8656a17

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

src/initialization.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -295,13 +295,18 @@ function get_initial_values(prob, valp, f, alg::OverrideInit,
295295
end
296296

297297
if initdata.initializeprobmap !== nothing
298-
u0 = initdata.initializeprobmap(nlsol)
298+
u02 = initdata.initializeprobmap(nlsol)
299299
end
300300
if initdata.initializeprobpmap !== nothing
301-
p = initdata.initializeprobpmap(valp, nlsol)
301+
p2 = initdata.initializeprobpmap(valp, nlsol)
302302
end
303303

304-
return u0, p, success
304+
# specifically needs to be written this way for Zygote
305+
# See https://github.com/SciML/ModelingToolkit.jl/pull/3585#issuecomment-2883919162
306+
u03 = isnothing(initdata.initializeprobmap) ? u0 : u02
307+
p3 = isnothing(initdata.initializeprobpmap) ? p : p2
308+
309+
return u03, p3, success
305310
end
306311

307312
"""

0 commit comments

Comments
 (0)