Skip to content

Commit bba5c52

Browse files
fix: update SII syntax in CR ext also
1 parent cc2316f commit bba5c52

File tree

3 files changed

+13
-12
lines changed

3 files changed

+13
-12
lines changed

ext/SciMLBaseChainRulesCoreExt.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module SciMLBaseChainRulesCoreExt
33
using SciMLBase
44
import ChainRulesCore
55
import ChainRulesCore: NoTangent, @non_differentiable
6+
using SymbolicIndexingInterface
67

78
function ChainRulesCore.rrule(
89
config::ChainRulesCore.RuleConfig{
@@ -13,7 +14,7 @@ function ChainRulesCore.rrule(
1314
sym,
1415
j::Integer)
1516
function ODESolution_getindex_pullback(Δ)
16-
i = symbolic_type(sym) != NotSymbolic() ? sym_to_index(sym, VA) : sym
17+
i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym
1718
if i === nothing
1819
getter = getobserved(VA)
1920
grz = rrule_via_ad(config, getter, sym, VA.u[j], VA.prob.p, VA.t[j])[2](Δ)
@@ -66,7 +67,7 @@ end
6667

6768
function ChainRulesCore.rrule(::typeof(getindex), VA::ODESolution, sym)
6869
function ODESolution_getindex_pullback(Δ)
69-
i = symbolic_type(sym) != NotSymbolic() ? sym_to_index(sym, VA) : sym
70+
i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym
7071
if i === nothing
7172
throw(error("AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated."))
7273
else

ext/SciMLBaseZygoteExt.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ using RecursiveArrayTools
1313
# This method resolves the ambiguity with the pullback defined in
1414
# RecursiveArrayToolsZygoteExt
1515
# https://github.com/SciML/RecursiveArrayTools.jl/blob/d06ecb856f43bc5e37cbaf50e5f63c578bf3f1bd/ext/RecursiveArrayToolsZygoteExt.jl#L67
16-
@adjoint function getindex(VA::ODESolution, i::Int, j::Int)
16+
@adjoint function Base.getindex(VA::ODESolution, i::Int, j::Int)
1717
function ODESolution_getindex_pullback(Δ)
1818
du = [m == j ? [i == k ? Δ : zero(VA.u[1][1]) for k in 1:length(VA.u[1])] :
1919
zero(VA.u[1]) for m in 1:length(VA.u)]
@@ -38,7 +38,7 @@ using RecursiveArrayTools
3838
VA[i, j], ODESolution_getindex_pullback
3939
end
4040

41-
@adjoint function getindex(VA::ODESolution, sym, j::Int)
41+
@adjoint function Base.getindex(VA::ODESolution, sym, j::Int)
4242
function ODESolution_getindex_pullback(Δ)
4343
i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym
4444
du, dprob = if i === nothing
@@ -92,7 +92,7 @@ end
9292
out, EnsembleSolution_adjoint
9393
end
9494

95-
@adjoint function getindex(VA::ODESolution, i::Int)
95+
@adjoint function Base.getindex(VA::ODESolution, i::Int)
9696
function ODESolution_getindex_pullback(Δ)
9797
Δ′ = [(i == j ? Δ : Zygote.FillArrays.Fill(zero(eltype(x)), size(x)))
9898
for (x, j) in zip(VA.u, 1:length(VA))]
@@ -106,7 +106,7 @@ end
106106
sim.u, p̄ -> (EnsembleSolution(p̄, 0.0, true, sim.stats),)
107107
end
108108

109-
@adjoint function getindex(VA::ODESolution, sym)
109+
@adjoint function Base.getindex(VA::ODESolution, sym)
110110
function ODESolution_getindex_pullback(Δ)
111111
i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym
112112
if i === nothing
@@ -120,7 +120,7 @@ end
120120
VA[sym], ODESolution_getindex_pullback
121121
end
122122

123-
@adjoint function getindex(VA::ODESolution{T}, sym::Union{Tuple, AbstractVector}) where T
123+
@adjoint function Base.getindex(VA::ODESolution{T}, sym::Union{Tuple, AbstractVector}) where T
124124
function ODESolution_getindex_pullback(Δ)
125125
sym = sym isa Tuple ? collect(sym) : sym
126126
i = map(x -> symbolic_type(x) != NotSymbolic() ? variable_index(VA, x) : x, sym)

test/downstream/symbol_indexing.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,9 @@ gs_sym, = Zygote.gradient(sol) do sol
102102
end
103103
idx_sym = SymbolicIndexingInterface.variable_index(sys, lorenz1.x)
104104
true_grad_sym = zeros(length(ModelingToolkit.unknowns(sys)))
105-
true_grad_sym[idx_sym] .= 1.
105+
true_grad_sym[idx_sym] = 1.
106106

107-
@test "Symbolic Indexing Adjoint: Symbol" all(x -> x == true_grad_sym, gs_sym)
107+
@test all(map(x -> x == true_grad_sym, gs_sym))
108108

109109
gs_vec, = Zygote.gradient(sol) do sol
110110
sum(sum.(sol[[lorenz1.x, lorenz2.x]]))
@@ -113,7 +113,7 @@ idx_vecsym = SymbolicIndexingInterface.variable_index.(Ref(sys), [lorenz1.x, lor
113113
true_grad_vecsym = zeros(length(ModelingToolkit.unknowns(sys)))
114114
true_grad_vecsym[idx_vecsym] .= 1.
115115

116-
@test "Symbolic Indexing Adjoint: Vector{Symbol}" all(x -> x == true_grad_vecsym, gs_vec)
116+
@test all(map(x -> x == true_grad_vecsym, gs_vec))
117117

118118
gs_tup, = Zygote.gradient(sol) do sol
119119
sum(sum.(collect.(sol[(lorenz1.x, lorenz2.x)])))
@@ -122,13 +122,13 @@ idx_tupsym = SymbolicIndexingInterface.variable_index.(Ref(sys), [lorenz1.x, lor
122122
true_grad_tupsym = zeros(length(ModelingToolkit.unknowns(sys)))
123123
true_grad_tupsym[idx_tupsym] .= 1.
124124

125-
@test "Symbolic Indexing Adjoint: Tuple{Symbol}" all(x -> x == true_grad_tupsym, gs_tup)
125+
@test all(x -> x == true_grad_tupsym, gs_tup)
126126

127127
gs_ts, = Zygote.gradient(sol) do sol
128128
sum(sol[[lorenz1.x, lorenz2], :])
129129
end
130130

131-
@test "Symbolic Indexing Adjoint: Timeseries/ Vector{Symbol}" all(x -> x == true_grad_vecsym, gs_ts)
131+
@test all(x -> x == true_grad_vecsym, gs_ts)
132132

133133
@variables q(t)[1:2] = [1.0, 2.0]
134134
eqs = [D(q[1]) ~ 2q[1]

0 commit comments

Comments
 (0)