You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
function ChainRulesCore.rrule(::typeof(getindex), VA::ODESolution, sym)
68
69
functionODESolution_getindex_pullback(Δ)
69
-
i =symbolic_type(sym) !=NotSymbolic() ?sym_to_index(sym, VA) : sym
70
+
i =symbolic_type(sym) !=NotSymbolic() ?variable_index(VA, sym) : sym
70
71
if i ===nothing
71
72
throw(error("AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated."))
i =symbolic_type(sym) !=NotSymbolic() ?variable_index(VA, sym) : sym
112
112
if i ===nothing
@@ -120,6 +120,30 @@ end
120
120
VA[sym], ODESolution_getindex_pullback
121
121
end
122
122
123
+
@adjointfunction Base.getindex(
124
+
VA::ODESolution{T}, sym::Union{Tuple, AbstractVector}) where {T}
125
+
functionODESolution_getindex_pullback(Δ)
126
+
sym = sym isa Tuple ?collect(sym) : sym
127
+
i =map(x ->symbolic_type(x) !=NotSymbolic() ?variable_index(VA, x) : x, sym)
128
+
if i ===nothing
129
+
throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated."))
0 commit comments