From 8586fc723bcdbcda61bd7cedb265ac9c3c963a43 Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Mon, 2 Jun 2025 21:48:13 +0530 Subject: [PATCH 01/13] VR_DirectFW added --- src/JumpProcesses.jl | 2 +- src/jumps.jl | 15 +++++++++++ src/variable_rate.jl | 59 ++++++++++++++++++++++++++++++++------------ 3 files changed, 59 insertions(+), 17 deletions(-) diff --git a/src/JumpProcesses.jl b/src/JumpProcesses.jl index d3d27967..3761830a 100644 --- a/src/JumpProcesses.jl +++ b/src/JumpProcesses.jl @@ -103,7 +103,7 @@ export reset_aggregated_jumps! export ExtendedJumpArray # Export VariableRateAggregator types -export VariableRateAggregator, VR_FRM, VR_Direct +export VariableRateAggregator, VR_FRM, VR_Direct, VR_DirectFW # spatial structs and functions export CartesianGrid, CartesianGridRej diff --git a/src/jumps.jl b/src/jumps.jl index 910f2175..ad14a6a5 100644 --- a/src/jumps.jl +++ b/src/jumps.jl @@ -741,3 +741,18 @@ function get_jump_info_fwrappers(u, p, t, constant_jumps) rates, affects! end + +function get_jump_info_vr_fwrappers(vjumps, prob) + RateWrapper = FunctionWrappers.FunctionWrapper{typeof(prob.tspan[1]), + Tuple{typeof(prob.u0), typeof(prob.p), typeof(prob.tspan[1])}} + + if (vjumps !== nothing) && !isempty(vjumps) + rates = [RateWrapper(c.rate) for c in vjumps] + affects! = Any[(x -> (c.affect!(x); nothing)) for c in vjumps] + else + rates = Vector{RateWrapper}() + affects! = Any[] + end + + rates, affects! +end diff --git a/src/variable_rate.jl b/src/variable_rate.jl index b4c46536..242b7052 100644 --- a/src/variable_rate.jl +++ b/src/variable_rate.jl @@ -231,7 +231,7 @@ end update_jumps!(du, u, p, t, idx, jumps...) end -################################### VR_Direct #################################### +################################### VR_Direct and VR_DirectFW #################################### """ $(TYPEDEF) @@ -240,10 +240,11 @@ A concrete `VariableRateAggregator` implementing a direct method-based approach simulating `VariableRateJump`s. `VR_Direct` (Variable Rate Direct Callback) efficiently samples jump times using one continuous callback to integrate the total intensity / propensity for all `VariableRateJump`s, sample when the next jump occurs, and then sample -which jump occurs at this time. +which jump occurs at this time. `VR_DirectFW` a separate FunctionWrapper mode, which +wraps things in FunctionWrappers in cases with large numbers of jumps ## Examples -Simulating a birth-death process with `VR_Direct` (default): +Simulating a birth-death process with `VR_Direct` (default) and VR_DirectFW: ```julia using JumpProcesses, OrdinaryDiffEq u0 = [1.0] # Initial population @@ -264,12 +265,16 @@ death_jump = VariableRateJump(death_rate, death_affect!) oprob = ODEProblem((du, u, p, t) -> du .= 0, u0, tspan, p) jprob = JumpProblem(oprob, birth_jump, death_jump; vr_aggregator = VR_Direct()) sol = solve(jprob, Tsit5()) + +jprob = JumpProblem(oprob, birth_jump, death_jump; vr_aggregator = VR_DirectFW()) +sol = solve(jprob, Tsit5()) ``` ## Notes -- `VR_Direct` is expected to generally be more performant than `VR_FRM`. +- `VR_Direct` and `VR_DirectFW` are expected to generally be more performant than `VR_FRM`. """ struct VR_Direct <: VariableRateAggregator end +struct VR_DirectFW <: VariableRateAggregator end mutable struct VR_DirectEventCache{T, RNG <: AbstractRNG, F1, F2} prev_time::T @@ -281,20 +286,34 @@ mutable struct VR_DirectEventCache{T, RNG <: AbstractRNG, F1, F2} rate_funcs::F1 affect_funcs::F2 cum_rate_sum::Vector{T} +end - function VR_DirectEventCache(jumps::JumpSet, ::Type{T}; rng = DEFAULT_RNG) where T - initial_threshold = randexp(rng, T) - vjumps = jumps.variable_jumps +function VR_DirectEventCache(jumps::JumpSet, ::VR_Direct, prob, ::Type{T}; rng = DEFAULT_RNG) where T + initial_threshold = randexp(rng, T) + vjumps = jumps.variable_jumps - # handle vjumps using tuples - rate_funcs, affect_funcs = get_jump_info_tuples(vjumps) + # handle vjumps using tuples + rate_funcs, affect_funcs = get_jump_info_tuples(vjumps) - cum_rate_sum = Vector{T}(undef, length(vjumps)) + cum_rate_sum = Vector{T}(undef, length(vjumps)) - new{T, typeof(rng), typeof(rate_funcs), typeof(affect_funcs)}(zero(T), - initial_threshold, zero(T), initial_threshold, zero(T), rng, rate_funcs, - affect_funcs, cum_rate_sum) - end + VR_DirectEventCache{T, typeof(rng), typeof(rate_funcs), typeof(affect_funcs)}(zero(T), + initial_threshold, zero(T), initial_threshold, zero(T), rng, rate_funcs, + affect_funcs, cum_rate_sum) +end + +function VR_DirectEventCache(jumps::JumpSet, ::VR_DirectFW, prob, ::Type{T}; rng = DEFAULT_RNG) where T + initial_threshold = randexp(rng, T) + vjumps = jumps.variable_jumps + + # handle vjumps using tuples + rate_funcs, affect_funcs = get_jump_info_vr_fwrappers(vjumps, prob) + + cum_rate_sum = Vector{T}(undef, length(vjumps)) + + VR_DirectEventCache{T, typeof(rng), typeof(rate_funcs), typeof(affect_funcs)}(zero(T), + initial_threshold, zero(T), initial_threshold, zero(T), rng, rate_funcs, + affect_funcs, cum_rate_sum) end # Initialization function for VR_DirectEventCache @@ -334,7 +353,15 @@ end function configure_jump_problem(prob, ::VR_Direct, jumps, cvrjs; rng = DEFAULT_RNG) new_prob = prob - cache = VR_DirectEventCache(jumps, eltype(prob.tspan); rng) + cache = VR_DirectEventCache(jumps, VR_Direct(), prob, eltype(prob.tspan); rng) + variable_jump_callback = build_variable_integcallback(cache, cvrjs) + cont_agg = cvrjs + return new_prob, variable_jump_callback, cont_agg +end + +function configure_jump_problem(prob, ::VR_DirectFW, jumps, cvrjs; rng = DEFAULT_RNG) + new_prob = prob + cache = VR_DirectEventCache(jumps, VR_DirectFW(), prob, eltype(prob.tspan); rng) variable_jump_callback = build_variable_integcallback(cache, cvrjs) cont_agg = cvrjs return new_prob, variable_jump_callback, cont_agg @@ -402,7 +429,7 @@ function (cache::VR_DirectEventCache)(u, t, integrator) return cache.current_threshold end -@generated function execute_affect!(cache::VR_DirectEventCache{T, RNG, F1, F2}, integrator, idx) where {T, RNG, F1, F2 <: Tuple} +@generated function execute_affect!(cache::VR_DirectEventCache{T, RNG, F1, F2}, integrator, idx) where {T, RNG, F1, F2} quote Base.Cartesian.@nif $(fieldcount(F2)) i -> (i == idx) i -> (@inbounds cache.affect_funcs[i](integrator)) i -> (@inbounds cache.affect_funcs[fieldcount(F2)](integrator)) end From e7b5918e1de5c8751e723f916afcb6f7adb2b4b3 Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Mon, 2 Jun 2025 21:57:38 +0530 Subject: [PATCH 02/13] test added --- test/geneexpr_test.jl | 2 +- test/hawkes_test.jl | 4 ++-- test/monte_carlo_test.jl | 6 ++++++ test/remake_test.jl | 6 ++++++ test/save_positions.jl | 16 +++++++--------- test/thread_safety.jl | 2 +- test/variable_rate.jl | 25 +++++++++++++++++-------- 7 files changed, 40 insertions(+), 21 deletions(-) diff --git a/test/geneexpr_test.jl b/test/geneexpr_test.jl index 23e6ddf0..25ef6398 100644 --- a/test/geneexpr_test.jl +++ b/test/geneexpr_test.jl @@ -186,7 +186,7 @@ let f(du, u, p, t) = (du .= 0; nothing) oprob = ODEProblem(f, u0f, (0.0, tf / 5), rates) - for vr_agg in (VR_FRM(), VR_Direct()) + for vr_agg in (VR_FRM(), VR_Direct(), VR_DirectFW()) vrjprob = JumpProblem(oprob, vrjs; vr_aggregator = vr_agg, save_positions = (false, false), rng) vrjmean = runSSAs_ode(vrjprob) @test abs(vrjmean - crjmean) < reltol * crjmean diff --git a/test/hawkes_test.jl b/test/hawkes_test.jl index 35bc1791..7299232d 100644 --- a/test/hawkes_test.jl +++ b/test/hawkes_test.jl @@ -133,7 +133,7 @@ end # test stepping Coevolve with continuous integrator and bounded jumps let alg = Coevolve() - for vr_aggregator in (VR_FRM(), VR_Direct()) + for vr_aggregator in (VR_FRM(), VR_Direct(), VR_DirectFW()) oprob = ODEProblem(f!, u0, tspan, p) jumps = hawkes_jump(u0, g, h) jprob = JumpProblem(oprob, alg, jumps...; vr_aggregator, dep_graph = g, rng) @@ -152,7 +152,7 @@ end # test disabling bounded jumps and using continuous integrator Nsims = 500 let alg = Coevolve() - for vr_aggregator in (VR_FRM(), VR_Direct()) + for vr_aggregator in (VR_FRM(), VR_Direct(), VR_DirectFW()) oprob = ODEProblem(f!, u0, tspan, p) jumps = hawkes_jump(u0, g, h) jprob = JumpProblem(oprob, alg, jumps...; vr_aggregator, dep_graph = g, rng, diff --git a/test/monte_carlo_test.jl b/test/monte_carlo_test.jl index 2235582a..24307e14 100644 --- a/test/monte_carlo_test.jl +++ b/test/monte_carlo_test.jl @@ -20,6 +20,12 @@ sol = solve(monte_prob, SRIW1(), EnsembleSerial(), trajectories = 3, save_everystep = false, dt = 0.001, adaptive = false) @test allunique(sol.u[1].t) +jump_prob = JumpProblem(prob, Direct(), jump; vr_aggregator = VR_DirectFW(), rng) +monte_prob = EnsembleProblem(jump_prob) +sol = solve(monte_prob, SRIW1(), EnsembleSerial(), trajectories = 3, + save_everystep = false, dt = 0.001, adaptive = false) +@test allunique(sol.u[1].t) + jump = ConstantRateJump(rate, affect!) jump_prob = JumpProblem(prob, Direct(), jump; save_positions = (true, false), rng) monte_prob = EnsembleProblem(jump_prob) diff --git a/test/remake_test.jl b/test/remake_test.jl index 2d5512d7..91ab4dab 100644 --- a/test/remake_test.jl +++ b/test/remake_test.jl @@ -78,6 +78,9 @@ let jprob = JumpProblem(prob, vrj; vr_aggregator = VR_Direct(), rng) sol = solve(jprob, Tsit5()) @test all(==(0.0), sol[1, :]) + jprob = JumpProblem(prob, vrj; vr_aggregator = VR_DirectFW(), rng) + sol = solve(jprob, Tsit5()) + @test all(==(0.0), sol[1, :]) jprob = JumpProblem(prob, vrj; vr_aggregator = VR_FRM(), rng) sol = solve(jprob, Tsit5()) @test all(==(0.0), sol[1, :]) @@ -107,6 +110,9 @@ let jprob = JumpProblem(prob, vrj; vr_aggregator = VR_Direct(), rng) sol = solve(jprob, Tsit5()) @test all(==(0.0), sol[1, :]) + jprob = JumpProblem(prob, vrj; vr_aggregator = VR_DirectFW(), rng) + sol = solve(jprob, Tsit5()) + @test all(==(0.0), sol[1, :]) jprob = JumpProblem(prob, vrj; vr_aggregator = VR_FRM(), rng) sol = solve(jprob, Tsit5()) @test all(==(0.0), sol[1, :]) diff --git a/test/save_positions.jl b/test/save_positions.jl index e9194557..a4faba94 100644 --- a/test/save_positions.jl +++ b/test/save_positions.jl @@ -22,15 +22,13 @@ let oprob = ODEProblem((du, u, p, t) -> 0, u0, tspan) jump = VariableRateJump((u, p, t) -> 0, (integrator) -> integrator.u[1] += 1; urate = (u, p, t) -> 1.0, rateinterval = (u, p, t) -> 5.0) - jumpproblem = JumpProblem(oprob, alg, jump; vr_aggregator = VR_Direct(), dep_graph = [[1]], - save_positions = (false, true), rng) - sol = solve(jumpproblem, Tsit5(); save_everystep = false) - @test sol.t == [0.0, 30.0] - - jumpproblem = JumpProblem(oprob, alg, jump; vr_aggregator = VR_FRM(), dep_graph = [[1]], - save_positions = (false, true), rng) - sol = solve(jumpproblem, Tsit5(); save_everystep = false) - @test sol.t == [0.0, 30.0] + + for vr_agg in (VR_FRM(), VR_Direct(), VR_DirectFW()) + jumpproblem = JumpProblem(oprob, alg, jump; vr_aggregator = vr_agg, dep_graph = [[1]], + save_positions = (false, true), rng) + sol = solve(jumpproblem, Tsit5(); save_everystep = false) + @test sol.t == [0.0, 30.0] + end end end diff --git a/test/thread_safety.jl b/test/thread_safety.jl index 2c886672..5d736591 100644 --- a/test/thread_safety.jl +++ b/test/thread_safety.jl @@ -24,7 +24,7 @@ let ode_prob = ODEProblem(f!, u_0, (0.0, 10)) vrj = VariableRateJump((u,p,t) -> 1.0, integrator -> nothing) - for agg in (VR_FRM(), VR_Direct()) + for agg in (VR_FRM(), VR_Direct(), VR_DirectFW()) jump_prob = JumpProblem(ode_prob, Direct(), vrj; vr_aggregator = agg) prob_func(prob, i, repeat) = deepcopy(prob) prob = EnsembleProblem(jump_prob,prob_func = prob_func) diff --git a/test/variable_rate.jl b/test/variable_rate.jl index 20b012c0..a3f26fc2 100644 --- a/test/variable_rate.jl +++ b/test/variable_rate.jl @@ -314,7 +314,7 @@ let ode_prob = ODEProblem(ode_fxn, u0, tspan, p) dt = 0.1 tsave = range(tspan[1], tspan[2]; step = dt) - for vr_aggregator in (VR_FRM(), VR_Direct()) + for vr_aggregator in (VR_FRM(), VR_Direct(), VR_DirectFW()) sjm_prob = JumpProblem(ode_prob, b_jump, d_jump; vr_aggregator, rng) for alg in (Tsit5(), Rodas5P(linsolve = QRFactorization())) @@ -347,9 +347,11 @@ let prob = ODEProblem(f, [0.2], (0.0, 10.0)) mean_vrfr = run_ensemble(prob, Tsit5(), jump, jump2) - mean_vrdcb = run_ensemble(prob, Tsit5(), jump, jump2; vr_aggregator=VR_Direct()) + mean_vrcb = run_ensemble(prob, Tsit5(), jump, jump2; vr_aggregator=VR_Direct()) + mean_vrcbfw = run_ensemble(prob, Tsit5(), jump, jump2; vr_aggregator=VR_DirectFW()) - @test isapprox(mean_vrfr, mean_vrdcb, rtol=0.05) + @test isapprox(mean_vrfr, mean_vrcb, rtol=0.05) + @test isapprox(mean_vrcb, mean_vrcbfw, rtol=0.05) end # Test 2: SDE with two variable rate jumps @@ -364,9 +366,11 @@ let prob = SDEProblem(f, g, [10.0], (0.0, 10.0)) mean_vrfr = run_ensemble(prob, SRIW1(), jump, jump2) - mean_vrdcb = run_ensemble(prob, SRIW1(), jump, jump2; vr_aggregator=VR_Direct()) + mean_vrcb = run_ensemble(prob, SRIW1(), jump, jump2; vr_aggregator=VR_Direct()) + mean_vrcbfw = run_ensemble(prob, SRIW1(), jump, jump2; vr_aggregator=VR_DirectFW()) - @test isapprox(mean_vrfr, mean_vrdcb, rtol=0.05) + @test isapprox(mean_vrfr, mean_vrcb, rtol=0.05) + @test isapprox(mean_vrcb, mean_vrcbfw, rtol=0.05) end # Test 3: ODE with analytical solution @@ -380,14 +384,16 @@ let prob = ODEProblem(f, [0.2], (0.0, 10.0)) mean_vrfr = run_ensemble(prob, Tsit5(), jump) - mean_vrdcb = run_ensemble(prob, Tsit5(), jump; vr_aggregator = VR_Direct()) + mean_vrcb = run_ensemble(prob, Tsit5(), jump; vr_aggregator = VR_Direct()) + mean_vrcbfw = run_ensemble(prob, Tsit5(), jump; vr_aggregator = VR_DirectFW()) t = 10.0 u0 = 0.2 analytical_mean = u0 * exp(-t) + λ*(1 - exp(-t)) @test isapprox(mean_vrfr, analytical_mean, rtol=0.05) - @test isapprox(mean_vrfr, mean_vrdcb, rtol=0.05) + @test isapprox(mean_vrfr, mean_vrcb, rtol=0.05) + @test isapprox(mean_vrcb, mean_vrcbfw, rtol=0.05) end # Test 4: No. of Jumps @@ -416,7 +422,7 @@ let results = Dict() u0 = [1.0] tspan = (0.0, 10.0) - for vr_aggregator in (VR_FRM(), VR_Direct()) + for vr_aggregator in (VR_FRM(), VR_Direct(), VR_DirectFW()) jump_counts = zeros(Int, Nsims) p = [0.0, 0.0, 0] prob = ODEProblem(f, u0, tspan, p) @@ -434,5 +440,8 @@ let mean_jumps_vrfr = results[VR_FRM()].mean_jumps mean_jumps_vrdcb = results[VR_Direct()].mean_jumps + mean_jumps_vrdcbfw = results[VR_DirectFW()].mean_jumps + @test isapprox(mean_jumps_vrfr, mean_jumps_vrdcb, rtol=0.1) + @test isapprox(mean_jumps_vrcb, mean_jumps_vrdcbfw, rtol=0.1) end From 75122fe5af3b26b6c9b335d63053cbf88e4b1f02 Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Mon, 2 Jun 2025 23:17:17 +0530 Subject: [PATCH 03/13] bug fix --- src/jumps.jl | 5 +++-- test/variable_rate.jl | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/jumps.jl b/src/jumps.jl index ad14a6a5..32e6aeb7 100644 --- a/src/jumps.jl +++ b/src/jumps.jl @@ -743,8 +743,9 @@ function get_jump_info_fwrappers(u, p, t, constant_jumps) end function get_jump_info_vr_fwrappers(vjumps, prob) - RateWrapper = FunctionWrappers.FunctionWrapper{typeof(prob.tspan[1]), - Tuple{typeof(prob.u0), typeof(prob.p), typeof(prob.tspan[1])}} + t, u = prob.tspan[1], prob.u0 + RateWrapper = FunctionWrappers.FunctionWrapper{typeof(t), + Tuple{typeof(u), typeof(prob.p), typeof(t)}} if (vjumps !== nothing) && !isempty(vjumps) rates = [RateWrapper(c.rate) for c in vjumps] diff --git a/test/variable_rate.jl b/test/variable_rate.jl index a3f26fc2..95b67595 100644 --- a/test/variable_rate.jl +++ b/test/variable_rate.jl @@ -285,7 +285,7 @@ let rng = StableRNG(seed) b = 2.0 d = 1.0 - n0 = 1 + n0 = 1.0 tspan = (0.0, 4.0) Nsims = 10000 n(t) = n0 * exp((b - d) * t) From b362e8e0454ca863ae28d9dbcdae2d19539103c9 Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Mon, 2 Jun 2025 23:44:09 +0530 Subject: [PATCH 04/13] typo --- test/variable_rate.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/variable_rate.jl b/test/variable_rate.jl index 95b67595..6a30694f 100644 --- a/test/variable_rate.jl +++ b/test/variable_rate.jl @@ -314,7 +314,7 @@ let ode_prob = ODEProblem(ode_fxn, u0, tspan, p) dt = 0.1 tsave = range(tspan[1], tspan[2]; step = dt) - for vr_aggregator in (VR_FRM(), VR_Direct(), VR_DirectFW()) + for vr_aggregator in (VR_Direct(), VR_DirectFW(), VR_FRM()) sjm_prob = JumpProblem(ode_prob, b_jump, d_jump; vr_aggregator, rng) for alg in (Tsit5(), Rodas5P(linsolve = QRFactorization())) @@ -439,7 +439,7 @@ let end mean_jumps_vrfr = results[VR_FRM()].mean_jumps - mean_jumps_vrdcb = results[VR_Direct()].mean_jumps + mean_jumps_vrcb = results[VR_Direct()].mean_jumps mean_jumps_vrdcbfw = results[VR_DirectFW()].mean_jumps @test isapprox(mean_jumps_vrfr, mean_jumps_vrdcb, rtol=0.1) From 7a78ce4c177de1bd5951ee4aa87a4eae036b4cee Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Tue, 3 Jun 2025 00:05:57 +0530 Subject: [PATCH 05/13] typo fix --- test/variable_rate.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/variable_rate.jl b/test/variable_rate.jl index 6a30694f..621bd4b5 100644 --- a/test/variable_rate.jl +++ b/test/variable_rate.jl @@ -440,8 +440,8 @@ let mean_jumps_vrfr = results[VR_FRM()].mean_jumps mean_jumps_vrcb = results[VR_Direct()].mean_jumps - mean_jumps_vrdcbfw = results[VR_DirectFW()].mean_jumps + mean_jumps_vrdbfw = results[VR_DirectFW()].mean_jumps - @test isapprox(mean_jumps_vrfr, mean_jumps_vrdcb, rtol=0.1) - @test isapprox(mean_jumps_vrcb, mean_jumps_vrdcbfw, rtol=0.1) + @test isapprox(mean_jumps_vrfr, mean_jumps_vrcb, rtol=0.1) + @test isapprox(mean_jumps_vrcb, mean_jumps_vrdbfw, rtol=0.1) end From c4cf286818745425720b9bc7245b958c3c10969d Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Tue, 3 Jun 2025 00:07:28 +0530 Subject: [PATCH 06/13] typo fix --- test/variable_rate.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/variable_rate.jl b/test/variable_rate.jl index 621bd4b5..b48617e1 100644 --- a/test/variable_rate.jl +++ b/test/variable_rate.jl @@ -440,8 +440,8 @@ let mean_jumps_vrfr = results[VR_FRM()].mean_jumps mean_jumps_vrcb = results[VR_Direct()].mean_jumps - mean_jumps_vrdbfw = results[VR_DirectFW()].mean_jumps + mean_jumps_vrcbfw = results[VR_DirectFW()].mean_jumps @test isapprox(mean_jumps_vrfr, mean_jumps_vrcb, rtol=0.1) - @test isapprox(mean_jumps_vrcb, mean_jumps_vrdbfw, rtol=0.1) + @test isapprox(mean_jumps_vrcb, mean_jumps_vrcbfw, rtol=0.1) end From 7eced2c120e0f6eb827e7ae3303423e43c703978 Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Wed, 4 Jun 2025 02:09:51 +0530 Subject: [PATCH 07/13] new dispach --- src/jumps.jl | 24 ++++-------------------- src/variable_rate.jl | 12 +++++++++--- 2 files changed, 13 insertions(+), 23 deletions(-) diff --git a/src/jumps.jl b/src/jumps.jl index 32e6aeb7..cc326065 100644 --- a/src/jumps.jl +++ b/src/jumps.jl @@ -727,29 +727,13 @@ function get_jump_info_tuples(jumps) rates, affects! end -function get_jump_info_fwrappers(u, p, t, constant_jumps) +function get_jump_info_fwrappers(u, p, t, jumps) RateWrapper = FunctionWrappers.FunctionWrapper{typeof(t), Tuple{typeof(u), typeof(p), typeof(t)}} - if (constant_jumps !== nothing) && !isempty(constant_jumps) - rates = [RateWrapper(c.rate) for c in constant_jumps] - affects! = Any[(x -> (c.affect!(x); nothing)) for c in constant_jumps] - else - rates = Vector{RateWrapper}() - affects! = Any[] - end - - rates, affects! -end - -function get_jump_info_vr_fwrappers(vjumps, prob) - t, u = prob.tspan[1], prob.u0 - RateWrapper = FunctionWrappers.FunctionWrapper{typeof(t), - Tuple{typeof(u), typeof(prob.p), typeof(t)}} - - if (vjumps !== nothing) && !isempty(vjumps) - rates = [RateWrapper(c.rate) for c in vjumps] - affects! = Any[(x -> (c.affect!(x); nothing)) for c in vjumps] + if (jumps !== nothing) && !isempty(jumps) + rates = [RateWrapper(c.rate) for c in jumps] + affects! = Any[(x -> (c.affect!(x); nothing)) for c in jumps] else rates = Vector{RateWrapper}() affects! = Any[] diff --git a/src/variable_rate.jl b/src/variable_rate.jl index 242b7052..043c8d82 100644 --- a/src/variable_rate.jl +++ b/src/variable_rate.jl @@ -306,12 +306,14 @@ function VR_DirectEventCache(jumps::JumpSet, ::VR_DirectFW, prob, ::Type{T}; rng initial_threshold = randexp(rng, T) vjumps = jumps.variable_jumps + t, u = prob.tspan[1], prob.u0 + # handle vjumps using tuples - rate_funcs, affect_funcs = get_jump_info_vr_fwrappers(vjumps, prob) + rate_funcs, affect_funcs = get_jump_info_fwrappers(u, prob.p, t, vjumps) cum_rate_sum = Vector{T}(undef, length(vjumps)) - VR_DirectEventCache{T, typeof(rng), typeof(rate_funcs), typeof(affect_funcs)}(zero(T), + VR_DirectEventCache{T, typeof(rng), typeof(rate_funcs), Any}(zero(T), initial_threshold, zero(T), initial_threshold, zero(T), rng, rate_funcs, affect_funcs, cum_rate_sum) end @@ -429,12 +431,16 @@ function (cache::VR_DirectEventCache)(u, t, integrator) return cache.current_threshold end -@generated function execute_affect!(cache::VR_DirectEventCache{T, RNG, F1, F2}, integrator, idx) where {T, RNG, F1, F2} +@generated function execute_affect!(cache::VR_DirectEventCache{T, RNG, F1, F2}, integrator, idx) where {T, RNG, F1, F2 <: Tuple} quote Base.Cartesian.@nif $(fieldcount(F2)) i -> (i == idx) i -> (@inbounds cache.affect_funcs[i](integrator)) i -> (@inbounds cache.affect_funcs[fieldcount(F2)](integrator)) end end +@inline function execute_affect!(cache::VR_DirectEventCache{T, RNG, F1, F2}, integrator, idx) where {T, RNG, F1, F2} + @inbounds cache.affect_funcs[idx](integrator) +end + # Affect functor defined directly on the cache function (cache::VR_DirectEventCache)(integrator) @unpack t, u, p = integrator From 524199e59c27eb21451884f4954d7048dd7df2d6 Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Wed, 4 Jun 2025 02:38:28 +0530 Subject: [PATCH 08/13] added concretize_vr_direct_affects! --- src/variable_rate.jl | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/variable_rate.jl b/src/variable_rate.jl index 043c8d82..ad5788d7 100644 --- a/src/variable_rate.jl +++ b/src/variable_rate.jl @@ -329,8 +329,24 @@ function initialize_vr_direct_cache!(cache::VR_DirectEventCache, u, t, integrato nothing end +@inline function concretize_vr_direct_affects!(cache::VR_DirectEventCache{T, RNG, F1, F2}, + ::I) where {T, RNG, F1, F2, I <: DiffEqBase.DEIntegrator} + if (cache.affect_funcs isa Vector) && + !(cache.affect_funcs isa Vector{FunctionWrappers.FunctionWrapper{Nothing, Tuple{I}}}) + AffectWrapper = FunctionWrappers.FunctionWrapper{Nothing, Tuple{I}} + cache.affect_funcs = AffectWrapper[makewrapper(AffectWrapper, aff) for aff in cache.affect_funcs] + end + nothing +end + +@inline function concretize_vr_direct_affects!(cache::VR_DirectEventCache{T, RNG, F1, F2}, + ::I) where {T, RNG, F1, F2 <: Tuple, I <: DiffEqBase.DEIntegrator} + nothing +end + # Wrapper for initialize to match ContinuousCallback signature function initialize_vr_direct_wrapper(cb::ContinuousCallback, u, t, integrator) + concretize_vr_direct_affects!(cb.condition, integrator) initialize_vr_direct_cache!(cb.condition, u, t, integrator) u_modified!(integrator, false) nothing From ad2805a56808b3ba02b575752f78cd0d6e764d59 Mon Sep 17 00:00:00 2001 From: Siva Sathyaseelan D N <95441117+sivasathyaseeelan@users.noreply.github.com> Date: Wed, 4 Jun 2025 22:28:43 +0530 Subject: [PATCH 09/13] Update src/variable_rate.jl Co-authored-by: Sam Isaacson --- src/variable_rate.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/variable_rate.jl b/src/variable_rate.jl index ad5788d7..9fbc275b 100644 --- a/src/variable_rate.jl +++ b/src/variable_rate.jl @@ -329,8 +329,8 @@ function initialize_vr_direct_cache!(cache::VR_DirectEventCache, u, t, integrato nothing end -@inline function concretize_vr_direct_affects!(cache::VR_DirectEventCache{T, RNG, F1, F2}, - ::I) where {T, RNG, F1, F2, I <: DiffEqBase.DEIntegrator} +@inline function concretize_vr_direct_affects!(cache::VR_DirectEventCache, + ::I) where {I <: DiffEqBase.DEIntegrator} if (cache.affect_funcs isa Vector) && !(cache.affect_funcs isa Vector{FunctionWrappers.FunctionWrapper{Nothing, Tuple{I}}}) AffectWrapper = FunctionWrappers.FunctionWrapper{Nothing, Tuple{I}} From befa6a32c59e63894295aaeff3b76212969bfd46 Mon Sep 17 00:00:00 2001 From: Siva Sathyaseelan D N <95441117+sivasathyaseeelan@users.noreply.github.com> Date: Wed, 4 Jun 2025 22:29:08 +0530 Subject: [PATCH 10/13] Update src/variable_rate.jl Co-authored-by: Sam Isaacson --- src/variable_rate.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/variable_rate.jl b/src/variable_rate.jl index 9fbc275b..ee59c580 100644 --- a/src/variable_rate.jl +++ b/src/variable_rate.jl @@ -453,7 +453,7 @@ end end end -@inline function execute_affect!(cache::VR_DirectEventCache{T, RNG, F1, F2}, integrator, idx) where {T, RNG, F1, F2} +@inline function execute_affect!(cache::VR_DirectEventCache, integrator, idx) @inbounds cache.affect_funcs[idx](integrator) end From 676ea0c43b201ecde409832ac68fa7160d7aa940 Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Wed, 4 Jun 2025 23:11:50 +0530 Subject: [PATCH 11/13] added type check in execute_affect --- src/variable_rate.jl | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/variable_rate.jl b/src/variable_rate.jl index ee59c580..aec0279f 100644 --- a/src/variable_rate.jl +++ b/src/variable_rate.jl @@ -453,8 +453,13 @@ end end end -@inline function execute_affect!(cache::VR_DirectEventCache, integrator, idx) - @inbounds cache.affect_funcs[idx](integrator) +@inline function execute_affect!(cache::VR_DirectEventCache, integrator::I, idx) where {I <: DiffEqBase.DEIntegrator} + @unpack affect_funcs = cache + if affect_funcs isa Vector{FunctionWrappers.FunctionWrapper{Nothing, Tuple{I}}} + @inbounds affect_funcs[idx](integrator) + else + error("Error, invalid affect_funcs type. Expected a vector of function wrappers and got $(typeof(affect_funcs))") + end end # Affect functor defined directly on the cache From 4a12ce1f955f2c6dee382d9e6ac03b9f2060903f Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Thu, 5 Jun 2025 00:07:34 +0530 Subject: [PATCH 12/13] fixed bug --- src/variable_rate.jl | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/variable_rate.jl b/src/variable_rate.jl index aec0279f..883beb49 100644 --- a/src/variable_rate.jl +++ b/src/variable_rate.jl @@ -329,8 +329,8 @@ function initialize_vr_direct_cache!(cache::VR_DirectEventCache, u, t, integrato nothing end -@inline function concretize_vr_direct_affects!(cache::VR_DirectEventCache, - ::I) where {I <: DiffEqBase.DEIntegrator} +@inline function concretize_vr_direct_affects!(cache::VR_DirectEventCache{T, RNG, F1, F2}, + ::I) where {T, RNG, F1, F2, I <: DiffEqBase.DEIntegrator} if (cache.affect_funcs isa Vector) && !(cache.affect_funcs isa Vector{FunctionWrappers.FunctionWrapper{Nothing, Tuple{I}}}) AffectWrapper = FunctionWrappers.FunctionWrapper{Nothing, Tuple{I}} @@ -447,13 +447,16 @@ function (cache::VR_DirectEventCache)(u, t, integrator) return cache.current_threshold end -@generated function execute_affect!(cache::VR_DirectEventCache{T, RNG, F1, F2}, integrator, idx) where {T, RNG, F1, F2 <: Tuple} +@generated function execute_affect!(cache::VR_DirectEventCache{T, RNG, F1, F2}, + integrator::I, idx) where {T, RNG, F1, F2 <: Tuple, I <: DiffEqBase.DEIntegrator} quote - Base.Cartesian.@nif $(fieldcount(F2)) i -> (i == idx) i -> (@inbounds cache.affect_funcs[i](integrator)) i -> (@inbounds cache.affect_funcs[fieldcount(F2)](integrator)) + @unpack affect_funcs = cache + Base.Cartesian.@nif $(fieldcount(F2)) i -> (i == idx) i -> (@inbounds affect_funcs[i](integrator)) i -> (@inbounds affect_funcs[fieldcount(F2)](integrator)) end end -@inline function execute_affect!(cache::VR_DirectEventCache, integrator::I, idx) where {I <: DiffEqBase.DEIntegrator} +@inline function execute_affect!(cache::VR_DirectEventCache{T, RNG, F1, F2}, + integrator::I, idx) where {T, RNG, F1, F2, I <: DiffEqBase.DEIntegrator} @unpack affect_funcs = cache if affect_funcs isa Vector{FunctionWrappers.FunctionWrapper{Nothing, Tuple{I}}} @inbounds affect_funcs[idx](integrator) From 80d33c7e1a5ec39c42f9921c0ac31eb7e89a0a14 Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Fri, 6 Jun 2025 19:52:55 +0530 Subject: [PATCH 13/13] ambigous bug fix --- src/variable_rate.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/variable_rate.jl b/src/variable_rate.jl index 883beb49..7587680c 100644 --- a/src/variable_rate.jl +++ b/src/variable_rate.jl @@ -276,7 +276,7 @@ sol = solve(jprob, Tsit5()) struct VR_Direct <: VariableRateAggregator end struct VR_DirectFW <: VariableRateAggregator end -mutable struct VR_DirectEventCache{T, RNG <: AbstractRNG, F1, F2} +mutable struct VR_DirectEventCache{T, RNG, F1, F2} prev_time::T prev_threshold::T current_time::T @@ -329,8 +329,8 @@ function initialize_vr_direct_cache!(cache::VR_DirectEventCache, u, t, integrato nothing end -@inline function concretize_vr_direct_affects!(cache::VR_DirectEventCache{T, RNG, F1, F2}, - ::I) where {T, RNG, F1, F2, I <: DiffEqBase.DEIntegrator} +@inline function concretize_vr_direct_affects!(cache::VR_DirectEventCache, + ::I) where {I <: DiffEqBase.DEIntegrator} if (cache.affect_funcs isa Vector) && !(cache.affect_funcs isa Vector{FunctionWrappers.FunctionWrapper{Nothing, Tuple{I}}}) AffectWrapper = FunctionWrappers.FunctionWrapper{Nothing, Tuple{I}} @@ -455,8 +455,8 @@ end end end -@inline function execute_affect!(cache::VR_DirectEventCache{T, RNG, F1, F2}, - integrator::I, idx) where {T, RNG, F1, F2, I <: DiffEqBase.DEIntegrator} +@inline function execute_affect!(cache::VR_DirectEventCache, + integrator::I, idx) where {I <: DiffEqBase.DEIntegrator} @unpack affect_funcs = cache if affect_funcs isa Vector{FunctionWrappers.FunctionWrapper{Nothing, Tuple{I}}} @inbounds affect_funcs[idx](integrator)