Skip to content

Commit fed63df

Browse files
Merge pull request #1039 from jClugstor/autospecialize_gradients_fix
Add adjoint for ODEProblem constructor w/ iip, specialization
2 parents ec0afb6 + 2bffd29 commit fed63df

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

ext/SciMLBaseChainRulesCoreExt.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,15 @@ function ChainRulesCore.rrule(::Type{ODEProblem}, args...; kwargs...)
6464
ODEProblem(args...; kwargs...), ODEProblemAdjoint
6565
end
6666

67+
function ChainRulesCore.rrule(::Type{
68+
<:ODEProblem{iip, T}}, args...; kwargs...) where {iip, T}
69+
function ODEProblemAdjoint(ȳ)
70+
(NoTangent(), ȳ.f, ȳ.u0, ȳ.tspan, ȳ.p, ȳ.kwargs, ȳ.problem_type)
71+
end
72+
73+
ODEProblem(args...; kwargs...), ODEProblemAdjoint
74+
end
75+
6776
function ChainRulesCore.rrule(::Type{SDEProblem}, args...; kwargs...)
6877
function SDEProblemAdjoint(ȳ)
6978
(NoTangent(), ȳ.f, ȳ.g, ȳ.u0, ȳ.tspan, ȳ.p, ȳ.kwargs, ȳ.problem_type)

0 commit comments

Comments
 (0)