Skip to content

VR_DirectFW aggregator #488

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jun 7, 2025
2 changes: 1 addition & 1 deletion src/JumpProcesses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ export reset_aggregated_jumps!
export ExtendedJumpArray

# Export VariableRateAggregator types
export VariableRateAggregator, VR_FRM, VR_Direct
export VariableRateAggregator, VR_FRM, VR_Direct, VR_DirectFW

# spatial structs and functions
export CartesianGrid, CartesianGridRej
Expand Down
8 changes: 4 additions & 4 deletions src/jumps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -727,13 +727,13 @@ function get_jump_info_tuples(jumps)
rates, affects!
end

function get_jump_info_fwrappers(u, p, t, constant_jumps)
function get_jump_info_fwrappers(u, p, t, jumps)
RateWrapper = FunctionWrappers.FunctionWrapper{typeof(t),
Tuple{typeof(u), typeof(p), typeof(t)}}

if (constant_jumps !== nothing) && !isempty(constant_jumps)
rates = [RateWrapper(c.rate) for c in constant_jumps]
affects! = Any[(x -> (c.affect!(x); nothing)) for c in constant_jumps]
if (jumps !== nothing) && !isempty(jumps)
rates = [RateWrapper(c.rate) for c in jumps]
affects! = Any[(x -> (c.affect!(x); nothing)) for c in jumps]
else
rates = Vector{RateWrapper}()
affects! = Any[]
Expand Down
93 changes: 75 additions & 18 deletions src/variable_rate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ end
update_jumps!(du, u, p, t, idx, jumps...)
end

################################### VR_Direct ####################################
################################### VR_Direct and VR_DirectFW ####################################

"""
$(TYPEDEF)
Expand All @@ -240,10 +240,11 @@ A concrete `VariableRateAggregator` implementing a direct method-based approach
simulating `VariableRateJump`s. `VR_Direct` (Variable Rate Direct Callback) efficiently
samples jump times using one continuous callback to integrate the total intensity /
propensity for all `VariableRateJump`s, sample when the next jump occurs, and then sample
which jump occurs at this time.
which jump occurs at this time. `VR_DirectFW` a separate FunctionWrapper mode, which
wraps things in FunctionWrappers in cases with large numbers of jumps

## Examples
Simulating a birth-death process with `VR_Direct` (default):
Simulating a birth-death process with `VR_Direct` (default) and VR_DirectFW:
```julia
using JumpProcesses, OrdinaryDiffEq
u0 = [1.0] # Initial population
Expand All @@ -264,14 +265,18 @@ death_jump = VariableRateJump(death_rate, death_affect!)
oprob = ODEProblem((du, u, p, t) -> du .= 0, u0, tspan, p)
jprob = JumpProblem(oprob, birth_jump, death_jump; vr_aggregator = VR_Direct())
sol = solve(jprob, Tsit5())

jprob = JumpProblem(oprob, birth_jump, death_jump; vr_aggregator = VR_DirectFW())
sol = solve(jprob, Tsit5())
```

## Notes
- `VR_Direct` is expected to generally be more performant than `VR_FRM`.
- `VR_Direct` and `VR_DirectFW` are expected to generally be more performant than `VR_FRM`.
"""
struct VR_Direct <: VariableRateAggregator end
struct VR_DirectFW <: VariableRateAggregator end

mutable struct VR_DirectEventCache{T, RNG <: AbstractRNG, F1, F2}
mutable struct VR_DirectEventCache{T, RNG, F1, F2}
prev_time::T
prev_threshold::T
current_time::T
Expand All @@ -281,20 +286,36 @@ mutable struct VR_DirectEventCache{T, RNG <: AbstractRNG, F1, F2}
rate_funcs::F1
affect_funcs::F2
cum_rate_sum::Vector{T}
end

function VR_DirectEventCache(jumps::JumpSet, ::Type{T}; rng = DEFAULT_RNG) where T
initial_threshold = randexp(rng, T)
vjumps = jumps.variable_jumps
function VR_DirectEventCache(jumps::JumpSet, ::VR_Direct, prob, ::Type{T}; rng = DEFAULT_RNG) where T
initial_threshold = randexp(rng, T)
vjumps = jumps.variable_jumps

# handle vjumps using tuples
rate_funcs, affect_funcs = get_jump_info_tuples(vjumps)
# handle vjumps using tuples
rate_funcs, affect_funcs = get_jump_info_tuples(vjumps)

cum_rate_sum = Vector{T}(undef, length(vjumps))
cum_rate_sum = Vector{T}(undef, length(vjumps))

new{T, typeof(rng), typeof(rate_funcs), typeof(affect_funcs)}(zero(T),
initial_threshold, zero(T), initial_threshold, zero(T), rng, rate_funcs,
affect_funcs, cum_rate_sum)
end
VR_DirectEventCache{T, typeof(rng), typeof(rate_funcs), typeof(affect_funcs)}(zero(T),
initial_threshold, zero(T), initial_threshold, zero(T), rng, rate_funcs,
affect_funcs, cum_rate_sum)
end

function VR_DirectEventCache(jumps::JumpSet, ::VR_DirectFW, prob, ::Type{T}; rng = DEFAULT_RNG) where T
initial_threshold = randexp(rng, T)
vjumps = jumps.variable_jumps

t, u = prob.tspan[1], prob.u0

# handle vjumps using tuples
rate_funcs, affect_funcs = get_jump_info_fwrappers(u, prob.p, t, vjumps)

cum_rate_sum = Vector{T}(undef, length(vjumps))

VR_DirectEventCache{T, typeof(rng), typeof(rate_funcs), Any}(zero(T),
initial_threshold, zero(T), initial_threshold, zero(T), rng, rate_funcs,
affect_funcs, cum_rate_sum)
end

# Initialization function for VR_DirectEventCache
Expand All @@ -308,8 +329,24 @@ function initialize_vr_direct_cache!(cache::VR_DirectEventCache, u, t, integrato
nothing
end

@inline function concretize_vr_direct_affects!(cache::VR_DirectEventCache,
::I) where {I <: DiffEqBase.DEIntegrator}
if (cache.affect_funcs isa Vector) &&
!(cache.affect_funcs isa Vector{FunctionWrappers.FunctionWrapper{Nothing, Tuple{I}}})
AffectWrapper = FunctionWrappers.FunctionWrapper{Nothing, Tuple{I}}
cache.affect_funcs = AffectWrapper[makewrapper(AffectWrapper, aff) for aff in cache.affect_funcs]
end
nothing
end

@inline function concretize_vr_direct_affects!(cache::VR_DirectEventCache{T, RNG, F1, F2},
::I) where {T, RNG, F1, F2 <: Tuple, I <: DiffEqBase.DEIntegrator}
nothing
end

# Wrapper for initialize to match ContinuousCallback signature
function initialize_vr_direct_wrapper(cb::ContinuousCallback, u, t, integrator)
concretize_vr_direct_affects!(cb.condition, integrator)
initialize_vr_direct_cache!(cb.condition, u, t, integrator)
u_modified!(integrator, false)
nothing
Expand All @@ -334,7 +371,15 @@ end

function configure_jump_problem(prob, ::VR_Direct, jumps, cvrjs; rng = DEFAULT_RNG)
new_prob = prob
cache = VR_DirectEventCache(jumps, eltype(prob.tspan); rng)
cache = VR_DirectEventCache(jumps, VR_Direct(), prob, eltype(prob.tspan); rng)
variable_jump_callback = build_variable_integcallback(cache, cvrjs)
cont_agg = cvrjs
return new_prob, variable_jump_callback, cont_agg
end

function configure_jump_problem(prob, ::VR_DirectFW, jumps, cvrjs; rng = DEFAULT_RNG)
new_prob = prob
cache = VR_DirectEventCache(jumps, VR_DirectFW(), prob, eltype(prob.tspan); rng)
variable_jump_callback = build_variable_integcallback(cache, cvrjs)
cont_agg = cvrjs
return new_prob, variable_jump_callback, cont_agg
Expand Down Expand Up @@ -402,9 +447,21 @@ function (cache::VR_DirectEventCache)(u, t, integrator)
return cache.current_threshold
end

@generated function execute_affect!(cache::VR_DirectEventCache{T, RNG, F1, F2}, integrator, idx) where {T, RNG, F1, F2 <: Tuple}
@generated function execute_affect!(cache::VR_DirectEventCache{T, RNG, F1, F2},
integrator::I, idx) where {T, RNG, F1, F2 <: Tuple, I <: DiffEqBase.DEIntegrator}
quote
Base.Cartesian.@nif $(fieldcount(F2)) i -> (i == idx) i -> (@inbounds cache.affect_funcs[i](integrator)) i -> (@inbounds cache.affect_funcs[fieldcount(F2)](integrator))
@unpack affect_funcs = cache
Base.Cartesian.@nif $(fieldcount(F2)) i -> (i == idx) i -> (@inbounds affect_funcs[i](integrator)) i -> (@inbounds affect_funcs[fieldcount(F2)](integrator))
end
end

@inline function execute_affect!(cache::VR_DirectEventCache,
integrator::I, idx) where {I <: DiffEqBase.DEIntegrator}
@unpack affect_funcs = cache
if affect_funcs isa Vector{FunctionWrappers.FunctionWrapper{Nothing, Tuple{I}}}
@inbounds affect_funcs[idx](integrator)
else
error("Error, invalid affect_funcs type. Expected a vector of function wrappers and got $(typeof(affect_funcs))")
end
end

Expand Down
2 changes: 1 addition & 1 deletion test/geneexpr_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ let
f(du, u, p, t) = (du .= 0; nothing)
oprob = ODEProblem(f, u0f, (0.0, tf / 5), rates)

for vr_agg in (VR_FRM(), VR_Direct())
for vr_agg in (VR_FRM(), VR_Direct(), VR_DirectFW())
vrjprob = JumpProblem(oprob, vrjs; vr_aggregator = vr_agg, save_positions = (false, false), rng)
vrjmean = runSSAs_ode(vrjprob)
@test abs(vrjmean - crjmean) < reltol * crjmean
Expand Down
4 changes: 2 additions & 2 deletions test/hawkes_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ end

# test stepping Coevolve with continuous integrator and bounded jumps
let alg = Coevolve()
for vr_aggregator in (VR_FRM(), VR_Direct())
for vr_aggregator in (VR_FRM(), VR_Direct(), VR_DirectFW())
oprob = ODEProblem(f!, u0, tspan, p)
jumps = hawkes_jump(u0, g, h)
jprob = JumpProblem(oprob, alg, jumps...; vr_aggregator, dep_graph = g, rng)
Expand All @@ -152,7 +152,7 @@ end
# test disabling bounded jumps and using continuous integrator
Nsims = 500
let alg = Coevolve()
for vr_aggregator in (VR_FRM(), VR_Direct())
for vr_aggregator in (VR_FRM(), VR_Direct(), VR_DirectFW())
oprob = ODEProblem(f!, u0, tspan, p)
jumps = hawkes_jump(u0, g, h)
jprob = JumpProblem(oprob, alg, jumps...; vr_aggregator, dep_graph = g, rng,
Expand Down
6 changes: 6 additions & 0 deletions test/monte_carlo_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ sol = solve(monte_prob, SRIW1(), EnsembleSerial(), trajectories = 3,
save_everystep = false, dt = 0.001, adaptive = false)
@test allunique(sol.u[1].t)

jump_prob = JumpProblem(prob, Direct(), jump; vr_aggregator = VR_DirectFW(), rng)
monte_prob = EnsembleProblem(jump_prob)
sol = solve(monte_prob, SRIW1(), EnsembleSerial(), trajectories = 3,
save_everystep = false, dt = 0.001, adaptive = false)
@test allunique(sol.u[1].t)

jump = ConstantRateJump(rate, affect!)
jump_prob = JumpProblem(prob, Direct(), jump; save_positions = (true, false), rng)
monte_prob = EnsembleProblem(jump_prob)
Expand Down
6 changes: 6 additions & 0 deletions test/remake_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ let
jprob = JumpProblem(prob, vrj; vr_aggregator = VR_Direct(), rng)
sol = solve(jprob, Tsit5())
@test all(==(0.0), sol[1, :])
jprob = JumpProblem(prob, vrj; vr_aggregator = VR_DirectFW(), rng)
sol = solve(jprob, Tsit5())
@test all(==(0.0), sol[1, :])
jprob = JumpProblem(prob, vrj; vr_aggregator = VR_FRM(), rng)
sol = solve(jprob, Tsit5())
@test all(==(0.0), sol[1, :])
Expand Down Expand Up @@ -107,6 +110,9 @@ let
jprob = JumpProblem(prob, vrj; vr_aggregator = VR_Direct(), rng)
sol = solve(jprob, Tsit5())
@test all(==(0.0), sol[1, :])
jprob = JumpProblem(prob, vrj; vr_aggregator = VR_DirectFW(), rng)
sol = solve(jprob, Tsit5())
@test all(==(0.0), sol[1, :])
jprob = JumpProblem(prob, vrj; vr_aggregator = VR_FRM(), rng)
sol = solve(jprob, Tsit5())
@test all(==(0.0), sol[1, :])
Expand Down
16 changes: 7 additions & 9 deletions test/save_positions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,13 @@ let
oprob = ODEProblem((du, u, p, t) -> 0, u0, tspan)
jump = VariableRateJump((u, p, t) -> 0, (integrator) -> integrator.u[1] += 1;
urate = (u, p, t) -> 1.0, rateinterval = (u, p, t) -> 5.0)
jumpproblem = JumpProblem(oprob, alg, jump; vr_aggregator = VR_Direct(), dep_graph = [[1]],
save_positions = (false, true), rng)
sol = solve(jumpproblem, Tsit5(); save_everystep = false)
@test sol.t == [0.0, 30.0]

jumpproblem = JumpProblem(oprob, alg, jump; vr_aggregator = VR_FRM(), dep_graph = [[1]],
save_positions = (false, true), rng)
sol = solve(jumpproblem, Tsit5(); save_everystep = false)
@test sol.t == [0.0, 30.0]

for vr_agg in (VR_FRM(), VR_Direct(), VR_DirectFW())
jumpproblem = JumpProblem(oprob, alg, jump; vr_aggregator = vr_agg, dep_graph = [[1]],
save_positions = (false, true), rng)
sol = solve(jumpproblem, Tsit5(); save_everystep = false)
@test sol.t == [0.0, 30.0]
end
end
end

Expand Down
2 changes: 1 addition & 1 deletion test/thread_safety.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ let
ode_prob = ODEProblem(f!, u_0, (0.0, 10))
vrj = VariableRateJump((u,p,t) -> 1.0, integrator -> nothing)

for agg in (VR_FRM(), VR_Direct())
for agg in (VR_FRM(), VR_Direct(), VR_DirectFW())
jump_prob = JumpProblem(ode_prob, Direct(), vrj; vr_aggregator = agg)
prob_func(prob, i, repeat) = deepcopy(prob)
prob = EnsembleProblem(jump_prob,prob_func = prob_func)
Expand Down
31 changes: 20 additions & 11 deletions test/variable_rate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ let
rng = StableRNG(seed)
b = 2.0
d = 1.0
n0 = 1
n0 = 1.0
tspan = (0.0, 4.0)
Nsims = 10000
n(t) = n0 * exp((b - d) * t)
Expand Down Expand Up @@ -314,7 +314,7 @@ let
ode_prob = ODEProblem(ode_fxn, u0, tspan, p)
dt = 0.1
tsave = range(tspan[1], tspan[2]; step = dt)
for vr_aggregator in (VR_FRM(), VR_Direct())
for vr_aggregator in (VR_Direct(), VR_DirectFW(), VR_FRM())
sjm_prob = JumpProblem(ode_prob, b_jump, d_jump; vr_aggregator, rng)

for alg in (Tsit5(), Rodas5P(linsolve = QRFactorization()))
Expand Down Expand Up @@ -347,9 +347,11 @@ let
prob = ODEProblem(f, [0.2], (0.0, 10.0))

mean_vrfr = run_ensemble(prob, Tsit5(), jump, jump2)
mean_vrdcb = run_ensemble(prob, Tsit5(), jump, jump2; vr_aggregator=VR_Direct())
mean_vrcb = run_ensemble(prob, Tsit5(), jump, jump2; vr_aggregator=VR_Direct())
mean_vrcbfw = run_ensemble(prob, Tsit5(), jump, jump2; vr_aggregator=VR_DirectFW())

@test isapprox(mean_vrfr, mean_vrdcb, rtol=0.05)
@test isapprox(mean_vrfr, mean_vrcb, rtol=0.05)
@test isapprox(mean_vrcb, mean_vrcbfw, rtol=0.05)
end

# Test 2: SDE with two variable rate jumps
Expand All @@ -364,9 +366,11 @@ let
prob = SDEProblem(f, g, [10.0], (0.0, 10.0))

mean_vrfr = run_ensemble(prob, SRIW1(), jump, jump2)
mean_vrdcb = run_ensemble(prob, SRIW1(), jump, jump2; vr_aggregator=VR_Direct())
mean_vrcb = run_ensemble(prob, SRIW1(), jump, jump2; vr_aggregator=VR_Direct())
mean_vrcbfw = run_ensemble(prob, SRIW1(), jump, jump2; vr_aggregator=VR_DirectFW())

@test isapprox(mean_vrfr, mean_vrdcb, rtol=0.05)
@test isapprox(mean_vrfr, mean_vrcb, rtol=0.05)
@test isapprox(mean_vrcb, mean_vrcbfw, rtol=0.05)
end

# Test 3: ODE with analytical solution
Expand All @@ -380,14 +384,16 @@ let
prob = ODEProblem(f, [0.2], (0.0, 10.0))

mean_vrfr = run_ensemble(prob, Tsit5(), jump)
mean_vrdcb = run_ensemble(prob, Tsit5(), jump; vr_aggregator = VR_Direct())
mean_vrcb = run_ensemble(prob, Tsit5(), jump; vr_aggregator = VR_Direct())
mean_vrcbfw = run_ensemble(prob, Tsit5(), jump; vr_aggregator = VR_DirectFW())

t = 10.0
u0 = 0.2
analytical_mean = u0 * exp(-t) + λ*(1 - exp(-t))

@test isapprox(mean_vrfr, analytical_mean, rtol=0.05)
@test isapprox(mean_vrfr, mean_vrdcb, rtol=0.05)
@test isapprox(mean_vrfr, mean_vrcb, rtol=0.05)
@test isapprox(mean_vrcb, mean_vrcbfw, rtol=0.05)
end

# Test 4: No. of Jumps
Expand Down Expand Up @@ -416,7 +422,7 @@ let
results = Dict()
u0 = [1.0]
tspan = (0.0, 10.0)
for vr_aggregator in (VR_FRM(), VR_Direct())
for vr_aggregator in (VR_FRM(), VR_Direct(), VR_DirectFW())
jump_counts = zeros(Int, Nsims)
p = [0.0, 0.0, 0]
prob = ODEProblem(f, u0, tspan, p)
Expand All @@ -433,6 +439,9 @@ let
end

mean_jumps_vrfr = results[VR_FRM()].mean_jumps
mean_jumps_vrdcb = results[VR_Direct()].mean_jumps
@test isapprox(mean_jumps_vrfr, mean_jumps_vrdcb, rtol=0.1)
mean_jumps_vrcb = results[VR_Direct()].mean_jumps
mean_jumps_vrcbfw = results[VR_DirectFW()].mean_jumps

@test isapprox(mean_jumps_vrfr, mean_jumps_vrcb, rtol=0.1)
@test isapprox(mean_jumps_vrcb, mean_jumps_vrcbfw, rtol=0.1)
end
Loading