Skip to content

Commit 9c035aa

Browse files
committed
refactor to reuse expression generation code
1 parent 6af889f commit 9c035aa

File tree

3 files changed

+120
-71
lines changed

3 files changed

+120
-71
lines changed

src/maketype.jl

Lines changed: 57 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ function maketype(abstracttype,
2828
f_symfuncs::Union{Matrix{SymEngine.Basic},Nothing}
2929
g::Union{Function,Nothing}
3030
g_func::Union{Vector{Any},Nothing}
31-
jumps::Union{Tuple{DiffEqJump.AbstractJump,Vararg{DiffEqJump.AbstractJump}},Nothing}
31+
jumps::Union{Tuple{Vararg{DiffEqJump.AbstractJump}},Nothing}
3232
regular_jumps::Union{RegularJump,Nothing}
3333
jump_rate_expr::Union{Tuple{Any,Vararg{Any}},Nothing}
3434
jump_affect_expr::Union{Tuple{Vector{Expr},Vararg{Vector{Expr}}},Nothing}
@@ -90,19 +90,52 @@ function maketype(abstracttype,
9090
typeex,constructorex
9191
end
9292

93+
# type function expressions
94+
function gentypefun_exprs(name; esc_exprs=true, gen_inplace=true, gen_outofplace=true, gen_constructor=true)
95+
exprs = Vector{Expr}(undef,0)
96+
97+
## Overload the type so that it can act as a function.
98+
if gen_inplace
99+
overloadex = :(((f::$name))(du, u, p, t::Number) = (f.f(du, u, p, t); nothing))
100+
push!(exprs,overloadex)
101+
end
102+
103+
## Add a method which allocates the `du` and returns it instead of being inplace
104+
if gen_outofplace
105+
overloadex = :(((f::$name))(u,p,t::Number) = (du=similar(u); f(du,u,p,t); du))
106+
push!(exprs,overloadex)
107+
end
108+
109+
# export type constructor
110+
if gen_constructor
111+
def_const_ex = :(($name)())
112+
push!(exprs,def_const_ex)
113+
end
114+
115+
# escape expressions for macros
116+
if esc_exprs
117+
for i in eachindex(exprs)
118+
exprs[i] = exprs[i] |> esc
119+
end
120+
end
121+
122+
exprs
123+
end
124+
93125
function addodes!(rn::DiffEqBase.AbstractReactionNetwork)
94-
@unpack reactions, syms_to_ints, params_to_ints = rn
126+
@unpack reactions, syms_to_ints, params_to_ints, syms = rn
95127

96-
f_expr = get_f(reactions, syms_to_ints)
97-
rn.f = eval(make_func(f_expr, syms_to_ints, params_to_ints))
98-
rn.f_func = [element.args[2] for element in f_expr]
99-
rn.symjac = eval( Expr(:quote, calculate_jac(deepcopy(rn.f_func), rn.syms)) )
100-
rn.f_symfuncs = hcat([SymEngine.Basic(f) for f in rn.f_func])
128+
(f_expr, f, f_rhs, symjac, f_symfuncs) = genode_exprs(reactions, syms_to_ints, params_to_ints, syms)
129+
rn.f = eval(f)
130+
rn.f_func = f_rhs
131+
rn.symjac = eval(symjac)
132+
rn.f_symfuncs = f_symfuncs
101133
rn.odefun = ODEFunction(rn.f; syms=rn.syms)
102134

103135
# functor for evaluating f
104-
eval( :(((f::typeof($rn)))(du, u, p, t::Number) = f.f(du, u, p, t)) )
105-
136+
functor_exprs = gentypefun_exprs(typeof(rn), esc_exprs=false, gen_constructor=false)
137+
eval( expr_arr_to_block(functor_exprs) )
138+
106139
nothing
107140
end
108141

@@ -114,25 +147,32 @@ function addsdes!(rn::DiffEqBase.AbstractReactionNetwork)
114147
addodes!(rn)
115148
end
116149

117-
g_expr = get_g(reactions, syms_to_ints, scale_noise)
118-
rn.g = eval(make_func(g_expr, syms_to_ints, params_to_ints))
119-
rn.g_func = [element.args[2] for element in g_expr]
120-
rn.p_matrix = zeros(length(syms_to_ints), length(reactions))
150+
(g_expr, g, g_funcs, p_matrix) = gensde_exprs(reactions, syms_to_ints, params_to_ints, scale_noise)
151+
rn.g = eval(g)
152+
rn.g_func = g_funcs
153+
rn.p_matrix = p_matrix
121154
rn.sdefun = SDEFunction(rn.f, rn.g; syms=rn.syms)
122155

123156
nothing
124157
end
125158

126-
function addjumps!(rn::DiffEqBase.AbstractReactionNetwork)
159+
function addjumps!(rn::DiffEqBase.AbstractReactionNetwork;
160+
build_jumps=true,
161+
build_regular_jumps=true,
162+
minimal_jumps=false)
163+
127164
@unpack reactions, syms_to_ints, params_to_ints = rn
128165

129166
# parse the jumps
130-
(jump_rate_expr, jump_affect_expr, jumps, regular_jumps) = get_jumps(reactions, syms_to_ints, params_to_ints)
167+
(jump_rate_expr, jump_affect_expr, jumps, regular_jumps) = get_jumps(reactions,
168+
syms_to_ints,
169+
params_to_ints;
170+
minimal_jumps=minimal_jumps)
131171

132172
rn.jump_rate_expr = jump_rate_expr
133173
rn.jump_affect_expr = jump_affect_expr
134-
rn.jumps = eval(jumps)
135-
rn.regular_jumps = eval(regular_jumps)
174+
rn.jumps = build_jumps ? eval(jumps) : nothing
175+
rn.regular_jumps = build_regular_jumps ? eval(regular_jumps) : nothing
136176

137177
nothing
138178
end

src/massaction_jump_utils.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,13 @@ function network_to_jumpset(rn, specmap, ratemap, params, jumps)
8787
Base.eval(param_context, :($param = $(params[index])))
8888
end
8989

90+
idx = 1
9091
for (i,rs) in enumerate(rn.reactions)
9192
if rs.is_pure_mass_action
9293
push!(majumpvec, make_majump(rs, specmap, ratemap, params, param_context))
9394
else
94-
push!(cjumpvec, jumps[i])
95+
push!(cjumpvec, jumps[idx])
96+
idx += 1
9597
end
9698
end
9799

src/reaction_network.jl

Lines changed: 60 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -92,60 +92,76 @@ funcdict = Dict{Symbol, Function}() #Stores user def
9292

9393
#Coordination function, actually does all the work of the macro.
9494
function coordinate(name, ex::Expr, p, scale_noise)
95-
reactions = get_reactions(ex) ::Vector{ReactionStruct}
96-
reactants = get_reactants(reactions) ::OrderedDict{Symbol,Int}
97-
parameters = get_parameters(p) ::OrderedDict{Symbol,Int}
95+
96+
# minimal reaction network components
97+
(reactions, reactants, parameters, syms, params) = get_minnetwork(ex, p)
98+
99+
# expressions for ODEs
100+
(f_expr, f, f_rhs, symjac, f_symfuncs) = genode_exprs(reactions, reactants, parameters, syms)
101+
odefun = :(ODEFunction(f; syms=$syms))
98102

99-
syms = collect(keys(reactants))
100-
params = collect(keys(parameters))
101-
(in(:t,union(syms,params))) && error("t is reserved for the time variable and may neither be used as a reactant nor a parameter")
103+
# expressions for SDEs
104+
(g_expr, g, g_funcs, p_matrix) = gensde_exprs(reactions, reactants, parameters, scale_noise)
105+
sdefun = :(SDEFunction(f, g; syms=$syms))
102106

103-
update_reaction_info(reactions,syms)
107+
# expressions for jumps
108+
(jump_rate_expr, jump_affect_expr, jumps, regular_jumps) = get_jumps(reactions, reactants, parameters)
104109

105-
f_expr = get_f(reactions, reactants)
106-
f = make_func(f_expr, reactants, parameters)
110+
# Build the type
111+
exprs = Vector{Expr}(undef,0)
112+
typeex,constructorex = maketype(DiffEqBase.AbstractReactionNetwork, name, f, f_rhs, f_symfuncs, g, g_funcs, jumps, regular_jumps, Meta.quot(jump_rate_expr), Meta.quot(jump_affect_expr), p_matrix, syms, scale_noise; params=params, reactions=reactions, symjac=symjac, syms_to_ints=reactants, params_to_ints=parameters, odefun=odefun, sdefun=sdefun)
113+
push!(exprs,typeex)
114+
push!(exprs,constructorex)
107115

108-
g_expr = get_g(reactions, reactants, scale_noise)
109-
g = make_func(g_expr, reactants, parameters)
110-
p_matrix = zeros(length(reactants), length(reactions))
116+
# add type functions
117+
append!(exprs, gentypefun_exprs(name))
111118

112-
(jump_rate_expr, jump_affect_expr, jumps, regular_jumps) = get_jumps(reactions, reactants, parameters)
119+
# return as one expression block
120+
expr_arr_to_block(exprs)
121+
end
113122

114-
f_rhs = [element.args[2] for element in f_expr]
115-
symjac = Expr(:quote, calculate_jac(deepcopy(f_rhs), syms))
116-
f_symfuncs = hcat([SymEngine.Basic(f) for f in f_rhs])
123+
# min_reaction_network coordination function, actually does all the work of the macro.
124+
function min_coordinate(name, ex::Expr, p, scale_noise)
125+
126+
# minimal reaction network components
127+
(reactions, reactants, parameters, syms, params) = get_minnetwork(ex, p)
117128

118129
# Build the type
119130
exprs = Vector{Expr}(undef,0)
120-
121-
## only get the right-hand-side of the equations.
122-
f_funcs = [element.args[2] for element in f_expr]
123-
g_funcs = [element.args[2] for element in g_expr]
124-
125-
odefun = :(ODEFunction(f; syms=$syms))
126-
sdefun = :(SDEFunction(f, g; syms=$syms))
127-
typeex,constructorex = maketype(DiffEqBase.AbstractReactionNetwork, name, f, f_funcs, f_symfuncs, g, g_funcs, jumps, regular_jumps, Meta.quot(jump_rate_expr), Meta.quot(jump_affect_expr), p_matrix, syms, scale_noise; params=params, reactions=reactions, symjac=symjac, syms_to_ints=reactants, params_to_ints=parameters, odefun=odefun, sdefun=sdefun)
128-
131+
typeex,constructorex = maketype(DiffEqBase.AbstractReactionNetwork, name, nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, syms, scale_noise; params=params, reactions=reactions, symjac=nothing, syms_to_ints=reactants, params_to_ints=parameters)
129132
push!(exprs,typeex)
130133
push!(exprs,constructorex)
131134

132-
## Overload the type so that it can act as a function.
133-
overloadex = :(((f::$name))(du, u, p, t::Number) = f.f(du, u, p, t)) |> esc
134-
push!(exprs,overloadex)
135+
# add type functions
136+
append!(exprs, gentypefun_exprs(name, gen_inplace=false, gen_outofplace=false))
135137

136-
## Add a method which allocates the `du` and returns it instead of being inplace
137-
overloadex = :(((f::$name))(u,p,t::Number) = (du=similar(u); f(du,u,p,t); du)) |> esc
138-
push!(exprs,overloadex)
138+
# return as one expression block
139+
expr_arr_to_block(exprs)
140+
end
139141

140-
# export type constructor
141-
def_const_ex = :(($name)()) |> esc
142-
push!(exprs,def_const_ex)
142+
# SDE expressions
143+
function gensde_exprs(reactions, reactants, parameters, scale_noise)
144+
g_expr = get_g(reactions, reactants, scale_noise)
145+
g = make_func(g_expr, reactants, parameters)
146+
g_funcs = [element.args[2] for element in g_expr]
147+
p_matrix = zeros(length(reactants), length(reactions))
143148

144-
expr_arr_to_block(exprs)
149+
(g_expr,g,g_funcs,p_matrix)
145150
end
146151

147-
# min_reaction_network coordination function, actually does all the work of the macro.
148-
function min_coordinate(name, ex::Expr, p, scale_noise)
152+
# ODE expressions
153+
function genode_exprs(reactions, reactants, parameters, syms)
154+
f_expr = get_f(reactions, reactants)
155+
f = make_func(f_expr, reactants, parameters)
156+
f_rhs = [element.args[2] for element in f_expr]
157+
symjac = Expr(:quote, calculate_jac(deepcopy(f_rhs), syms))
158+
f_symfuncs = hcat([SymEngine.Basic(f) for f in f_rhs])
159+
160+
(f_expr,f,f_rhs,symjac,f_symfuncs)
161+
end
162+
163+
# generate the minimal network components
164+
function get_minnetwork(ex::Expr, p)
149165
reactions = get_reactions(ex) ::Vector{ReactionStruct}
150166
reactants = get_reactants(reactions) ::OrderedDict{Symbol,Int}
151167
parameters = get_parameters(p) ::OrderedDict{Symbol,Int}
@@ -156,18 +172,7 @@ function min_coordinate(name, ex::Expr, p, scale_noise)
156172

157173
update_reaction_info(reactions,syms)
158174

159-
# Build the type
160-
exprs = Vector{Expr}(undef,0)
161-
typeex,constructorex = maketype(DiffEqBase.AbstractReactionNetwork, name, nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, syms, scale_noise; params=params, reactions=reactions, symjac=nothing, syms_to_ints=reactants, params_to_ints=parameters)
162-
163-
push!(exprs,typeex)
164-
push!(exprs,constructorex)
165-
166-
# export type constructor
167-
def_const_ex = :(($name)()) |> esc
168-
push!(exprs,def_const_ex)
169-
170-
expr_arr_to_block(exprs)
175+
(reactions,reactants,parameters,syms,params)
171176
end
172177

173178
#Generates a vector containing a number of reaction structures, each containing the infromation about one reaction.
@@ -362,7 +367,7 @@ function make_func(func_expr::Vector{Expr},reactants::OrderedDict{Symbol,Int}, p
362367
end
363368

364369
#Creates expressions for jump affects and rates. Also creates and array with MassAction, ConstantRate and VariableRate Jumps.
365-
function get_jumps(reactions::Vector{ReactionStruct}, reactants::OrderedDict{Symbol,Int}, parameters::OrderedDict{Symbol,Int})
370+
function get_jumps(reactions::Vector{ReactionStruct}, reactants::OrderedDict{Symbol,Int}, parameters::OrderedDict{Symbol,Int}; minimal_jumps=false)
366371
rates = Vector{Any}(undef,length(reactions))
367372
affects = Vector{Vector{Expr}}(undef,length(reactions))
368373
jumps = Expr(:tuple)
@@ -375,17 +380,19 @@ function get_jumps(reactions::Vector{ReactionStruct}, reactants::OrderedDict{Sym
375380
reactant_set = union(getfield.(reaction.products, :reactant),getfield.(reaction.substrates, :reactant))
376381
foreach(r -> push!(affects[idx],:(@inbounds integrator.u[$(reactants[r])] += $(get_stoch_diff(reaction,r)))), reactant_set)
377382
syntax_rate = recursive_replace!(deepcopy(rates[idx]), (reactants,:internal_var___u), (parameters, :internal_var___p))
378-
#if reaction.is_pure_mass_action
383+
384+
if minimal_jumps && reaction.is_pure_mass_action
385+
recursive_contains(:t,rates[idx]) && push!(jumps.args,Expr(:call,:VariableRateJump))
379386
# ma_sub_stoch = :(reactant_stoich = [[]])
380387
# ma_stoch_change = :(reactant_stoich = [[]])
381388
# foreach(sub -> push!(ma_sub_stoch.args[2].args[1].args),:($(reactants[sub.reactant])=>$(sub.stoichiometry)),reaction.substrates)
382389
# foreach(reactant -> push!(ma_stoch_change.args[2].args[1].args),:($(reactants[reactant.reactant])=>$(get_stoch_diff(reaction,reactant))),reaction.substrates)
383390
# push!(jumps.args,:(MassActionJump($(reaction.rate_org),$(ma_sub_stoch),$(ma_stoch_change))))
384-
#else
391+
else
385392
recursive_contains(:t,rates[idx]) ? push!(jumps.args,Expr(:call,:VariableRateJump)) : push!(jumps.args,Expr(:call,:ConstantRateJump))
386393
push!(jumps.args[idx].args, :((internal_var___u,internal_var___p,t) -> $syntax_rate))
387394
push!(jumps.args[idx].args, :(integrator -> $(expr_arr_to_block(deepcopy(affects[idx])))))
388-
#end
395+
end
389396
push!(reg_rates.args,:(internal_var___out[$idx]=$syntax_rate))
390397
foreach(r -> push!(reg_c.args,:(internal_var___dc[$(reactants[r]),$idx]=$(get_stoch_diff(reaction,r)))), reactant_set)
391398
end

0 commit comments

Comments
 (0)