diff --git a/src/algorithms.jl b/src/algorithms.jl index a02e90f2..e2e93dfe 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -847,8 +847,102 @@ SKenCarp(;chunk_size=0,autodiff=true,diff_type=Val{:central}, # Jumps -struct TauLeaping <: StochasticDiffEqJumpAdaptiveAlgorithm end -struct CaoTauLeaping <: StochasticDiffEqJumpAdaptiveAlgorithm end +function TauLeaping_docstring( + description::String, + name::String; + references::String = "", + extra_keyword_description::String = "", + extra_keyword_default::String = "") + keyword_default = """ + adaptive = true, + """ * "\n" * extra_keyword_default + + keyword_default_description = """ + - `adaptive`: Boolean to enable/disable adaptive step sizing. When `true`, the step size `τ` is adjusted dynamically based on error estimates or bounds. Defaults to `true`. + """ * "\n" * extra_keyword_description + + docstring = """ + $description + + ### Algorithm Type + Stochastic Jump Method + + ### References + $references + + ### Keyword Arguments + $keyword_default_description + + ### Default Values + $keyword_default + """ + return docstring +end + +@doc TauLeaping_docstring( + "An explicit tau-leaping method for stochastic jump processes with optional post-leap step size adaptivity. " * + "This algorithm approximates the stochastic simulation algorithm (SSA) by advancing the system state over " * + "a fixed time step `τ` using Poisson-distributed jump counts based on initial propensities. When `adaptive=true`, " * + "it adjusts `τ` dynamically based on post-leap error estimates derived from propensity changes.", + "TauLeaping", + references = """@article{gillespie2001approximate, + title={Approximate accelerated stochastic simulation of chemically reacting systems}, + author={Gillespie, Daniel T}, + journal={The Journal of Chemical Physics}, + volume={115}, + number={4}, + pages={1716--1733}, + year={2001}, + publisher={AIP Publishing}}""", + extra_keyword_description = """ + - `dtmax`: Maximum allowed step size. + - `dtmin`: Minimum allowed step size. + """, + extra_keyword_default = """ + dtmax = 10.0, + dtmin = 1e-6 + """) +struct TauLeaping <: StochasticDiffEqJumpAdaptiveAlgorithm + adaptive::Bool +end + +function TauLeaping(; adaptive=true) + TauLeaping(adaptive) +end + +@doc TauLeaping_docstring( + "An adaptive tau-leaping method for stochastic jump processes that selects the step size `τ` prior to each leap " * + "based on bounds on the expected change in state variables. Introduced by Cao et al., this method ensures stability " * + "and accuracy by constraining the relative change in propensities, controlled by the `epsilon` parameter. " * + "When `adaptive=false`, a fixed step size is used.", + "CaoTauLeaping", + references = """@article{cao2006efficient, + title={Efficient step size selection for the tau-leaping simulation method}, + author={Cao, Yang and Gillespie, Daniel T and Petzold, Linda R}, + journal={The Journal of Chemical Physics}, + volume={124}, + number={4}, + pages={044109}, + year={2006}, + publisher={AIP Publishing}}""", + extra_keyword_description = """ + - `epsilon`: Tolerance parameter controlling the relative change in state variables for step size selection. + - `dtmax`: Maximum allowed step size. + - `dtmin`: Minimum allowed step size. + """, + extra_keyword_default = """ + epsilon = 0.03, + dtmax = 10.0, + dtmin = 1e-6 + """) +struct CaoTauLeaping <: StochasticDiffEqJumpAdaptiveAlgorithm + adaptive::Bool + epsilon::Float64 +end + +function CaoTauLeaping(; adaptive=true, epsilon=0.03) + CaoTauLeaping(adaptive, epsilon) +end ################################################################################ diff --git a/src/caches/tau_caches.jl b/src/caches/tau_caches.jl index e047b49c..4c04e343 100644 --- a/src/caches/tau_caches.jl +++ b/src/caches/tau_caches.jl @@ -1,25 +1,39 @@ -struct TauLeapingConstantCache <: StochasticDiffEqConstantCache end - -@cache struct TauLeapingCache{uType,rateType} <: StochasticDiffEqMutableCache +@cache mutable struct TauLeapingCache{uType, rateType, jumpRateType} <: StochasticDiffEqMutableCache u::uType uprev::uType tmp::uType + rate::rateType newrate::rateType - EEstcache::rateType + EEstcache::jumpRateType end -alg_cache(alg::TauLeaping,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,jump_rate_prototype,::Type{uEltypeNoUnits},::Type{uBottomEltypeNoUnits},::Type{tTypeNoUnits},uprev,f,t,dt,::Type{Val{false}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits} = TauLeapingConstantCache() - -function alg_cache(alg::TauLeaping,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,jump_rate_prototype,::Type{uEltypeNoUnits},::Type{uBottomEltypeNoUnits},::Type{tTypeNoUnits},uprev,f,t,dt,::Type{Val{true}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits} +function alg_cache(alg::TauLeaping, prob, u, ΔW, ΔZ, p, rate_prototype, noise_rate_prototype, jump_rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, f, t, dt, ::Type{Val{true}}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} tmp = zero(u) + rate = zero(jump_rate_prototype) newrate = zero(jump_rate_prototype) EEstcache = zero(jump_rate_prototype) - TauLeapingCache(u,uprev,tmp,newrate,EEstcache) + TauLeapingCache(u, uprev, tmp, rate, newrate, EEstcache) end -alg_cache(alg::CaoTauLeaping,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,jump_rate_prototype,::Type{uEltypeNoUnits},::Type{uBottomEltypeNoUnits},::Type{tTypeNoUnits},uprev,f,t,dt,::Type{Val{false}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits} = TauLeapingConstantCache() +@cache mutable struct CaoTauLeapingCache{uType, rateType, muType} <: StochasticDiffEqMutableCache + u::uType + uprev::uType + tmp::uType + rate::rateType + mu::muType + sigma2::muType +end -function alg_cache(alg::CaoTauLeaping,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,jump_rate_prototype,::Type{uEltypeNoUnits},::Type{uBottomEltypeNoUnits},::Type{tTypeNoUnits},uprev,f,t,dt,::Type{Val{true}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits} +function alg_cache(alg::CaoTauLeaping, prob, u, ΔW, ΔZ, p, rate_prototype, noise_rate_prototype, jump_rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, f, t, dt, ::Type{Val{true}}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} tmp = zero(u) - TauLeapingCache(u,uprev,tmp,nothing,nothing) + rate = zero(jump_rate_prototype) + mu = zero(u) + sigma2 = zero(u) + CaoTauLeapingCache(u, uprev, tmp, rate, mu, sigma2) end + +struct TauLeapingConstantCache <: StochasticDiffEqConstantCache end +struct CaoTauLeapingConstantCache <: StochasticDiffEqConstantCache end + +alg_cache(alg::TauLeaping, prob, u, ΔW, ΔZ, p, rate_prototype, noise_rate_prototype, jump_rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, f, t, dt, ::Type{Val{false}}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} = TauLeapingConstantCache() +alg_cache(alg::CaoTauLeaping, prob, u, ΔW, ΔZ, p, rate_prototype, noise_rate_prototype, jump_rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, f, t, dt, ::Type{Val{false}}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} = CaoTauLeapingConstantCache() diff --git a/src/integrators/stepsize_controllers.jl b/src/integrators/stepsize_controllers.jl index 01a0206b..9d4163f2 100644 --- a/src/integrators/stepsize_controllers.jl +++ b/src/integrators/stepsize_controllers.jl @@ -19,27 +19,85 @@ end function stepsize_controller!(integrator::SDEIntegrator, alg::TauLeaping) - nothing + nothing # Post-leap adjustment happens in perform_step! end function step_accept_controller!(integrator::SDEIntegrator, alg::TauLeaping) + if alg.adaptive integrator.q = min(integrator.opts.gamma / integrator.EEst, integrator.opts.qmax) return integrator.dt * integrator.q + else + return integrator.dt + end end function step_reject_controller!(integrator::SDEIntegrator, alg::TauLeaping) + if alg.adaptive integrator.dt = integrator.opts.gamma * integrator.dt / integrator.EEst + end end - +# CaoTauLeaping: Pre-leap τ computation function stepsize_controller!(integrator::SDEIntegrator, alg::CaoTauLeaping) - nothing + if !alg.adaptive + return + end + + @unpack u, p, t, P, opts, c = integrator + cache = integrator.cache + + P === nothing && error("CaoTauLeaping requires a JumpProblem with a RegularJump") + + # Handle both constant and mutable caches + if isa(cache, CaoTauLeapingConstantCache) + rate = P.cache.rate(u, p, t) # Compute propensities directly + mu = zero(u) + sigma2 = zero(u) + else # CaoTauLeapingCache + @unpack mu, sigma2, rate = cache + P.cache.rate(rate, u, p, t) # Compute propensities into cache + fill!(mu, zero(eltype(mu))) + fill!(sigma2, zero(eltype(sigma2))) + end + + # Infer ν_ij using c by applying unit counts for each reaction + num_reactions = length(rate) + ν = zeros(eltype(u), length(u), num_reactions) + unit_counts = zeros(eltype(rate), num_reactions) + for j in 1:num_reactions + unit_counts[j] = 1 + c(ν[:, j], u, p, t, unit_counts, nothing) # ν[:, j] is the change vector for reaction j + unit_counts[j] = 0 # Reset + end + + # Compute μ_i and σ_i^2 + for i in eachindex(u) + for j in 1:num_reactions + ν_ij = ν[i, j] + mu[i] += ν_ij * rate[j] + sigma2[i] += ν_ij^2 * rate[j] + end + end + + # Compute τ per species + ϵ = alg.epsilon + τ_vals = similar(u, Float64) + for i in eachindex(u) + max_term = max(ϵ * u[i], 1.0) + τ1 = abs(mu[i]) > 0 ? max_term / abs(mu[i]) : Inf + τ2 = sigma2[i] > 0 ? max_term^2 / sigma2[i] : Inf + τ_vals[i] = min(τ1, τ2) + end + + τ = min(minimum(τ_vals), opts.dtmax) + integrator.dt = max(τ, opts.dtmin) + integrator.EEst = 1.0 end function step_accept_controller!(integrator::SDEIntegrator, alg::CaoTauLeaping) - return integrator.EEst # use EEst for the τ + return integrator.dt end function step_reject_controller!(integrator::SDEIntegrator, alg::CaoTauLeaping) - error("CaoTauLeaping should never reject steps") + error("CaoTauLeaping should never reject steps") end diff --git a/src/perform_step/tau_leaping.jl b/src/perform_step/tau_leaping.jl index 9e0097a2..b0106a2d 100644 --- a/src/perform_step/tau_leaping.jl +++ b/src/perform_step/tau_leaping.jl @@ -1,40 +1,66 @@ -@muladd function perform_step!(integrator,cache::TauLeapingConstantCache) - @unpack t,dt,uprev,u,W,p,P,c = integrator +# Perform Step +@muladd function perform_step!(integrator, cache::TauLeapingConstantCache) + @unpack t, dt, uprev, u, p, P, c = integrator + + P === nothing && error("TauLeaping requires a JumpProblem with a RegularJump") + P.dt = dt tmp = c(uprev, p, t, P.dW, nothing) integrator.u = uprev .+ tmp - if integrator.opts.adaptive - if integrator.alg isa TauLeaping - oldrate = P.cache.currate - newrate = P.cache.rate(integrator.u,p,t+dt) - EEstcache = @. abs(newrate - oldrate) / max(50integrator.opts.reltol*oldrate,integrator.rate_constants/integrator.dt) - integrator.EEst = maximum(EEstcache) - if integrator.EEst <= 1 - P.cache.currate = newrate - end - elseif integrator.alg isa CaoTauLeaping - # Calculate τ as EEst + if integrator.alg.adaptive + oldrate = P.cache.currate + newrate = P.cache.rate(integrator.u, p, t + dt) + EEstcache = @. abs(newrate - oldrate) / max(50 * integrator.opts.reltol * oldrate, integrator.rate_constants / integrator.dt) + integrator.EEst = integrator.opts.internalnorm(EEstcache, t) + if integrator.EEst <= 1 + P.cache.currate = newrate end + else + integrator.EEst = 1.0 end end -@muladd function perform_step!(integrator,cache::TauLeapingCache) - @unpack t,dt,uprev,u,W,p,P,c = integrator - @unpack tmp, newrate, EEstcache = cache +@muladd function perform_step!(integrator, cache::TauLeapingCache) + @unpack t, dt, uprev, u, p, P, c = integrator + @unpack tmp, rate, newrate, EEstcache = cache + + P === nothing && error("TauLeaping requires a JumpProblem with a RegularJump") + P.dt = dt c(tmp, uprev, p, t, P.dW, nothing) @.. u = uprev + tmp - if integrator.opts.adaptive - if integrator.alg isa TauLeaping - oldrate = P.cache.currate - P.cache.rate(newrate,u,p,t+dt) - @.. EEstcache = abs(newrate - oldrate) / max(50integrator.opts.reltol*oldrate,integrator.rate_constants/integrator.dt) - integrator.EEst = maximum(EEstcache) - if integrator.EEst <= 1 - P.cache.currate .= newrate - end - elseif integrator.alg isa CaoTauLeaping - # Calculate τ as EEst + if integrator.alg.adaptive + P.cache.rate(newrate, u, p, t + dt) + P.cache.rate(rate, uprev, p, t) + @.. EEstcache = abs(newrate - rate) / max(50 * integrator.opts.reltol * rate, integrator.rate_constants / integrator.dt) + integrator.EEst = integrator.opts.internalnorm(EEstcache, t) + if integrator.EEst <= 1 + P.cache.currate .= newrate end + else + integrator.EEst = 1.0 end end + +@muladd function perform_step!(integrator, cache::CaoTauLeapingConstantCache) + @unpack t, dt, uprev, u, p, P, c = integrator + + P === nothing && error("CaoTauLeaping requires a JumpProblem with a RegularJump") + P.dt = dt + tmp = c(uprev, p, t, P.dW, nothing) + integrator.u = uprev .+ tmp + + integrator.EEst = 1.0 +end + +@muladd function perform_step!(integrator, cache::CaoTauLeapingCache) + @unpack t, dt, uprev, u, p, P, c = integrator + @unpack tmp = cache + + P === nothing && error("CaoTauLeaping requires a JumpProblem with a RegularJump") + P.dt = dt + c(tmp, uprev, p, t, P.dW, nothing) + @.. u = uprev + tmp + + integrator.EEst = 1.0 +end diff --git a/test/tau_leaping.jl b/test/tau_leaping.jl index 4fb87bab..c333630b 100644 --- a/test/tau_leaping.jl +++ b/test/tau_leaping.jl @@ -29,10 +29,14 @@ jump_iipprob = JumpProblem(iip_prob,Direct(),rj) N = 40_000 sol1 = solve(EnsembleProblem(jump_iipprob),SimpleTauLeaping();dt=1.0,trajectories = N) sol2 = solve(EnsembleProblem(jump_iipprob),TauLeaping();dt=1.0,adaptive=false,save_everystep=false,trajectories = N) +sol3 = solve(EnsembleProblem(jump_iipprob),CaoTauLeaping();dt=1.0,trajectories = N) mean1 = mean([sol1[i][end,end] for i in 1:N]) mean2 = mean([sol2[i][end,end] for i in 1:N]) +mean3 = mean([sol3[i][end,end] for i in 1:N]) @test mean1 ≈ mean2 rtol=1e-2 +@test mean2 ≈ mean3 rtol=1e-2 +@test mean1 ≈ mean3 rtol=1e-2 f(du,u,p,t) = (du .= 0) g(du,u,p,t) = (du .= 0) @@ -68,8 +72,12 @@ jump_prob = JumpProblem(prob,Direct(),rj) sol = solve(jump_prob,TauLeaping(),reltol=5e-2) sol2 = solve(EnsembleProblem(jump_prob),TauLeaping();dt=1.0,adaptive=false,save_everystep=false,trajectories = N) +sol3 = solve(EnsembleProblem(jump_prob),CaoTauLeaping();dt=1.0,adaptive=false,save_everystep=false,trajectories = N) mean2 = mean([sol2[i][end,end] for i in 1:N]) +mean3 = mean([sol3[i][end,end] for i in 1:N]) @test mean1 ≈ mean2 rtol=1e-2 +@test mean2 ≈ mean3 rtol=1e-2 +@test mean1 ≈ mean3 rtol=1e-2 foop(u,p,t) = [0.0,0.0,0.0] goop(u,p,t) = [0.0,0.0,0.0]