From e11fbfde0f35de200c784d682d95e96858f01355 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 20 May 2025 17:34:31 +0530 Subject: [PATCH] Revert "Make `get_initial_values` differentiate" --- src/initialization.jl | 13 +++++++++---- src/utils.jl | 3 --- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/initialization.jl b/src/initialization.jl index 4d25633f6..423d2e93c 100644 --- a/src/initialization.jl +++ b/src/initialization.jl @@ -260,7 +260,7 @@ function get_initial_values(prob, valp, f, alg::OverrideInit, end if is_trivial_initialization(initdata) - nlsol = initdata + nlsol = initprob success = true else nlsolve_alg = something(nlsolve_alg, alg.nlsolve, Some(nothing)) @@ -295,13 +295,18 @@ function get_initial_values(prob, valp, f, alg::OverrideInit, end if initdata.initializeprobmap !== nothing - u0 = initdata.initializeprobmap(choose_branch(nlsol)) + u02 = initdata.initializeprobmap(nlsol) end if initdata.initializeprobpmap !== nothing - p = initdata.initializeprobpmap(valp, choose_branch(nlsol)) + p2 = initdata.initializeprobpmap(valp, nlsol) end - return u0, p, success + # specifically needs to be written this way for Zygote + # See https://github.com/SciML/ModelingToolkit.jl/pull/3585#issuecomment-2883919162 + u03 = isnothing(initdata.initializeprobmap) ? u0 : u02 + p3 = isnothing(initdata.initializeprobpmap) ? p : p2 + + return u03, p3, success end """ diff --git a/src/utils.jl b/src/utils.jl index 5720be1cd..ecded5af1 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -550,6 +550,3 @@ Strips a SciMLSolution object and its interpolation of their functions to better function strip_solution(sol::AbstractSciMLSolution) sol end - -choose_branch(x::OverrideInitData) = x.initializeprob -choose_branch(sol::AbstractSciMLSolution) = sol \ No newline at end of file