Skip to content

Commit cc2316f

Browse files
test: index with correct symbols
1 parent a19208a commit cc2316f

1 file changed

Lines changed: 3 additions & 3 deletions

File tree

test/downstream/symbol_indexing.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using ModelingToolkit, OrdinaryDiffEq, RecursiveArrayTools, SymbolicIndexingInterface, Test
1+
using ModelingToolkit, OrdinaryDiffEq, RecursiveArrayTools, SymbolicIndexingInterface, Zygote, Test
22
using Optimization, OptimizationOptimJL
33
using ModelingToolkit: t_nounits as t, D_nounits as D
44

@@ -107,7 +107,7 @@ true_grad_sym[idx_sym] .= 1.
107107
@test "Symbolic Indexing Adjoint: Symbol" all(x -> x == true_grad_sym, gs_sym)
108108

109109
gs_vec, = Zygote.gradient(sol) do sol
110-
sum(sum.(sol[[lorenz1.x, lorenz2]]))
110+
sum(sum.(sol[[lorenz1.x, lorenz2.x]]))
111111
end
112112
idx_vecsym = SymbolicIndexingInterface.variable_index.(Ref(sys), [lorenz1.x, lorenz2.x])
113113
true_grad_vecsym = zeros(length(ModelingToolkit.unknowns(sys)))
@@ -116,7 +116,7 @@ true_grad_vecsym[idx_vecsym] .= 1.
116116
@test "Symbolic Indexing Adjoint: Vector{Symbol}" all(x -> x == true_grad_vecsym, gs_vec)
117117

118118
gs_tup, = Zygote.gradient(sol) do sol
119-
sum(sum.(collect.(sol[(lorenz1.x, lorenz2)])))
119+
sum(sum.(collect.(sol[(lorenz1.x, lorenz2.x)])))
120120
end
121121
idx_tupsym = SymbolicIndexingInterface.variable_index.(Ref(sys), [lorenz1.x, lorenz2.x])
122122
true_grad_tupsym = zeros(length(ModelingToolkit.unknowns(sys)))

0 commit comments

Comments
 (0)