diff --git a/ext/SciMLBaseChainRulesCoreExt.jl b/ext/SciMLBaseChainRulesCoreExt.jl index 0e84ccbc3..4e9412063 100644 --- a/ext/SciMLBaseChainRulesCoreExt.jl +++ b/ext/SciMLBaseChainRulesCoreExt.jl @@ -111,15 +111,15 @@ end function ChainRulesCore.rrule( ::Type{ - <:ODESolution{uType, tType, isinplace, P, NP, F, G, K, + <:RODESolution{uType, tType, isinplace, P, NP, F, G, K, ND }}, u, args...) where {uType, tType, isinplace, P, NP, F, G, K, ND} - function SDESolutionAdjoint(ȳ) + function RODESolutionAdjoint(ȳ) (NoTangent(), ȳ, ntuple(_ -> NoTangent(), length(args))...) end - SDESolution{uType, tType, isinplace, P, NP, F, G, K, ND}(u, args...), SDESolutionAdjoint + RODESolution{uType, tType, isinplace, P, NP, F, G, K, ND}(u, args...), RODESolutionAdjoint end function ChainRulesCore.rrule(::SciMLBase.EnsembleSolution, sim, time, converged) diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index c7ba1827e..a89bb3a93 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -135,11 +135,11 @@ end @adjoint function SDEProblem{uType, tType, isinplace, P, NP, F, G, K, ND}(u, args...) where {uType, tType, isinplace, P, NP, F, G, K, ND} - function SDESolutionAdjoint(ȳ) + function SDEProblemAdjoint(ȳ) (ȳ, ntuple(_ -> nothing, length(args))...) end - SDESolution{uType, tType, isinplace, P, NP, F, G, K, ND}(u, args...), SDESolutionAdjoint + SDEProblem{uType, tType, isinplace, P, NP, F, G, K, ND}(u, args...), SDEProblemAdjoint end @adjoint function NonlinearSolution{T, N, uType, R, P, A, O, uType2}(u,