Skip to content

Support in-place interpolation of symbolic idxs #988

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 51 additions & 1 deletion src/solutions/ode_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ function is_discrete_expression(indp, expr)
length(ts_idxs) > 1 || length(ts_idxs) == 1 && only(ts_idxs) != ContinuousTimeseries()
end

# These are the two main documented user-facing interpolation API functions (out-of-place and in-place versions)
function (sol::AbstractODESolution)(t, ::Type{deriv} = Val{0}; idxs = nothing,
continuity = :left) where {deriv}
if t isa IndexedClock
Expand All @@ -225,9 +226,12 @@ function (sol::AbstractODESolution)(v, t, ::Type{deriv} = Val{0}; idxs = nothing
if t isa IndexedClock
t = canonicalize_indexed_clock(t, sol)
end
sol.interp(v, t, idxs, deriv, sol.prob.p, continuity)
sol(v, t, deriv, idxs, continuity)
end

# Below are many internal dispatches for different combinations of arguments to the main API
# TODO: could use a clever rewrite, since a lot of reused code has accumulated
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The issue is that the symbolic dispatch is kept open, since there's more than one issymbolic type. We should union that in the future, but it needs JuliaSymbolics/SymbolicUtils.jl#737


function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs::Nothing,
continuity) where {deriv}
sol.interp(t, idxs, deriv, sol.prob.p, continuity)
Expand Down Expand Up @@ -365,6 +369,52 @@ function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv},
return DiffEqArray(u, t, p, sol; discretes)
end

function (sol::AbstractODESolution)(v::AbstractArray, t::Number, ::Type{deriv},
idxs::Union{Nothing, Integer, AbstractArray{<:Integer}}, continuity) where {deriv}
return sol.interp(v, t, idxs, deriv, sol.prob.p, continuity)
end
function (sol::AbstractODESolution)(
v::AbstractArray, t::AbstractVector{<:Number}, ::Type{deriv},
idxs::Union{Nothing, Integer, AbstractArray{<:Integer}}, continuity) where {deriv}
return sol.interp(v, t, idxs, deriv, sol.prob.p, continuity)
end
function (sol::AbstractODESolution)(
v::AbstractArray, t::AbstractVector{<:Number}, ::Type{deriv}, idxs,
continuity) where {deriv}
symbolic_type(idxs) == NotSymbolic() && error("Incorrect specification of `idxs`")
error_if_observed_derivative(sol, idxs, deriv)
p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing
getter = getsym(sol, idxs)
if is_parameter_timeseries(sol) == NotTimeseries() || !is_discrete_expression(sol, idxs)
u = zeros(eltype(sol), size(sol)[1])
v .= map(eachindex(t)) do ti
sol.interp(u, t[ti], nothing, deriv, p, continuity)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
v .= map(eachindex(t)) do ti
sol.interp(u, t[ti], nothing, deriv, p, continuity)
if eltype(v) <: Number
v[ti] = sol.interp(t[ti], nothing, deriv, p, continuity)
else
for ti in eachindex(t)
sol.interp(v[ti], t[ti], nothing, deriv, p, continuity)
end
end

return getter(ProblemState(; u = u, p = p, t = t[ti]))
end
return v
end
error("In-place interpolation with discretes is not implemented.")
end
function (sol::AbstractODESolution)(
v::AbstractArray, t::AbstractVector{<:Number}, ::Type{deriv},
idxs::AbstractVector, continuity) where {deriv}
if symbolic_type(idxs) == NotSymbolic() && isempty(idxs)
return map(_ -> eltype(eltype(sol.u))[], t)
end
error_if_observed_derivative(sol, idxs, deriv)
p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing
getter = getsym(sol, idxs)
if is_parameter_timeseries(sol) == NotTimeseries() || !is_discrete_expression(sol, idxs)
u = zeros(eltype(sol), size(sol)[1])
v .= map(eachindex(t)) do ti
sol.interp(u, t[ti], nothing, deriv, p, continuity)
return getter(ProblemState(; u = u, p = p, t = t[ti]))
end
return v
end
error("In-place interpolation with discretes is not implemented.")
end

struct DDESolutionHistoryWrapper{T}
sol::T
end
Expand Down
23 changes: 22 additions & 1 deletion test/downstream/solution_interface.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using ModelingToolkit, OrdinaryDiffEq, RecursiveArrayTools, StochasticDiffEq, Test
using StochasticDiffEq
using SymbolicIndexingInterface
using ModelingToolkit: t_nounits as t, D_nounits as D
using ModelingToolkit: observed, t_nounits as t, D_nounits as D
using Plots: Plots, plot

### Tests on non-layered model (everything should work). ###
Expand Down Expand Up @@ -148,6 +148,27 @@ sol9 = sol(0.0:1.0:10.0, idxs = 2)
sol10 = sol(0.1, idxs = 2)
@test sol10 isa Real

# in-place interpolation with single (unknown) symbolic index
ts = 0.0:0.1:10.0
out = zeros(eltype(sol), size(ts))
idxs = unknowns(sys)[1]
@test sol(out, ts; idxs) == sol(ts; idxs)
@test (@allocated sol(out, ts; idxs)) < (@allocated sol(ts; idxs))
@test_nowarn @inferred sol(out, ts; idxs)

# in-place interpolation with single (observed) symbolic index
idxs = observed(sys)[1].lhs
@test sol(out, ts; idxs) == sol(ts; idxs)
@test (@allocated sol(out, ts; idxs)) < (@allocated sol(ts; idxs))
@test_nowarn @inferred sol(out, ts; idxs)

# in-place interpolation with multiple (unknown+observed) symbolic indices
idxs = [unknowns(sys)[1], observed(sys)[1].lhs]
out = [zeros(eltype(sol), size(idxs)) for _ in eachindex(ts)]
@test sol(out, ts; idxs) == sol(ts; idxs).u
@test (@allocated sol(out, ts; idxs)) < (@allocated sol(ts; idxs))
@test_nowarn @inferred sol(out, ts; idxs)

@testset "Plot idxs" begin
@variables x(t) y(t)
@parameters p
Expand Down
Loading