Skip to content

Commit ccbb392

Browse files
Merge pull request #488 from sivasathyaseeelan/VR_DirectFW
VR_DirectFW aggregator
2 parents 069e129 + 80d33c7 commit ccbb392

10 files changed

+123
-47
lines changed

src/JumpProcesses.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ export reset_aggregated_jumps!
103103
export ExtendedJumpArray
104104

105105
# Export VariableRateAggregator types
106-
export VariableRateAggregator, VR_FRM, VR_Direct
106+
export VariableRateAggregator, VR_FRM, VR_Direct, VR_DirectFW
107107

108108
# spatial structs and functions
109109
export CartesianGrid, CartesianGridRej

src/jumps.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -727,13 +727,13 @@ function get_jump_info_tuples(jumps)
727727
rates, affects!
728728
end
729729

730-
function get_jump_info_fwrappers(u, p, t, constant_jumps)
730+
function get_jump_info_fwrappers(u, p, t, jumps)
731731
RateWrapper = FunctionWrappers.FunctionWrapper{typeof(t),
732732
Tuple{typeof(u), typeof(p), typeof(t)}}
733733

734-
if (constant_jumps !== nothing) && !isempty(constant_jumps)
735-
rates = [RateWrapper(c.rate) for c in constant_jumps]
736-
affects! = Any[(x -> (c.affect!(x); nothing)) for c in constant_jumps]
734+
if (jumps !== nothing) && !isempty(jumps)
735+
rates = [RateWrapper(c.rate) for c in jumps]
736+
affects! = Any[(x -> (c.affect!(x); nothing)) for c in jumps]
737737
else
738738
rates = Vector{RateWrapper}()
739739
affects! = Any[]

src/variable_rate.jl

Lines changed: 75 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ end
231231
update_jumps!(du, u, p, t, idx, jumps...)
232232
end
233233

234-
################################### VR_Direct ####################################
234+
################################### VR_Direct and VR_DirectFW ####################################
235235

236236
"""
237237
$(TYPEDEF)
@@ -240,10 +240,11 @@ A concrete `VariableRateAggregator` implementing a direct method-based approach
240240
simulating `VariableRateJump`s. `VR_Direct` (Variable Rate Direct Callback) efficiently
241241
samples jump times using one continuous callback to integrate the total intensity /
242242
propensity for all `VariableRateJump`s, sample when the next jump occurs, and then sample
243-
which jump occurs at this time.
243+
which jump occurs at this time. `VR_DirectFW` a separate FunctionWrapper mode, which
244+
wraps things in FunctionWrappers in cases with large numbers of jumps
244245
245246
## Examples
246-
Simulating a birth-death process with `VR_Direct` (default):
247+
Simulating a birth-death process with `VR_Direct` (default) and VR_DirectFW:
247248
```julia
248249
using JumpProcesses, OrdinaryDiffEq
249250
u0 = [1.0] # Initial population
@@ -264,14 +265,18 @@ death_jump = VariableRateJump(death_rate, death_affect!)
264265
oprob = ODEProblem((du, u, p, t) -> du .= 0, u0, tspan, p)
265266
jprob = JumpProblem(oprob, birth_jump, death_jump; vr_aggregator = VR_Direct())
266267
sol = solve(jprob, Tsit5())
268+
269+
jprob = JumpProblem(oprob, birth_jump, death_jump; vr_aggregator = VR_DirectFW())
270+
sol = solve(jprob, Tsit5())
267271
```
268272
269273
## Notes
270-
- `VR_Direct` is expected to generally be more performant than `VR_FRM`.
274+
- `VR_Direct` and `VR_DirectFW` are expected to generally be more performant than `VR_FRM`.
271275
"""
272276
struct VR_Direct <: VariableRateAggregator end
277+
struct VR_DirectFW <: VariableRateAggregator end
273278

274-
mutable struct VR_DirectEventCache{T, RNG <: AbstractRNG, F1, F2}
279+
mutable struct VR_DirectEventCache{T, RNG, F1, F2}
275280
prev_time::T
276281
prev_threshold::T
277282
current_time::T
@@ -281,20 +286,36 @@ mutable struct VR_DirectEventCache{T, RNG <: AbstractRNG, F1, F2}
281286
rate_funcs::F1
282287
affect_funcs::F2
283288
cum_rate_sum::Vector{T}
289+
end
284290

285-
function VR_DirectEventCache(jumps::JumpSet, ::Type{T}; rng = DEFAULT_RNG) where T
286-
initial_threshold = randexp(rng, T)
287-
vjumps = jumps.variable_jumps
291+
function VR_DirectEventCache(jumps::JumpSet, ::VR_Direct, prob, ::Type{T}; rng = DEFAULT_RNG) where T
292+
initial_threshold = randexp(rng, T)
293+
vjumps = jumps.variable_jumps
288294

289-
# handle vjumps using tuples
290-
rate_funcs, affect_funcs = get_jump_info_tuples(vjumps)
295+
# handle vjumps using tuples
296+
rate_funcs, affect_funcs = get_jump_info_tuples(vjumps)
291297

292-
cum_rate_sum = Vector{T}(undef, length(vjumps))
298+
cum_rate_sum = Vector{T}(undef, length(vjumps))
293299

294-
new{T, typeof(rng), typeof(rate_funcs), typeof(affect_funcs)}(zero(T),
295-
initial_threshold, zero(T), initial_threshold, zero(T), rng, rate_funcs,
296-
affect_funcs, cum_rate_sum)
297-
end
300+
VR_DirectEventCache{T, typeof(rng), typeof(rate_funcs), typeof(affect_funcs)}(zero(T),
301+
initial_threshold, zero(T), initial_threshold, zero(T), rng, rate_funcs,
302+
affect_funcs, cum_rate_sum)
303+
end
304+
305+
function VR_DirectEventCache(jumps::JumpSet, ::VR_DirectFW, prob, ::Type{T}; rng = DEFAULT_RNG) where T
306+
initial_threshold = randexp(rng, T)
307+
vjumps = jumps.variable_jumps
308+
309+
t, u = prob.tspan[1], prob.u0
310+
311+
# handle vjumps using tuples
312+
rate_funcs, affect_funcs = get_jump_info_fwrappers(u, prob.p, t, vjumps)
313+
314+
cum_rate_sum = Vector{T}(undef, length(vjumps))
315+
316+
VR_DirectEventCache{T, typeof(rng), typeof(rate_funcs), Any}(zero(T),
317+
initial_threshold, zero(T), initial_threshold, zero(T), rng, rate_funcs,
318+
affect_funcs, cum_rate_sum)
298319
end
299320

300321
# Initialization function for VR_DirectEventCache
@@ -308,8 +329,24 @@ function initialize_vr_direct_cache!(cache::VR_DirectEventCache, u, t, integrato
308329
nothing
309330
end
310331

332+
@inline function concretize_vr_direct_affects!(cache::VR_DirectEventCache,
333+
::I) where {I <: DiffEqBase.DEIntegrator}
334+
if (cache.affect_funcs isa Vector) &&
335+
!(cache.affect_funcs isa Vector{FunctionWrappers.FunctionWrapper{Nothing, Tuple{I}}})
336+
AffectWrapper = FunctionWrappers.FunctionWrapper{Nothing, Tuple{I}}
337+
cache.affect_funcs = AffectWrapper[makewrapper(AffectWrapper, aff) for aff in cache.affect_funcs]
338+
end
339+
nothing
340+
end
341+
342+
@inline function concretize_vr_direct_affects!(cache::VR_DirectEventCache{T, RNG, F1, F2},
343+
::I) where {T, RNG, F1, F2 <: Tuple, I <: DiffEqBase.DEIntegrator}
344+
nothing
345+
end
346+
311347
# Wrapper for initialize to match ContinuousCallback signature
312348
function initialize_vr_direct_wrapper(cb::ContinuousCallback, u, t, integrator)
349+
concretize_vr_direct_affects!(cb.condition, integrator)
313350
initialize_vr_direct_cache!(cb.condition, u, t, integrator)
314351
u_modified!(integrator, false)
315352
nothing
@@ -334,7 +371,15 @@ end
334371

335372
function configure_jump_problem(prob, ::VR_Direct, jumps, cvrjs; rng = DEFAULT_RNG)
336373
new_prob = prob
337-
cache = VR_DirectEventCache(jumps, eltype(prob.tspan); rng)
374+
cache = VR_DirectEventCache(jumps, VR_Direct(), prob, eltype(prob.tspan); rng)
375+
variable_jump_callback = build_variable_integcallback(cache, cvrjs)
376+
cont_agg = cvrjs
377+
return new_prob, variable_jump_callback, cont_agg
378+
end
379+
380+
function configure_jump_problem(prob, ::VR_DirectFW, jumps, cvrjs; rng = DEFAULT_RNG)
381+
new_prob = prob
382+
cache = VR_DirectEventCache(jumps, VR_DirectFW(), prob, eltype(prob.tspan); rng)
338383
variable_jump_callback = build_variable_integcallback(cache, cvrjs)
339384
cont_agg = cvrjs
340385
return new_prob, variable_jump_callback, cont_agg
@@ -402,9 +447,21 @@ function (cache::VR_DirectEventCache)(u, t, integrator)
402447
return cache.current_threshold
403448
end
404449

405-
@generated function execute_affect!(cache::VR_DirectEventCache{T, RNG, F1, F2}, integrator, idx) where {T, RNG, F1, F2 <: Tuple}
450+
@generated function execute_affect!(cache::VR_DirectEventCache{T, RNG, F1, F2},
451+
integrator::I, idx) where {T, RNG, F1, F2 <: Tuple, I <: DiffEqBase.DEIntegrator}
406452
quote
407-
Base.Cartesian.@nif $(fieldcount(F2)) i -> (i == idx) i -> (@inbounds cache.affect_funcs[i](integrator)) i -> (@inbounds cache.affect_funcs[fieldcount(F2)](integrator))
453+
@unpack affect_funcs = cache
454+
Base.Cartesian.@nif $(fieldcount(F2)) i -> (i == idx) i -> (@inbounds affect_funcs[i](integrator)) i -> (@inbounds affect_funcs[fieldcount(F2)](integrator))
455+
end
456+
end
457+
458+
@inline function execute_affect!(cache::VR_DirectEventCache,
459+
integrator::I, idx) where {I <: DiffEqBase.DEIntegrator}
460+
@unpack affect_funcs = cache
461+
if affect_funcs isa Vector{FunctionWrappers.FunctionWrapper{Nothing, Tuple{I}}}
462+
@inbounds affect_funcs[idx](integrator)
463+
else
464+
error("Error, invalid affect_funcs type. Expected a vector of function wrappers and got $(typeof(affect_funcs))")
408465
end
409466
end
410467

test/geneexpr_test.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ let
186186
f(du, u, p, t) = (du .= 0; nothing)
187187
oprob = ODEProblem(f, u0f, (0.0, tf / 5), rates)
188188

189-
for vr_agg in (VR_FRM(), VR_Direct())
189+
for vr_agg in (VR_FRM(), VR_Direct(), VR_DirectFW())
190190
vrjprob = JumpProblem(oprob, vrjs; vr_aggregator = vr_agg, save_positions = (false, false), rng)
191191
vrjmean = runSSAs_ode(vrjprob)
192192
@test abs(vrjmean - crjmean) < reltol * crjmean

test/hawkes_test.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ end
133133

134134
# test stepping Coevolve with continuous integrator and bounded jumps
135135
let alg = Coevolve()
136-
for vr_aggregator in (VR_FRM(), VR_Direct())
136+
for vr_aggregator in (VR_FRM(), VR_Direct(), VR_DirectFW())
137137
oprob = ODEProblem(f!, u0, tspan, p)
138138
jumps = hawkes_jump(u0, g, h)
139139
jprob = JumpProblem(oprob, alg, jumps...; vr_aggregator, dep_graph = g, rng)
@@ -152,7 +152,7 @@ end
152152
# test disabling bounded jumps and using continuous integrator
153153
Nsims = 500
154154
let alg = Coevolve()
155-
for vr_aggregator in (VR_FRM(), VR_Direct())
155+
for vr_aggregator in (VR_FRM(), VR_Direct(), VR_DirectFW())
156156
oprob = ODEProblem(f!, u0, tspan, p)
157157
jumps = hawkes_jump(u0, g, h)
158158
jprob = JumpProblem(oprob, alg, jumps...; vr_aggregator, dep_graph = g, rng,

test/monte_carlo_test.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,12 @@ sol = solve(monte_prob, SRIW1(), EnsembleSerial(), trajectories = 3,
2020
save_everystep = false, dt = 0.001, adaptive = false)
2121
@test allunique(sol.u[1].t)
2222

23+
jump_prob = JumpProblem(prob, Direct(), jump; vr_aggregator = VR_DirectFW(), rng)
24+
monte_prob = EnsembleProblem(jump_prob)
25+
sol = solve(monte_prob, SRIW1(), EnsembleSerial(), trajectories = 3,
26+
save_everystep = false, dt = 0.001, adaptive = false)
27+
@test allunique(sol.u[1].t)
28+
2329
jump = ConstantRateJump(rate, affect!)
2430
jump_prob = JumpProblem(prob, Direct(), jump; save_positions = (true, false), rng)
2531
monte_prob = EnsembleProblem(jump_prob)

test/remake_test.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,9 @@ let
7878
jprob = JumpProblem(prob, vrj; vr_aggregator = VR_Direct(), rng)
7979
sol = solve(jprob, Tsit5())
8080
@test all(==(0.0), sol[1, :])
81+
jprob = JumpProblem(prob, vrj; vr_aggregator = VR_DirectFW(), rng)
82+
sol = solve(jprob, Tsit5())
83+
@test all(==(0.0), sol[1, :])
8184
jprob = JumpProblem(prob, vrj; vr_aggregator = VR_FRM(), rng)
8285
sol = solve(jprob, Tsit5())
8386
@test all(==(0.0), sol[1, :])
@@ -107,6 +110,9 @@ let
107110
jprob = JumpProblem(prob, vrj; vr_aggregator = VR_Direct(), rng)
108111
sol = solve(jprob, Tsit5())
109112
@test all(==(0.0), sol[1, :])
113+
jprob = JumpProblem(prob, vrj; vr_aggregator = VR_DirectFW(), rng)
114+
sol = solve(jprob, Tsit5())
115+
@test all(==(0.0), sol[1, :])
110116
jprob = JumpProblem(prob, vrj; vr_aggregator = VR_FRM(), rng)
111117
sol = solve(jprob, Tsit5())
112118
@test all(==(0.0), sol[1, :])

test/save_positions.jl

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,13 @@ let
2222
oprob = ODEProblem((du, u, p, t) -> 0, u0, tspan)
2323
jump = VariableRateJump((u, p, t) -> 0, (integrator) -> integrator.u[1] += 1;
2424
urate = (u, p, t) -> 1.0, rateinterval = (u, p, t) -> 5.0)
25-
jumpproblem = JumpProblem(oprob, alg, jump; vr_aggregator = VR_Direct(), dep_graph = [[1]],
26-
save_positions = (false, true), rng)
27-
sol = solve(jumpproblem, Tsit5(); save_everystep = false)
28-
@test sol.t == [0.0, 30.0]
29-
30-
jumpproblem = JumpProblem(oprob, alg, jump; vr_aggregator = VR_FRM(), dep_graph = [[1]],
31-
save_positions = (false, true), rng)
32-
sol = solve(jumpproblem, Tsit5(); save_everystep = false)
33-
@test sol.t == [0.0, 30.0]
25+
26+
for vr_agg in (VR_FRM(), VR_Direct(), VR_DirectFW())
27+
jumpproblem = JumpProblem(oprob, alg, jump; vr_aggregator = vr_agg, dep_graph = [[1]],
28+
save_positions = (false, true), rng)
29+
sol = solve(jumpproblem, Tsit5(); save_everystep = false)
30+
@test sol.t == [0.0, 30.0]
31+
end
3432
end
3533
end
3634

test/thread_safety.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ let
2424
ode_prob = ODEProblem(f!, u_0, (0.0, 10))
2525
vrj = VariableRateJump((u,p,t) -> 1.0, integrator -> nothing)
2626

27-
for agg in (VR_FRM(), VR_Direct())
27+
for agg in (VR_FRM(), VR_Direct(), VR_DirectFW())
2828
jump_prob = JumpProblem(ode_prob, Direct(), vrj; vr_aggregator = agg)
2929
prob_func(prob, i, repeat) = deepcopy(prob)
3030
prob = EnsembleProblem(jump_prob,prob_func = prob_func)

test/variable_rate.jl

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ let
285285
rng = StableRNG(seed)
286286
b = 2.0
287287
d = 1.0
288-
n0 = 1
288+
n0 = 1.0
289289
tspan = (0.0, 4.0)
290290
Nsims = 10000
291291
n(t) = n0 * exp((b - d) * t)
@@ -314,7 +314,7 @@ let
314314
ode_prob = ODEProblem(ode_fxn, u0, tspan, p)
315315
dt = 0.1
316316
tsave = range(tspan[1], tspan[2]; step = dt)
317-
for vr_aggregator in (VR_FRM(), VR_Direct())
317+
for vr_aggregator in (VR_Direct(), VR_DirectFW(), VR_FRM())
318318
sjm_prob = JumpProblem(ode_prob, b_jump, d_jump; vr_aggregator, rng)
319319

320320
for alg in (Tsit5(), Rodas5P(linsolve = QRFactorization()))
@@ -347,9 +347,11 @@ let
347347
prob = ODEProblem(f, [0.2], (0.0, 10.0))
348348

349349
mean_vrfr = run_ensemble(prob, Tsit5(), jump, jump2)
350-
mean_vrdcb = run_ensemble(prob, Tsit5(), jump, jump2; vr_aggregator=VR_Direct())
350+
mean_vrcb = run_ensemble(prob, Tsit5(), jump, jump2; vr_aggregator=VR_Direct())
351+
mean_vrcbfw = run_ensemble(prob, Tsit5(), jump, jump2; vr_aggregator=VR_DirectFW())
351352

352-
@test isapprox(mean_vrfr, mean_vrdcb, rtol=0.05)
353+
@test isapprox(mean_vrfr, mean_vrcb, rtol=0.05)
354+
@test isapprox(mean_vrcb, mean_vrcbfw, rtol=0.05)
353355
end
354356

355357
# Test 2: SDE with two variable rate jumps
@@ -364,9 +366,11 @@ let
364366
prob = SDEProblem(f, g, [10.0], (0.0, 10.0))
365367

366368
mean_vrfr = run_ensemble(prob, SRIW1(), jump, jump2)
367-
mean_vrdcb = run_ensemble(prob, SRIW1(), jump, jump2; vr_aggregator=VR_Direct())
369+
mean_vrcb = run_ensemble(prob, SRIW1(), jump, jump2; vr_aggregator=VR_Direct())
370+
mean_vrcbfw = run_ensemble(prob, SRIW1(), jump, jump2; vr_aggregator=VR_DirectFW())
368371

369-
@test isapprox(mean_vrfr, mean_vrdcb, rtol=0.05)
372+
@test isapprox(mean_vrfr, mean_vrcb, rtol=0.05)
373+
@test isapprox(mean_vrcb, mean_vrcbfw, rtol=0.05)
370374
end
371375

372376
# Test 3: ODE with analytical solution
@@ -380,14 +384,16 @@ let
380384
prob = ODEProblem(f, [0.2], (0.0, 10.0))
381385

382386
mean_vrfr = run_ensemble(prob, Tsit5(), jump)
383-
mean_vrdcb = run_ensemble(prob, Tsit5(), jump; vr_aggregator = VR_Direct())
387+
mean_vrcb = run_ensemble(prob, Tsit5(), jump; vr_aggregator = VR_Direct())
388+
mean_vrcbfw = run_ensemble(prob, Tsit5(), jump; vr_aggregator = VR_DirectFW())
384389

385390
t = 10.0
386391
u0 = 0.2
387392
analytical_mean = u0 * exp(-t) + λ*(1 - exp(-t))
388393

389394
@test isapprox(mean_vrfr, analytical_mean, rtol=0.05)
390-
@test isapprox(mean_vrfr, mean_vrdcb, rtol=0.05)
395+
@test isapprox(mean_vrfr, mean_vrcb, rtol=0.05)
396+
@test isapprox(mean_vrcb, mean_vrcbfw, rtol=0.05)
391397
end
392398

393399
# Test 4: No. of Jumps
@@ -416,7 +422,7 @@ let
416422
results = Dict()
417423
u0 = [1.0]
418424
tspan = (0.0, 10.0)
419-
for vr_aggregator in (VR_FRM(), VR_Direct())
425+
for vr_aggregator in (VR_FRM(), VR_Direct(), VR_DirectFW())
420426
jump_counts = zeros(Int, Nsims)
421427
p = [0.0, 0.0, 0]
422428
prob = ODEProblem(f, u0, tspan, p)
@@ -433,6 +439,9 @@ let
433439
end
434440

435441
mean_jumps_vrfr = results[VR_FRM()].mean_jumps
436-
mean_jumps_vrdcb = results[VR_Direct()].mean_jumps
437-
@test isapprox(mean_jumps_vrfr, mean_jumps_vrdcb, rtol=0.1)
442+
mean_jumps_vrcb = results[VR_Direct()].mean_jumps
443+
mean_jumps_vrcbfw = results[VR_DirectFW()].mean_jumps
444+
445+
@test isapprox(mean_jumps_vrfr, mean_jumps_vrcb, rtol=0.1)
446+
@test isapprox(mean_jumps_vrcb, mean_jumps_vrcbfw, rtol=0.1)
438447
end

0 commit comments

Comments
 (0)