1
1
module DiffEqBaseReverseDiffExt
2
2
3
- if isdefined (Base, :get_extension )
4
- using DiffEqBase
5
- import DiffEqBase: value
6
- import ReverseDiff
7
- else
8
- using .. DiffEqBase
9
- import .. DiffEqBase: value
10
- import .. ReverseDiff
11
- end
3
+ using DiffEqBase
4
+ import DiffEqBase: value
5
+ import ReverseDiff
6
+ import DiffEqBase. ArrayInterface
12
7
13
8
DiffEqBase. value (x:: Type{ReverseDiff.TrackedReal{V, D, O}} ) where {V, D, O} = V
14
9
function DiffEqBase. value (x:: Type {
@@ -113,7 +108,8 @@ function DiffEqBase.solve_up(prob::DiffEqBase.AbstractDEProblem,
113
108
u0:: AbstractArray{<:ReverseDiff.TrackedReal} ,
114
109
p:: AbstractArray{<:ReverseDiff.TrackedReal} , args... ;
115
110
kwargs... )
116
- DiffEqBase. solve_up (prob, sensealg, reduce (vcat, u0), reduce (vcat, p), args... ;
111
+ DiffEqBase. solve_up (prob, sensealg, ArrayInterface. aos_to_soa (u0),
112
+ ArrayInterface. aos_to_soa (p), args... ;
117
113
kwargs... )
118
114
end
119
115
@@ -123,7 +119,8 @@ function DiffEqBase.solve_up(prob::DiffEqBase.AbstractDEProblem,
123
119
Nothing}, u0,
124
120
p:: AbstractArray{<:ReverseDiff.TrackedReal} ,
125
121
args... ; kwargs... )
126
- DiffEqBase. solve_up (prob, sensealg, u0, reduce (vcat, p), args... ; kwargs... )
122
+ DiffEqBase. solve_up (
123
+ prob, sensealg, u0, ArrayInterface. aos_to_soa (p), args... ; kwargs... )
127
124
end
128
125
129
126
function DiffEqBase. solve_up (prob:: DiffEqBase.AbstractDEProblem ,
@@ -132,7 +129,8 @@ function DiffEqBase.solve_up(prob::DiffEqBase.AbstractDEProblem,
132
129
Nothing}, u0:: ReverseDiff.TrackedArray ,
133
130
p:: AbstractArray{<:ReverseDiff.TrackedReal} ,
134
131
args... ; kwargs... )
135
- DiffEqBase. solve_up (prob, sensealg, u0, reduce (vcat, p), args... ; kwargs... )
132
+ DiffEqBase. solve_up (
133
+ prob, sensealg, u0, ArrayInterface. aos_to_soa (p), args... ; kwargs... )
136
134
end
137
135
138
136
function DiffEqBase. solve_up (prob:: DiffEqBase.DEProblem ,
@@ -141,7 +139,8 @@ function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
141
139
Nothing},
142
140
u0:: AbstractArray{<:ReverseDiff.TrackedReal} , p,
143
141
args... ; kwargs... )
144
- DiffEqBase. solve_up (prob, sensealg, reduce (vcat, u0), p, args... ; kwargs... )
142
+ DiffEqBase. solve_up (
143
+ prob, sensealg, ArrayInterface. aos_to_soa (u0), p, args... ; kwargs... )
145
144
end
146
145
147
146
function DiffEqBase. solve_up (prob:: DiffEqBase.DEProblem ,
@@ -150,7 +149,8 @@ function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
150
149
Nothing},
151
150
u0:: AbstractArray{<:ReverseDiff.TrackedReal} , p:: ReverseDiff.TrackedArray ,
152
151
args... ; kwargs... )
153
- DiffEqBase. solve_up (prob, sensealg, reduce (vcat, u0), p, args... ; kwargs... )
152
+ DiffEqBase. solve_up (
153
+ prob, sensealg, ArrayInterface. aos_to_soa (u0), p, args... ; kwargs... )
154
154
end
155
155
156
156
# Required becase ReverseDiff.@grad function DiffEqBase.solve_up is not supported!
0 commit comments