From 83e7abe70bb4e06d58cda35738c969ed2e5dd587 Mon Sep 17 00:00:00 2001 From: Hendrik Ranocha Date: Wed, 24 Jul 2024 13:26:54 +0200 Subject: [PATCH 1/4] implement et_tmp_cache --- Project.toml | 2 +- src/PositiveIntegrators.jl | 3 ++- src/mprk.jl | 6 ++++++ src/sspmprk.jl | 4 ++++ test/runtests.jl | 18 +++++++++++++++--- 5 files changed, 28 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index fdc71584..32939ff8 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "PositiveIntegrators" uuid = "d1b20bf0-b083-4985-a874-dc5121669aa5" authors = ["Stefan Kopecz, Hendrik Ranocha, and contributors"] -version = "0.2.0" +version = "0.2.1" [deps] FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" diff --git a/src/PositiveIntegrators.jl b/src/PositiveIntegrators.jl index dc87fb75..87c3f589 100644 --- a/src/PositiveIntegrators.jl +++ b/src/PositiveIntegrators.jl @@ -40,7 +40,8 @@ using OrdinaryDiffEq: @cache, recursivefill!, _vec, wrapprecs, dolinsolve import OrdinaryDiffEq: alg_order, isfsal, calculate_residuals, calculate_residuals!, - alg_cache, initialize!, perform_step!, + alg_cache, get_tmp_cache, + initialize!, perform_step!, _ode_interpolant, _ode_interpolant! # 2. Export functionality defining the public API diff --git a/src/mprk.jl b/src/mprk.jl index 6e54e4b2..6814f23b 100644 --- a/src/mprk.jl +++ b/src/mprk.jl @@ -337,6 +337,8 @@ struct MPEConservativeCache{PType, uType, tabType, F} <: OrdinaryDiffEqMutableCa linsolve::F end +get_tmp_cache(integrator, ::MPE, cache::OrdinaryDiffEqMutableCache) = (cache.σ,) + # In-place function alg_cache(alg::MPE, u, rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, @@ -638,6 +640,8 @@ struct MPRK22ConservativeCache{uType, PType, tabType, F} <: linsolve::F end +get_tmp_cache(integrator, ::MPRK22, cache::OrdinaryDiffEqMutableCache) = (cache.σ,) + # In-place function alg_cache(alg::MPRK22, u, rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, @@ -1218,6 +1222,8 @@ struct MPRK43ConservativeCache{uType, PType, tabType, F} <: OrdinaryDiffEqMutabl linsolve::F end +get_tmp_cache(integrator, ::Union{MPRK43I, MPRK43II}, cache::OrdinaryDiffEqMutableCache) = (cache.σ,) + # In-place function alg_cache(alg::Union{MPRK43I, MPRK43II}, u, rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, diff --git a/src/sspmprk.jl b/src/sspmprk.jl index c25c79a3..5473033b 100644 --- a/src/sspmprk.jl +++ b/src/sspmprk.jl @@ -223,6 +223,8 @@ struct SSPMPRK22ConservativeCache{uType, PType, tabType, F} <: linsolve::F end +get_tmp_cache(integrator, ::SSPMPRK22, cache::OrdinaryDiffEqMutableCache) = (cache.σ,) + # In-place function alg_cache(alg::SSPMPRK22, u, rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, @@ -734,6 +736,8 @@ struct SSPMPRK43ConservativeCache{uType, PType, tabType, F} <: OrdinaryDiffEqMut linsolve::F end +get_tmp_cache(integrator, ::SSPMPRK43, cache::OrdinaryDiffEqMutableCache) = (cache.σ,) + # In-place function alg_cache(alg::SSPMPRK43, u, rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, diff --git a/test/runtests.jl b/test/runtests.jl index 0148c6d6..a8a7c921 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1198,7 +1198,7 @@ end end end - # Here we check that the type of p_prototype actually + # Here we check that the type of p_prototype actually # defines the types of the Ps inside the algorithm caches. # We test sparse, tridiagonal, and dense matrices. @testset "Prototype type check" begin @@ -1246,7 +1246,7 @@ end p_prototype = P_dense) prob_sparse = ConservativePDSProblem(prod_sparse!, u0, tspan; p_prototype = P_sparse) - ## nonconservative PDS + ## nonconservative PDS prob_default2 = PDSProblem(prod_dense!, dest!, u0, tspan) prob_tridiagonal2 = PDSProblem(prod_tridiagonal!, dest!, u0, tspan; p_prototype = P_tridiagonal) @@ -1262,7 +1262,19 @@ end for prob in (prob_default, prob_tridiagonal, prob_dense, prob_sparse, prob_default2, prob_tridiagonal2, prob_dense2, prob_sparse2) - solve(prob, alg; dt, adaptive = false) + sol1 = solve(prob, alg; dt, adaptive = false) + + # test get_tmp_cache and integrator interface - modifying + # values from the cache should not changes the final results + integrator = init(prob, alg; dt, adaptive = false) + step!(integrator) + cache = @inferred get_tmp_cache(integrator) + @test !isempty(cache) + tmp = first(cache) + fill!(tmp, NaN) + sol2 = solve!(integrator) + @test sol1.t ≈ sol2.t + @test sol1.u ≈ sol2.u end end end From fc9b00b729a8201aceb47aa00400c2d0a7709d28 Mon Sep 17 00:00:00 2001 From: Hendrik Ranocha Date: Wed, 24 Jul 2024 13:29:55 +0200 Subject: [PATCH 2/4] format --- src/mprk.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/mprk.jl b/src/mprk.jl index 6814f23b..83f11491 100644 --- a/src/mprk.jl +++ b/src/mprk.jl @@ -1222,7 +1222,10 @@ struct MPRK43ConservativeCache{uType, PType, tabType, F} <: OrdinaryDiffEqMutabl linsolve::F end -get_tmp_cache(integrator, ::Union{MPRK43I, MPRK43II}, cache::OrdinaryDiffEqMutableCache) = (cache.σ,) +function get_tmp_cache(integrator, ::Union{MPRK43I, MPRK43II}, + cache::OrdinaryDiffEqMutableCache) + (cache.σ,) +end # In-place function alg_cache(alg::Union{MPRK43I, MPRK43II}, u, rate_prototype, ::Type{uEltypeNoUnits}, From 47ba351f11cfaae8ede527a43cfa80ff5fe16ef5 Mon Sep 17 00:00:00 2001 From: Hendrik Ranocha Date: Wed, 24 Jul 2024 17:03:58 +0200 Subject: [PATCH 3/4] fix typo Co-authored-by: Joshua Lampert <51029046+JoshuaLampert@users.noreply.github.com> --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index a8a7c921..1907244e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1265,7 +1265,7 @@ end sol1 = solve(prob, alg; dt, adaptive = false) # test get_tmp_cache and integrator interface - modifying - # values from the cache should not changes the final results + # values from the cache should not change the final results integrator = init(prob, alg; dt, adaptive = false) step!(integrator) cache = @inferred get_tmp_cache(integrator) From 787d8a5a73b7bf3e89b0b058d388090cddd94b9e Mon Sep 17 00:00:00 2001 From: Hendrik Ranocha Date: Sat, 27 Jul 2024 15:40:21 +0200 Subject: [PATCH 4/4] other plotting command for sum --- docs/src/index.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/index.md b/docs/src/index.md index 22d89feb..5e0a8aed 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -174,7 +174,7 @@ Finally, we can use [Plots.jl](https://docs.juliaplots.org/stable/) to visualize using Plots plot(sol, label = ["S" "I" "R"], legend=:right) -plot!(sol, idxs = ((t, S, I, R) -> (t, S + I + R), 0, 1, 2, 3), label = "S+I+R") #Plot S+I+R over time. +plot!(sol.t, sum.(sol.u), label = "S+I+R") # Plot S+I+R over time. ``` We see that there is always a nonnegative number of people in each compartment, while the population ``S+I+R`` remains constant over time.