Skip to content

Commit 4f9368b

Browse files
Merge pull request #1018 from SciML/reversediff
Fix ReverseDiff Array of Structs to Struct of Array
2 parents e8de48e + 7c00788 commit 4f9368b

File tree

2 files changed

+15
-15
lines changed

2 files changed

+15
-15
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ DiffEqBaseTrackerExt = "Tracker"
5858
DiffEqBaseUnitfulExt = "Unitful"
5959

6060
[compat]
61-
ArrayInterface = "7"
61+
ArrayInterface = "7.8"
6262
ChainRulesCore = "1"
6363
DataStructures = "0.18"
6464
Distributions = "0.25"

ext/DiffEqBaseReverseDiffExt.jl

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,9 @@
11
module DiffEqBaseReverseDiffExt
22

3-
if isdefined(Base, :get_extension)
4-
using DiffEqBase
5-
import DiffEqBase: value
6-
import ReverseDiff
7-
else
8-
using ..DiffEqBase
9-
import ..DiffEqBase: value
10-
import ..ReverseDiff
11-
end
3+
using DiffEqBase
4+
import DiffEqBase: value
5+
import ReverseDiff
6+
import DiffEqBase.ArrayInterface
127

138
DiffEqBase.value(x::Type{ReverseDiff.TrackedReal{V, D, O}}) where {V, D, O} = V
149
function DiffEqBase.value(x::Type{
@@ -113,7 +108,8 @@ function DiffEqBase.solve_up(prob::DiffEqBase.AbstractDEProblem,
113108
u0::AbstractArray{<:ReverseDiff.TrackedReal},
114109
p::AbstractArray{<:ReverseDiff.TrackedReal}, args...;
115110
kwargs...)
116-
DiffEqBase.solve_up(prob, sensealg, reduce(vcat, u0), reduce(vcat, p), args...;
111+
DiffEqBase.solve_up(prob, sensealg, ArrayInterface.aos_to_soa(u0),
112+
ArrayInterface.aos_to_soa(p), args...;
117113
kwargs...)
118114
end
119115

@@ -123,7 +119,8 @@ function DiffEqBase.solve_up(prob::DiffEqBase.AbstractDEProblem,
123119
Nothing}, u0,
124120
p::AbstractArray{<:ReverseDiff.TrackedReal},
125121
args...; kwargs...)
126-
DiffEqBase.solve_up(prob, sensealg, u0, reduce(vcat, p), args...; kwargs...)
122+
DiffEqBase.solve_up(
123+
prob, sensealg, u0, ArrayInterface.aos_to_soa(p), args...; kwargs...)
127124
end
128125

129126
function DiffEqBase.solve_up(prob::DiffEqBase.AbstractDEProblem,
@@ -132,7 +129,8 @@ function DiffEqBase.solve_up(prob::DiffEqBase.AbstractDEProblem,
132129
Nothing}, u0::ReverseDiff.TrackedArray,
133130
p::AbstractArray{<:ReverseDiff.TrackedReal},
134131
args...; kwargs...)
135-
DiffEqBase.solve_up(prob, sensealg, u0, reduce(vcat, p), args...; kwargs...)
132+
DiffEqBase.solve_up(
133+
prob, sensealg, u0, ArrayInterface.aos_to_soa(p), args...; kwargs...)
136134
end
137135

138136
function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
@@ -141,7 +139,8 @@ function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
141139
Nothing},
142140
u0::AbstractArray{<:ReverseDiff.TrackedReal}, p,
143141
args...; kwargs...)
144-
DiffEqBase.solve_up(prob, sensealg, reduce(vcat, u0), p, args...; kwargs...)
142+
DiffEqBase.solve_up(
143+
prob, sensealg, ArrayInterface.aos_to_soa(u0), p, args...; kwargs...)
145144
end
146145

147146
function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
@@ -150,7 +149,8 @@ function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
150149
Nothing},
151150
u0::AbstractArray{<:ReverseDiff.TrackedReal}, p::ReverseDiff.TrackedArray,
152151
args...; kwargs...)
153-
DiffEqBase.solve_up(prob, sensealg, reduce(vcat, u0), p, args...; kwargs...)
152+
DiffEqBase.solve_up(
153+
prob, sensealg, ArrayInterface.aos_to_soa(u0), p, args...; kwargs...)
154154
end
155155

156156
# Required becase ReverseDiff.@grad function DiffEqBase.solve_up is not supported!

0 commit comments

Comments
 (0)