231
231
update_jumps! (du, u, p, t, idx, jumps... )
232
232
end
233
233
234
- # ################################## VR_Direct ####################################
234
+ # ################################## VR_Direct and VR_DirectFW ####################################
235
235
236
236
"""
237
237
$(TYPEDEF)
@@ -240,10 +240,11 @@ A concrete `VariableRateAggregator` implementing a direct method-based approach
240
240
simulating `VariableRateJump`s. `VR_Direct` (Variable Rate Direct Callback) efficiently
241
241
samples jump times using one continuous callback to integrate the total intensity /
242
242
propensity for all `VariableRateJump`s, sample when the next jump occurs, and then sample
243
- which jump occurs at this time.
243
+ which jump occurs at this time. `VR_DirectFW` a separate FunctionWrapper mode, which
244
+ wraps things in FunctionWrappers in cases with large numbers of jumps
244
245
245
246
## Examples
246
- Simulating a birth-death process with `VR_Direct` (default):
247
+ Simulating a birth-death process with `VR_Direct` (default) and VR_DirectFW :
247
248
```julia
248
249
using JumpProcesses, OrdinaryDiffEq
249
250
u0 = [1.0] # Initial population
@@ -264,14 +265,18 @@ death_jump = VariableRateJump(death_rate, death_affect!)
264
265
oprob = ODEProblem((du, u, p, t) -> du .= 0, u0, tspan, p)
265
266
jprob = JumpProblem(oprob, birth_jump, death_jump; vr_aggregator = VR_Direct())
266
267
sol = solve(jprob, Tsit5())
268
+
269
+ jprob = JumpProblem(oprob, birth_jump, death_jump; vr_aggregator = VR_DirectFW())
270
+ sol = solve(jprob, Tsit5())
267
271
```
268
272
269
273
## Notes
270
- - `VR_Direct` is expected to generally be more performant than `VR_FRM`.
274
+ - `VR_Direct` and `VR_DirectFW` are expected to generally be more performant than `VR_FRM`.
271
275
"""
272
276
struct VR_Direct <: VariableRateAggregator end
277
+ struct VR_DirectFW <: VariableRateAggregator end
273
278
274
- mutable struct VR_DirectEventCache{T, RNG <: AbstractRNG , F1, F2}
279
+ mutable struct VR_DirectEventCache{T, RNG, F1, F2}
275
280
prev_time:: T
276
281
prev_threshold:: T
277
282
current_time:: T
@@ -281,20 +286,36 @@ mutable struct VR_DirectEventCache{T, RNG <: AbstractRNG, F1, F2}
281
286
rate_funcs:: F1
282
287
affect_funcs:: F2
283
288
cum_rate_sum:: Vector{T}
289
+ end
284
290
285
- function VR_DirectEventCache (jumps:: JumpSet , :: Type{T} ; rng = DEFAULT_RNG) where T
286
- initial_threshold = randexp (rng, T)
287
- vjumps = jumps. variable_jumps
291
+ function VR_DirectEventCache (jumps:: JumpSet , :: VR_Direct , prob , :: Type{T} ; rng = DEFAULT_RNG) where T
292
+ initial_threshold = randexp (rng, T)
293
+ vjumps = jumps. variable_jumps
288
294
289
- # handle vjumps using tuples
290
- rate_funcs, affect_funcs = get_jump_info_tuples (vjumps)
295
+ # handle vjumps using tuples
296
+ rate_funcs, affect_funcs = get_jump_info_tuples (vjumps)
291
297
292
- cum_rate_sum = Vector {T} (undef, length (vjumps))
298
+ cum_rate_sum = Vector {T} (undef, length (vjumps))
293
299
294
- new {T, typeof(rng), typeof(rate_funcs), typeof(affect_funcs)} (zero (T),
295
- initial_threshold, zero (T), initial_threshold, zero (T), rng, rate_funcs,
296
- affect_funcs, cum_rate_sum)
297
- end
300
+ VR_DirectEventCache {T, typeof(rng), typeof(rate_funcs), typeof(affect_funcs)} (zero (T),
301
+ initial_threshold, zero (T), initial_threshold, zero (T), rng, rate_funcs,
302
+ affect_funcs, cum_rate_sum)
303
+ end
304
+
305
+ function VR_DirectEventCache (jumps:: JumpSet , :: VR_DirectFW , prob, :: Type{T} ; rng = DEFAULT_RNG) where T
306
+ initial_threshold = randexp (rng, T)
307
+ vjumps = jumps. variable_jumps
308
+
309
+ t, u = prob. tspan[1 ], prob. u0
310
+
311
+ # handle vjumps using tuples
312
+ rate_funcs, affect_funcs = get_jump_info_fwrappers (u, prob. p, t, vjumps)
313
+
314
+ cum_rate_sum = Vector {T} (undef, length (vjumps))
315
+
316
+ VR_DirectEventCache {T, typeof(rng), typeof(rate_funcs), Any} (zero (T),
317
+ initial_threshold, zero (T), initial_threshold, zero (T), rng, rate_funcs,
318
+ affect_funcs, cum_rate_sum)
298
319
end
299
320
300
321
# Initialization function for VR_DirectEventCache
@@ -308,8 +329,24 @@ function initialize_vr_direct_cache!(cache::VR_DirectEventCache, u, t, integrato
308
329
nothing
309
330
end
310
331
332
+ @inline function concretize_vr_direct_affects! (cache:: VR_DirectEventCache ,
333
+ :: I ) where {I <: DiffEqBase.DEIntegrator }
334
+ if (cache. affect_funcs isa Vector) &&
335
+ ! (cache. affect_funcs isa Vector{FunctionWrappers. FunctionWrapper{Nothing, Tuple{I}}})
336
+ AffectWrapper = FunctionWrappers. FunctionWrapper{Nothing, Tuple{I}}
337
+ cache. affect_funcs = AffectWrapper[makewrapper (AffectWrapper, aff) for aff in cache. affect_funcs]
338
+ end
339
+ nothing
340
+ end
341
+
342
+ @inline function concretize_vr_direct_affects! (cache:: VR_DirectEventCache{T, RNG, F1, F2} ,
343
+ :: I ) where {T, RNG, F1, F2 <: Tuple , I <: DiffEqBase.DEIntegrator }
344
+ nothing
345
+ end
346
+
311
347
# Wrapper for initialize to match ContinuousCallback signature
312
348
function initialize_vr_direct_wrapper (cb:: ContinuousCallback , u, t, integrator)
349
+ concretize_vr_direct_affects! (cb. condition, integrator)
313
350
initialize_vr_direct_cache! (cb. condition, u, t, integrator)
314
351
u_modified! (integrator, false )
315
352
nothing
334
371
335
372
function configure_jump_problem (prob, :: VR_Direct , jumps, cvrjs; rng = DEFAULT_RNG)
336
373
new_prob = prob
337
- cache = VR_DirectEventCache (jumps, eltype (prob. tspan); rng)
374
+ cache = VR_DirectEventCache (jumps, VR_Direct (), prob, eltype (prob. tspan); rng)
375
+ variable_jump_callback = build_variable_integcallback (cache, cvrjs)
376
+ cont_agg = cvrjs
377
+ return new_prob, variable_jump_callback, cont_agg
378
+ end
379
+
380
+ function configure_jump_problem (prob, :: VR_DirectFW , jumps, cvrjs; rng = DEFAULT_RNG)
381
+ new_prob = prob
382
+ cache = VR_DirectEventCache (jumps, VR_DirectFW (), prob, eltype (prob. tspan); rng)
338
383
variable_jump_callback = build_variable_integcallback (cache, cvrjs)
339
384
cont_agg = cvrjs
340
385
return new_prob, variable_jump_callback, cont_agg
@@ -402,9 +447,21 @@ function (cache::VR_DirectEventCache)(u, t, integrator)
402
447
return cache. current_threshold
403
448
end
404
449
405
- @generated function execute_affect! (cache:: VR_DirectEventCache{T, RNG, F1, F2} , integrator, idx) where {T, RNG, F1, F2 <: Tuple }
450
+ @generated function execute_affect! (cache:: VR_DirectEventCache{T, RNG, F1, F2} ,
451
+ integrator:: I , idx) where {T, RNG, F1, F2 <: Tuple , I <: DiffEqBase.DEIntegrator }
406
452
quote
407
- Base. Cartesian. @nif $ (fieldcount (F2)) i -> (i == idx) i -> (@inbounds cache. affect_funcs[i](integrator)) i -> (@inbounds cache. affect_funcs[fieldcount (F2)](integrator))
453
+ @unpack affect_funcs = cache
454
+ Base. Cartesian. @nif $ (fieldcount (F2)) i -> (i == idx) i -> (@inbounds affect_funcs[i](integrator)) i -> (@inbounds affect_funcs[fieldcount (F2)](integrator))
455
+ end
456
+ end
457
+
458
+ @inline function execute_affect! (cache:: VR_DirectEventCache ,
459
+ integrator:: I , idx) where {I <: DiffEqBase.DEIntegrator }
460
+ @unpack affect_funcs = cache
461
+ if affect_funcs isa Vector{FunctionWrappers. FunctionWrapper{Nothing, Tuple{I}}}
462
+ @inbounds affect_funcs[idx](integrator)
463
+ else
464
+ error (" Error, invalid affect_funcs type. Expected a vector of function wrappers and got $(typeof (affect_funcs)) " )
408
465
end
409
466
end
410
467
0 commit comments