diff --git a/src/initialization.jl b/src/initialization.jl index c580f9440..423d2e93c 100644 --- a/src/initialization.jl +++ b/src/initialization.jl @@ -295,13 +295,18 @@ function get_initial_values(prob, valp, f, alg::OverrideInit, end if initdata.initializeprobmap !== nothing - u0 = initdata.initializeprobmap(nlsol) + u02 = initdata.initializeprobmap(nlsol) end if initdata.initializeprobpmap !== nothing - p = initdata.initializeprobpmap(valp, nlsol) + p2 = initdata.initializeprobpmap(valp, nlsol) end - return u0, p, success + # specifically needs to be written this way for Zygote + # See https://github.com/SciML/ModelingToolkit.jl/pull/3585#issuecomment-2883919162 + u03 = isnothing(initdata.initializeprobmap) ? u0 : u02 + p3 = isnothing(initdata.initializeprobpmap) ? p : p2 + + return u03, p3, success end """