Skip to content

Commit 7c2effa

Browse files
refactor: ODESolution dispatches to avoid piracy
1 parent 32f9216 commit 7c2effa

File tree

3 files changed

+15
-3
lines changed

3 files changed

+15
-3
lines changed

src/NeuralPDE.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ using Reexport: @reexport
2929
using RuntimeGeneratedFunctions: RuntimeGeneratedFunctions, @RuntimeGeneratedFunction
3030
using SciMLBase: SciMLBase, BatchIntegralFunction, IntegralProblem, NoiseProblem,
3131
OptimizationFunction, OptimizationProblem, ReturnCode, discretize,
32-
isinplace, solve, symbolic_discretize, ODEProblem, AbstractODESolution
32+
isinplace, solve, symbolic_discretize, ODEProblem, ODESolution
3333
using Statistics: Statistics, mean
3434
using QuasiMonteCarlo: QuasiMonteCarlo, LatinHypercubeSample
3535
using WeightInitializers: glorot_uniform, zeros32

src/ode_solve.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,12 @@ function (f::NNODEInterpolation)(t::Vector, idxs, ::Type{Val{0}}, p, continuity)
281281
return DiffEqArray([out[idxs, i] for i in axes(out, 2)], t)
282282
end
283283

284+
function (sol::ODESolution{T, N, U, U2, D, T2, R, D2, P, A})(
285+
t::AbstractVector{<:Number}, ::Type{deriv}, idxs::Nothing,
286+
continuity) where {T, N, U, U2, D, T2, R, D2, P, A <: NNODE, deriv}
287+
sol.interp(t, idxs, deriv, sol.prob.p, continuity)
288+
end
289+
284290
SciMLBase.interp_summary(::NNODEInterpolation) = "Trained neural network interpolation"
285291
SciMLBase.allowscomplex(::NNODE) = true
286292

src/pino_ode_solve.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -292,9 +292,15 @@ end
292292
SciMLBase.interp_summary(::PINOODEInterpolation) = "Trained neural network interpolation"
293293
SciMLBase.allowscomplex(::PINOODE) = true
294294

295-
function (sol::AbstractODESolution)(
295+
function (sol::ODESolution{T, N, U, U2, D, T2, R, D2, P, A})(
296296
t::AbstractArray, ::Type{deriv}, idxs::Nothing,
297-
continuity) where {deriv}
297+
continuity) where {T, N, U, U2, D, T2, R, D2, P, A <: PINOODE, deriv}
298+
sol.interp(t, idxs, deriv, sol.prob.p, continuity)
299+
end
300+
301+
function (sol::ODESolution{T, N, U, U2, D, T2, R, D2, P, A})(
302+
t::AbstractVector{<:Number}, ::Type{deriv}, idxs::Nothing,
303+
continuity) where {T, N, U, U2, D, T2, R, D2, P, A <: PINOODE, deriv}
298304
sol.interp(t, idxs, deriv, sol.prob.p, continuity)
299305
end
300306

0 commit comments

Comments
 (0)