diff --git a/ext/SciMLBaseChainRulesCoreExt.jl b/ext/SciMLBaseChainRulesCoreExt.jl index 028d505ba..71d14c8a9 100644 --- a/ext/SciMLBaseChainRulesCoreExt.jl +++ b/ext/SciMLBaseChainRulesCoreExt.jl @@ -116,4 +116,15 @@ function ChainRulesCore.rrule(::SciMLBase.EnsembleSolution, sim, time, converged out, EnsembleSolution_adjoint end +function ChainRulesCore.rrule( + ::Type{<:SciMLBase.NonlinearSolution{ + T, N, uType, R, P, A, O, uType2, S, Tr}}, u, resid, prob, + args...) where {T, N, uType, R, P, A, O, uType2, S, Tr} + function NonlinearSolutionAdjoint(ȳ) + (NoTangent(), ȳ.u, NoTangent(), ȳ.prob, ntuple(_ -> NoTangent(), length(args))...) + end + SciMLBase.NonlinearSolution{T, N, uType, R, P, A, O, uType2, S, Tr}(u, resid, prob, args...), + NonlinearSolutionAdjoint +end + end diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index 45a8e0f63..b22a5b051 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -228,7 +228,7 @@ end uType2 } function NonlinearSolutionAdjoint(ȳ) - (ȳ, ntuple(_ -> nothing, length(args))...) + (ȳ.u, ntuple(_ -> nothing, length(args))...) end NonlinearSolution{T, N, uType, R, P, A, O, uType2}(u, args...), NonlinearSolutionAdjoint end