Skip to content

Commit 737942e

Browse files
committed
Support in-place interpolation of symbolic idxs
1 parent 8c0a53a commit 737942e

File tree

2 files changed

+61
-1
lines changed

2 files changed

+61
-1
lines changed

src/solutions/ode_solutions.jl

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ function is_discrete_expression(indp, expr)
213213
length(ts_idxs) > 1 || length(ts_idxs) == 1 && only(ts_idxs) != ContinuousTimeseries()
214214
end
215215

216+
# These are the two main documented user-facing interpolation API functions (out-of-place and in-place versions)
216217
function (sol::AbstractODESolution)(t, ::Type{deriv} = Val{0}; idxs = nothing,
217218
continuity = :left) where {deriv}
218219
if t isa IndexedClock
@@ -225,9 +226,12 @@ function (sol::AbstractODESolution)(v, t, ::Type{deriv} = Val{0}; idxs = nothing
225226
if t isa IndexedClock
226227
t = canonicalize_indexed_clock(t, sol)
227228
end
228-
sol.interp(v, t, idxs, deriv, sol.prob.p, continuity)
229+
sol(v, t, deriv, idxs, continuity)
229230
end
230231

232+
# Below are many internal dispatches for different combinations of arguments to the main API
233+
# TODO: could use a clever rewrite, since a lot of reused code has accumulated
234+
231235
function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs::Nothing,
232236
continuity) where {deriv}
233237
sol.interp(t, idxs, deriv, sol.prob.p, continuity)
@@ -365,6 +369,41 @@ function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv},
365369
return DiffEqArray(u, t, p, sol; discretes)
366370
end
367371

372+
function (sol::AbstractODESolution)(v::AbstractArray, t::AbstractVector{<:Number}, ::Type{deriv}, idxs,
373+
continuity) where {deriv}
374+
symbolic_type(idxs) == NotSymbolic() && error("Incorrect specification of `idxs`")
375+
error_if_observed_derivative(sol, idxs, deriv)
376+
p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing
377+
getter = getsym(sol, idxs)
378+
if is_parameter_timeseries(sol) == NotTimeseries() || !is_discrete_expression(sol, idxs)
379+
u = zeros(eltype(sol), size(sol)[1])
380+
v .= map(eachindex(t)) do ti
381+
sol.interp(u, t[ti], nothing, deriv, p, continuity)
382+
return getter(ProblemState(; u = u, p = p, t = t[ti]))
383+
end
384+
return v
385+
end
386+
error("In-place interpolation with discretes is not implemented.")
387+
end
388+
function (sol::AbstractODESolution)(v::AbstractArray, t::AbstractVector{<:Number}, ::Type{deriv},
389+
idxs::AbstractVector, continuity) where {deriv}
390+
if symbolic_type(idxs) == NotSymbolic() && isempty(idxs)
391+
return map(_ -> eltype(eltype(sol.u))[], t)
392+
end
393+
error_if_observed_derivative(sol, idxs, deriv)
394+
p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing
395+
getter = getsym(sol, idxs)
396+
if is_parameter_timeseries(sol) == NotTimeseries() || !is_discrete_expression(sol, idxs)
397+
u = zeros(eltype(sol), size(sol)[1])
398+
v .= map(eachindex(t)) do ti
399+
sol.interp(u, t[ti], nothing, deriv, p, continuity)
400+
return getter(ProblemState(; u = u, p = p, t = t[ti]))
401+
end
402+
return v
403+
end
404+
error("In-place interpolation with discretes is not implemented.")
405+
end
406+
368407
struct DDESolutionHistoryWrapper{T}
369408
sol::T
370409
end

test/downstream/solution_interface.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,27 @@ sol9 = sol(0.0:1.0:10.0, idxs = 2)
148148
sol10 = sol(0.1, idxs = 2)
149149
@test sol10 isa Real
150150

151+
# in-place interpolation with single (unknown) symbolic index
152+
ts = 0.0:0.1:10.0
153+
out = zeros(eltype(sol), size(ts))
154+
idxs = unknowns(sys)[1]
155+
@test sol(out, ts; idxs) == sol(ts; idxs)
156+
@test (@allocated sol(out, ts; idxs)) < (@allocated sol(ts; idxs))
157+
@test_nowarn @inferred sol(out, ts; idxs)
158+
159+
# in-place interpolation with single (observed) symbolic index
160+
idxs = observed(sys)[1].lhs
161+
@test sol(out, ts; idxs) == sol(ts; idxs)
162+
@test (@allocated sol(out, ts; idxs)) < (@allocated sol(ts; idxs))
163+
@test_nowarn @inferred sol(out, ts; idxs)
164+
165+
# in-place interpolation with multiple (unknown+observed) symbolic indices
166+
idxs = [unknowns(sys)[1], observed(sys)[1].lhs]
167+
out = [zeros(eltype(sol), size(idxs)) for _ in eachindex(ts)]
168+
@test sol(out, ts; idxs) == sol(ts; idxs).u
169+
@test (@allocated sol(out, ts; idxs)) < (@allocated sol(ts; idxs))
170+
@test_nowarn @inferred sol(out, ts; idxs)
171+
151172
@testset "Plot idxs" begin
152173
@variables x(t) y(t)
153174
@parameters p

0 commit comments

Comments
 (0)