From 7861bc727c5124ee75489b7b3c51158ddec6b63f Mon Sep 17 00:00:00 2001 From: Vasily Ilin Date: Sun, 10 Sep 2023 15:00:14 -0400 Subject: [PATCH 01/37] Write RSSACR-Direct but it's incorrect. I will turn it into CR-RSSA. --- src/JumpProcesses.jl | 3 +- src/aggregators/aggregators.jl | 3 + src/spatial/bracketing.jl | 43 +++++-- src/spatial/directcrdirect.jl | 5 +- src/spatial/hop_rates.jl | 37 +++--- src/spatial/nsm.jl | 4 +- src/spatial/reaction_rates.jl | 15 +++ src/spatial/rssacrdirect.jl | 228 +++++++++++++++++++++++++++++++++ src/spatial/utils.jl | 25 ++-- test/spatial/ABC.jl | 67 ++++++---- 10 files changed, 364 insertions(+), 66 deletions(-) create mode 100644 src/spatial/rssacrdirect.jl diff --git a/src/JumpProcesses.jl b/src/JumpProcesses.jl index 39dd5465..640f832d 100644 --- a/src/JumpProcesses.jl +++ b/src/JumpProcesses.jl @@ -65,6 +65,7 @@ include("spatial/bracketing.jl") include("spatial/nsm.jl") include("spatial/directcrdirect.jl") +include("spatial/rssacrdirect.jl") include("aggregators/aggregated_api.jl") @@ -101,6 +102,6 @@ export ExtendedJumpArray export CartesianGrid, CartesianGridRej export SpatialMassActionJump export outdegree, num_sites, neighbors -export NSM, DirectCRDirect +export NSM, DirectCRDirect, RSSACRDirect end # module diff --git a/src/aggregators/aggregators.jl b/src/aggregators/aggregators.jl index c1553d03..46c220a5 100644 --- a/src/aggregators/aggregators.jl +++ b/src/aggregators/aggregators.jl @@ -159,6 +159,8 @@ algorithm with optimal binning, Journal of Chemical Physics 143, 074108 """ struct DirectCRDirect <: AbstractAggregatorAlgorithm end +struct RSSACRDirect <: AbstractAggregatorAlgorithm end + const JUMP_AGGREGATORS = (Direct(), DirectFW(), DirectCR(), SortingDirect(), RSSA(), FRM(), FRMFW(), NRM(), RSSACR(), RDirect(), Coevolve()) @@ -187,3 +189,4 @@ supports_variablerates(aggregator::Coevolve) = true is_spatial(aggregator::AbstractAggregatorAlgorithm) = false is_spatial(aggregator::NSM) = true is_spatial(aggregator::DirectCRDirect) = true +is_spatial(aggregator::RSSACRDirect) = true diff --git a/src/spatial/bracketing.jl b/src/spatial/bracketing.jl index 106add2f..290efc56 100644 --- a/src/spatial/bracketing.jl +++ b/src/spatial/bracketing.jl @@ -5,9 +5,15 @@ struct LowHigh{T} low::T high::T - LowHigh(low::T, high::T) where {T} = new{T}(deepcopy(low), deepcopy(high)) - LowHigh(pair::Tuple{T,T}) where {T} = new{T}(pair[1], pair[2]) - LowHigh(low_and_high::T) where {T} = new{T}(low_and_high, deepcopy(low_and_high)) + function LowHigh(low::T, high::T; do_copy = true) where {T} + if do_copy + return new{T}(deepcopy(low), deepcopy(high)) + else + return new{T}(low, high) + end + end + LowHigh(pair::Tuple{T,T}; kwargs...) where {T} = LowHigh(pair[1], pair[2]; kwargs...) + LowHigh(low_and_high::T; kwargs...) where {T} = LowHigh(low_and_high, low_and_high; kwargs...) end function Base.show(io::IO, ::MIME"text/plain", low_high::LowHigh) @@ -16,25 +22,39 @@ function Base.show(io::IO, ::MIME"text/plain", low_high::LowHigh) end @inline function update_u_brackets!(u_low_high::LowHigh, bracket_data, u::AbstractMatrix) - @inbounds for (i, uval) in enumerate(u) - u_low_high[i] = LowHigh(get_spec_brackets(bracket_data, i, uval)) + num_species, num_sites = size(u) + update_u_brackets!(u_low_high, bracket_data, u, 1:num_species, 1:num_sites) +end + +@inline function update_u_brackets!(u_low_high::LowHigh, bracket_data, u::AbstractMatrix, species_vec, sites) + @inbounds for site in sites + for species in species_vec + u_low_high[species, site] = LowHigh(get_spec_brackets(bracket_data, species, u[species, site])) + end end nothing end +function is_outside_brackets(u_low_high::LowHigh{M}, u::M, species, site) where {M} + return u[species, site] < u_low_high.low[species, site] || u[species, site] > u_low_high.high[species, site] +end + ### convenience functions for LowHigh ### -function setindex!(low_high::LowHigh, val::LowHigh, i) - low_high.low[i] = val.low - low_high.high[i] = val.high +function setindex!(low_high::LowHigh, val::LowHigh, i...) + low_high.low[i...] = val.low + low_high.high[i...] = val.high val end +get_majumps(rx_rates::LowHigh{R}) where {R <: RxRates} = get_majumps(rx_rates.low) + function total_site_rate(rx_rates::LowHigh, hop_rates::LowHigh, site) return LowHigh( total_site_rate(rx_rates.low, hop_rates.low, site), total_site_rate(rx_rates.high, hop_rates.high, site)) end +# Compatible with constant rate jumps, because u_low_high.low and u_low_high.high are used in rate(). function update_rx_rates!(rx_rates::LowHigh, rxs, u_low_high, integrator, site) update_rx_rates!(rx_rates.low, rxs, u_low_high.low, integrator, site) update_rx_rates!(rx_rates.high, rxs, u_low_high.high, integrator, site) @@ -44,3 +64,10 @@ function update_hop_rates!(hop_rates::LowHigh, species, u_low_high, site, spatia update_hop_rates!(hop_rates.low, species, u_low_high.low, site, spatial_system) update_hop_rates!(hop_rates.high, species, u_low_high.high, site, spatial_system) end + +function reset!(low_high::LowHigh) + reset!(low_high.low) + reset!(low_high.high) +end + +reset!(array::AbstractArray) = fill!(array, zero(eltype(array))) \ No newline at end of file diff --git a/src/spatial/directcrdirect.jl b/src/spatial/directcrdirect.jl index bb44b144..611846c3 100644 --- a/src/spatial/directcrdirect.jl +++ b/src/spatial/directcrdirect.jl @@ -4,7 +4,6 @@ const MINJUMPRATE = 2.0^exponent(1e-12) #NOTE state vector u is a matrix. u[i,j] is species i, site j -#NOTE hopping_constants is a matrix. hopping_constants[i,j] is species i, site j mutable struct DirectCRDirectJumpAggregation{T, S, F1, F2, RNG, J, RX, HOP, DEPGR, VJMAP, JVMAP, SS, U <: PriorityTable, W <: Function} <: @@ -107,12 +106,12 @@ end function initialize!(p::DirectCRDirectJumpAggregation, integrator, u, params, t) p.end_time = integrator.sol.prob.tspan[2] fill_rates_and_get_times!(p, integrator, t) - generate_jumps!(p, integrator, params, u, t) + generate_jumps!(p, integrator, u, params, t) nothing end # calculate the next jump / jump time -function generate_jumps!(p::DirectCRDirectJumpAggregation, integrator, params, u, t) +function generate_jumps!(p::DirectCRDirectJumpAggregation, integrator, u, params, t) p.next_jump_time = t + randexp(p.rng) / p.rt.gsum p.next_jump_time >= p.end_time && return nothing site = sample(p.rt, p.site_rates, p.rng) diff --git a/src/spatial/hop_rates.jl b/src/spatial/hop_rates.jl index ef7f73a5..2b26283c 100644 --- a/src/spatial/hop_rates.jl +++ b/src/spatial/hop_rates.jl @@ -57,29 +57,28 @@ function HopRates(p::Pair{SpecHop, SiteHop}, end """ - update_hop_rates!(hop_rates::AbstractHopRates, species::AbstractArray, u, site, spatial_system) + update_hop_rates!(hop_rates::HopRatesGraphDsi, species_vec, u, site, spatial_system) -update rates of all specs in species at site + update rates of all species in species_vec at site """ -function update_hop_rates!(hop_rates::AbstractHopRates, species::AbstractArray, u, site, - spatial_system) - @inbounds for spec in species - update_hop_rate!(hop_rates, spec, u, site, spatial_system) +function update_hop_rates!(hop_rates::AbstractHopRates, species_vec, u, site, spatial_system) + @inbounds for species in species_vec + rates = hop_rates.rates + old_rate = rates[species, site] + rates[species, site] = evalhoprate(hop_rates, u, species, site, + spatial_system) + hop_rates.sum_rates[site] += rates[species, site] - old_rate + old_rate end end -""" - update_hop_rate!(hop_rates::HopRatesGraphDsi, species, u, site, spatial_system) - -update rates of single species at site -""" -function update_hop_rate!(hop_rates::AbstractHopRates, species, u, site, spatial_system) - rates = hop_rates.rates - @inbounds old_rate = rates[species, site] - @inbounds rates[species, site] = evalhoprate(hop_rates, u, species, site, - spatial_system) - @inbounds hop_rates.sum_rates[site] += rates[species, site] - old_rate - old_rate +function recompute_site_hop_rate(hop_rates::HP, u, site, spatial_system) where {HP <: AbstractHopRates} + rate = zero(eltype(hop_rates.rates)) + num_species = size(hop_rates.rates, 1) + for species in 1:num_species + rate += evalhoprate(hop_rates, u, species, site, spatial_system) + end + return rate end """ @@ -197,7 +196,7 @@ end return hopping rate of species at site """ function evalhoprate(hop_rates::HopRatesGraphDsi, u, species, site, spatial_system) - @inbounds u[species, site] * hop_rates.hopping_constants[species, site] * + u[species, site] * hop_rates.hopping_constants[species, site] * outdegree(spatial_system, site) end diff --git a/src/spatial/nsm.jl b/src/spatial/nsm.jl index 3cfe7eed..a5341665 100644 --- a/src/spatial/nsm.jl +++ b/src/spatial/nsm.jl @@ -95,12 +95,12 @@ end function initialize!(p::NSMJumpAggregation, integrator, u, params, t) p.end_time = integrator.sol.prob.tspan[2] fill_rates_and_get_times!(p, integrator, t) - generate_jumps!(p, integrator, params, u, t) + generate_jumps!(p, integrator, u, params, t) nothing end # calculate the next jump / jump time -function generate_jumps!(p::NSMJumpAggregation, integrator, params, u, t) +function generate_jumps!(p::NSMJumpAggregation, integrator, u, params, t) p.next_jump_time, site = top_with_handle(p.pq) p.next_jump_time >= p.end_time && return nothing p.next_jump = sample_jump_direct(p, site) diff --git a/src/spatial/reaction_rates.jl b/src/spatial/reaction_rates.jl index 2aba9df1..54f38324 100644 --- a/src/spatial/reaction_rates.jl +++ b/src/spatial/reaction_rates.jl @@ -26,6 +26,7 @@ function RxRates(num_sites::Int, ma_jumps::M) where {M} end num_rxs(rx_rates::RxRates) = get_num_majumps(rx_rates.ma_jumps) +get_majumps(rx_rates::RxRates) = rx_rates.ma_jumps """ reset!(rx_rates::RxRates) @@ -77,6 +78,20 @@ function sample_rx_at_site(rx_rates::RxRates, site, rng) rand(rng) * total_site_rx_rate(rx_rates, site)) end +""" + recompute_site_rx_rate(rx_rates::RxRates, u, site) + +compute the total reaction rate at site at the current state u +""" +function recompute_site_rx_rate(rx_rates::RxRates, u, site) + rate = zero(eltype(rx_rates.rates)) + ma_jumps = rx_rates.ma_jumps + for rx in 1:num_rxs(rx_rates) + rate += eval_massaction_rate(u, rx, ma_jumps, site) + end + return rate +end + # helper functions function set_rx_rate_at_site!(rx_rates::RxRates, site, rx, rate) @inbounds old_rate = rx_rates.rates[rx, site] diff --git a/src/spatial/rssacrdirect.jl b/src/spatial/rssacrdirect.jl new file mode 100644 index 00000000..8413570c --- /dev/null +++ b/src/spatial/rssacrdirect.jl @@ -0,0 +1,228 @@ +# site chosen with RSSACR, rx or hop chosen with Direct + +############################ RSSACRDirect ################################### +const MINJUMPRATE = 2.0^exponent(1e-12) + +#NOTE state vector u is a matrix. u[i,j] is species i, site j +mutable struct RSSACRDirectJumpAggregation{T, BD, M, RNG, J, RX, HOP, DEPGR, + VJMAP, JVMAP, SS, U <: PriorityTable, S, F1, F2} <: + AbstractSSAJumpAggregator{T, S, F1, F2, RNG} + next_jump::SpatialJump{J} + prev_jump::SpatialJump{J} + next_jump_time::T + end_time::T + bracket_data::BD + u_low_high::LowHigh{M} # species bracketing + rx_rates::LowHigh{RX} + hop_rates::LowHigh{HOP} + site_rates::LowHigh{Vector{T}} + save_positions::Tuple{Bool, Bool} + rng::RNG + dep_gr::DEPGR #dep graph is same for each site + vartojumps_map::VJMAP #vartojumps_map is same for each site + jumptovars_map::JVMAP #jumptovars_map is same for each site + spatial_system::SS + numspecies::Int #number of species + rt::U + rates::F1 # legacy, not used + affects!::F2 # legacy, not used +end + +function RSSACRDirectJumpAggregation(nj::SpatialJump{J}, njt::T, et::T, bd::BD, + u_low_high::LowHigh{M}, rx_rates::LowHigh{RX}, + hop_rates::LowHigh{HOP}, site_rates::LowHigh{Vector{T}}, + sps::Tuple{Bool, Bool}, rng::RNG, spatial_system::SS; + num_specs, minrate = convert(T, MINJUMPRATE), + vartojumps_map = nothing, jumptovars_map = nothing, + dep_graph = nothing, + kwargs...) where {J, T, BD, RX, HOP, RNG, SS, M} + + # a dependency graph is needed + if dep_graph === nothing + dg = make_dependency_graph(num_specs, rx_rates.low.ma_jumps) + else + dg = dep_graph + # make sure each jump depends on itself + add_self_dependencies!(dg) + end + + # a species-to-reactions graph is needed + if vartojumps_map === nothing + vtoj_map = var_to_jumps_map(num_specs, rx_rates.low.ma_jumps) + else + vtoj_map = vartojumps_map + end + + if jumptovars_map === nothing + jtov_map = jump_to_vars_map(rx_rates.low.ma_jumps) + else + jtov_map = jumptovars_map + end + + # mapping from jump rate to group id + minexponent = exponent(minrate) + + # use the largest power of two that is <= the passed in minrate + minrate = 2.0^minexponent + ratetogroup = rate -> priortogid(rate, minexponent) + + # construct an empty initial priority table -- we'll reset this in init + rt = PriorityTable(ratetogroup, zeros(T, 1), minrate, 2 * minrate) + + RSSACRDirectJumpAggregation{T, BD, M, RNG, J, RX, HOP, typeof(dg), typeof(vtoj_map), typeof(jtov_map), SS, typeof(rt), Nothing, Nothing, Nothing}( + nj, nj, njt, et, bd, u_low_high, rx_rates, hop_rates, site_rates, sps, rng, dg, vtoj_map, jtov_map, spatial_system, num_specs, rt, nothing, nothing) +end + +############################# Required Functions ############################## +# creating the JumpAggregation structure (function wrapper-based constant jumps) +function aggregate(aggregator::RSSACRDirect, starting_state, p, t, end_time, + constant_jumps, ma_jumps, save_positions, rng; hopping_constants, + spatial_system, bracket_data = nothing, kwargs...) + T = typeof(end_time) + num_species = size(starting_state, 1) + majumps = ma_jumps + if majumps === nothing + majumps = MassActionJump(Vector{T}(), + Vector{Vector{Pair{Int, Int}}}(), + Vector{Vector{Pair{Int, Int}}}()) + end + + next_jump = SpatialJump{Int}(typemax(Int), typemax(Int), typemax(Int)) #a placeholder + next_jump_time = typemax(T) + rx_rates = LowHigh(RxRates(num_sites(spatial_system), majumps), + RxRates(num_sites(spatial_system), majumps); + do_copy = false) # do not copy ma_jumps + hop_rates = LowHigh(HopRates(hopping_constants, spatial_system), + HopRates(hopping_constants, spatial_system); + do_copy = false) # do not copy hopping_constants + site_rates = LowHigh(zeros(T, num_sites(spatial_system))) + bd = (bracket_data === nothing) ? BracketData{T, eltype(starting_state)}() : + bracket_data + u_low_high = LowHigh(starting_state) + + RSSACRDirectJumpAggregation(next_jump, next_jump_time, end_time, bd, u_low_high, + rx_rates, hop_rates, + site_rates, save_positions, rng, spatial_system; + num_specs = num_species, kwargs...) +end + +# set up a new simulation and calculate the first jump / jump time +function initialize!(p::RSSACRDirectJumpAggregation, integrator, u, params, t) + p.end_time = integrator.sol.prob.tspan[2] + fill_rates_and_get_times!(p, integrator, t) + generate_jumps!(p, integrator, u, params, t) + nothing +end + +# calculate the next jump / jump time +function generate_jumps!(p::RSSACRDirectJumpAggregation, integrator, u, params, t) + @unpack rng, rt, site_rates, rx_rates, hop_rates, spatial_system = p + time_delta = zero(t) + site = zero(eltype(u)) + while true + site = sample(rt, site_rates.high, rng) + time_delta += randexp(rng) + accept_jump(rx_rates, hop_rates, site_rates, u, site, spatial_system, rng) && break + end + p.next_jump_time = t + time_delta / groupsum(rt) + p.next_jump = sample_jump_direct(rx_rates.high, hop_rates.high, site, spatial_system, rng) + nothing +end + +# execute one jump, changing the system state +function execute_jumps!(p::RSSACRDirectJumpAggregation, integrator, u, params, t, + affects!) + update_state!(p, integrator) + update_dependent_rates!(p, integrator, t) + nothing +end + +######################## SSA specific helper routines ######################## +# Return true if site is accepted. +function accept_jump(rx_rates, hop_rates, site_rates, u, site, spatial_system, rng) + acceptance_threshold = rand(rng) * site_rates.high[site] + if acceptance_threshold < site_rates.low[site] + return true + else + site_rate = recompute_site_hop_rate(hop_rates.low, u, site, spatial_system) + + recompute_site_rx_rate(rx_rates.low, u, site) + return acceptance_threshold < site_rate + end +end + +""" + fill_rates_and_get_times!(aggregation::RSSACRDirectJumpAggregation, u, t) + +reset all stucts, reevaluate all rates, repopulate the priority table +""" +function fill_rates_and_get_times!(aggregation::RSSACRDirectJumpAggregation, integrator, t) + @unpack bracket_data, u_low_high, spatial_system, rx_rates, hop_rates, site_rates, rt = aggregation + u = integrator.u + update_u_brackets!(u_low_high::LowHigh, bracket_data, u::AbstractMatrix) + + reset!(rx_rates) + reset!(hop_rates) + reset!(site_rates) + + rxs = 1:num_rxs(rx_rates.low) + species = 1:(aggregation.numspecies) + + for site in 1:num_sites(spatial_system) + update_rx_rates!(rx_rates, rxs, u_low_high, integrator, site) + update_hop_rates!(hop_rates, species, u_low_high, site, spatial_system) + site_rates[site] = total_site_rate(rx_rates, hop_rates, site) + end + + # setup PriorityTable + reset!(rt) + for (pid, priority) in enumerate(site_rates.high) + insert!(rt, pid, priority) + end + nothing +end + +""" + update_dependent_rates!(p, integrator, t) + +recalculate jump rates for jumps that depend on the just executed jump (p.prev_jump) +""" +function update_dependent_rates!(p::RSSACRDirectJumpAggregation, + integrator, + t) + @unpack rx_rates, hop_rates, site_rates, u_low_high, bracket_data, vartojumps_map, jumptovars_map, spatial_system = p + + u = integrator.u + site_rates = p.site_rates + jump = p.prev_jump + + if is_hop(p, jump) + species_to_update = jump.jidx + sites_to_update = (jump.src, jump.dst) + else + species_to_update = jumptovars_map[reaction_id_from_jump(p, jump)] + sites_to_update = jump.src + end + + for site in sites_to_update, species in species_to_update + if is_outside_brackets(u_low_high, u, species, site) + update_u_brackets!(u_low_high, bracket_data, u, species, site) + update_rx_rates!(rx_rates, + vartojumps_map[species], + u_low_high, + integrator, + site) + update_hop_rates!(hop_rates, species, u_low_high, site, spatial_system) + + oldrate = site_rates.high[site] + site_rates[site] = total_site_rate(p.rx_rates, p.hop_rates, site) + update!(p.rt, site, oldrate, site_rates.high[site]) + end + end +end + +""" + num_constant_rate_jumps(aggregator::RSSACRDirectJumpAggregation) + +number of constant rate jumps +""" +num_constant_rate_jumps(aggregator::RSSACRDirectJumpAggregation) = 0 \ No newline at end of file diff --git a/src/spatial/utils.jl b/src/spatial/utils.jl index ddb9db41..16191002 100644 --- a/src/spatial/utils.jl +++ b/src/spatial/utils.jl @@ -27,18 +27,23 @@ end sample jump at site with direct method """ -function sample_jump_direct(p, site) - if rand(p.rng) * (total_site_rate(p.rx_rates, p.hop_rates, site)) < - total_site_rx_rate(p.rx_rates, site) - rx = sample_rx_at_site(p.rx_rates, site, p.rng) - return SpatialJump(site, rx + p.numspecies, site) +function sample_jump_direct(rx_rates, hop_rates, site, spatial_system, rng) + numspecies = size(hop_rates.rates, 1) + if rand(rng) * (total_site_rate(rx_rates, hop_rates, site)) < + total_site_rx_rate(rx_rates, site) + rx = sample_rx_at_site(rx_rates, site, rng) + return SpatialJump(site, rx + numspecies, site) else - species_to_diffuse, target_site = sample_hop_at_site(p.hop_rates, site, p.rng, - p.spatial_system) + species_to_diffuse, target_site = sample_hop_at_site(hop_rates, site, rng, + spatial_system) return SpatialJump(site, species_to_diffuse, target_site) end end +function sample_jump_direct(p, site) + sample_jump_direct(p.rx_rates, p.hop_rates, site, p.spatial_system, p.rng) +end + function total_site_rate(rx_rates::RxRates, hop_rates::AbstractHopRates, site) total_site_hop_rate(hop_rates, site) + total_site_rx_rate(rx_rates, site) end @@ -52,10 +57,10 @@ end function update_rates_after_hop!(p, integrator, source_site, target_site, species) u = integrator.u update_rx_rates!(p.rx_rates, p.vartojumps_map[species], integrator, source_site) - update_hop_rate!(p.hop_rates, species, u, source_site, p.spatial_system) + update_hop_rates!(p.hop_rates, species, u, source_site, p.spatial_system) update_rx_rates!(p.rx_rates, p.vartojumps_map[species], integrator, target_site) - update_hop_rate!(p.hop_rates, species, u, target_site, p.spatial_system) + update_hop_rates!(p.hop_rates, species, u, target_site, p.spatial_system) end """ @@ -70,7 +75,7 @@ function update_state!(p, integrator) else rx_index = reaction_id_from_jump(p, jump) @inbounds executerx!((@view integrator.u[:, jump.src]), rx_index, - p.rx_rates.ma_jumps) + get_majumps(p.rx_rates)) end # save jump that was just exectued p.prev_jump = jump diff --git a/test/spatial/ABC.jl b/test/spatial/ABC.jl index 558a701a..48a22c3e 100644 --- a/test/spatial/ABC.jl +++ b/test/spatial/ABC.jl @@ -2,12 +2,12 @@ using JumpProcesses, DiffEqBase # using BenchmarkTools using Test, Graphs -Nsims = 100 +Nsims = 1000 reltol = 0.05 non_spatial_mean = [65.7395, 65.7395, 434.2605] #mean of 10,000 simulations dim = 1 -linear_size = 5 +linear_size = 1 dims = Tuple(repeat([linear_size], dim)) num_nodes = prod(dims) starting_site = trunc(Int, (linear_size^dim + 1) / 2) @@ -47,27 +47,27 @@ end # testing grids = [CartesianGridRej(dims), Graphs.grid(dims)] -jump_problems = JumpProblem[JumpProblem(prob, NSM(), majumps, - hopping_constants = hopping_constants, - spatial_system = grid, - save_positions = (false, false)) for grid in grids] -push!(jump_problems, - JumpProblem(prob, DirectCRDirect(), majumps, hopping_constants = hopping_constants, - spatial_system = grids[1], save_positions = (false, false))) -# setup flattenned jump prob -push!(jump_problems, - JumpProblem(prob, NRM(), majumps, hopping_constants = hopping_constants, - spatial_system = grids[1], save_positions = (false, false))) -# test -for spatial_jump_prob in jump_problems - solution = solve(spatial_jump_prob, SSAStepper()) - mean_end_state = get_mean_end_state(spatial_jump_prob, Nsims) - mean_end_state = reshape(mean_end_state, num_species, num_nodes) - diff = sum(mean_end_state, dims = 2) - non_spatial_mean - for (i, d) in enumerate(diff) - @test abs(d) < reltol * non_spatial_mean[i] - end -end +# jump_problems = JumpProblem[JumpProblem(prob, NSM(), majumps, +# hopping_constants = hopping_constants, +# spatial_system = grid, +# save_positions = (false, false)) for grid in grids] +# push!(jump_problems, +# JumpProblem(prob, DirectCRDirect(), majumps, hopping_constants = hopping_constants, +# spatial_system = grids[1], save_positions = (false, false))) +# # setup flattenned jump prob +# push!(jump_problems, +# JumpProblem(prob, NRM(), majumps, hopping_constants = hopping_constants, +# spatial_system = grids[1], save_positions = (false, false))) +# # test +# for spatial_jump_prob in jump_problems +# solution = solve(spatial_jump_prob, SSAStepper()) +# mean_end_state = get_mean_end_state(spatial_jump_prob, Nsims) +# mean_end_state = reshape(mean_end_state, num_species, num_nodes) +# diff = sum(mean_end_state, dims = 2) - non_spatial_mean +# for (i, d) in enumerate(diff) +# @test abs(d) < reltol * non_spatial_mean[i] +# end +# end #using non-spatial SSAs to get the mean # non_spatial_rates = [0.1,1.0] @@ -77,3 +77,24 @@ end # non_spatial_prob = DiscreteProblem(u0,(0.0,end_time), non_spatial_rates) # jump_prob = JumpProblem(non_spatial_prob, Direct(), majumps) # non_spatial_mean = get_mean_end_state(jump_prob, 10000) + +spatial_jump_prob = JumpProblem(prob, RSSACRDirect(), majumps, hopping_constants = hopping_constants, + spatial_system = grids[1], save_positions = (false, false)) +sol = solve(spatial_jump_prob, SSAStepper()) +mean_end_state = get_mean_end_state(spatial_jump_prob, Nsims) +mean_end_state = reshape(mean_end_state, num_species, num_nodes) +diff = sum(mean_end_state, dims = 2) - non_spatial_mean +for (i, d) in enumerate(diff) + @test abs(d) < reltol * non_spatial_mean[i] +end + + +spatial_jump_prob = JumpProblem(prob, NSM(), majumps, hopping_constants = hopping_constants, + spatial_system = grids[1], save_positions = (false, false)) +sol = solve(spatial_jump_prob, SSAStepper()) +mean_end_state = get_mean_end_state(spatial_jump_prob, Nsims) +mean_end_state = reshape(mean_end_state, num_species, num_nodes) +diff = sum(mean_end_state, dims = 2) - non_spatial_mean +for (i, d) in enumerate(diff) + @test abs(d) < reltol * non_spatial_mean[i] +end \ No newline at end of file From e36f4663d815ab8dcfa52d47587e618ae00bd873 Mon Sep 17 00:00:00 2001 From: Vasily Ilin Date: Mon, 11 Sep 2023 00:38:26 -0400 Subject: [PATCH 02/37] Fix the main part of the SSA code. Time to clean up. --- src/spatial/bracketing.jl | 4 +- src/spatial/hop_rates.jl | 9 +--- src/spatial/reaction_rates.jl | 17 ++----- src/spatial/rssacrdirect.jl | 89 +++++++++++++++++++++++++---------- test/spatial/ABC.jl | 4 +- 5 files changed, 71 insertions(+), 52 deletions(-) diff --git a/src/spatial/bracketing.jl b/src/spatial/bracketing.jl index 290efc56..3c810f95 100644 --- a/src/spatial/bracketing.jl +++ b/src/spatial/bracketing.jl @@ -35,8 +35,8 @@ end nothing end -function is_outside_brackets(u_low_high::LowHigh{M}, u::M, species, site) where {M} - return u[species, site] < u_low_high.low[species, site] || u[species, site] > u_low_high.high[species, site] +function is_inside_brackets(u_low_high::LowHigh{M}, u::M, species, site) where {M} + return u_low_high.low[species, site] < u[species, site] < u_low_high.high[species, site] end ### convenience functions for LowHigh ### diff --git a/src/spatial/hop_rates.jl b/src/spatial/hop_rates.jl index 2b26283c..9e5f430a 100644 --- a/src/spatial/hop_rates.jl +++ b/src/spatial/hop_rates.jl @@ -72,14 +72,7 @@ function update_hop_rates!(hop_rates::AbstractHopRates, species_vec, u, site, sp end end -function recompute_site_hop_rate(hop_rates::HP, u, site, spatial_system) where {HP <: AbstractHopRates} - rate = zero(eltype(hop_rates.rates)) - num_species = size(hop_rates.rates, 1) - for species in 1:num_species - rate += evalhoprate(hop_rates, u, species, site, spatial_system) - end - return rate -end +hop_rate(hop_rates, species, site) = @inbounds hop_rates.rates[species, site] """ total_site_hop_rate(hop_rates::AbstractHopRates, site) diff --git a/src/spatial/reaction_rates.jl b/src/spatial/reaction_rates.jl index 54f38324..a9c78c9d 100644 --- a/src/spatial/reaction_rates.jl +++ b/src/spatial/reaction_rates.jl @@ -39,6 +39,9 @@ function reset!(rx_rates::RxRates) nothing end +rx_rate(rx_rates, rx, site) = rx_rates.rates[rx, site] +evalrxrate(rx_rates, u, rx, site) = eval_massaction_rate(u, rx, rx_rates.ma_jumps, site) + """ total_site_rx_rate(rx_rates::RxRates, site) @@ -78,20 +81,6 @@ function sample_rx_at_site(rx_rates::RxRates, site, rng) rand(rng) * total_site_rx_rate(rx_rates, site)) end -""" - recompute_site_rx_rate(rx_rates::RxRates, u, site) - -compute the total reaction rate at site at the current state u -""" -function recompute_site_rx_rate(rx_rates::RxRates, u, site) - rate = zero(eltype(rx_rates.rates)) - ma_jumps = rx_rates.ma_jumps - for rx in 1:num_rxs(rx_rates) - rate += eval_massaction_rate(u, rx, ma_jumps, site) - end - return rate -end - # helper functions function set_rx_rate_at_site!(rx_rates::RxRates, site, rx, rate) @inbounds old_rate = rx_rates.rates[rx, site] diff --git a/src/spatial/rssacrdirect.jl b/src/spatial/rssacrdirect.jl index 8413570c..80596624 100644 --- a/src/spatial/rssacrdirect.jl +++ b/src/spatial/rssacrdirect.jl @@ -15,7 +15,7 @@ mutable struct RSSACRDirectJumpAggregation{T, BD, M, RNG, J, RX, HOP, DEPGR, u_low_high::LowHigh{M} # species bracketing rx_rates::LowHigh{RX} hop_rates::LowHigh{HOP} - site_rates::LowHigh{Vector{T}} + site_rates::LowHigh{Vector{T}} # TODO(vilin97): we never use site_rates.low save_positions::Tuple{Bool, Bool} rng::RNG dep_gr::DEPGR #dep graph is same for each site @@ -69,8 +69,24 @@ function RSSACRDirectJumpAggregation(nj::SpatialJump{J}, njt::T, et::T, bd::BD, # construct an empty initial priority table -- we'll reset this in init rt = PriorityTable(ratetogroup, zeros(T, 1), minrate, 2 * minrate) - RSSACRDirectJumpAggregation{T, BD, M, RNG, J, RX, HOP, typeof(dg), typeof(vtoj_map), typeof(jtov_map), SS, typeof(rt), Nothing, Nothing, Nothing}( - nj, nj, njt, et, bd, u_low_high, rx_rates, hop_rates, site_rates, sps, rng, dg, vtoj_map, jtov_map, spatial_system, num_specs, rt, nothing, nothing) + RSSACRDirectJumpAggregation{ + T, + BD, + M, + RNG, + J, + RX, + HOP, + typeof(dg), + typeof(vtoj_map), + typeof(jtov_map), + SS, + typeof(rt), + Nothing, + Nothing, + Nothing, + }(nj, nj, njt, et, bd, u_low_high, rx_rates, hop_rates, site_rates, sps, rng, dg, + vtoj_map, jtov_map, spatial_system, num_specs, rt, nothing, nothing) end ############################# Required Functions ############################## @@ -118,14 +134,16 @@ end function generate_jumps!(p::RSSACRDirectJumpAggregation, integrator, u, params, t) @unpack rng, rt, site_rates, rx_rates, hop_rates, spatial_system = p time_delta = zero(t) - site = zero(eltype(u)) while true site = sample(rt, site_rates.high, rng) + jump = sample_jump_direct(rx_rates.high, hop_rates.high, site, spatial_system, rng) time_delta += randexp(rng) - accept_jump(rx_rates, hop_rates, site_rates, u, site, spatial_system, rng) && break + if accept_jump(p, u, jump) + p.next_jump_time = t + time_delta / groupsum(rt) + p.next_jump = jump + break + end end - p.next_jump_time = t + time_delta / groupsum(rt) - p.next_jump = sample_jump_direct(rx_rates.high, hop_rates.high, site, spatial_system, rng) nothing end @@ -139,14 +157,37 @@ end ######################## SSA specific helper routines ######################## # Return true if site is accepted. -function accept_jump(rx_rates, hop_rates, site_rates, u, site, spatial_system, rng) - acceptance_threshold = rand(rng) * site_rates.high[site] - if acceptance_threshold < site_rates.low[site] +function accept_jump(p, u, jump) + if is_hop(p, jump) + return accept_hop(p, u, jump) + else + return accept_rx(p, u, jump) + end +end + +function accept_hop(p, u, jump) + @unpack hop_rates, spatial_system, rng = p + species, site = jump.jidx, jump.src + acceptance_threshold = rand(rng) * hop_rate(hop_rates.high, species, site) + if hop_rate(hop_rates.low, species, site) > acceptance_threshold return true else - site_rate = recompute_site_hop_rate(hop_rates.low, u, site, spatial_system) + - recompute_site_rx_rate(rx_rates.low, u, site) - return acceptance_threshold < site_rate + # compute the real rate. Could have used hop_rates.high as well. + real_rate = evalhoprate(hop_rates.low, u, species, site, spatial_system) + return real_rate > acceptance_threshold + end +end + +function accept_rx(p, u, jump) + @unpack rx_rates, rng = p + rx, site = reaction_id_from_jump(p, jump), jump.src + acceptance_threshold = rand(rng) * rx_rate(rx_rates.high, rx, site) + if rx_rate(rx_rates.low, rx, site) > acceptance_threshold + return true + else + # compute the real rate. Could have used rx_rates.high as well. + real_rate = evalrxrate(rx_rates.low, u, rx, site) + return real_rate > acceptance_threshold end end @@ -186,25 +227,20 @@ end recalculate jump rates for jumps that depend on the just executed jump (p.prev_jump) """ -function update_dependent_rates!(p::RSSACRDirectJumpAggregation, - integrator, - t) - @unpack rx_rates, hop_rates, site_rates, u_low_high, bracket_data, vartojumps_map, jumptovars_map, spatial_system = p - - u = integrator.u - site_rates = p.site_rates +function update_dependent_rates!(p::RSSACRDirectJumpAggregation, integrator, t) jump = p.prev_jump - if is_hop(p, jump) - species_to_update = jump.jidx - sites_to_update = (jump.src, jump.dst) + update_brackets!(p, integrator, jump.jidx, (jump.src, jump.dst)) else - species_to_update = jumptovars_map[reaction_id_from_jump(p, jump)] - sites_to_update = jump.src + update_brackets!(p, integrator, p.jumptovars_map[reaction_id_from_jump(p, jump)], jump.src) end +end +function update_brackets!(p, integrator, species_to_update, sites_to_update) + @unpack rx_rates, hop_rates, site_rates, u_low_high, bracket_data, vartojumps_map, spatial_system = p + u = integrator.u for site in sites_to_update, species in species_to_update - if is_outside_brackets(u_low_high, u, species, site) + if !is_inside_brackets(u_low_high, u, species, site) update_u_brackets!(u_low_high, bracket_data, u, species, site) update_rx_rates!(rx_rates, vartojumps_map[species], @@ -218,6 +254,7 @@ function update_dependent_rates!(p::RSSACRDirectJumpAggregation, update!(p.rt, site, oldrate, site_rates.high[site]) end end + nothing end """ diff --git a/test/spatial/ABC.jl b/test/spatial/ABC.jl index 48a22c3e..6dd9c7ee 100644 --- a/test/spatial/ABC.jl +++ b/test/spatial/ABC.jl @@ -2,12 +2,12 @@ using JumpProcesses, DiffEqBase # using BenchmarkTools using Test, Graphs -Nsims = 1000 +Nsims = 100 reltol = 0.05 non_spatial_mean = [65.7395, 65.7395, 434.2605] #mean of 10,000 simulations dim = 1 -linear_size = 1 +linear_size = 5 dims = Tuple(repeat([linear_size], dim)) num_nodes = prod(dims) starting_site = trunc(Int, (linear_size^dim + 1) / 2) From 5ce7d25bd79430e5795eafdfe2fd91cc730b9133 Mon Sep 17 00:00:00 2001 From: Vasily Ilin Date: Mon, 11 Sep 2023 20:19:57 -0400 Subject: [PATCH 03/37] Rename to `DirectCRRSSA`. --- src/JumpProcesses.jl | 4 +-- src/aggregators/aggregators.jl | 4 +-- .../{rssacrdirect.jl => directcrrssa.jl} | 30 +++++++++---------- test/spatial/ABC.jl | 13 +------- 4 files changed, 20 insertions(+), 31 deletions(-) rename src/spatial/{rssacrdirect.jl => directcrrssa.jl} (90%) diff --git a/src/JumpProcesses.jl b/src/JumpProcesses.jl index 640f832d..58f42e63 100644 --- a/src/JumpProcesses.jl +++ b/src/JumpProcesses.jl @@ -65,7 +65,7 @@ include("spatial/bracketing.jl") include("spatial/nsm.jl") include("spatial/directcrdirect.jl") -include("spatial/rssacrdirect.jl") +include("spatial/directcrrssa.jl") include("aggregators/aggregated_api.jl") @@ -102,6 +102,6 @@ export ExtendedJumpArray export CartesianGrid, CartesianGridRej export SpatialMassActionJump export outdegree, num_sites, neighbors -export NSM, DirectCRDirect, RSSACRDirect +export NSM, DirectCRDirect, DirectCRRSSA end # module diff --git a/src/aggregators/aggregators.jl b/src/aggregators/aggregators.jl index 46c220a5..8a3cb2f0 100644 --- a/src/aggregators/aggregators.jl +++ b/src/aggregators/aggregators.jl @@ -159,7 +159,7 @@ algorithm with optimal binning, Journal of Chemical Physics 143, 074108 """ struct DirectCRDirect <: AbstractAggregatorAlgorithm end -struct RSSACRDirect <: AbstractAggregatorAlgorithm end +struct DirectCRRSSA <: AbstractAggregatorAlgorithm end const JUMP_AGGREGATORS = (Direct(), DirectFW(), DirectCR(), SortingDirect(), RSSA(), FRM(), FRMFW(), NRM(), RSSACR(), RDirect(), Coevolve()) @@ -189,4 +189,4 @@ supports_variablerates(aggregator::Coevolve) = true is_spatial(aggregator::AbstractAggregatorAlgorithm) = false is_spatial(aggregator::NSM) = true is_spatial(aggregator::DirectCRDirect) = true -is_spatial(aggregator::RSSACRDirect) = true +is_spatial(aggregator::DirectCRRSSA) = true diff --git a/src/spatial/rssacrdirect.jl b/src/spatial/directcrrssa.jl similarity index 90% rename from src/spatial/rssacrdirect.jl rename to src/spatial/directcrrssa.jl index 80596624..dd6c0af1 100644 --- a/src/spatial/rssacrdirect.jl +++ b/src/spatial/directcrrssa.jl @@ -1,10 +1,10 @@ -# site chosen with RSSACR, rx or hop chosen with Direct +# site chosen with DirectCR, rx or hop chosen with RSSA -############################ RSSACRDirect ################################### +############################ DirectCRRSSA ################################### const MINJUMPRATE = 2.0^exponent(1e-12) #NOTE state vector u is a matrix. u[i,j] is species i, site j -mutable struct RSSACRDirectJumpAggregation{T, BD, M, RNG, J, RX, HOP, DEPGR, +mutable struct DirectCRRSSAJumpAggregation{T, BD, M, RNG, J, RX, HOP, DEPGR, VJMAP, JVMAP, SS, U <: PriorityTable, S, F1, F2} <: AbstractSSAJumpAggregator{T, S, F1, F2, RNG} next_jump::SpatialJump{J} @@ -28,7 +28,7 @@ mutable struct RSSACRDirectJumpAggregation{T, BD, M, RNG, J, RX, HOP, DEPGR, affects!::F2 # legacy, not used end -function RSSACRDirectJumpAggregation(nj::SpatialJump{J}, njt::T, et::T, bd::BD, +function DirectCRRSSAJumpAggregation(nj::SpatialJump{J}, njt::T, et::T, bd::BD, u_low_high::LowHigh{M}, rx_rates::LowHigh{RX}, hop_rates::LowHigh{HOP}, site_rates::LowHigh{Vector{T}}, sps::Tuple{Bool, Bool}, rng::RNG, spatial_system::SS; @@ -69,7 +69,7 @@ function RSSACRDirectJumpAggregation(nj::SpatialJump{J}, njt::T, et::T, bd::BD, # construct an empty initial priority table -- we'll reset this in init rt = PriorityTable(ratetogroup, zeros(T, 1), minrate, 2 * minrate) - RSSACRDirectJumpAggregation{ + DirectCRRSSAJumpAggregation{ T, BD, M, @@ -91,7 +91,7 @@ end ############################# Required Functions ############################## # creating the JumpAggregation structure (function wrapper-based constant jumps) -function aggregate(aggregator::RSSACRDirect, starting_state, p, t, end_time, +function aggregate(aggregator::DirectCRRSSA, starting_state, p, t, end_time, constant_jumps, ma_jumps, save_positions, rng; hopping_constants, spatial_system, bracket_data = nothing, kwargs...) T = typeof(end_time) @@ -116,14 +116,14 @@ function aggregate(aggregator::RSSACRDirect, starting_state, p, t, end_time, bracket_data u_low_high = LowHigh(starting_state) - RSSACRDirectJumpAggregation(next_jump, next_jump_time, end_time, bd, u_low_high, + DirectCRRSSAJumpAggregation(next_jump, next_jump_time, end_time, bd, u_low_high, rx_rates, hop_rates, site_rates, save_positions, rng, spatial_system; num_specs = num_species, kwargs...) end # set up a new simulation and calculate the first jump / jump time -function initialize!(p::RSSACRDirectJumpAggregation, integrator, u, params, t) +function initialize!(p::DirectCRRSSAJumpAggregation, integrator, u, params, t) p.end_time = integrator.sol.prob.tspan[2] fill_rates_and_get_times!(p, integrator, t) generate_jumps!(p, integrator, u, params, t) @@ -131,7 +131,7 @@ function initialize!(p::RSSACRDirectJumpAggregation, integrator, u, params, t) end # calculate the next jump / jump time -function generate_jumps!(p::RSSACRDirectJumpAggregation, integrator, u, params, t) +function generate_jumps!(p::DirectCRRSSAJumpAggregation, integrator, u, params, t) @unpack rng, rt, site_rates, rx_rates, hop_rates, spatial_system = p time_delta = zero(t) while true @@ -148,7 +148,7 @@ function generate_jumps!(p::RSSACRDirectJumpAggregation, integrator, u, params, end # execute one jump, changing the system state -function execute_jumps!(p::RSSACRDirectJumpAggregation, integrator, u, params, t, +function execute_jumps!(p::DirectCRRSSAJumpAggregation, integrator, u, params, t, affects!) update_state!(p, integrator) update_dependent_rates!(p, integrator, t) @@ -192,11 +192,11 @@ function accept_rx(p, u, jump) end """ - fill_rates_and_get_times!(aggregation::RSSACRDirectJumpAggregation, u, t) + fill_rates_and_get_times!(aggregation::DirectCRRSSAJumpAggregation, u, t) reset all stucts, reevaluate all rates, repopulate the priority table """ -function fill_rates_and_get_times!(aggregation::RSSACRDirectJumpAggregation, integrator, t) +function fill_rates_and_get_times!(aggregation::DirectCRRSSAJumpAggregation, integrator, t) @unpack bracket_data, u_low_high, spatial_system, rx_rates, hop_rates, site_rates, rt = aggregation u = integrator.u update_u_brackets!(u_low_high::LowHigh, bracket_data, u::AbstractMatrix) @@ -227,7 +227,7 @@ end recalculate jump rates for jumps that depend on the just executed jump (p.prev_jump) """ -function update_dependent_rates!(p::RSSACRDirectJumpAggregation, integrator, t) +function update_dependent_rates!(p::DirectCRRSSAJumpAggregation, integrator, t) jump = p.prev_jump if is_hop(p, jump) update_brackets!(p, integrator, jump.jidx, (jump.src, jump.dst)) @@ -258,8 +258,8 @@ function update_brackets!(p, integrator, species_to_update, sites_to_update) end """ - num_constant_rate_jumps(aggregator::RSSACRDirectJumpAggregation) + num_constant_rate_jumps(aggregator::DirectCRRSSAJumpAggregation) number of constant rate jumps """ -num_constant_rate_jumps(aggregator::RSSACRDirectJumpAggregation) = 0 \ No newline at end of file +num_constant_rate_jumps(aggregator::DirectCRRSSAJumpAggregation) = 0 \ No newline at end of file diff --git a/test/spatial/ABC.jl b/test/spatial/ABC.jl index 6dd9c7ee..c77a95ff 100644 --- a/test/spatial/ABC.jl +++ b/test/spatial/ABC.jl @@ -78,18 +78,7 @@ grids = [CartesianGridRej(dims), Graphs.grid(dims)] # jump_prob = JumpProblem(non_spatial_prob, Direct(), majumps) # non_spatial_mean = get_mean_end_state(jump_prob, 10000) -spatial_jump_prob = JumpProblem(prob, RSSACRDirect(), majumps, hopping_constants = hopping_constants, - spatial_system = grids[1], save_positions = (false, false)) -sol = solve(spatial_jump_prob, SSAStepper()) -mean_end_state = get_mean_end_state(spatial_jump_prob, Nsims) -mean_end_state = reshape(mean_end_state, num_species, num_nodes) -diff = sum(mean_end_state, dims = 2) - non_spatial_mean -for (i, d) in enumerate(diff) - @test abs(d) < reltol * non_spatial_mean[i] -end - - -spatial_jump_prob = JumpProblem(prob, NSM(), majumps, hopping_constants = hopping_constants, +spatial_jump_prob = JumpProblem(prob, DirectCRRSSA(), majumps, hopping_constants = hopping_constants, spatial_system = grids[1], save_positions = (false, false)) sol = solve(spatial_jump_prob, SSAStepper()) mean_end_state = get_mean_end_state(spatial_jump_prob, Nsims) From 2a8bfa79cb5769d5315ea643165f59511841e611 Mon Sep 17 00:00:00 2001 From: Vasily Ilin Date: Mon, 11 Sep 2023 20:22:43 -0400 Subject: [PATCH 04/37] Fix a docstring. --- src/spatial/hop_rates.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spatial/hop_rates.jl b/src/spatial/hop_rates.jl index 9e5f430a..d35f370a 100644 --- a/src/spatial/hop_rates.jl +++ b/src/spatial/hop_rates.jl @@ -57,9 +57,9 @@ function HopRates(p::Pair{SpecHop, SiteHop}, end """ - update_hop_rates!(hop_rates::HopRatesGraphDsi, species_vec, u, site, spatial_system) + update_hop_rates!(hop_rates::AbstractHopRates, species_vec, u, site, spatial_system) - update rates of all species in species_vec at site +update rates of all species in species_vec at site """ function update_hop_rates!(hop_rates::AbstractHopRates, species_vec, u, site, spatial_system) @inbounds for species in species_vec From 0d96e805ee4feeadf625f61b78145e3691cb0ea0 Mon Sep 17 00:00:00 2001 From: Vasily Ilin Date: Mon, 11 Sep 2023 20:24:50 -0400 Subject: [PATCH 05/37] Shorten a function. --- src/spatial/reaction_rates.jl | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/spatial/reaction_rates.jl b/src/spatial/reaction_rates.jl index a9c78c9d..f33293aa 100644 --- a/src/spatial/reaction_rates.jl +++ b/src/spatial/reaction_rates.jl @@ -56,8 +56,8 @@ end update rates of all reactions in rxs at site """ -function update_rx_rates!(rx_rates::RxRates{F, M}, rxs, u::AbstractMatrix, integrator, - site) where {F, M} +function update_rx_rates!(rx_rates::RxRates, rxs, u, integrator, + site) ma_jumps = rx_rates.ma_jumps @inbounds for rx in rxs rate = eval_massaction_rate(u, rx, ma_jumps, site) @@ -65,11 +65,8 @@ function update_rx_rates!(rx_rates::RxRates{F, M}, rxs, u::AbstractMatrix, integ end end -function update_rx_rates!(rx_rates::RxRates{F, M}, rxs, integrator, - site) where {F, M <: AbstractMassActionJump} - u = integrator.u - update_rx_rates!(rx_rates, rxs, u, integrator, site) -end +update_rx_rates!(rx_rates::RxRates, rxs, integrator, + site) = update_rx_rates!(rx_rates, rxs, integrator.u, integrator, site) """ sample_rx_at_site(rx_rates::RxRates, site, rng) From 07873e8aeece2c31ed410257cd1de0b41e073ecc Mon Sep 17 00:00:00 2001 From: Vasily Ilin Date: Mon, 11 Sep 2023 20:33:51 -0400 Subject: [PATCH 06/37] Uncomment tests in `ABC.jl` --- test/spatial/ABC.jl | 59 ++++++++++++++++++++------------------------- 1 file changed, 26 insertions(+), 33 deletions(-) diff --git a/test/spatial/ABC.jl b/test/spatial/ABC.jl index c77a95ff..394ba3d1 100644 --- a/test/spatial/ABC.jl +++ b/test/spatial/ABC.jl @@ -47,43 +47,36 @@ end # testing grids = [CartesianGridRej(dims), Graphs.grid(dims)] -# jump_problems = JumpProblem[JumpProblem(prob, NSM(), majumps, -# hopping_constants = hopping_constants, -# spatial_system = grid, -# save_positions = (false, false)) for grid in grids] -# push!(jump_problems, -# JumpProblem(prob, DirectCRDirect(), majumps, hopping_constants = hopping_constants, -# spatial_system = grids[1], save_positions = (false, false))) -# # setup flattenned jump prob -# push!(jump_problems, -# JumpProblem(prob, NRM(), majumps, hopping_constants = hopping_constants, -# spatial_system = grids[1], save_positions = (false, false))) -# # test -# for spatial_jump_prob in jump_problems -# solution = solve(spatial_jump_prob, SSAStepper()) -# mean_end_state = get_mean_end_state(spatial_jump_prob, Nsims) -# mean_end_state = reshape(mean_end_state, num_species, num_nodes) -# diff = sum(mean_end_state, dims = 2) - non_spatial_mean -# for (i, d) in enumerate(diff) -# @test abs(d) < reltol * non_spatial_mean[i] -# end -# end +jump_problems = JumpProblem[JumpProblem(prob, NSM(), majumps, + hopping_constants = hopping_constants, + spatial_system = grid, + save_positions = (false, false)) for grid in grids] -#using non-spatial SSAs to get the mean +# SSAs +for alg in [DirectCRDirect(), DirectCRRSSA()] + push!(jump_problems, JumpProblem(prob, DirectCRDirect(), majumps, hopping_constants = hopping_constants, spatial_system = grids[1], save_positions = (false, false))) +end + +# setup flattenned jump prob +push!(jump_problems, + JumpProblem(prob, NRM(), majumps, hopping_constants = hopping_constants, + spatial_system = grids[1], save_positions = (false, false))) +# test +for spatial_jump_prob in jump_problems + solution = solve(spatial_jump_prob, SSAStepper()) + mean_end_state = get_mean_end_state(spatial_jump_prob, Nsims) + mean_end_state = reshape(mean_end_state, num_species, num_nodes) + diff = sum(mean_end_state, dims = 2) - non_spatial_mean + for (i, d) in enumerate(diff) + @test abs(d) < reltol * non_spatial_mean[i] + end +end + +# using non-spatial SSAs to get the mean # non_spatial_rates = [0.1,1.0] # reactstoch = [[1 => 1, 2 => 1],[3 => 1]] # netstoch = [[1 => -1, 2 => -1, 3 => 1],[1 => 1, 2 => 1, 3 => -1]] # majumps = MassActionJump(non_spatial_rates, reactstoch, netstoch) # non_spatial_prob = DiscreteProblem(u0,(0.0,end_time), non_spatial_rates) # jump_prob = JumpProblem(non_spatial_prob, Direct(), majumps) -# non_spatial_mean = get_mean_end_state(jump_prob, 10000) - -spatial_jump_prob = JumpProblem(prob, DirectCRRSSA(), majumps, hopping_constants = hopping_constants, - spatial_system = grids[1], save_positions = (false, false)) -sol = solve(spatial_jump_prob, SSAStepper()) -mean_end_state = get_mean_end_state(spatial_jump_prob, Nsims) -mean_end_state = reshape(mean_end_state, num_species, num_nodes) -diff = sum(mean_end_state, dims = 2) - non_spatial_mean -for (i, d) in enumerate(diff) - @test abs(d) < reltol * non_spatial_mean[i] -end \ No newline at end of file +# non_spatial_mean = get_mean_end_state(jump_prob, 10000) \ No newline at end of file From 7fce14e2f49c2b089261dbb5e0bc7b8a9818e770 Mon Sep 17 00:00:00 2001 From: Vasily Ilin Date: Mon, 11 Sep 2023 20:35:27 -0400 Subject: [PATCH 07/37] Delete comment. --- src/spatial/nsm.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spatial/nsm.jl b/src/spatial/nsm.jl index a5341665..b941fea7 100644 --- a/src/spatial/nsm.jl +++ b/src/spatial/nsm.jl @@ -2,7 +2,6 @@ ############################ NSM ################################### #NOTE state vector u is a matrix. u[i,j] is species i, site j -#NOTE hopping_constants is a matrix. hopping_constants[i,j] is species i, site j mutable struct NSMJumpAggregation{T, S, F1, F2, RNG, J, RX, HOP, DEPGR, VJMAP, JVMAP, PQ, SS} <: AbstractSSAJumpAggregator{T, S, F1, F2, RNG} From 74346b208394f2619d90c7e48bded6e72f5dbfa8 Mon Sep 17 00:00:00 2001 From: Vasily Ilin Date: Mon, 11 Sep 2023 20:36:28 -0400 Subject: [PATCH 08/37] Add DirectCRRSSA to the diffusion test. --- test/spatial/diffusion.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/spatial/diffusion.jl b/test/spatial/diffusion.jl index 1e89f0c6..4f4d0c6c 100644 --- a/test/spatial/diffusion.jl +++ b/test/spatial/diffusion.jl @@ -59,7 +59,7 @@ Nsims = 50000 rel_tol = 0.02 times = 0.0:(tf / num_time_points):tf -algs = [NSM(), DirectCRDirect()] +algs = [NSM(), DirectCRDirect(), DirectCRRSSA()] grids = [CartesianGridRej(dims), Graphs.grid(dims)] jump_problems = JumpProblem[JumpProblem(prob, algs[2], majumps, hopping_constants = hopping_constants, From c5f93dc5fe933fd128b845876f4e6ee2803a85a3 Mon Sep 17 00:00:00 2001 From: Vasily Ilin Date: Mon, 11 Sep 2023 20:50:07 -0400 Subject: [PATCH 09/37] Add `DirectCRRSSA` to `ABC.jl`. --- test/spatial/ABC.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/spatial/ABC.jl b/test/spatial/ABC.jl index 394ba3d1..6307be0e 100644 --- a/test/spatial/ABC.jl +++ b/test/spatial/ABC.jl @@ -54,7 +54,7 @@ jump_problems = JumpProblem[JumpProblem(prob, NSM(), majumps, # SSAs for alg in [DirectCRDirect(), DirectCRRSSA()] - push!(jump_problems, JumpProblem(prob, DirectCRDirect(), majumps, hopping_constants = hopping_constants, spatial_system = grids[1], save_positions = (false, false))) + push!(jump_problems, JumpProblem(prob, alg, majumps, hopping_constants = hopping_constants, spatial_system = grids[1], save_positions = (false, false))) end # setup flattenned jump prob From fd3ea1e14545ba1bf66327c3be69b3f116dbb7d7 Mon Sep 17 00:00:00 2001 From: Vasily Ilin Date: Mon, 11 Sep 2023 21:03:11 -0400 Subject: [PATCH 10/37] Shorten `getindex`. --- src/spatial/bracketing.jl | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/spatial/bracketing.jl b/src/spatial/bracketing.jl index 0e683c38..a2515bc5 100644 --- a/src/spatial/bracketing.jl +++ b/src/spatial/bracketing.jl @@ -40,15 +40,12 @@ function is_inside_brackets(u_low_high::LowHigh{M}, u::M, species, site) where { end ### convenience functions for LowHigh ### -function setindex!(low_high::LowHigh, val::LowHigh, i...) +function setindex!(low_high::LowHigh{A}, val::LowHigh, i...) where {A <: AbstractArray} low_high.low[i...] = val.low low_high.high[i...] = val.high val end - -function getindex(low_high::LowHigh, i) - return LowHigh(low_high.low[i], low_high.high[i]) -end +getindex(low_high::LowHigh{A}, i) where {A <: AbstractArray} = LowHigh(low_high.low[i], low_high.high[i]) get_majumps(rx_rates::LowHigh{R}) where {R <: RxRates} = get_majumps(rx_rates.low) From 107068f18f0a82dd53fe10dc03fcde732a768e5d Mon Sep 17 00:00:00 2001 From: Vasily Ilin Date: Mon, 11 Sep 2023 21:13:40 -0400 Subject: [PATCH 11/37] Remove the low bound on site rates, as it is not used. --- src/spatial/bracketing.jl | 11 +---------- src/spatial/directcrrssa.jl | 36 ++++++++++++++++++------------------ 2 files changed, 19 insertions(+), 28 deletions(-) diff --git a/src/spatial/bracketing.jl b/src/spatial/bracketing.jl index a2515bc5..12f267a2 100644 --- a/src/spatial/bracketing.jl +++ b/src/spatial/bracketing.jl @@ -49,13 +49,6 @@ getindex(low_high::LowHigh{A}, i) where {A <: AbstractArray} = LowHigh(low_high. get_majumps(rx_rates::LowHigh{R}) where {R <: RxRates} = get_majumps(rx_rates.low) -function total_site_rate(rx_rates::LowHigh, hop_rates::LowHigh, site) - return LowHigh( - total_site_rate(rx_rates.low, hop_rates.low, site), - total_site_rate(rx_rates.high, hop_rates.high, site)) -end - -# Compatible with constant rate jumps, because u_low_high.low and u_low_high.high are used in rate(). function update_rx_rates!(rx_rates::LowHigh, rxs, u_low_high, integrator, site) update_rx_rates!(rx_rates.low, rxs, u_low_high.low, integrator, site) update_rx_rates!(rx_rates.high, rxs, u_low_high.high, integrator, site) @@ -69,6 +62,4 @@ end function reset!(low_high::LowHigh) reset!(low_high.low) reset!(low_high.high) -end - -reset!(array::AbstractArray) = fill!(array, zero(eltype(array))) \ No newline at end of file +end \ No newline at end of file diff --git a/src/spatial/directcrrssa.jl b/src/spatial/directcrrssa.jl index dd6c0af1..9bb7b572 100644 --- a/src/spatial/directcrrssa.jl +++ b/src/spatial/directcrrssa.jl @@ -15,7 +15,7 @@ mutable struct DirectCRRSSAJumpAggregation{T, BD, M, RNG, J, RX, HOP, DEPGR, u_low_high::LowHigh{M} # species bracketing rx_rates::LowHigh{RX} hop_rates::LowHigh{HOP} - site_rates::LowHigh{Vector{T}} # TODO(vilin97): we never use site_rates.low + site_rates_high::Vector{T} # we do not need site_rates_low save_positions::Tuple{Bool, Bool} rng::RNG dep_gr::DEPGR #dep graph is same for each site @@ -30,7 +30,7 @@ end function DirectCRRSSAJumpAggregation(nj::SpatialJump{J}, njt::T, et::T, bd::BD, u_low_high::LowHigh{M}, rx_rates::LowHigh{RX}, - hop_rates::LowHigh{HOP}, site_rates::LowHigh{Vector{T}}, + hop_rates::LowHigh{HOP}, site_rates_high::Vector{T}, sps::Tuple{Bool, Bool}, rng::RNG, spatial_system::SS; num_specs, minrate = convert(T, MINJUMPRATE), vartojumps_map = nothing, jumptovars_map = nothing, @@ -39,7 +39,7 @@ function DirectCRRSSAJumpAggregation(nj::SpatialJump{J}, njt::T, et::T, bd::BD, # a dependency graph is needed if dep_graph === nothing - dg = make_dependency_graph(num_specs, rx_rates.low.ma_jumps) + dg = make_dependency_graph(num_specs, get_majumps(rx_rates)) else dg = dep_graph # make sure each jump depends on itself @@ -48,13 +48,13 @@ function DirectCRRSSAJumpAggregation(nj::SpatialJump{J}, njt::T, et::T, bd::BD, # a species-to-reactions graph is needed if vartojumps_map === nothing - vtoj_map = var_to_jumps_map(num_specs, rx_rates.low.ma_jumps) + vtoj_map = var_to_jumps_map(num_specs, get_majumps(rx_rates)) else vtoj_map = vartojumps_map end if jumptovars_map === nothing - jtov_map = jump_to_vars_map(rx_rates.low.ma_jumps) + jtov_map = jump_to_vars_map(get_majumps(rx_rates)) else jtov_map = jumptovars_map end @@ -85,7 +85,7 @@ function DirectCRRSSAJumpAggregation(nj::SpatialJump{J}, njt::T, et::T, bd::BD, Nothing, Nothing, Nothing, - }(nj, nj, njt, et, bd, u_low_high, rx_rates, hop_rates, site_rates, sps, rng, dg, + }(nj, nj, njt, et, bd, u_low_high, rx_rates, hop_rates, site_rates_high, sps, rng, dg, vtoj_map, jtov_map, spatial_system, num_specs, rt, nothing, nothing) end @@ -111,14 +111,14 @@ function aggregate(aggregator::DirectCRRSSA, starting_state, p, t, end_time, hop_rates = LowHigh(HopRates(hopping_constants, spatial_system), HopRates(hopping_constants, spatial_system); do_copy = false) # do not copy hopping_constants - site_rates = LowHigh(zeros(T, num_sites(spatial_system))) + site_rates_high = zeros(T, num_sites(spatial_system)) bd = (bracket_data === nothing) ? BracketData{T, eltype(starting_state)}() : bracket_data u_low_high = LowHigh(starting_state) DirectCRRSSAJumpAggregation(next_jump, next_jump_time, end_time, bd, u_low_high, rx_rates, hop_rates, - site_rates, save_positions, rng, spatial_system; + site_rates_high, save_positions, rng, spatial_system; num_specs = num_species, kwargs...) end @@ -132,10 +132,10 @@ end # calculate the next jump / jump time function generate_jumps!(p::DirectCRRSSAJumpAggregation, integrator, u, params, t) - @unpack rng, rt, site_rates, rx_rates, hop_rates, spatial_system = p + @unpack rng, rt, site_rates_high, rx_rates, hop_rates, spatial_system = p time_delta = zero(t) while true - site = sample(rt, site_rates.high, rng) + site = sample(rt, site_rates_high, rng) jump = sample_jump_direct(rx_rates.high, hop_rates.high, site, spatial_system, rng) time_delta += randexp(rng) if accept_jump(p, u, jump) @@ -197,13 +197,13 @@ end reset all stucts, reevaluate all rates, repopulate the priority table """ function fill_rates_and_get_times!(aggregation::DirectCRRSSAJumpAggregation, integrator, t) - @unpack bracket_data, u_low_high, spatial_system, rx_rates, hop_rates, site_rates, rt = aggregation + @unpack bracket_data, u_low_high, spatial_system, rx_rates, hop_rates, site_rates_high, rt = aggregation u = integrator.u update_u_brackets!(u_low_high::LowHigh, bracket_data, u::AbstractMatrix) reset!(rx_rates) reset!(hop_rates) - reset!(site_rates) + fill!(site_rates_high, zero(eltype(site_rates_high))) rxs = 1:num_rxs(rx_rates.low) species = 1:(aggregation.numspecies) @@ -211,12 +211,12 @@ function fill_rates_and_get_times!(aggregation::DirectCRRSSAJumpAggregation, int for site in 1:num_sites(spatial_system) update_rx_rates!(rx_rates, rxs, u_low_high, integrator, site) update_hop_rates!(hop_rates, species, u_low_high, site, spatial_system) - site_rates[site] = total_site_rate(rx_rates, hop_rates, site) + site_rates_high[site] = total_site_rate(rx_rates.high, hop_rates.high, site) end # setup PriorityTable reset!(rt) - for (pid, priority) in enumerate(site_rates.high) + for (pid, priority) in enumerate(site_rates_high) insert!(rt, pid, priority) end nothing @@ -237,7 +237,7 @@ function update_dependent_rates!(p::DirectCRRSSAJumpAggregation, integrator, t) end function update_brackets!(p, integrator, species_to_update, sites_to_update) - @unpack rx_rates, hop_rates, site_rates, u_low_high, bracket_data, vartojumps_map, spatial_system = p + @unpack rx_rates, hop_rates, site_rates_high, u_low_high, bracket_data, vartojumps_map, spatial_system = p u = integrator.u for site in sites_to_update, species in species_to_update if !is_inside_brackets(u_low_high, u, species, site) @@ -249,9 +249,9 @@ function update_brackets!(p, integrator, species_to_update, sites_to_update) site) update_hop_rates!(hop_rates, species, u_low_high, site, spatial_system) - oldrate = site_rates.high[site] - site_rates[site] = total_site_rate(p.rx_rates, p.hop_rates, site) - update!(p.rt, site, oldrate, site_rates.high[site]) + oldrate = site_rates_high[site] + site_rates_high[site] = total_site_rate(rx_rates.high, hop_rates.high, site) + update!(p.rt, site, oldrate, site_rates_high[site]) end end nothing From e1ed533d516238f764c75f4cf95adb98b3bcbbaa Mon Sep 17 00:00:00 2001 From: Vasily Ilin Date: Mon, 11 Sep 2023 21:22:47 -0400 Subject: [PATCH 12/37] Add an `@inbounds`. --- src/spatial/hop_rates.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spatial/hop_rates.jl b/src/spatial/hop_rates.jl index d35f370a..1b1806a6 100644 --- a/src/spatial/hop_rates.jl +++ b/src/spatial/hop_rates.jl @@ -189,7 +189,7 @@ end return hopping rate of species at site """ function evalhoprate(hop_rates::HopRatesGraphDsi, u, species, site, spatial_system) - u[species, site] * hop_rates.hopping_constants[species, site] * + @inbounds u[species, site] * hop_rates.hopping_constants[species, site] * outdegree(spatial_system, site) end From 3290d30daa27a2314c996ef5d8222ac056cd86bb Mon Sep 17 00:00:00 2001 From: Vasily Ilin Date: Mon, 11 Sep 2023 21:23:34 -0400 Subject: [PATCH 13/37] Add `AbstractMatrix` back in. --- src/spatial/reaction_rates.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spatial/reaction_rates.jl b/src/spatial/reaction_rates.jl index f33293aa..f2b09367 100644 --- a/src/spatial/reaction_rates.jl +++ b/src/spatial/reaction_rates.jl @@ -56,7 +56,7 @@ end update rates of all reactions in rxs at site """ -function update_rx_rates!(rx_rates::RxRates, rxs, u, integrator, +function update_rx_rates!(rx_rates::RxRates, rxs, u::AbstractMatrix, integrator, site) ma_jumps = rx_rates.ma_jumps @inbounds for rx in rxs From af28404c9671dae596d67da76515ea1b78a3bd84 Mon Sep 17 00:00:00 2001 From: Vasily Ilin Date: Mon, 11 Sep 2023 21:24:10 -0400 Subject: [PATCH 14/37] Remove another change to shorten the PR. --- src/spatial/reaction_rates.jl | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/spatial/reaction_rates.jl b/src/spatial/reaction_rates.jl index f2b09367..9ef91d45 100644 --- a/src/spatial/reaction_rates.jl +++ b/src/spatial/reaction_rates.jl @@ -65,8 +65,11 @@ function update_rx_rates!(rx_rates::RxRates, rxs, u::AbstractMatrix, integrator, end end -update_rx_rates!(rx_rates::RxRates, rxs, integrator, - site) = update_rx_rates!(rx_rates, rxs, integrator.u, integrator, site) +function update_rx_rates!(rx_rates::RxRates, rxs, integrator, + site) + u = integrator.u + update_rx_rates!(rx_rates, rxs, u, integrator, site) +end """ sample_rx_at_site(rx_rates::RxRates, site, rng) From f5bdc99ac60d7612d964534b8c59d2792e343aaf Mon Sep 17 00:00:00 2001 From: Vasily Ilin Date: Mon, 11 Sep 2023 21:25:44 -0400 Subject: [PATCH 15/37] Shorten a function. --- src/spatial/utils.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/spatial/utils.jl b/src/spatial/utils.jl index 16191002..24e7ea66 100644 --- a/src/spatial/utils.jl +++ b/src/spatial/utils.jl @@ -40,9 +40,7 @@ function sample_jump_direct(rx_rates, hop_rates, site, spatial_system, rng) end end -function sample_jump_direct(p, site) - sample_jump_direct(p.rx_rates, p.hop_rates, site, p.spatial_system, p.rng) -end +sample_jump_direct(p, site) = sample_jump_direct(p.rx_rates, p.hop_rates, site, p.spatial_system, p.rng) function total_site_rate(rx_rates::RxRates, hop_rates::AbstractHopRates, site) total_site_hop_rate(hop_rates, site) + total_site_rx_rate(rx_rates, site) From 37576069c9aabcc16eea1a3ab6923b1f1574aa6f Mon Sep 17 00:00:00 2001 From: Vasily Ilin Date: Mon, 11 Sep 2023 21:26:29 -0400 Subject: [PATCH 16/37] Swap order of functions. --- src/spatial/utils.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spatial/utils.jl b/src/spatial/utils.jl index 24e7ea66..2c50ddc6 100644 --- a/src/spatial/utils.jl +++ b/src/spatial/utils.jl @@ -27,6 +27,8 @@ end sample jump at site with direct method """ +sample_jump_direct(p, site) = sample_jump_direct(p.rx_rates, p.hop_rates, site, p.spatial_system, p.rng) + function sample_jump_direct(rx_rates, hop_rates, site, spatial_system, rng) numspecies = size(hop_rates.rates, 1) if rand(rng) * (total_site_rate(rx_rates, hop_rates, site)) < @@ -40,8 +42,6 @@ function sample_jump_direct(rx_rates, hop_rates, site, spatial_system, rng) end end -sample_jump_direct(p, site) = sample_jump_direct(p.rx_rates, p.hop_rates, site, p.spatial_system, p.rng) - function total_site_rate(rx_rates::RxRates, hop_rates::AbstractHopRates, site) total_site_hop_rate(hop_rates, site) + total_site_rx_rate(rx_rates, site) end From 22e8ecef1cc324ff59d0de8a393b8979e045c853 Mon Sep 17 00:00:00 2001 From: Vasily Ilin Date: Mon, 11 Sep 2023 21:27:16 -0400 Subject: [PATCH 17/37] Remove typos from `ABC.jl`. --- test/spatial/ABC.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/spatial/ABC.jl b/test/spatial/ABC.jl index 6307be0e..db4dd6c2 100644 --- a/test/spatial/ABC.jl +++ b/test/spatial/ABC.jl @@ -72,11 +72,11 @@ for spatial_jump_prob in jump_problems end end -# using non-spatial SSAs to get the mean +#using non-spatial SSAs to get the mean # non_spatial_rates = [0.1,1.0] # reactstoch = [[1 => 1, 2 => 1],[3 => 1]] # netstoch = [[1 => -1, 2 => -1, 3 => 1],[1 => 1, 2 => 1, 3 => -1]] # majumps = MassActionJump(non_spatial_rates, reactstoch, netstoch) # non_spatial_prob = DiscreteProblem(u0,(0.0,end_time), non_spatial_rates) # jump_prob = JumpProblem(non_spatial_prob, Direct(), majumps) -# non_spatial_mean = get_mean_end_state(jump_prob, 10000) \ No newline at end of file +# non_spatial_mean = get_mean_end_state(jump_prob, 10000) From 289a9dab7b7686c8e9ec8bac6f23b936a4d2259a Mon Sep 17 00:00:00 2001 From: Vasily Ilin Date: Tue, 12 Sep 2023 00:01:00 -0400 Subject: [PATCH 18/37] Fix test. --- test/spatial/bracketing.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/spatial/bracketing.jl b/test/spatial/bracketing.jl index 476b1dd7..95e78352 100644 --- a/test/spatial/bracketing.jl +++ b/test/spatial/bracketing.jl @@ -10,7 +10,6 @@ n = 3 # number of sites # set up spatial system spatial_system = CartesianGrid((n,)) # n sites -site_rates = JP.LowHigh(zeros(n), zeros(n)) # set up reaction rates majump_rates = [0.1] # death at rate 0.1 @@ -36,7 +35,6 @@ integrator = Nothing # only needed for constant rate jumps for site in 1:num_sites(spatial_system) JP.update_rx_rates!(rx_rates, rxs, u_low_high, integrator, site) JP.update_hop_rates!(hop_rates, species_vec, u_low_high, site, spatial_system) - site_rates[site] = JP.total_site_rate(rx_rates, hop_rates, site) end # test species brackets From 0be2fd0e6c999cb64446be4a31f4ebde77e0ae45 Mon Sep 17 00:00:00 2001 From: Vasily Ilin Date: Sun, 10 Sep 2023 15:00:14 -0400 Subject: [PATCH 19/37] Write RSSACR-Direct but it's incorrect. I will turn it into CR-RSSA. --- src/JumpProcesses.jl | 3 +- src/aggregators/aggregators.jl | 3 + src/spatial/bracketing.jl | 43 +++++-- src/spatial/directcrdirect.jl | 5 +- src/spatial/hop_rates.jl | 37 +++--- src/spatial/nsm.jl | 4 +- src/spatial/reaction_rates.jl | 15 +++ src/spatial/rssacrdirect.jl | 228 +++++++++++++++++++++++++++++++++ src/spatial/utils.jl | 25 ++-- test/spatial/ABC.jl | 67 ++++++---- 10 files changed, 364 insertions(+), 66 deletions(-) create mode 100644 src/spatial/rssacrdirect.jl diff --git a/src/JumpProcesses.jl b/src/JumpProcesses.jl index ed35ae5f..5aac876d 100644 --- a/src/JumpProcesses.jl +++ b/src/JumpProcesses.jl @@ -65,6 +65,7 @@ include("spatial/bracketing.jl") include("spatial/nsm.jl") include("spatial/directcrdirect.jl") +include("spatial/rssacrdirect.jl") include("aggregators/aggregated_api.jl") @@ -101,6 +102,6 @@ export ExtendedJumpArray export CartesianGrid, CartesianGridRej export SpatialMassActionJump export outdegree, num_sites, neighbors -export NSM, DirectCRDirect +export NSM, DirectCRDirect, RSSACRDirect end # module diff --git a/src/aggregators/aggregators.jl b/src/aggregators/aggregators.jl index 86b81273..cea89520 100644 --- a/src/aggregators/aggregators.jl +++ b/src/aggregators/aggregators.jl @@ -163,6 +163,8 @@ algorithm with optimal binning, Journal of Chemical Physics 143, 074108 """ struct DirectCRDirect <: AbstractAggregatorAlgorithm end +struct RSSACRDirect <: AbstractAggregatorAlgorithm end + const JUMP_AGGREGATORS = (Direct(), DirectFW(), DirectCR(), SortingDirect(), RSSA(), FRM(), FRMFW(), NRM(), RSSACR(), RDirect(), Coevolve()) @@ -191,3 +193,4 @@ supports_variablerates(aggregator::Coevolve) = true is_spatial(aggregator::AbstractAggregatorAlgorithm) = false is_spatial(aggregator::NSM) = true is_spatial(aggregator::DirectCRDirect) = true +is_spatial(aggregator::RSSACRDirect) = true diff --git a/src/spatial/bracketing.jl b/src/spatial/bracketing.jl index 650cd035..7f62dfac 100644 --- a/src/spatial/bracketing.jl +++ b/src/spatial/bracketing.jl @@ -5,9 +5,15 @@ struct LowHigh{T} low::T high::T - LowHigh(low::T, high::T) where {T} = new{T}(deepcopy(low), deepcopy(high)) - LowHigh(pair::Tuple{T,T}) where {T} = new{T}(pair[1], pair[2]) - LowHigh(low_and_high::T) where {T} = new{T}(low_and_high, deepcopy(low_and_high)) + function LowHigh(low::T, high::T; do_copy = true) where {T} + if do_copy + return new{T}(deepcopy(low), deepcopy(high)) + else + return new{T}(low, high) + end + end + LowHigh(pair::Tuple{T,T}; kwargs...) where {T} = LowHigh(pair[1], pair[2]; kwargs...) + LowHigh(low_and_high::T; kwargs...) where {T} = LowHigh(low_and_high, low_and_high; kwargs...) end function Base.show(io::IO, ::MIME"text/plain", low_high::LowHigh) @@ -16,16 +22,27 @@ function Base.show(io::IO, ::MIME"text/plain", low_high::LowHigh) end @inline function update_u_brackets!(u_low_high::LowHigh, bracket_data, u::AbstractMatrix) - @inbounds for (i, uval) in enumerate(u) - u_low_high[i] = LowHigh(get_spec_brackets(bracket_data, i, uval)) + num_species, num_sites = size(u) + update_u_brackets!(u_low_high, bracket_data, u, 1:num_species, 1:num_sites) +end + +@inline function update_u_brackets!(u_low_high::LowHigh, bracket_data, u::AbstractMatrix, species_vec, sites) + @inbounds for site in sites + for species in species_vec + u_low_high[species, site] = LowHigh(get_spec_brackets(bracket_data, species, u[species, site])) + end end nothing end +function is_outside_brackets(u_low_high::LowHigh{M}, u::M, species, site) where {M} + return u[species, site] < u_low_high.low[species, site] || u[species, site] > u_low_high.high[species, site] +end + ### convenience functions for LowHigh ### -function setindex!(low_high::LowHigh, val::LowHigh, i) - low_high.low[i] = val.low - low_high.high[i] = val.high +function setindex!(low_high::LowHigh, val::LowHigh, i...) + low_high.low[i...] = val.low + low_high.high[i...] = val.high val end @@ -33,12 +50,15 @@ function getindex(low_high::LowHigh, i) return LowHigh(low_high.low[i], low_high.high[i]) end +get_majumps(rx_rates::LowHigh{R}) where {R <: RxRates} = get_majumps(rx_rates.low) + function total_site_rate(rx_rates::LowHigh, hop_rates::LowHigh, site) return LowHigh( total_site_rate(rx_rates.low, hop_rates.low, site), total_site_rate(rx_rates.high, hop_rates.high, site)) end +# Compatible with constant rate jumps, because u_low_high.low and u_low_high.high are used in rate(). function update_rx_rates!(rx_rates::LowHigh, rxs, u_low_high, integrator, site) update_rx_rates!(rx_rates.low, rxs, u_low_high.low, integrator, site) update_rx_rates!(rx_rates.high, rxs, u_low_high.high, integrator, site) @@ -48,3 +68,10 @@ function update_hop_rates!(hop_rates::LowHigh, species, u_low_high, site, spatia update_hop_rates!(hop_rates.low, species, u_low_high.low, site, spatial_system) update_hop_rates!(hop_rates.high, species, u_low_high.high, site, spatial_system) end + +function reset!(low_high::LowHigh) + reset!(low_high.low) + reset!(low_high.high) +end + +reset!(array::AbstractArray) = fill!(array, zero(eltype(array))) \ No newline at end of file diff --git a/src/spatial/directcrdirect.jl b/src/spatial/directcrdirect.jl index 56252829..1596c336 100644 --- a/src/spatial/directcrdirect.jl +++ b/src/spatial/directcrdirect.jl @@ -4,7 +4,6 @@ const MINJUMPRATE = 2.0^exponent(1e-12) #NOTE state vector u is a matrix. u[i,j] is species i, site j -#NOTE hopping_constants is a matrix. hopping_constants[i,j] is species i, site j mutable struct DirectCRDirectJumpAggregation{T, S, F1, F2, RNG, J, RX, HOP, DEPGR, VJMAP, JVMAP, SS, U <: PriorityTable, W <: Function} <: @@ -107,12 +106,12 @@ end function initialize!(p::DirectCRDirectJumpAggregation, integrator, u, params, t) p.end_time = integrator.sol.prob.tspan[2] fill_rates_and_get_times!(p, integrator, t) - generate_jumps!(p, integrator, params, u, t) + generate_jumps!(p, integrator, u, params, t) nothing end # calculate the next jump / jump time -function generate_jumps!(p::DirectCRDirectJumpAggregation, integrator, params, u, t) +function generate_jumps!(p::DirectCRDirectJumpAggregation, integrator, u, params, t) p.next_jump_time = t + randexp(p.rng) / p.rt.gsum p.next_jump_time >= p.end_time && return nothing site = sample(p.rt, p.site_rates, p.rng) diff --git a/src/spatial/hop_rates.jl b/src/spatial/hop_rates.jl index ef7f73a5..2b26283c 100644 --- a/src/spatial/hop_rates.jl +++ b/src/spatial/hop_rates.jl @@ -57,29 +57,28 @@ function HopRates(p::Pair{SpecHop, SiteHop}, end """ - update_hop_rates!(hop_rates::AbstractHopRates, species::AbstractArray, u, site, spatial_system) + update_hop_rates!(hop_rates::HopRatesGraphDsi, species_vec, u, site, spatial_system) -update rates of all specs in species at site + update rates of all species in species_vec at site """ -function update_hop_rates!(hop_rates::AbstractHopRates, species::AbstractArray, u, site, - spatial_system) - @inbounds for spec in species - update_hop_rate!(hop_rates, spec, u, site, spatial_system) +function update_hop_rates!(hop_rates::AbstractHopRates, species_vec, u, site, spatial_system) + @inbounds for species in species_vec + rates = hop_rates.rates + old_rate = rates[species, site] + rates[species, site] = evalhoprate(hop_rates, u, species, site, + spatial_system) + hop_rates.sum_rates[site] += rates[species, site] - old_rate + old_rate end end -""" - update_hop_rate!(hop_rates::HopRatesGraphDsi, species, u, site, spatial_system) - -update rates of single species at site -""" -function update_hop_rate!(hop_rates::AbstractHopRates, species, u, site, spatial_system) - rates = hop_rates.rates - @inbounds old_rate = rates[species, site] - @inbounds rates[species, site] = evalhoprate(hop_rates, u, species, site, - spatial_system) - @inbounds hop_rates.sum_rates[site] += rates[species, site] - old_rate - old_rate +function recompute_site_hop_rate(hop_rates::HP, u, site, spatial_system) where {HP <: AbstractHopRates} + rate = zero(eltype(hop_rates.rates)) + num_species = size(hop_rates.rates, 1) + for species in 1:num_species + rate += evalhoprate(hop_rates, u, species, site, spatial_system) + end + return rate end """ @@ -197,7 +196,7 @@ end return hopping rate of species at site """ function evalhoprate(hop_rates::HopRatesGraphDsi, u, species, site, spatial_system) - @inbounds u[species, site] * hop_rates.hopping_constants[species, site] * + u[species, site] * hop_rates.hopping_constants[species, site] * outdegree(spatial_system, site) end diff --git a/src/spatial/nsm.jl b/src/spatial/nsm.jl index 422c2bf1..87a2b7d7 100644 --- a/src/spatial/nsm.jl +++ b/src/spatial/nsm.jl @@ -95,12 +95,12 @@ end function initialize!(p::NSMJumpAggregation, integrator, u, params, t) p.end_time = integrator.sol.prob.tspan[2] fill_rates_and_get_times!(p, integrator, t) - generate_jumps!(p, integrator, params, u, t) + generate_jumps!(p, integrator, u, params, t) nothing end # calculate the next jump / jump time -function generate_jumps!(p::NSMJumpAggregation, integrator, params, u, t) +function generate_jumps!(p::NSMJumpAggregation, integrator, u, params, t) p.next_jump_time, site = top_with_handle(p.pq) p.next_jump_time >= p.end_time && return nothing p.next_jump = sample_jump_direct(p, site) diff --git a/src/spatial/reaction_rates.jl b/src/spatial/reaction_rates.jl index 6d688719..e015515b 100644 --- a/src/spatial/reaction_rates.jl +++ b/src/spatial/reaction_rates.jl @@ -26,6 +26,7 @@ function RxRates(num_sites::Int, ma_jumps::M) where {M} end num_rxs(rx_rates::RxRates) = get_num_majumps(rx_rates.ma_jumps) +get_majumps(rx_rates::RxRates) = rx_rates.ma_jumps """ reset!(rx_rates::RxRates) @@ -77,6 +78,20 @@ function sample_rx_at_site(rx_rates::RxRates, site, rng) rand(rng) * total_site_rx_rate(rx_rates, site)) end +""" + recompute_site_rx_rate(rx_rates::RxRates, u, site) + +compute the total reaction rate at site at the current state u +""" +function recompute_site_rx_rate(rx_rates::RxRates, u, site) + rate = zero(eltype(rx_rates.rates)) + ma_jumps = rx_rates.ma_jumps + for rx in 1:num_rxs(rx_rates) + rate += eval_massaction_rate(u, rx, ma_jumps, site) + end + return rate +end + # helper functions function set_rx_rate_at_site!(rx_rates::RxRates, site, rx, rate) @inbounds old_rate = rx_rates.rates[rx, site] diff --git a/src/spatial/rssacrdirect.jl b/src/spatial/rssacrdirect.jl new file mode 100644 index 00000000..8413570c --- /dev/null +++ b/src/spatial/rssacrdirect.jl @@ -0,0 +1,228 @@ +# site chosen with RSSACR, rx or hop chosen with Direct + +############################ RSSACRDirect ################################### +const MINJUMPRATE = 2.0^exponent(1e-12) + +#NOTE state vector u is a matrix. u[i,j] is species i, site j +mutable struct RSSACRDirectJumpAggregation{T, BD, M, RNG, J, RX, HOP, DEPGR, + VJMAP, JVMAP, SS, U <: PriorityTable, S, F1, F2} <: + AbstractSSAJumpAggregator{T, S, F1, F2, RNG} + next_jump::SpatialJump{J} + prev_jump::SpatialJump{J} + next_jump_time::T + end_time::T + bracket_data::BD + u_low_high::LowHigh{M} # species bracketing + rx_rates::LowHigh{RX} + hop_rates::LowHigh{HOP} + site_rates::LowHigh{Vector{T}} + save_positions::Tuple{Bool, Bool} + rng::RNG + dep_gr::DEPGR #dep graph is same for each site + vartojumps_map::VJMAP #vartojumps_map is same for each site + jumptovars_map::JVMAP #jumptovars_map is same for each site + spatial_system::SS + numspecies::Int #number of species + rt::U + rates::F1 # legacy, not used + affects!::F2 # legacy, not used +end + +function RSSACRDirectJumpAggregation(nj::SpatialJump{J}, njt::T, et::T, bd::BD, + u_low_high::LowHigh{M}, rx_rates::LowHigh{RX}, + hop_rates::LowHigh{HOP}, site_rates::LowHigh{Vector{T}}, + sps::Tuple{Bool, Bool}, rng::RNG, spatial_system::SS; + num_specs, minrate = convert(T, MINJUMPRATE), + vartojumps_map = nothing, jumptovars_map = nothing, + dep_graph = nothing, + kwargs...) where {J, T, BD, RX, HOP, RNG, SS, M} + + # a dependency graph is needed + if dep_graph === nothing + dg = make_dependency_graph(num_specs, rx_rates.low.ma_jumps) + else + dg = dep_graph + # make sure each jump depends on itself + add_self_dependencies!(dg) + end + + # a species-to-reactions graph is needed + if vartojumps_map === nothing + vtoj_map = var_to_jumps_map(num_specs, rx_rates.low.ma_jumps) + else + vtoj_map = vartojumps_map + end + + if jumptovars_map === nothing + jtov_map = jump_to_vars_map(rx_rates.low.ma_jumps) + else + jtov_map = jumptovars_map + end + + # mapping from jump rate to group id + minexponent = exponent(minrate) + + # use the largest power of two that is <= the passed in minrate + minrate = 2.0^minexponent + ratetogroup = rate -> priortogid(rate, minexponent) + + # construct an empty initial priority table -- we'll reset this in init + rt = PriorityTable(ratetogroup, zeros(T, 1), minrate, 2 * minrate) + + RSSACRDirectJumpAggregation{T, BD, M, RNG, J, RX, HOP, typeof(dg), typeof(vtoj_map), typeof(jtov_map), SS, typeof(rt), Nothing, Nothing, Nothing}( + nj, nj, njt, et, bd, u_low_high, rx_rates, hop_rates, site_rates, sps, rng, dg, vtoj_map, jtov_map, spatial_system, num_specs, rt, nothing, nothing) +end + +############################# Required Functions ############################## +# creating the JumpAggregation structure (function wrapper-based constant jumps) +function aggregate(aggregator::RSSACRDirect, starting_state, p, t, end_time, + constant_jumps, ma_jumps, save_positions, rng; hopping_constants, + spatial_system, bracket_data = nothing, kwargs...) + T = typeof(end_time) + num_species = size(starting_state, 1) + majumps = ma_jumps + if majumps === nothing + majumps = MassActionJump(Vector{T}(), + Vector{Vector{Pair{Int, Int}}}(), + Vector{Vector{Pair{Int, Int}}}()) + end + + next_jump = SpatialJump{Int}(typemax(Int), typemax(Int), typemax(Int)) #a placeholder + next_jump_time = typemax(T) + rx_rates = LowHigh(RxRates(num_sites(spatial_system), majumps), + RxRates(num_sites(spatial_system), majumps); + do_copy = false) # do not copy ma_jumps + hop_rates = LowHigh(HopRates(hopping_constants, spatial_system), + HopRates(hopping_constants, spatial_system); + do_copy = false) # do not copy hopping_constants + site_rates = LowHigh(zeros(T, num_sites(spatial_system))) + bd = (bracket_data === nothing) ? BracketData{T, eltype(starting_state)}() : + bracket_data + u_low_high = LowHigh(starting_state) + + RSSACRDirectJumpAggregation(next_jump, next_jump_time, end_time, bd, u_low_high, + rx_rates, hop_rates, + site_rates, save_positions, rng, spatial_system; + num_specs = num_species, kwargs...) +end + +# set up a new simulation and calculate the first jump / jump time +function initialize!(p::RSSACRDirectJumpAggregation, integrator, u, params, t) + p.end_time = integrator.sol.prob.tspan[2] + fill_rates_and_get_times!(p, integrator, t) + generate_jumps!(p, integrator, u, params, t) + nothing +end + +# calculate the next jump / jump time +function generate_jumps!(p::RSSACRDirectJumpAggregation, integrator, u, params, t) + @unpack rng, rt, site_rates, rx_rates, hop_rates, spatial_system = p + time_delta = zero(t) + site = zero(eltype(u)) + while true + site = sample(rt, site_rates.high, rng) + time_delta += randexp(rng) + accept_jump(rx_rates, hop_rates, site_rates, u, site, spatial_system, rng) && break + end + p.next_jump_time = t + time_delta / groupsum(rt) + p.next_jump = sample_jump_direct(rx_rates.high, hop_rates.high, site, spatial_system, rng) + nothing +end + +# execute one jump, changing the system state +function execute_jumps!(p::RSSACRDirectJumpAggregation, integrator, u, params, t, + affects!) + update_state!(p, integrator) + update_dependent_rates!(p, integrator, t) + nothing +end + +######################## SSA specific helper routines ######################## +# Return true if site is accepted. +function accept_jump(rx_rates, hop_rates, site_rates, u, site, spatial_system, rng) + acceptance_threshold = rand(rng) * site_rates.high[site] + if acceptance_threshold < site_rates.low[site] + return true + else + site_rate = recompute_site_hop_rate(hop_rates.low, u, site, spatial_system) + + recompute_site_rx_rate(rx_rates.low, u, site) + return acceptance_threshold < site_rate + end +end + +""" + fill_rates_and_get_times!(aggregation::RSSACRDirectJumpAggregation, u, t) + +reset all stucts, reevaluate all rates, repopulate the priority table +""" +function fill_rates_and_get_times!(aggregation::RSSACRDirectJumpAggregation, integrator, t) + @unpack bracket_data, u_low_high, spatial_system, rx_rates, hop_rates, site_rates, rt = aggregation + u = integrator.u + update_u_brackets!(u_low_high::LowHigh, bracket_data, u::AbstractMatrix) + + reset!(rx_rates) + reset!(hop_rates) + reset!(site_rates) + + rxs = 1:num_rxs(rx_rates.low) + species = 1:(aggregation.numspecies) + + for site in 1:num_sites(spatial_system) + update_rx_rates!(rx_rates, rxs, u_low_high, integrator, site) + update_hop_rates!(hop_rates, species, u_low_high, site, spatial_system) + site_rates[site] = total_site_rate(rx_rates, hop_rates, site) + end + + # setup PriorityTable + reset!(rt) + for (pid, priority) in enumerate(site_rates.high) + insert!(rt, pid, priority) + end + nothing +end + +""" + update_dependent_rates!(p, integrator, t) + +recalculate jump rates for jumps that depend on the just executed jump (p.prev_jump) +""" +function update_dependent_rates!(p::RSSACRDirectJumpAggregation, + integrator, + t) + @unpack rx_rates, hop_rates, site_rates, u_low_high, bracket_data, vartojumps_map, jumptovars_map, spatial_system = p + + u = integrator.u + site_rates = p.site_rates + jump = p.prev_jump + + if is_hop(p, jump) + species_to_update = jump.jidx + sites_to_update = (jump.src, jump.dst) + else + species_to_update = jumptovars_map[reaction_id_from_jump(p, jump)] + sites_to_update = jump.src + end + + for site in sites_to_update, species in species_to_update + if is_outside_brackets(u_low_high, u, species, site) + update_u_brackets!(u_low_high, bracket_data, u, species, site) + update_rx_rates!(rx_rates, + vartojumps_map[species], + u_low_high, + integrator, + site) + update_hop_rates!(hop_rates, species, u_low_high, site, spatial_system) + + oldrate = site_rates.high[site] + site_rates[site] = total_site_rate(p.rx_rates, p.hop_rates, site) + update!(p.rt, site, oldrate, site_rates.high[site]) + end + end +end + +""" + num_constant_rate_jumps(aggregator::RSSACRDirectJumpAggregation) + +number of constant rate jumps +""" +num_constant_rate_jumps(aggregator::RSSACRDirectJumpAggregation) = 0 \ No newline at end of file diff --git a/src/spatial/utils.jl b/src/spatial/utils.jl index e265da0b..f7ed7a7b 100644 --- a/src/spatial/utils.jl +++ b/src/spatial/utils.jl @@ -27,18 +27,23 @@ end sample jump at site with direct method """ -function sample_jump_direct(p, site) - if rand(p.rng) * (total_site_rate(p.rx_rates, p.hop_rates, site)) < - total_site_rx_rate(p.rx_rates, site) - rx = sample_rx_at_site(p.rx_rates, site, p.rng) - return SpatialJump(site, rx + p.numspecies, site) +function sample_jump_direct(rx_rates, hop_rates, site, spatial_system, rng) + numspecies = size(hop_rates.rates, 1) + if rand(rng) * (total_site_rate(rx_rates, hop_rates, site)) < + total_site_rx_rate(rx_rates, site) + rx = sample_rx_at_site(rx_rates, site, rng) + return SpatialJump(site, rx + numspecies, site) else - species_to_diffuse, target_site = sample_hop_at_site(p.hop_rates, site, p.rng, - p.spatial_system) + species_to_diffuse, target_site = sample_hop_at_site(hop_rates, site, rng, + spatial_system) return SpatialJump(site, species_to_diffuse, target_site) end end +function sample_jump_direct(p, site) + sample_jump_direct(p.rx_rates, p.hop_rates, site, p.spatial_system, p.rng) +end + function total_site_rate(rx_rates::RxRates, hop_rates::AbstractHopRates, site) total_site_hop_rate(hop_rates, site) + total_site_rx_rate(rx_rates, site) end @@ -52,10 +57,10 @@ end function update_rates_after_hop!(p, integrator, source_site, target_site, species) u = integrator.u update_rx_rates!(p.rx_rates, p.vartojumps_map[species], integrator, source_site) - update_hop_rate!(p.hop_rates, species, u, source_site, p.spatial_system) + update_hop_rates!(p.hop_rates, species, u, source_site, p.spatial_system) update_rx_rates!(p.rx_rates, p.vartojumps_map[species], integrator, target_site) - update_hop_rate!(p.hop_rates, species, u, target_site, p.spatial_system) + update_hop_rates!(p.hop_rates, species, u, target_site, p.spatial_system) end """ @@ -70,7 +75,7 @@ function update_state!(p, integrator) else rx_index = reaction_id_from_jump(p, jump) @inbounds executerx!((@view integrator.u[:, jump.src]), rx_index, - p.rx_rates.ma_jumps) + get_majumps(p.rx_rates)) end # save jump that was just executed p.prev_jump = jump diff --git a/test/spatial/ABC.jl b/test/spatial/ABC.jl index c44358e8..ec99d23b 100644 --- a/test/spatial/ABC.jl +++ b/test/spatial/ABC.jl @@ -4,12 +4,12 @@ using Test, Graphs using StableRNGs rng = StableRNG(12345) -Nsims = 100 +Nsims = 1000 reltol = 0.05 non_spatial_mean = [65.7395, 65.7395, 434.2605] #mean of 10,000 simulations dim = 1 -linear_size = 5 +linear_size = 1 dims = Tuple(repeat([linear_size], dim)) num_nodes = prod(dims) starting_site = trunc(Int, (linear_size^dim + 1) / 2) @@ -49,27 +49,27 @@ end # testing grids = [CartesianGridRej(dims), Graphs.grid(dims)] -jump_problems = JumpProblem[JumpProblem(prob, NSM(), majumps, - hopping_constants = hopping_constants, - spatial_system = grid, - save_positions = (false, false), rng = rng) for grid in grids] -push!(jump_problems, - JumpProblem(prob, DirectCRDirect(), majumps, hopping_constants = hopping_constants, - spatial_system = grids[1], save_positions = (false, false), rng = rng)) -# setup flattenned jump prob -push!(jump_problems, - JumpProblem(prob, NRM(), majumps, hopping_constants = hopping_constants, - spatial_system = grids[1], save_positions = (false, false), rng = rng)) -# test -for spatial_jump_prob in jump_problems - solution = solve(spatial_jump_prob, SSAStepper()) - mean_end_state = get_mean_end_state(spatial_jump_prob, Nsims) - mean_end_state = reshape(mean_end_state, num_species, num_nodes) - diff = sum(mean_end_state, dims = 2) - non_spatial_mean - for (i, d) in enumerate(diff) - @test abs(d) < reltol * non_spatial_mean[i] - end -end +# jump_problems = JumpProblem[JumpProblem(prob, NSM(), majumps, +# hopping_constants = hopping_constants, +# spatial_system = grid, +# save_positions = (false, false), rng = rng) for grid in grids] +# push!(jump_problems, +# JumpProblem(prob, DirectCRDirect(), majumps, hopping_constants = hopping_constants, +# spatial_system = grids[1], save_positions = (false, false), rng = rng)) +# # setup flattenned jump prob +# push!(jump_problems, +# JumpProblem(prob, NRM(), majumps, hopping_constants = hopping_constants, +# spatial_system = grids[1], save_positions = (false, false), rng = rng)) +# # test +# for spatial_jump_prob in jump_problems +# solution = solve(spatial_jump_prob, SSAStepper()) +# mean_end_state = get_mean_end_state(spatial_jump_prob, Nsims) +# mean_end_state = reshape(mean_end_state, num_species, num_nodes) +# diff = sum(mean_end_state, dims = 2) - non_spatial_mean +# for (i, d) in enumerate(diff) +# @test abs(d) < reltol * non_spatial_mean[i] +# end +# end #using non-spatial SSAs to get the mean # non_spatial_rates = [0.1,1.0] @@ -79,3 +79,24 @@ end # non_spatial_prob = DiscreteProblem(u0,(0.0,end_time), non_spatial_rates) # jump_prob = JumpProblem(non_spatial_prob, Direct(), majumps) # non_spatial_mean = get_mean_end_state(jump_prob, 10000) + +spatial_jump_prob = JumpProblem(prob, RSSACRDirect(), majumps, hopping_constants = hopping_constants, + spatial_system = grids[1], save_positions = (false, false)) +sol = solve(spatial_jump_prob, SSAStepper()) +mean_end_state = get_mean_end_state(spatial_jump_prob, Nsims) +mean_end_state = reshape(mean_end_state, num_species, num_nodes) +diff = sum(mean_end_state, dims = 2) - non_spatial_mean +for (i, d) in enumerate(diff) + @test abs(d) < reltol * non_spatial_mean[i] +end + + +spatial_jump_prob = JumpProblem(prob, NSM(), majumps, hopping_constants = hopping_constants, + spatial_system = grids[1], save_positions = (false, false)) +sol = solve(spatial_jump_prob, SSAStepper()) +mean_end_state = get_mean_end_state(spatial_jump_prob, Nsims) +mean_end_state = reshape(mean_end_state, num_species, num_nodes) +diff = sum(mean_end_state, dims = 2) - non_spatial_mean +for (i, d) in enumerate(diff) + @test abs(d) < reltol * non_spatial_mean[i] +end \ No newline at end of file From 675c140e269bc06dd0b430b8b8f0cc869400fcd1 Mon Sep 17 00:00:00 2001 From: Vasily Ilin Date: Mon, 11 Sep 2023 00:38:26 -0400 Subject: [PATCH 20/37] Fix the main part of the SSA code. Time to clean up. --- src/spatial/bracketing.jl | 4 +- src/spatial/hop_rates.jl | 9 +--- src/spatial/reaction_rates.jl | 17 ++----- src/spatial/rssacrdirect.jl | 89 +++++++++++++++++++++++++---------- test/spatial/ABC.jl | 4 +- 5 files changed, 71 insertions(+), 52 deletions(-) diff --git a/src/spatial/bracketing.jl b/src/spatial/bracketing.jl index 7f62dfac..0e683c38 100644 --- a/src/spatial/bracketing.jl +++ b/src/spatial/bracketing.jl @@ -35,8 +35,8 @@ end nothing end -function is_outside_brackets(u_low_high::LowHigh{M}, u::M, species, site) where {M} - return u[species, site] < u_low_high.low[species, site] || u[species, site] > u_low_high.high[species, site] +function is_inside_brackets(u_low_high::LowHigh{M}, u::M, species, site) where {M} + return u_low_high.low[species, site] < u[species, site] < u_low_high.high[species, site] end ### convenience functions for LowHigh ### diff --git a/src/spatial/hop_rates.jl b/src/spatial/hop_rates.jl index 2b26283c..9e5f430a 100644 --- a/src/spatial/hop_rates.jl +++ b/src/spatial/hop_rates.jl @@ -72,14 +72,7 @@ function update_hop_rates!(hop_rates::AbstractHopRates, species_vec, u, site, sp end end -function recompute_site_hop_rate(hop_rates::HP, u, site, spatial_system) where {HP <: AbstractHopRates} - rate = zero(eltype(hop_rates.rates)) - num_species = size(hop_rates.rates, 1) - for species in 1:num_species - rate += evalhoprate(hop_rates, u, species, site, spatial_system) - end - return rate -end +hop_rate(hop_rates, species, site) = @inbounds hop_rates.rates[species, site] """ total_site_hop_rate(hop_rates::AbstractHopRates, site) diff --git a/src/spatial/reaction_rates.jl b/src/spatial/reaction_rates.jl index e015515b..9ef91d45 100644 --- a/src/spatial/reaction_rates.jl +++ b/src/spatial/reaction_rates.jl @@ -39,6 +39,9 @@ function reset!(rx_rates::RxRates) nothing end +rx_rate(rx_rates, rx, site) = rx_rates.rates[rx, site] +evalrxrate(rx_rates, u, rx, site) = eval_massaction_rate(u, rx, rx_rates.ma_jumps, site) + """ total_site_rx_rate(rx_rates::RxRates, site) @@ -78,20 +81,6 @@ function sample_rx_at_site(rx_rates::RxRates, site, rng) rand(rng) * total_site_rx_rate(rx_rates, site)) end -""" - recompute_site_rx_rate(rx_rates::RxRates, u, site) - -compute the total reaction rate at site at the current state u -""" -function recompute_site_rx_rate(rx_rates::RxRates, u, site) - rate = zero(eltype(rx_rates.rates)) - ma_jumps = rx_rates.ma_jumps - for rx in 1:num_rxs(rx_rates) - rate += eval_massaction_rate(u, rx, ma_jumps, site) - end - return rate -end - # helper functions function set_rx_rate_at_site!(rx_rates::RxRates, site, rx, rate) @inbounds old_rate = rx_rates.rates[rx, site] diff --git a/src/spatial/rssacrdirect.jl b/src/spatial/rssacrdirect.jl index 8413570c..80596624 100644 --- a/src/spatial/rssacrdirect.jl +++ b/src/spatial/rssacrdirect.jl @@ -15,7 +15,7 @@ mutable struct RSSACRDirectJumpAggregation{T, BD, M, RNG, J, RX, HOP, DEPGR, u_low_high::LowHigh{M} # species bracketing rx_rates::LowHigh{RX} hop_rates::LowHigh{HOP} - site_rates::LowHigh{Vector{T}} + site_rates::LowHigh{Vector{T}} # TODO(vilin97): we never use site_rates.low save_positions::Tuple{Bool, Bool} rng::RNG dep_gr::DEPGR #dep graph is same for each site @@ -69,8 +69,24 @@ function RSSACRDirectJumpAggregation(nj::SpatialJump{J}, njt::T, et::T, bd::BD, # construct an empty initial priority table -- we'll reset this in init rt = PriorityTable(ratetogroup, zeros(T, 1), minrate, 2 * minrate) - RSSACRDirectJumpAggregation{T, BD, M, RNG, J, RX, HOP, typeof(dg), typeof(vtoj_map), typeof(jtov_map), SS, typeof(rt), Nothing, Nothing, Nothing}( - nj, nj, njt, et, bd, u_low_high, rx_rates, hop_rates, site_rates, sps, rng, dg, vtoj_map, jtov_map, spatial_system, num_specs, rt, nothing, nothing) + RSSACRDirectJumpAggregation{ + T, + BD, + M, + RNG, + J, + RX, + HOP, + typeof(dg), + typeof(vtoj_map), + typeof(jtov_map), + SS, + typeof(rt), + Nothing, + Nothing, + Nothing, + }(nj, nj, njt, et, bd, u_low_high, rx_rates, hop_rates, site_rates, sps, rng, dg, + vtoj_map, jtov_map, spatial_system, num_specs, rt, nothing, nothing) end ############################# Required Functions ############################## @@ -118,14 +134,16 @@ end function generate_jumps!(p::RSSACRDirectJumpAggregation, integrator, u, params, t) @unpack rng, rt, site_rates, rx_rates, hop_rates, spatial_system = p time_delta = zero(t) - site = zero(eltype(u)) while true site = sample(rt, site_rates.high, rng) + jump = sample_jump_direct(rx_rates.high, hop_rates.high, site, spatial_system, rng) time_delta += randexp(rng) - accept_jump(rx_rates, hop_rates, site_rates, u, site, spatial_system, rng) && break + if accept_jump(p, u, jump) + p.next_jump_time = t + time_delta / groupsum(rt) + p.next_jump = jump + break + end end - p.next_jump_time = t + time_delta / groupsum(rt) - p.next_jump = sample_jump_direct(rx_rates.high, hop_rates.high, site, spatial_system, rng) nothing end @@ -139,14 +157,37 @@ end ######################## SSA specific helper routines ######################## # Return true if site is accepted. -function accept_jump(rx_rates, hop_rates, site_rates, u, site, spatial_system, rng) - acceptance_threshold = rand(rng) * site_rates.high[site] - if acceptance_threshold < site_rates.low[site] +function accept_jump(p, u, jump) + if is_hop(p, jump) + return accept_hop(p, u, jump) + else + return accept_rx(p, u, jump) + end +end + +function accept_hop(p, u, jump) + @unpack hop_rates, spatial_system, rng = p + species, site = jump.jidx, jump.src + acceptance_threshold = rand(rng) * hop_rate(hop_rates.high, species, site) + if hop_rate(hop_rates.low, species, site) > acceptance_threshold return true else - site_rate = recompute_site_hop_rate(hop_rates.low, u, site, spatial_system) + - recompute_site_rx_rate(rx_rates.low, u, site) - return acceptance_threshold < site_rate + # compute the real rate. Could have used hop_rates.high as well. + real_rate = evalhoprate(hop_rates.low, u, species, site, spatial_system) + return real_rate > acceptance_threshold + end +end + +function accept_rx(p, u, jump) + @unpack rx_rates, rng = p + rx, site = reaction_id_from_jump(p, jump), jump.src + acceptance_threshold = rand(rng) * rx_rate(rx_rates.high, rx, site) + if rx_rate(rx_rates.low, rx, site) > acceptance_threshold + return true + else + # compute the real rate. Could have used rx_rates.high as well. + real_rate = evalrxrate(rx_rates.low, u, rx, site) + return real_rate > acceptance_threshold end end @@ -186,25 +227,20 @@ end recalculate jump rates for jumps that depend on the just executed jump (p.prev_jump) """ -function update_dependent_rates!(p::RSSACRDirectJumpAggregation, - integrator, - t) - @unpack rx_rates, hop_rates, site_rates, u_low_high, bracket_data, vartojumps_map, jumptovars_map, spatial_system = p - - u = integrator.u - site_rates = p.site_rates +function update_dependent_rates!(p::RSSACRDirectJumpAggregation, integrator, t) jump = p.prev_jump - if is_hop(p, jump) - species_to_update = jump.jidx - sites_to_update = (jump.src, jump.dst) + update_brackets!(p, integrator, jump.jidx, (jump.src, jump.dst)) else - species_to_update = jumptovars_map[reaction_id_from_jump(p, jump)] - sites_to_update = jump.src + update_brackets!(p, integrator, p.jumptovars_map[reaction_id_from_jump(p, jump)], jump.src) end +end +function update_brackets!(p, integrator, species_to_update, sites_to_update) + @unpack rx_rates, hop_rates, site_rates, u_low_high, bracket_data, vartojumps_map, spatial_system = p + u = integrator.u for site in sites_to_update, species in species_to_update - if is_outside_brackets(u_low_high, u, species, site) + if !is_inside_brackets(u_low_high, u, species, site) update_u_brackets!(u_low_high, bracket_data, u, species, site) update_rx_rates!(rx_rates, vartojumps_map[species], @@ -218,6 +254,7 @@ function update_dependent_rates!(p::RSSACRDirectJumpAggregation, update!(p.rt, site, oldrate, site_rates.high[site]) end end + nothing end """ diff --git a/test/spatial/ABC.jl b/test/spatial/ABC.jl index ec99d23b..12d34f19 100644 --- a/test/spatial/ABC.jl +++ b/test/spatial/ABC.jl @@ -4,12 +4,12 @@ using Test, Graphs using StableRNGs rng = StableRNG(12345) -Nsims = 1000 +Nsims = 100 reltol = 0.05 non_spatial_mean = [65.7395, 65.7395, 434.2605] #mean of 10,000 simulations dim = 1 -linear_size = 1 +linear_size = 5 dims = Tuple(repeat([linear_size], dim)) num_nodes = prod(dims) starting_site = trunc(Int, (linear_size^dim + 1) / 2) From 551af91e24d71fdfc7db97aa62f0fcec45b1466b Mon Sep 17 00:00:00 2001 From: Vasily Ilin Date: Mon, 11 Sep 2023 20:19:57 -0400 Subject: [PATCH 21/37] Rename to `DirectCRRSSA`. --- src/JumpProcesses.jl | 4 +-- src/aggregators/aggregators.jl | 4 +-- .../{rssacrdirect.jl => directcrrssa.jl} | 30 +++++++++---------- test/spatial/ABC.jl | 13 +------- 4 files changed, 20 insertions(+), 31 deletions(-) rename src/spatial/{rssacrdirect.jl => directcrrssa.jl} (90%) diff --git a/src/JumpProcesses.jl b/src/JumpProcesses.jl index 5aac876d..900bb389 100644 --- a/src/JumpProcesses.jl +++ b/src/JumpProcesses.jl @@ -65,7 +65,7 @@ include("spatial/bracketing.jl") include("spatial/nsm.jl") include("spatial/directcrdirect.jl") -include("spatial/rssacrdirect.jl") +include("spatial/directcrrssa.jl") include("aggregators/aggregated_api.jl") @@ -102,6 +102,6 @@ export ExtendedJumpArray export CartesianGrid, CartesianGridRej export SpatialMassActionJump export outdegree, num_sites, neighbors -export NSM, DirectCRDirect, RSSACRDirect +export NSM, DirectCRDirect, DirectCRRSSA end # module diff --git a/src/aggregators/aggregators.jl b/src/aggregators/aggregators.jl index cea89520..1a85136d 100644 --- a/src/aggregators/aggregators.jl +++ b/src/aggregators/aggregators.jl @@ -163,7 +163,7 @@ algorithm with optimal binning, Journal of Chemical Physics 143, 074108 """ struct DirectCRDirect <: AbstractAggregatorAlgorithm end -struct RSSACRDirect <: AbstractAggregatorAlgorithm end +struct DirectCRRSSA <: AbstractAggregatorAlgorithm end const JUMP_AGGREGATORS = (Direct(), DirectFW(), DirectCR(), SortingDirect(), RSSA(), FRM(), FRMFW(), NRM(), RSSACR(), RDirect(), Coevolve()) @@ -193,4 +193,4 @@ supports_variablerates(aggregator::Coevolve) = true is_spatial(aggregator::AbstractAggregatorAlgorithm) = false is_spatial(aggregator::NSM) = true is_spatial(aggregator::DirectCRDirect) = true -is_spatial(aggregator::RSSACRDirect) = true +is_spatial(aggregator::DirectCRRSSA) = true diff --git a/src/spatial/rssacrdirect.jl b/src/spatial/directcrrssa.jl similarity index 90% rename from src/spatial/rssacrdirect.jl rename to src/spatial/directcrrssa.jl index 80596624..dd6c0af1 100644 --- a/src/spatial/rssacrdirect.jl +++ b/src/spatial/directcrrssa.jl @@ -1,10 +1,10 @@ -# site chosen with RSSACR, rx or hop chosen with Direct +# site chosen with DirectCR, rx or hop chosen with RSSA -############################ RSSACRDirect ################################### +############################ DirectCRRSSA ################################### const MINJUMPRATE = 2.0^exponent(1e-12) #NOTE state vector u is a matrix. u[i,j] is species i, site j -mutable struct RSSACRDirectJumpAggregation{T, BD, M, RNG, J, RX, HOP, DEPGR, +mutable struct DirectCRRSSAJumpAggregation{T, BD, M, RNG, J, RX, HOP, DEPGR, VJMAP, JVMAP, SS, U <: PriorityTable, S, F1, F2} <: AbstractSSAJumpAggregator{T, S, F1, F2, RNG} next_jump::SpatialJump{J} @@ -28,7 +28,7 @@ mutable struct RSSACRDirectJumpAggregation{T, BD, M, RNG, J, RX, HOP, DEPGR, affects!::F2 # legacy, not used end -function RSSACRDirectJumpAggregation(nj::SpatialJump{J}, njt::T, et::T, bd::BD, +function DirectCRRSSAJumpAggregation(nj::SpatialJump{J}, njt::T, et::T, bd::BD, u_low_high::LowHigh{M}, rx_rates::LowHigh{RX}, hop_rates::LowHigh{HOP}, site_rates::LowHigh{Vector{T}}, sps::Tuple{Bool, Bool}, rng::RNG, spatial_system::SS; @@ -69,7 +69,7 @@ function RSSACRDirectJumpAggregation(nj::SpatialJump{J}, njt::T, et::T, bd::BD, # construct an empty initial priority table -- we'll reset this in init rt = PriorityTable(ratetogroup, zeros(T, 1), minrate, 2 * minrate) - RSSACRDirectJumpAggregation{ + DirectCRRSSAJumpAggregation{ T, BD, M, @@ -91,7 +91,7 @@ end ############################# Required Functions ############################## # creating the JumpAggregation structure (function wrapper-based constant jumps) -function aggregate(aggregator::RSSACRDirect, starting_state, p, t, end_time, +function aggregate(aggregator::DirectCRRSSA, starting_state, p, t, end_time, constant_jumps, ma_jumps, save_positions, rng; hopping_constants, spatial_system, bracket_data = nothing, kwargs...) T = typeof(end_time) @@ -116,14 +116,14 @@ function aggregate(aggregator::RSSACRDirect, starting_state, p, t, end_time, bracket_data u_low_high = LowHigh(starting_state) - RSSACRDirectJumpAggregation(next_jump, next_jump_time, end_time, bd, u_low_high, + DirectCRRSSAJumpAggregation(next_jump, next_jump_time, end_time, bd, u_low_high, rx_rates, hop_rates, site_rates, save_positions, rng, spatial_system; num_specs = num_species, kwargs...) end # set up a new simulation and calculate the first jump / jump time -function initialize!(p::RSSACRDirectJumpAggregation, integrator, u, params, t) +function initialize!(p::DirectCRRSSAJumpAggregation, integrator, u, params, t) p.end_time = integrator.sol.prob.tspan[2] fill_rates_and_get_times!(p, integrator, t) generate_jumps!(p, integrator, u, params, t) @@ -131,7 +131,7 @@ function initialize!(p::RSSACRDirectJumpAggregation, integrator, u, params, t) end # calculate the next jump / jump time -function generate_jumps!(p::RSSACRDirectJumpAggregation, integrator, u, params, t) +function generate_jumps!(p::DirectCRRSSAJumpAggregation, integrator, u, params, t) @unpack rng, rt, site_rates, rx_rates, hop_rates, spatial_system = p time_delta = zero(t) while true @@ -148,7 +148,7 @@ function generate_jumps!(p::RSSACRDirectJumpAggregation, integrator, u, params, end # execute one jump, changing the system state -function execute_jumps!(p::RSSACRDirectJumpAggregation, integrator, u, params, t, +function execute_jumps!(p::DirectCRRSSAJumpAggregation, integrator, u, params, t, affects!) update_state!(p, integrator) update_dependent_rates!(p, integrator, t) @@ -192,11 +192,11 @@ function accept_rx(p, u, jump) end """ - fill_rates_and_get_times!(aggregation::RSSACRDirectJumpAggregation, u, t) + fill_rates_and_get_times!(aggregation::DirectCRRSSAJumpAggregation, u, t) reset all stucts, reevaluate all rates, repopulate the priority table """ -function fill_rates_and_get_times!(aggregation::RSSACRDirectJumpAggregation, integrator, t) +function fill_rates_and_get_times!(aggregation::DirectCRRSSAJumpAggregation, integrator, t) @unpack bracket_data, u_low_high, spatial_system, rx_rates, hop_rates, site_rates, rt = aggregation u = integrator.u update_u_brackets!(u_low_high::LowHigh, bracket_data, u::AbstractMatrix) @@ -227,7 +227,7 @@ end recalculate jump rates for jumps that depend on the just executed jump (p.prev_jump) """ -function update_dependent_rates!(p::RSSACRDirectJumpAggregation, integrator, t) +function update_dependent_rates!(p::DirectCRRSSAJumpAggregation, integrator, t) jump = p.prev_jump if is_hop(p, jump) update_brackets!(p, integrator, jump.jidx, (jump.src, jump.dst)) @@ -258,8 +258,8 @@ function update_brackets!(p, integrator, species_to_update, sites_to_update) end """ - num_constant_rate_jumps(aggregator::RSSACRDirectJumpAggregation) + num_constant_rate_jumps(aggregator::DirectCRRSSAJumpAggregation) number of constant rate jumps """ -num_constant_rate_jumps(aggregator::RSSACRDirectJumpAggregation) = 0 \ No newline at end of file +num_constant_rate_jumps(aggregator::DirectCRRSSAJumpAggregation) = 0 \ No newline at end of file diff --git a/test/spatial/ABC.jl b/test/spatial/ABC.jl index 12d34f19..f41a86c9 100644 --- a/test/spatial/ABC.jl +++ b/test/spatial/ABC.jl @@ -80,18 +80,7 @@ grids = [CartesianGridRej(dims), Graphs.grid(dims)] # jump_prob = JumpProblem(non_spatial_prob, Direct(), majumps) # non_spatial_mean = get_mean_end_state(jump_prob, 10000) -spatial_jump_prob = JumpProblem(prob, RSSACRDirect(), majumps, hopping_constants = hopping_constants, - spatial_system = grids[1], save_positions = (false, false)) -sol = solve(spatial_jump_prob, SSAStepper()) -mean_end_state = get_mean_end_state(spatial_jump_prob, Nsims) -mean_end_state = reshape(mean_end_state, num_species, num_nodes) -diff = sum(mean_end_state, dims = 2) - non_spatial_mean -for (i, d) in enumerate(diff) - @test abs(d) < reltol * non_spatial_mean[i] -end - - -spatial_jump_prob = JumpProblem(prob, NSM(), majumps, hopping_constants = hopping_constants, +spatial_jump_prob = JumpProblem(prob, DirectCRRSSA(), majumps, hopping_constants = hopping_constants, spatial_system = grids[1], save_positions = (false, false)) sol = solve(spatial_jump_prob, SSAStepper()) mean_end_state = get_mean_end_state(spatial_jump_prob, Nsims) From 5d02237822124ca2593d81307d9a61f42e77e9ab Mon Sep 17 00:00:00 2001 From: Vasily Ilin Date: Mon, 11 Sep 2023 20:22:43 -0400 Subject: [PATCH 22/37] Fix a docstring. --- src/spatial/hop_rates.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spatial/hop_rates.jl b/src/spatial/hop_rates.jl index 9e5f430a..d35f370a 100644 --- a/src/spatial/hop_rates.jl +++ b/src/spatial/hop_rates.jl @@ -57,9 +57,9 @@ function HopRates(p::Pair{SpecHop, SiteHop}, end """ - update_hop_rates!(hop_rates::HopRatesGraphDsi, species_vec, u, site, spatial_system) + update_hop_rates!(hop_rates::AbstractHopRates, species_vec, u, site, spatial_system) - update rates of all species in species_vec at site +update rates of all species in species_vec at site """ function update_hop_rates!(hop_rates::AbstractHopRates, species_vec, u, site, spatial_system) @inbounds for species in species_vec From 1a24b08aa813490cc1e6501a2e8e9640036569de Mon Sep 17 00:00:00 2001 From: Vasily Ilin Date: Mon, 11 Sep 2023 20:24:50 -0400 Subject: [PATCH 23/37] Shorten a function. --- src/spatial/reaction_rates.jl | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/spatial/reaction_rates.jl b/src/spatial/reaction_rates.jl index 9ef91d45..f33293aa 100644 --- a/src/spatial/reaction_rates.jl +++ b/src/spatial/reaction_rates.jl @@ -56,7 +56,7 @@ end update rates of all reactions in rxs at site """ -function update_rx_rates!(rx_rates::RxRates, rxs, u::AbstractMatrix, integrator, +function update_rx_rates!(rx_rates::RxRates, rxs, u, integrator, site) ma_jumps = rx_rates.ma_jumps @inbounds for rx in rxs @@ -65,11 +65,8 @@ function update_rx_rates!(rx_rates::RxRates, rxs, u::AbstractMatrix, integrator, end end -function update_rx_rates!(rx_rates::RxRates, rxs, integrator, - site) - u = integrator.u - update_rx_rates!(rx_rates, rxs, u, integrator, site) -end +update_rx_rates!(rx_rates::RxRates, rxs, integrator, + site) = update_rx_rates!(rx_rates, rxs, integrator.u, integrator, site) """ sample_rx_at_site(rx_rates::RxRates, site, rng) From 4701171e594a328d7b2e0f15c2f5df1861090b7d Mon Sep 17 00:00:00 2001 From: Vasily Ilin Date: Mon, 11 Sep 2023 20:33:51 -0400 Subject: [PATCH 24/37] Uncomment tests in `ABC.jl` --- test/spatial/ABC.jl | 59 ++++++++++++++++++++------------------------- 1 file changed, 26 insertions(+), 33 deletions(-) diff --git a/test/spatial/ABC.jl b/test/spatial/ABC.jl index f41a86c9..6ca346fc 100644 --- a/test/spatial/ABC.jl +++ b/test/spatial/ABC.jl @@ -49,43 +49,36 @@ end # testing grids = [CartesianGridRej(dims), Graphs.grid(dims)] -# jump_problems = JumpProblem[JumpProblem(prob, NSM(), majumps, -# hopping_constants = hopping_constants, -# spatial_system = grid, -# save_positions = (false, false), rng = rng) for grid in grids] -# push!(jump_problems, -# JumpProblem(prob, DirectCRDirect(), majumps, hopping_constants = hopping_constants, -# spatial_system = grids[1], save_positions = (false, false), rng = rng)) -# # setup flattenned jump prob -# push!(jump_problems, -# JumpProblem(prob, NRM(), majumps, hopping_constants = hopping_constants, -# spatial_system = grids[1], save_positions = (false, false), rng = rng)) -# # test -# for spatial_jump_prob in jump_problems -# solution = solve(spatial_jump_prob, SSAStepper()) -# mean_end_state = get_mean_end_state(spatial_jump_prob, Nsims) -# mean_end_state = reshape(mean_end_state, num_species, num_nodes) -# diff = sum(mean_end_state, dims = 2) - non_spatial_mean -# for (i, d) in enumerate(diff) -# @test abs(d) < reltol * non_spatial_mean[i] -# end -# end +jump_problems = JumpProblem[JumpProblem(prob, NSM(), majumps, + hopping_constants = hopping_constants, + spatial_system = grid, + save_positions = (false, false), rng = rng) for grid in grids] -#using non-spatial SSAs to get the mean +# SSAs +for alg in [DirectCRDirect(), DirectCRRSSA()] + push!(jump_problems, JumpProblem(prob, DirectCRDirect(), majumps, hopping_constants = hopping_constants, spatial_system = grids[1], save_positions = (false, false), rng = rng)) +end + +# setup flattenned jump prob +push!(jump_problems, + JumpProblem(prob, NRM(), majumps, hopping_constants = hopping_constants, + spatial_system = grids[1], save_positions = (false, false), rng = rng)) +# test +for spatial_jump_prob in jump_problems + solution = solve(spatial_jump_prob, SSAStepper()) + mean_end_state = get_mean_end_state(spatial_jump_prob, Nsims) + mean_end_state = reshape(mean_end_state, num_species, num_nodes) + diff = sum(mean_end_state, dims = 2) - non_spatial_mean + for (i, d) in enumerate(diff) + @test abs(d) < reltol * non_spatial_mean[i] + end +end + +# using non-spatial SSAs to get the mean # non_spatial_rates = [0.1,1.0] # reactstoch = [[1 => 1, 2 => 1],[3 => 1]] # netstoch = [[1 => -1, 2 => -1, 3 => 1],[1 => 1, 2 => 1, 3 => -1]] # majumps = MassActionJump(non_spatial_rates, reactstoch, netstoch) # non_spatial_prob = DiscreteProblem(u0,(0.0,end_time), non_spatial_rates) # jump_prob = JumpProblem(non_spatial_prob, Direct(), majumps) -# non_spatial_mean = get_mean_end_state(jump_prob, 10000) - -spatial_jump_prob = JumpProblem(prob, DirectCRRSSA(), majumps, hopping_constants = hopping_constants, - spatial_system = grids[1], save_positions = (false, false)) -sol = solve(spatial_jump_prob, SSAStepper()) -mean_end_state = get_mean_end_state(spatial_jump_prob, Nsims) -mean_end_state = reshape(mean_end_state, num_species, num_nodes) -diff = sum(mean_end_state, dims = 2) - non_spatial_mean -for (i, d) in enumerate(diff) - @test abs(d) < reltol * non_spatial_mean[i] -end \ No newline at end of file +# non_spatial_mean = get_mean_end_state(jump_prob, 10000) \ No newline at end of file From 4a7e5621db06d56830d44792830c8bbb9c61c1aa Mon Sep 17 00:00:00 2001 From: Vasily Ilin Date: Mon, 11 Sep 2023 20:35:27 -0400 Subject: [PATCH 25/37] Delete comment. --- src/spatial/nsm.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spatial/nsm.jl b/src/spatial/nsm.jl index 87a2b7d7..aa5f75e7 100644 --- a/src/spatial/nsm.jl +++ b/src/spatial/nsm.jl @@ -2,7 +2,6 @@ ############################ NSM ################################### #NOTE state vector u is a matrix. u[i,j] is species i, site j -#NOTE hopping_constants is a matrix. hopping_constants[i,j] is species i, site j mutable struct NSMJumpAggregation{T, S, F1, F2, RNG, J, RX, HOP, DEPGR, VJMAP, JVMAP, PQ, SS} <: AbstractSSAJumpAggregator{T, S, F1, F2, RNG} From d3bda542888780a6b470f285f0266cd1178f8c90 Mon Sep 17 00:00:00 2001 From: Vasily Ilin Date: Mon, 11 Sep 2023 20:36:28 -0400 Subject: [PATCH 26/37] Add DirectCRRSSA to the diffusion test. --- test/spatial/diffusion.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/spatial/diffusion.jl b/test/spatial/diffusion.jl index 39577c97..536a87dd 100644 --- a/test/spatial/diffusion.jl +++ b/test/spatial/diffusion.jl @@ -61,7 +61,7 @@ Nsims = 10000 rel_tol = 0.02 times = 0.0:(tf / num_time_points):tf -algs = [NSM(), DirectCRDirect()] +algs = [NSM(), DirectCRDirect(), DirectCRRSSA()] grids = [CartesianGridRej(dims), Graphs.grid(dims)] jump_problems = JumpProblem[JumpProblem(prob, algs[2], majumps, hopping_constants = hopping_constants, From d14628033fc2d3d7fa6d5d397d82e0b09f6ad025 Mon Sep 17 00:00:00 2001 From: Vasily Ilin Date: Mon, 11 Sep 2023 20:50:07 -0400 Subject: [PATCH 27/37] Add `DirectCRRSSA` to `ABC.jl`. --- test/spatial/ABC.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/spatial/ABC.jl b/test/spatial/ABC.jl index 6ca346fc..9775df2c 100644 --- a/test/spatial/ABC.jl +++ b/test/spatial/ABC.jl @@ -56,7 +56,7 @@ jump_problems = JumpProblem[JumpProblem(prob, NSM(), majumps, # SSAs for alg in [DirectCRDirect(), DirectCRRSSA()] - push!(jump_problems, JumpProblem(prob, DirectCRDirect(), majumps, hopping_constants = hopping_constants, spatial_system = grids[1], save_positions = (false, false), rng = rng)) + push!(jump_problems, JumpProblem(prob, alg, majumps, hopping_constants = hopping_constants, spatial_system = grids[1], save_positions = (false, false), rng = rng)) end # setup flattenned jump prob From ab6a9f6f944e4f0851eec221d9dbc40e754e6681 Mon Sep 17 00:00:00 2001 From: Vasily Ilin Date: Mon, 11 Sep 2023 21:03:11 -0400 Subject: [PATCH 28/37] Shorten `getindex`. --- src/spatial/bracketing.jl | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/spatial/bracketing.jl b/src/spatial/bracketing.jl index 0e683c38..a2515bc5 100644 --- a/src/spatial/bracketing.jl +++ b/src/spatial/bracketing.jl @@ -40,15 +40,12 @@ function is_inside_brackets(u_low_high::LowHigh{M}, u::M, species, site) where { end ### convenience functions for LowHigh ### -function setindex!(low_high::LowHigh, val::LowHigh, i...) +function setindex!(low_high::LowHigh{A}, val::LowHigh, i...) where {A <: AbstractArray} low_high.low[i...] = val.low low_high.high[i...] = val.high val end - -function getindex(low_high::LowHigh, i) - return LowHigh(low_high.low[i], low_high.high[i]) -end +getindex(low_high::LowHigh{A}, i) where {A <: AbstractArray} = LowHigh(low_high.low[i], low_high.high[i]) get_majumps(rx_rates::LowHigh{R}) where {R <: RxRates} = get_majumps(rx_rates.low) From 97a79749e3934087e90b611976337563bdfcd23c Mon Sep 17 00:00:00 2001 From: Vasily Ilin Date: Mon, 11 Sep 2023 21:13:40 -0400 Subject: [PATCH 29/37] Remove the low bound on site rates, as it is not used. --- src/spatial/bracketing.jl | 11 +---------- src/spatial/directcrrssa.jl | 36 ++++++++++++++++++------------------ 2 files changed, 19 insertions(+), 28 deletions(-) diff --git a/src/spatial/bracketing.jl b/src/spatial/bracketing.jl index a2515bc5..12f267a2 100644 --- a/src/spatial/bracketing.jl +++ b/src/spatial/bracketing.jl @@ -49,13 +49,6 @@ getindex(low_high::LowHigh{A}, i) where {A <: AbstractArray} = LowHigh(low_high. get_majumps(rx_rates::LowHigh{R}) where {R <: RxRates} = get_majumps(rx_rates.low) -function total_site_rate(rx_rates::LowHigh, hop_rates::LowHigh, site) - return LowHigh( - total_site_rate(rx_rates.low, hop_rates.low, site), - total_site_rate(rx_rates.high, hop_rates.high, site)) -end - -# Compatible with constant rate jumps, because u_low_high.low and u_low_high.high are used in rate(). function update_rx_rates!(rx_rates::LowHigh, rxs, u_low_high, integrator, site) update_rx_rates!(rx_rates.low, rxs, u_low_high.low, integrator, site) update_rx_rates!(rx_rates.high, rxs, u_low_high.high, integrator, site) @@ -69,6 +62,4 @@ end function reset!(low_high::LowHigh) reset!(low_high.low) reset!(low_high.high) -end - -reset!(array::AbstractArray) = fill!(array, zero(eltype(array))) \ No newline at end of file +end \ No newline at end of file diff --git a/src/spatial/directcrrssa.jl b/src/spatial/directcrrssa.jl index dd6c0af1..9bb7b572 100644 --- a/src/spatial/directcrrssa.jl +++ b/src/spatial/directcrrssa.jl @@ -15,7 +15,7 @@ mutable struct DirectCRRSSAJumpAggregation{T, BD, M, RNG, J, RX, HOP, DEPGR, u_low_high::LowHigh{M} # species bracketing rx_rates::LowHigh{RX} hop_rates::LowHigh{HOP} - site_rates::LowHigh{Vector{T}} # TODO(vilin97): we never use site_rates.low + site_rates_high::Vector{T} # we do not need site_rates_low save_positions::Tuple{Bool, Bool} rng::RNG dep_gr::DEPGR #dep graph is same for each site @@ -30,7 +30,7 @@ end function DirectCRRSSAJumpAggregation(nj::SpatialJump{J}, njt::T, et::T, bd::BD, u_low_high::LowHigh{M}, rx_rates::LowHigh{RX}, - hop_rates::LowHigh{HOP}, site_rates::LowHigh{Vector{T}}, + hop_rates::LowHigh{HOP}, site_rates_high::Vector{T}, sps::Tuple{Bool, Bool}, rng::RNG, spatial_system::SS; num_specs, minrate = convert(T, MINJUMPRATE), vartojumps_map = nothing, jumptovars_map = nothing, @@ -39,7 +39,7 @@ function DirectCRRSSAJumpAggregation(nj::SpatialJump{J}, njt::T, et::T, bd::BD, # a dependency graph is needed if dep_graph === nothing - dg = make_dependency_graph(num_specs, rx_rates.low.ma_jumps) + dg = make_dependency_graph(num_specs, get_majumps(rx_rates)) else dg = dep_graph # make sure each jump depends on itself @@ -48,13 +48,13 @@ function DirectCRRSSAJumpAggregation(nj::SpatialJump{J}, njt::T, et::T, bd::BD, # a species-to-reactions graph is needed if vartojumps_map === nothing - vtoj_map = var_to_jumps_map(num_specs, rx_rates.low.ma_jumps) + vtoj_map = var_to_jumps_map(num_specs, get_majumps(rx_rates)) else vtoj_map = vartojumps_map end if jumptovars_map === nothing - jtov_map = jump_to_vars_map(rx_rates.low.ma_jumps) + jtov_map = jump_to_vars_map(get_majumps(rx_rates)) else jtov_map = jumptovars_map end @@ -85,7 +85,7 @@ function DirectCRRSSAJumpAggregation(nj::SpatialJump{J}, njt::T, et::T, bd::BD, Nothing, Nothing, Nothing, - }(nj, nj, njt, et, bd, u_low_high, rx_rates, hop_rates, site_rates, sps, rng, dg, + }(nj, nj, njt, et, bd, u_low_high, rx_rates, hop_rates, site_rates_high, sps, rng, dg, vtoj_map, jtov_map, spatial_system, num_specs, rt, nothing, nothing) end @@ -111,14 +111,14 @@ function aggregate(aggregator::DirectCRRSSA, starting_state, p, t, end_time, hop_rates = LowHigh(HopRates(hopping_constants, spatial_system), HopRates(hopping_constants, spatial_system); do_copy = false) # do not copy hopping_constants - site_rates = LowHigh(zeros(T, num_sites(spatial_system))) + site_rates_high = zeros(T, num_sites(spatial_system)) bd = (bracket_data === nothing) ? BracketData{T, eltype(starting_state)}() : bracket_data u_low_high = LowHigh(starting_state) DirectCRRSSAJumpAggregation(next_jump, next_jump_time, end_time, bd, u_low_high, rx_rates, hop_rates, - site_rates, save_positions, rng, spatial_system; + site_rates_high, save_positions, rng, spatial_system; num_specs = num_species, kwargs...) end @@ -132,10 +132,10 @@ end # calculate the next jump / jump time function generate_jumps!(p::DirectCRRSSAJumpAggregation, integrator, u, params, t) - @unpack rng, rt, site_rates, rx_rates, hop_rates, spatial_system = p + @unpack rng, rt, site_rates_high, rx_rates, hop_rates, spatial_system = p time_delta = zero(t) while true - site = sample(rt, site_rates.high, rng) + site = sample(rt, site_rates_high, rng) jump = sample_jump_direct(rx_rates.high, hop_rates.high, site, spatial_system, rng) time_delta += randexp(rng) if accept_jump(p, u, jump) @@ -197,13 +197,13 @@ end reset all stucts, reevaluate all rates, repopulate the priority table """ function fill_rates_and_get_times!(aggregation::DirectCRRSSAJumpAggregation, integrator, t) - @unpack bracket_data, u_low_high, spatial_system, rx_rates, hop_rates, site_rates, rt = aggregation + @unpack bracket_data, u_low_high, spatial_system, rx_rates, hop_rates, site_rates_high, rt = aggregation u = integrator.u update_u_brackets!(u_low_high::LowHigh, bracket_data, u::AbstractMatrix) reset!(rx_rates) reset!(hop_rates) - reset!(site_rates) + fill!(site_rates_high, zero(eltype(site_rates_high))) rxs = 1:num_rxs(rx_rates.low) species = 1:(aggregation.numspecies) @@ -211,12 +211,12 @@ function fill_rates_and_get_times!(aggregation::DirectCRRSSAJumpAggregation, int for site in 1:num_sites(spatial_system) update_rx_rates!(rx_rates, rxs, u_low_high, integrator, site) update_hop_rates!(hop_rates, species, u_low_high, site, spatial_system) - site_rates[site] = total_site_rate(rx_rates, hop_rates, site) + site_rates_high[site] = total_site_rate(rx_rates.high, hop_rates.high, site) end # setup PriorityTable reset!(rt) - for (pid, priority) in enumerate(site_rates.high) + for (pid, priority) in enumerate(site_rates_high) insert!(rt, pid, priority) end nothing @@ -237,7 +237,7 @@ function update_dependent_rates!(p::DirectCRRSSAJumpAggregation, integrator, t) end function update_brackets!(p, integrator, species_to_update, sites_to_update) - @unpack rx_rates, hop_rates, site_rates, u_low_high, bracket_data, vartojumps_map, spatial_system = p + @unpack rx_rates, hop_rates, site_rates_high, u_low_high, bracket_data, vartojumps_map, spatial_system = p u = integrator.u for site in sites_to_update, species in species_to_update if !is_inside_brackets(u_low_high, u, species, site) @@ -249,9 +249,9 @@ function update_brackets!(p, integrator, species_to_update, sites_to_update) site) update_hop_rates!(hop_rates, species, u_low_high, site, spatial_system) - oldrate = site_rates.high[site] - site_rates[site] = total_site_rate(p.rx_rates, p.hop_rates, site) - update!(p.rt, site, oldrate, site_rates.high[site]) + oldrate = site_rates_high[site] + site_rates_high[site] = total_site_rate(rx_rates.high, hop_rates.high, site) + update!(p.rt, site, oldrate, site_rates_high[site]) end end nothing From 2b3a7a7920d956631bbb1245e28f7002c9304eaf Mon Sep 17 00:00:00 2001 From: Vasily Ilin Date: Mon, 11 Sep 2023 21:22:47 -0400 Subject: [PATCH 30/37] Add an `@inbounds`. --- src/spatial/hop_rates.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spatial/hop_rates.jl b/src/spatial/hop_rates.jl index d35f370a..1b1806a6 100644 --- a/src/spatial/hop_rates.jl +++ b/src/spatial/hop_rates.jl @@ -189,7 +189,7 @@ end return hopping rate of species at site """ function evalhoprate(hop_rates::HopRatesGraphDsi, u, species, site, spatial_system) - u[species, site] * hop_rates.hopping_constants[species, site] * + @inbounds u[species, site] * hop_rates.hopping_constants[species, site] * outdegree(spatial_system, site) end From 7c7a039c3b3cf8e98b247424c4868f3c24539317 Mon Sep 17 00:00:00 2001 From: Vasily Ilin Date: Mon, 11 Sep 2023 21:23:34 -0400 Subject: [PATCH 31/37] Add `AbstractMatrix` back in. --- src/spatial/reaction_rates.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spatial/reaction_rates.jl b/src/spatial/reaction_rates.jl index f33293aa..f2b09367 100644 --- a/src/spatial/reaction_rates.jl +++ b/src/spatial/reaction_rates.jl @@ -56,7 +56,7 @@ end update rates of all reactions in rxs at site """ -function update_rx_rates!(rx_rates::RxRates, rxs, u, integrator, +function update_rx_rates!(rx_rates::RxRates, rxs, u::AbstractMatrix, integrator, site) ma_jumps = rx_rates.ma_jumps @inbounds for rx in rxs From 92880dbf6c52c203f8aad0ecf9cb99b9d37169fd Mon Sep 17 00:00:00 2001 From: Vasily Ilin Date: Mon, 11 Sep 2023 21:24:10 -0400 Subject: [PATCH 32/37] Remove another change to shorten the PR. --- src/spatial/reaction_rates.jl | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/spatial/reaction_rates.jl b/src/spatial/reaction_rates.jl index f2b09367..9ef91d45 100644 --- a/src/spatial/reaction_rates.jl +++ b/src/spatial/reaction_rates.jl @@ -65,8 +65,11 @@ function update_rx_rates!(rx_rates::RxRates, rxs, u::AbstractMatrix, integrator, end end -update_rx_rates!(rx_rates::RxRates, rxs, integrator, - site) = update_rx_rates!(rx_rates, rxs, integrator.u, integrator, site) +function update_rx_rates!(rx_rates::RxRates, rxs, integrator, + site) + u = integrator.u + update_rx_rates!(rx_rates, rxs, u, integrator, site) +end """ sample_rx_at_site(rx_rates::RxRates, site, rng) From 681be86a848b21e5922e02cf3a33d6a4c3d1c2c0 Mon Sep 17 00:00:00 2001 From: Vasily Ilin Date: Mon, 11 Sep 2023 21:25:44 -0400 Subject: [PATCH 33/37] Shorten a function. --- src/spatial/utils.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/spatial/utils.jl b/src/spatial/utils.jl index f7ed7a7b..04548b99 100644 --- a/src/spatial/utils.jl +++ b/src/spatial/utils.jl @@ -40,9 +40,7 @@ function sample_jump_direct(rx_rates, hop_rates, site, spatial_system, rng) end end -function sample_jump_direct(p, site) - sample_jump_direct(p.rx_rates, p.hop_rates, site, p.spatial_system, p.rng) -end +sample_jump_direct(p, site) = sample_jump_direct(p.rx_rates, p.hop_rates, site, p.spatial_system, p.rng) function total_site_rate(rx_rates::RxRates, hop_rates::AbstractHopRates, site) total_site_hop_rate(hop_rates, site) + total_site_rx_rate(rx_rates, site) From 6ed6e1d02b9eb045c07d5de0b79e73b4025032e1 Mon Sep 17 00:00:00 2001 From: Vasily Ilin Date: Mon, 11 Sep 2023 21:26:29 -0400 Subject: [PATCH 34/37] Swap order of functions. --- src/spatial/utils.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spatial/utils.jl b/src/spatial/utils.jl index 04548b99..ff32aba0 100644 --- a/src/spatial/utils.jl +++ b/src/spatial/utils.jl @@ -27,6 +27,8 @@ end sample jump at site with direct method """ +sample_jump_direct(p, site) = sample_jump_direct(p.rx_rates, p.hop_rates, site, p.spatial_system, p.rng) + function sample_jump_direct(rx_rates, hop_rates, site, spatial_system, rng) numspecies = size(hop_rates.rates, 1) if rand(rng) * (total_site_rate(rx_rates, hop_rates, site)) < @@ -40,8 +42,6 @@ function sample_jump_direct(rx_rates, hop_rates, site, spatial_system, rng) end end -sample_jump_direct(p, site) = sample_jump_direct(p.rx_rates, p.hop_rates, site, p.spatial_system, p.rng) - function total_site_rate(rx_rates::RxRates, hop_rates::AbstractHopRates, site) total_site_hop_rate(hop_rates, site) + total_site_rx_rate(rx_rates, site) end From 0ccc09a07002797206e1ea07bcc573b151477848 Mon Sep 17 00:00:00 2001 From: Vasily Ilin Date: Mon, 11 Sep 2023 21:27:16 -0400 Subject: [PATCH 35/37] Remove typos from `ABC.jl`. --- test/spatial/ABC.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/spatial/ABC.jl b/test/spatial/ABC.jl index 9775df2c..25724625 100644 --- a/test/spatial/ABC.jl +++ b/test/spatial/ABC.jl @@ -74,11 +74,11 @@ for spatial_jump_prob in jump_problems end end -# using non-spatial SSAs to get the mean +#using non-spatial SSAs to get the mean # non_spatial_rates = [0.1,1.0] # reactstoch = [[1 => 1, 2 => 1],[3 => 1]] # netstoch = [[1 => -1, 2 => -1, 3 => 1],[1 => 1, 2 => 1, 3 => -1]] # majumps = MassActionJump(non_spatial_rates, reactstoch, netstoch) # non_spatial_prob = DiscreteProblem(u0,(0.0,end_time), non_spatial_rates) # jump_prob = JumpProblem(non_spatial_prob, Direct(), majumps) -# non_spatial_mean = get_mean_end_state(jump_prob, 10000) \ No newline at end of file +# non_spatial_mean = get_mean_end_state(jump_prob, 10000) From 355778d1eee834dd61eca548200e592e9cd8c60e Mon Sep 17 00:00:00 2001 From: Vasily Ilin Date: Tue, 12 Sep 2023 00:01:00 -0400 Subject: [PATCH 36/37] Fix test. --- test/spatial/bracketing.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/spatial/bracketing.jl b/test/spatial/bracketing.jl index 476b1dd7..95e78352 100644 --- a/test/spatial/bracketing.jl +++ b/test/spatial/bracketing.jl @@ -10,7 +10,6 @@ n = 3 # number of sites # set up spatial system spatial_system = CartesianGrid((n,)) # n sites -site_rates = JP.LowHigh(zeros(n), zeros(n)) # set up reaction rates majump_rates = [0.1] # death at rate 0.1 @@ -36,7 +35,6 @@ integrator = Nothing # only needed for constant rate jumps for site in 1:num_sites(spatial_system) JP.update_rx_rates!(rx_rates, rxs, u_low_high, integrator, site) JP.update_hop_rates!(hop_rates, species_vec, u_low_high, site, spatial_system) - site_rates[site] = JP.total_site_rate(rx_rates, hop_rates, site) end # test species brackets From 88d22b74793b68a28b8bf01a87d0f70749e729db Mon Sep 17 00:00:00 2001 From: Vasily Ilin Date: Tue, 20 Aug 2024 23:40:38 -0700 Subject: [PATCH 37/37] Address comments. --- src/spatial/directcrrssa.jl | 9 ++------- src/spatial/hop_rates.jl | 30 ++++++++++++++++++++---------- test/spatial/ABC.jl | 2 +- 3 files changed, 23 insertions(+), 18 deletions(-) diff --git a/src/spatial/directcrrssa.jl b/src/spatial/directcrrssa.jl index 9bb7b572..500a6f60 100644 --- a/src/spatial/directcrrssa.jl +++ b/src/spatial/directcrrssa.jl @@ -157,13 +157,8 @@ end ######################## SSA specific helper routines ######################## # Return true if site is accepted. -function accept_jump(p, u, jump) - if is_hop(p, jump) - return accept_hop(p, u, jump) - else - return accept_rx(p, u, jump) - end -end +@inline accept_jump(p, u, jump) = is_hop(p, jump) ? accept_hop(p, u, jump) : accept_rx(p, u, jump) + function accept_hop(p, u, jump) @unpack hop_rates, spatial_system, rng = p diff --git a/src/spatial/hop_rates.jl b/src/spatial/hop_rates.jl index 8d88db44..fdab78aa 100644 --- a/src/spatial/hop_rates.jl +++ b/src/spatial/hop_rates.jl @@ -57,23 +57,33 @@ function HopRates(p::Pair{SpecHop, SiteHop}, end """ - update_hop_rates!(hop_rates::AbstractHopRates, species_vec, u, site, spatial_system) + update_hop_rates!(hop_rates::AbstractHopRates, species::AbstractArray, u, site, spatial_system) -update rates of all species in species_vec at site +update rates of all specs in species at site """ -function update_hop_rates!(hop_rates::AbstractHopRates, species_vec, u, site, spatial_system) - @inbounds for species in species_vec - rates = hop_rates.rates - old_rate = rates[species, site] - rates[species, site] = evalhoprate(hop_rates, u, species, site, - spatial_system) - hop_rates.sum_rates[site] += rates[species, site] - old_rate - old_rate +function update_hop_rates!(hop_rates::AbstractHopRates, species::AbstractArray, u, site, + spatial_system) + @inbounds for spec in species + update_hop_rate!(hop_rates, spec, u, site, spatial_system) end end hop_rate(hop_rates, species, site) = @inbounds hop_rates.rates[species, site] +""" + update_hop_rate!(hop_rates::HopRatesGraphDsi, species, u, site, spatial_system) + +update rates of single species at site +""" +function update_hop_rate!(hop_rates::AbstractHopRates, species, u, site, spatial_system) + rates = hop_rates.rates + @inbounds old_rate = rates[species, site] + @inbounds rates[species, site] = evalhoprate(hop_rates, u, species, site, + spatial_system) + @inbounds hop_rates.sum_rates[site] += rates[species, site] - old_rate + old_rate +end + """ total_site_hop_rate(hop_rates::AbstractHopRates, site) diff --git a/test/spatial/ABC.jl b/test/spatial/ABC.jl index 24e7ac88..61c5a1dd 100644 --- a/test/spatial/ABC.jl +++ b/test/spatial/ABC.jl @@ -56,7 +56,7 @@ jump_problems = JumpProblem[JumpProblem(prob, NSM(), majumps, # SSAs for alg in [DirectCRDirect(), DirectCRRSSA()] - push!(jump_problems, JumpProblem(prob, alg, majumps, hopping_constants = hopping_constants, spatial_system = grids[1], save_positions = (false, false), rng = rng)) + push!(jump_problems, JumpProblem(prob, alg, majumps; hopping_constants, spatial_system = grids[1], save_positions = (false, false), rng)) end # setup flattenned jump prob