Skip to content

Commit c261a03

Browse files
Merge pull request #678 from DhairyaLGandhi/dg/vecsym
fix: Correct gradients for vector of symbols while indexing
2 parents fc5a573 + 071ad2f commit c261a03

File tree

3 files changed

+66
-7
lines changed

3 files changed

+66
-7
lines changed

ext/SciMLBaseChainRulesCoreExt.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module SciMLBaseChainRulesCoreExt
33
using SciMLBase
44
import ChainRulesCore
55
import ChainRulesCore: NoTangent, @non_differentiable
6+
using SymbolicIndexingInterface
67

78
function ChainRulesCore.rrule(
89
config::ChainRulesCore.RuleConfig{
@@ -13,7 +14,7 @@ function ChainRulesCore.rrule(
1314
sym,
1415
j::Integer)
1516
function ODESolution_getindex_pullback(Δ)
16-
i = symbolic_type(sym) != NotSymbolic() ? sym_to_index(sym, VA) : sym
17+
i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym
1718
if i === nothing
1819
getter = getobserved(VA)
1920
grz = rrule_via_ad(config, getter, sym, VA.u[j], VA.prob.p, VA.t[j])[2](Δ)
@@ -66,7 +67,7 @@ end
6667

6768
function ChainRulesCore.rrule(::typeof(getindex), VA::ODESolution, sym)
6869
function ODESolution_getindex_pullback(Δ)
69-
i = symbolic_type(sym) != NotSymbolic() ? sym_to_index(sym, VA) : sym
70+
i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym
7071
if i === nothing
7172
throw(error("AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated."))
7273
else

ext/SciMLBaseZygoteExt.jl

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ using RecursiveArrayTools
1313
# This method resolves the ambiguity with the pullback defined in
1414
# RecursiveArrayToolsZygoteExt
1515
# https://github.com/SciML/RecursiveArrayTools.jl/blob/d06ecb856f43bc5e37cbaf50e5f63c578bf3f1bd/ext/RecursiveArrayToolsZygoteExt.jl#L67
16-
@adjoint function getindex(VA::ODESolution, i::Int, j::Int)
16+
@adjoint function Base.getindex(VA::ODESolution, i::Int, j::Int)
1717
function ODESolution_getindex_pullback(Δ)
1818
du = [m == j ? [i == k ? Δ : zero(VA.u[1][1]) for k in 1:length(VA.u[1])] :
1919
zero(VA.u[1]) for m in 1:length(VA.u)]
@@ -38,7 +38,7 @@ using RecursiveArrayTools
3838
VA[i, j], ODESolution_getindex_pullback
3939
end
4040

41-
@adjoint function getindex(VA::ODESolution, sym, j::Int)
41+
@adjoint function Base.getindex(VA::ODESolution, sym, j::Int)
4242
function ODESolution_getindex_pullback(Δ)
4343
i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym
4444
du, dprob = if i === nothing
@@ -92,7 +92,7 @@ end
9292
out, EnsembleSolution_adjoint
9393
end
9494

95-
@adjoint function getindex(VA::ODESolution, i::Int)
95+
@adjoint function Base.getindex(VA::ODESolution, i::Int)
9696
function ODESolution_getindex_pullback(Δ)
9797
Δ′ = [(i == j ? Δ : Zygote.FillArrays.Fill(zero(eltype(x)), size(x)))
9898
for (x, j) in zip(VA.u, 1:length(VA))]
@@ -106,7 +106,7 @@ end
106106
sim.u, p̄ -> (EnsembleSolution(p̄, 0.0, true, sim.stats),)
107107
end
108108

109-
@adjoint function getindex(VA::ODESolution, sym)
109+
@adjoint function Base.getindex(VA::ODESolution, sym)
110110
function ODESolution_getindex_pullback(Δ)
111111
i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym
112112
if i === nothing
@@ -120,6 +120,30 @@ end
120120
VA[sym], ODESolution_getindex_pullback
121121
end
122122

123+
@adjoint function Base.getindex(
124+
VA::ODESolution{T}, sym::Union{Tuple, AbstractVector}) where {T}
125+
function ODESolution_getindex_pullback(Δ)
126+
sym = sym isa Tuple ? collect(sym) : sym
127+
i = map(x -> symbolic_type(x) != NotSymbolic() ? variable_index(VA, x) : x, sym)
128+
if i === nothing
129+
throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated."))
130+
else
131+
Δ′ = map(enumerate(VA.u)) do (t_idx, us)
132+
map(enumerate(us)) do (u_idx, u)
133+
if u_idx in i
134+
idx = findfirst(isequal(u_idx), i)
135+
Δ[t_idx][idx]
136+
else
137+
zero(T)
138+
end
139+
end
140+
end
141+
(Δ′, nothing)
142+
end
143+
end
144+
VA[sym], ODESolution_getindex_pullback
145+
end
146+
123147
@adjoint function ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12
124148
}(u,
125149
args...) where {T1, T2, T3, T4, T5, T6, T7, T8,

test/downstream/symbol_indexing.jl

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using ModelingToolkit, OrdinaryDiffEq, RecursiveArrayTools, SymbolicIndexingInterface, Test
1+
using ModelingToolkit, OrdinaryDiffEq, RecursiveArrayTools, SymbolicIndexingInterface,
2+
Zygote, Test
23
using Optimization, OptimizationOptimJL
34
using ModelingToolkit: t_nounits as t, D_nounits as D
45

@@ -97,6 +98,39 @@ end
9798
@test length(sol[[lorenz1.x, lorenz2.x], :]) == length(sol)
9899
@test length(sol[[lorenz1.x, lorenz2.x], :][1]) == 2
99100

101+
gs_sym, = Zygote.gradient(sol) do sol
102+
sum(sol[lorenz1.x])
103+
end
104+
idx_sym = SymbolicIndexingInterface.variable_index(sys, lorenz1.x)
105+
true_grad_sym = zeros(length(ModelingToolkit.unknowns(sys)))
106+
true_grad_sym[idx_sym] = 1.0
107+
108+
@test all(map(x -> x == true_grad_sym, gs_sym))
109+
110+
gs_vec, = Zygote.gradient(sol) do sol
111+
sum(sum.(sol[[lorenz1.x, lorenz2.x]]))
112+
end
113+
idx_vecsym = SymbolicIndexingInterface.variable_index.(Ref(sys), [lorenz1.x, lorenz2.x])
114+
true_grad_vecsym = zeros(length(ModelingToolkit.unknowns(sys)))
115+
true_grad_vecsym[idx_vecsym] .= 1.0
116+
117+
@test all(map(x -> x == true_grad_vecsym, gs_vec))
118+
119+
gs_tup, = Zygote.gradient(sol) do sol
120+
sum(sum.(collect.(sol[(lorenz1.x, lorenz2.x)])))
121+
end
122+
idx_tupsym = SymbolicIndexingInterface.variable_index.(Ref(sys), [lorenz1.x, lorenz2.x])
123+
true_grad_tupsym = zeros(length(ModelingToolkit.unknowns(sys)))
124+
true_grad_tupsym[idx_tupsym] .= 1.0
125+
126+
@test all(map(x -> x == true_grad_tupsym, gs_tup))
127+
128+
gs_ts, = Zygote.gradient(sol) do sol
129+
sum(sol[[lorenz1.x, lorenz2.x], :])
130+
end
131+
132+
@test all(map(x -> x == true_grad_vecsym, gs_ts))
133+
100134
@variables q(t)[1:2] = [1.0, 2.0]
101135
eqs = [D(q[1]) ~ 2q[1]
102136
D(q[2]) ~ 2.0]

0 commit comments

Comments
 (0)