From f75c944c9327845b85a2fa05f4cc0316e1457168 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Mon, 29 Apr 2024 11:49:13 -0400 Subject: [PATCH 1/6] Batch observed function eval if possible --- src/solutions/ode_solutions.jl | 5 +++-- src/solutions/solution_interface.jl | 15 ++++++++++++++- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/src/solutions/ode_solutions.jl b/src/solutions/ode_solutions.jl index 6ad75a3be0..7d56f8c7a3 100644 --- a/src/solutions/ode_solutions.jl +++ b/src/solutions/ode_solutions.jl @@ -203,7 +203,7 @@ function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs::AbstractVect all(!isequal(NotSymbolic()), symbolic_type.(idxs)) || error("Incorrect specification of `idxs`") interp_sol = augment(sol.interp([t], nothing, deriv, sol.prob.p, continuity), sol) - [is_parameter(sol, idx) ? getp(sol, idx)(sol) : first(interp_sol[idx]) for idx in idxs] + first(interp_sol[idxs]) end function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, idxs, @@ -224,8 +224,9 @@ function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, error("Incorrect specification of `idxs`") interp_sol = augment(sol.interp(t, nothing, deriv, sol.prob.p, continuity), sol) p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing + indexed_sol = interp_sol[idxs] return DiffEqArray( - [[interp_sol[idx][i] for idx in idxs] for i in 1:length(t)], t, p, sol) + [indexed_sol[i] for i in 1:length(t)], t, p, sol) end function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem}, diff --git a/src/solutions/solution_interface.jl b/src/solutions/solution_interface.jl index f72284ee99..609132a795 100644 --- a/src/solutions/solution_interface.jl +++ b/src/solutions/solution_interface.jl @@ -436,6 +436,18 @@ function solplot_vecs_and_labels(dims, vars, plott, sol, plot_analytic, plot_vecs = [] labels = String[] varsyms = variable_symbols(sol) + batch_symbolic_vars = [] + for x in vars + for j in 2:length(x) + if (x[j] isa Integer && x[j] == 0) || isequal(x[j], getindepsym_defaultt(sol)) + else + push!(batch_symbolic_vars, x[j]) + end + end + end + batch_symbolic_vars = identity.(batch_symbolic_vars) + indexed_solution = sol(plott; idxs = batch_symbolic_vars) + idxx = 0 for x in vars tmp = [] strs = String[] @@ -444,7 +456,8 @@ function solplot_vecs_and_labels(dims, vars, plott, sol, plot_analytic, push!(tmp, plott) push!(strs, "t") else - push!(tmp, sol(plott; idxs = x[j])) + idxx += 1 + push!(tmp, indexed_solution[idxx, :]) if !isempty(varsyms) && x[j] isa Integer push!(strs, String(getname(varsyms[x[j]]))) elseif hasname(x[j]) From 32aa2c694829129437dbffd33f726a4d75d56702 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 30 Apr 2024 17:02:32 +0530 Subject: [PATCH 2/6] fix: handle edge case where `sol.u isa Vector{<:Number}` --- src/solutions/ode_solutions.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/solutions/ode_solutions.jl b/src/solutions/ode_solutions.jl index 7d56f8c7a3..46bc70f926 100644 --- a/src/solutions/ode_solutions.jl +++ b/src/solutions/ode_solutions.jl @@ -172,6 +172,9 @@ end function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs::AbstractVector{<:Integer}, continuity) where {deriv} + if eltype(sol.u) <: Number + idxs = only(idxs) + end sol.interp(t, idxs, deriv, sol.prob.p, continuity) end function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, @@ -183,6 +186,9 @@ end function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, idxs::AbstractVector{<:Integer}, continuity) where {deriv} + if eltype(sol.u) <: Number + idxs = only(idxs) + end A = sol.interp(t, idxs, deriv, sol.prob.p, continuity) p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing return DiffEqArray(A.u, A.t, p, sol) From b927080c834ac5b98f6cf22816ad9c54715104c0 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 30 Apr 2024 17:02:42 +0530 Subject: [PATCH 3/6] test: fix MTK remake test --- test/downstream/modelingtoolkit_remake.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/downstream/modelingtoolkit_remake.jl b/test/downstream/modelingtoolkit_remake.jl index 78d47822c5..61a6e3b174 100644 --- a/test/downstream/modelingtoolkit_remake.jl +++ b/test/downstream/modelingtoolkit_remake.jl @@ -137,7 +137,8 @@ eqs = [D(x) ~ Hold(ud) xd ~ Sample(t, dt)(x)] @mtkbuild sys = ODESystem(eqs, t; parameter_dependencies = [p3 => 2p1]) prob = ODEProblem(sys, [x => 1.0], (0.0, 5.0), - [p1 => 1.0, p2 => 2, ud(k - 1) => 3.0, xd(k - 1) => 4.0, xd(k - 2) => 5.0]) + [p1 => 1.0, p2 => 2, ud(k - 1) => 3.0, + xd(k - 1) => 4.0, xd(k - 2) => 5.0, yd(k - 1) => 0.0]) # parameter dependencies prob2 = @inferred ODEProblem remake(prob; p = [p1 => 2.0]) From 3376d560db2e578dec66ef0b43813d152c46c7d7 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 30 Apr 2024 17:02:54 +0530 Subject: [PATCH 4/6] fix: fix indexing test --- test/downstream/symbol_indexing.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/downstream/symbol_indexing.jl b/test/downstream/symbol_indexing.jl index 893fdf437c..98d07c5ba7 100644 --- a/test/downstream/symbol_indexing.jl +++ b/test/downstream/symbol_indexing.jl @@ -93,9 +93,9 @@ end @test length(sol[(lorenz1.x, lorenz2.x)]) == length(sol) @test all(length.(sol[(lorenz1.x, lorenz2.x)]) .== 2) -@test sol[[lorenz1.x, lorenz2.x], :] isa Matrix{Float64} -@test size(sol[[lorenz1.x, lorenz2.x], :]) == (2, length(sol)) -@test size(sol[[lorenz1.x, lorenz2.x], :]) == size(sol[[1, 2], :]) == size(sol[1:2, :]) +@test sol[[lorenz1.x, lorenz2.x], :] isa Vector{Vector{Float64}} +@test length(sol[[lorenz1.x, lorenz2.x], :]) == length(sol) +@test length(sol[[lorenz1.x, lorenz2.x], :][1]) == 2 @variables q(t)[1:2] = [1.0, 2.0] eqs = [D(q[1]) ~ 2q[1] From cc1b0508457a90d41c9d613daf4424d6a9b457ae Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 30 Apr 2024 20:29:47 +0530 Subject: [PATCH 5/6] build: bump RAT compat --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 645bbfa887..fc8bf10c83 100644 --- a/Project.toml +++ b/Project.toml @@ -76,7 +76,7 @@ PyCall = "1.96" PythonCall = "0.9.15" RCall = "0.14.0" RecipesBase = "1.3.4" -RecursiveArrayTools = "3.8.0" +RecursiveArrayTools = "3.14.0" Reexport = "1" RuntimeGeneratedFunctions = "0.5.12" SciMLOperators = "0.3.7" From 668f3f7927215bf44d306c856fc4b7b9745c6fe2 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 2 May 2024 18:59:46 +0530 Subject: [PATCH 6/6] build: bump SII compat --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index fc8bf10c83..38151cef35 100644 --- a/Project.toml +++ b/Project.toml @@ -84,7 +84,7 @@ SciMLStructures = "1.1" StaticArrays = "1.7" StaticArraysCore = "1.4" Statistics = "1.10" -SymbolicIndexingInterface = "0.3.15" +SymbolicIndexingInterface = "0.3.20" Tables = "1.11" Zygote = "0.6.67" julia = "1.10"