Skip to content

Commit 21378a0

Browse files
authored
Merge pull request #494 from isaacsas/dsl_updates
Parametric stoich via DSL
2 parents 406af9e + c83250c commit 21378a0

File tree

2 files changed

+44
-9
lines changed

2 files changed

+44
-9
lines changed

src/reaction_network.jl

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ end
192192
#Structure containing information about one reactant in one reaction.
193193
struct ReactantStruct
194194
reactant::Union{Symbol,Expr}
195-
stoichiometry::Number
195+
stoichiometry::ExprValues
196196
end
197197
#Structure containing information about one Reaction. Contain all its substrates and products as well as its rate. Contains a specialized constructor.
198198
struct ReactionStruct
@@ -261,7 +261,7 @@ function make_reaction_system(ex::Expr, parameters; name=:(gensym(:ReactionSyste
261261
# parse DSL lines
262262
reactions = get_reactions(ex)
263263
reactants = get_reactants(reactions)
264-
allspecies = union(reactants, get_rate_species(reactions,parameters))
264+
allspecies = union(reactants, get_rx_species(reactions,parameters))
265265
!isempty(intersect(forbidden_symbols,union(allspecies,parameters))) &&
266266
error("The following symbol(s) are used as species or parameters: "*((map(s -> "'"*string(s)*"', ",intersect(forbidden_symbols,union(species,parameters)))...))*"this is not permited.")
267267

@@ -290,7 +290,7 @@ function make_reaction(ex::Expr)
290290
# parse DSL lines
291291
reaction = get_reaction(ex)
292292
allspecies = get_reactants(reaction) # species defined by stoich
293-
parameters = get_rate_species([reaction],Symbol[]) # anything in a rate is a parameter
293+
parameters = get_rx_species([reaction],Symbol[]) # anything in a rate is a parameter
294294
!isempty(intersect(forbidden_symbols,union(allspecies,parameters))) &&
295295
error("The following symbol(s) are used as species or parameters: "*((map(s -> "'"*string(s)*"', ",intersect(forbidden_symbols,union(species,parameters)))...))*"this is not permited.")
296296

@@ -305,11 +305,17 @@ function make_reaction(ex::Expr)
305305
end
306306
end
307307

308-
function get_rate_species(rxs, ps)
308+
function get_rx_species(rxs, ps)
309309
pset = Set(ps)
310310
species_set = Set{Symbol}()
311311
for rx in rxs
312312
find_species_in_rate!(species_set, rx.rate, pset)
313+
for sub in rx.substrates
314+
find_species_in_rate!(species_set, sub.stoichiometry, pset)
315+
end
316+
for prod in rx.products
317+
find_species_in_rate!(species_set, prod.stoichiometry, pset)
318+
end
313319
end
314320
collect(species_set)
315321
end
@@ -387,18 +393,31 @@ function push_reactions!(reactions::Vector{ReactionStruct}, sub_line::ExprValues
387393
end
388394
end
389395

396+
function processmult(op, mult, stoich)
397+
if (mult isa Number) && (stoich isa Number)
398+
op(mult, stoich)
399+
else
400+
:($op($mult,$stoich))
401+
end
402+
end
403+
390404
#Recursive function that loops through the reaction line and finds the reactants and their stoichiometry. Recursion makes it able to handle weird cases like 2(X+Y+3(Z+XY)).
391-
function recursive_find_reactants!(ex::ExprValues, mult::Number, reactants::Vector{ReactantStruct})
405+
function recursive_find_reactants!(ex::ExprValues, mult::ExprValues, reactants::Vector{ReactantStruct})
392406
if typeof(ex)!=Expr || (ex.head == :escape)
393407
(ex == 0 || in(ex,empty_set)) && (return reactants)
394-
if in(ex, getfield.(reactants,:reactant))
408+
if any(ex==reactant.reactant for reactant in reactants)
395409
idx = findall(x -> x==ex, getfield.(reactants,:reactant))[1]
396-
reactants[idx] = ReactantStruct(ex,mult+reactants[idx].stoichiometry)
410+
reactants[idx] = ReactantStruct(ex,processmult(+,mult,reactants[idx].stoichiometry))
397411
else
398412
push!(reactants, ReactantStruct(ex,mult))
399413
end
400414
elseif ex.args[1] == :*
401-
recursive_find_reactants!(ex.args[3],mult*ex.args[2],reactants)
415+
if length(ex.args) == 3
416+
recursive_find_reactants!(ex.args[3],processmult(*,mult,ex.args[2]),reactants)
417+
else
418+
newmult = processmult(*, mult, Expr(:call,ex.args[1:end-1]...))
419+
recursive_find_reactants!(ex.args[end],newmult,reactants)
420+
end
402421
elseif ex.args[1] == :+
403422
for i = 2:length(ex.args)
404423
recursive_find_reactants!(ex.args[i],mult,reactants)
@@ -409,6 +428,7 @@ function recursive_find_reactants!(ex::ExprValues, mult::Number, reactants::Vect
409428
return reactants
410429
end
411430

431+
412432
function get_reactants(reaction::ReactionStruct, reactants=Vector{Union{Symbol,Expr}}())
413433
for reactant in Iterators.flatten((reaction.substrates,reaction.products))
414434
!in(reactant.reactant,reactants) && push!(reactants,reactant.reactant)

test/symbolic_stoich.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using Catalyst, ModelingToolkit, OrdinaryDiffEq, Test, LinearAlgebra, DiffEqJump
22

3-
@parameters t k α
3+
@parameters k α
44
@variables t, A(t), B(t), C(t), D(t)
55
rxs = [Reaction(t*k, [A], [B], [2*α^2], [k+α*C])
66
Reaction(1.0, [A,B], [C,D], [α,2], [k,α])
@@ -10,6 +10,19 @@ rxs = [Reaction(t*k, [A], [B], [2*α^2], [k+α*C])
1010
@test issetequal(parameters(rs), [k,α])
1111
osys = convert(ODESystem, rs)
1212

13+
g = (k+α*C)
14+
rs2 = @reaction_network rs begin
15+
t*k, 2*α^2*A --> $g*B
16+
1.0, α*A + 2*B --> k*C + α*D
17+
end k α
18+
@test rs2 == rs
19+
20+
21+
rxs2 = [(@reaction t*k, 2*α^2*A --> $g*B),
22+
(@reaction 1.0, α*A + 2*B --> k*C + α*D)]
23+
rs3 = ReactionSystem(rxs2, t; name=:rs)
24+
@test rs3 == rs
25+
1326
u0map = [A => 3.0, B => 2.0, C => 3.0, D => 1.5]
1427
pmap = (k => 2.5, α => 2)
1528
tspan = (0.0,5.0)
@@ -37,6 +50,8 @@ du2 = copy(du1)
3750
oprob2.f(du2,oprob2.u0,oprob2.p,1.5)
3851
@test norm(du1 .- du2) < 100*eps()
3952

53+
54+
4055
# test without rate law scalings
4156
osys = convert(ODESystem, rs, combinatoric_ratelaws=false)
4257
oprob = ODEProblem(osys, u0map, tspan, pmap)

0 commit comments

Comments
 (0)