Skip to content

Commit 738a7d8

Browse files
authored
Merge pull request #1041 from vyudu/DSL
expanding equations when passed to `@reaction_network`
2 parents 31b0f99 + 672225a commit 738a7d8

File tree

4 files changed

+97
-12
lines changed

4 files changed

+97
-12
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Catalyst"
22
uuid = "479239e8-5488-4da2-87a7-35f2df7eef83"
3-
version = "14.4"
3+
version = "14.4.0"
44

55
[deps]
66
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"

src/dsl.jl

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ function make_reaction_system(ex::Expr; name = :(gensym(:ReactionSystem)))
297297

298298
# Get macro options.
299299
if length(unique(arg.args[1] for arg in option_lines)) < length(option_lines)
300-
error("Some options where given multiple times.")
300+
error("Some options were given multiple times.")
301301
end
302302
options = Dict(map(arg -> Symbol(String(arg.args[1])[2:end]) => arg,
303303
option_lines))
@@ -315,12 +315,12 @@ function make_reaction_system(ex::Expr; name = :(gensym(:ReactionSystem)))
315315
parameters_declared = extract_syms(options, :parameters)
316316
variables_declared = extract_syms(options, :variables)
317317

318-
# Reads more options.
318+
# Reads equations.
319319
vars_extracted, add_default_diff, equations = read_equations_options(
320320
options, variables_declared)
321321
variables = vcat(variables_declared, vars_extracted)
322322

323-
# handle independent variables
323+
# Handle independent variables
324324
if haskey(options, :ivs)
325325
ivs = Tuple(extract_syms(options, :ivs))
326326
ivexpr = copy(options[:ivs])
@@ -339,14 +339,16 @@ function make_reaction_system(ex::Expr; name = :(gensym(:ReactionSystem)))
339339
combinatoric_ratelaws = true
340340
end
341341

342-
# Reads more options.
342+
# Reads observables.
343343
observed_vars, observed_eqs, obs_syms = read_observed_options(
344344
options, [species_declared; variables], all_ivs)
345345

346+
# Collect species and parameters, including ones inferred from the reactions.
346347
declared_syms = Set(Iterators.flatten((parameters_declared, species_declared,
347348
variables)))
348-
species_extracted, parameters_extracted = extract_species_and_parameters!(reactions,
349-
declared_syms)
349+
species_extracted, parameters_extracted = extract_species_and_parameters!(
350+
reactions, declared_syms)
351+
350352
species = vcat(species_declared, species_extracted)
351353
parameters = vcat(parameters_declared, parameters_extracted)
352354

@@ -376,9 +378,11 @@ function make_reaction_system(ex::Expr; name = :(gensym(:ReactionSystem)))
376378
push!(rxexprs.args, get_rxexprs(reaction))
377379
end
378380
for equation in equations
381+
equation = escape_equation_RHS!(equation)
379382
push!(rxexprs.args, equation)
380383
end
381384

385+
# Output code corresponding to the reaction system.
382386
quote
383387
$ivexpr
384388
$ps
@@ -572,7 +576,7 @@ function get_rxexprs(rxstruct)
572576
subs_stoich_init = deepcopy(subs_init)
573577
prod_init = isempty(rxstruct.products) ? nothing : :([])
574578
prod_stoich_init = deepcopy(prod_init)
575-
reaction_func = :(Reaction($(recursive_expand_functions!(rxstruct.rate)), $subs_init,
579+
reaction_func = :(Reaction($(recursive_escape_functions!(rxstruct.rate)), $subs_init,
576580
$prod_init, $subs_stoich_init, $prod_stoich_init,
577581
metadata = $(rxstruct.metadata)))
578582
for sub in rxstruct.substrates
@@ -904,17 +908,24 @@ end
904908

905909
### Generic Expression Manipulation ###
906910

907-
# Recursively traverses an expression and replaces special function call like "hill(...)" with the actual corresponding expression.
908-
function recursive_expand_functions!(expr::ExprValues)
911+
# Recursively traverses an expression and escapes all the user-defined functions. Special function calls like "hill(...)" are not expanded.
912+
function recursive_escape_functions!(expr::ExprValues)
909913
(typeof(expr) != Expr) && (return expr)
910-
foreach(i -> expr.args[i] = recursive_expand_functions!(expr.args[i]),
914+
foreach(i -> expr.args[i] = recursive_escape_functions!(expr.args[i]),
911915
1:length(expr.args))
912916
if expr.head == :call
913917
!isdefined(Catalyst, expr.args[1]) && (expr.args[1] = esc(expr.args[1]))
914918
end
915919
expr
916920
end
917921

922+
# Recursively escape functions in the right-hand-side of an equation written using user-defined functions. Special function calls like "hill(...)" are not expanded.
923+
function escape_equation_RHS!(eqexpr::Expr)
924+
rhs = recursive_escape_functions!(eqexpr.args[3])
925+
eqexpr.args[3] = rhs
926+
eqexpr
927+
end
928+
918929
# Returns the length of a expression tuple, or 1 if it is not an expression tuple (probably a Symbol/Numerical).
919930
function tup_leng(ex::ExprValues)
920931
(typeof(ex) == Expr && ex.head == :tuple) && (return length(ex.args))

test/dsl/dsl_options.jl

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ seed = rand(rng, 1:100)
1313

1414
# Sets the default `t` to use.
1515
t = default_t()
16+
D = default_time_deriv()
1617

1718
### Tests `@parameters`, `@species`, and `@variables` Options ###
1819

@@ -952,3 +953,76 @@ let
952953
@unpack k1, A = rn3
953954
@test isequal(rl, k1*A^2)
954955
end
956+
957+
# Test whether user-defined functions are properly expanded in equations.
958+
let
959+
f(A, t) = 2*A*t
960+
961+
# Test user-defined function
962+
rn = @reaction_network begin
963+
@equations D(A) ~ f(A, t)
964+
end
965+
@test length(equations(rn)) == 1
966+
@test equations(rn)[1] isa Equation
967+
@species A(t)
968+
@test isequal(equations(rn)[1], D(A) ~ 2*A*t)
969+
970+
971+
# Test whether expansion happens properly for unregistered/registered functions.
972+
hill_unregistered(A, v, K, n) = v*(A^n) / (A^n + K^n)
973+
rn2 = @reaction_network begin
974+
@parameters v K n
975+
@equations D(A) ~ hill_unregistered(A, v, K, n)
976+
end
977+
@test length(equations(rn2)) == 1
978+
@test equations(rn2)[1] isa Equation
979+
@parameters v K n
980+
@test isequal(equations(rn2)[1], D(A) ~ v*(A^n) / (A^n + K^n))
981+
982+
hill2(A, v, K, n) = v*(A^n) / (A^n + K^n)
983+
@register_symbolic hill2(A, v, K, n)
984+
# Registered symbolic function should not expand.
985+
rn2r = @reaction_network begin
986+
@parameters v K n
987+
@equations D(A) ~ hill2(A, v, K, n)
988+
end
989+
@test length(equations(rn2r)) == 1
990+
@test equations(rn2r)[1] isa Equation
991+
@parameters v K n
992+
@test isequal(equations(rn2r)[1], D(A) ~ hill2(A, v, K, n))
993+
994+
995+
rn3 = @reaction_network begin
996+
@species Iapp(t)
997+
@equations begin
998+
D(A) ~ Iapp
999+
Iapp ~ f(A,t)
1000+
end
1001+
end
1002+
@test length(equations(rn3)) == 2
1003+
@test equations(rn3)[1] isa Equation
1004+
@test equations(rn3)[2] isa Equation
1005+
@variables Iapp(t)
1006+
@test isequal(equations(rn3)[1], D(A) ~ Iapp)
1007+
@test isequal(equations(rn3)[2], Iapp ~ 2*A*t)
1008+
1009+
# Test whether the DSL and symbolic ways of creating the network generate the same system
1010+
@species Iapp(t) A(t)
1011+
eq = [D(A) ~ Iapp, Iapp ~ f(A, t)]
1012+
@named rn3_sym = ReactionSystem(eq, t)
1013+
rn3_sym = complete(rn3_sym)
1014+
@test isequivalent(rn3, rn3_sym)
1015+
1016+
1017+
# Test more complicated expression involving both registered function and a user-defined function.
1018+
g(A, K, n) = A^n + K^n
1019+
rn4 = @reaction_network begin
1020+
@parameters v K n
1021+
@equations D(A) ~ hill(A, v, K, n)*g(A, K, n)
1022+
end
1023+
@test length(equations(rn4)) == 1
1024+
@test equations(rn4)[1] isa Equation
1025+
@parameters v n
1026+
@test isequal(Catalyst.expand_registered_functions(equations(rn4)[1]), D(A) ~ v*(A^n))
1027+
end
1028+

test/reactionsystem_core/coupled_equation_crn_systems.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1042,4 +1042,4 @@ let
10421042
u0 = [S1 => 1.0, S2 => 2.0, V1 => 0.1]
10431043
ps = [p1 => 2.0, p2 => 3.0]
10441044
@test_throws Exception ODEProblem(rs, u0, (0.0, 1.0), ps; structural_simplify = true)
1045-
end
1045+
end

0 commit comments

Comments
 (0)