Skip to content

Commit f98b23e

Browse files
authored
Merge pull request #1257 from SciML/use_jumpinputs_in_tests
Remove` DiscreteProblem`s in tests
2 parents d2b8957 + cabdb63 commit f98b23e

8 files changed

+66
-68
lines changed

test/reactionsystem_core/events.jl

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -113,17 +113,15 @@ let
113113
@test prob.ps[α] isa Int64
114114
end
115115

116-
# Handles `DiscreteProblem`s and `JumpProblem`s (these cannot contain continuous events or variables).
116+
# Handles `JumpInput`s and `JumpProblem`s (these cannot contain continuous events or variables).
117117
discrete_events = [2.0 => [A ~ A + α]]
118118
@named rs_de_2 = ReactionSystem(rxs, t; discrete_events)
119119
rs_de_2 = complete(rs_de_2)
120-
dprob = DiscreteProblem(rs_de_2, u0, (0.0, 10.0), ps)
121-
jprob = JumpProblem(rs_de_2, dprob, Direct())
122-
for prob in [dprob, jprob]
123-
@test dprob[A] == 2
124-
@test dprob.ps[α] == 1
125-
@test dprob.ps[α] isa Int64
126-
end
120+
jin = JumpInputs(rs_de_2, u0, (0.0, 10.0), ps)
121+
jprob = JumpProblem(jin)
122+
@test jprob[A] == 2
123+
@test jprob.ps[α] == 1
124+
@test jprob.ps[α] isa Int64
127125
end
128126

129127

@@ -358,13 +356,13 @@ let
358356
# Simulates the model for conditions where it *definitely* will cross `X = 1000.0`
359357
u0 = [:X => 999]
360358
ps = [:p => 10.0, :d => 0.001]
361-
dprob = DiscreteProblem(rn, u0, (0.0, 2.0), ps)
362-
jprob = JumpProblem(rn, dprob, Direct(); rng)
359+
jin = JumpInputs(rn, u0, (0.0, 2.0), ps)
360+
jprob = JumpProblem(jin; rng)
363361
sol = solve(jprob, SSAStepper(); seed)
364362

365363
# Checks that all `e` parameters have been updated properly.
366364
@test sol.ps[:e1] == 1
367-
@test sol.ps[:e2] == 1
365+
@test sol.ps[:e2] == 1
368366
@test sol.ps[:e3] == 1
369367
end
370368

@@ -434,10 +432,10 @@ let
434432
# Checks for Jump simulations. (note, non-seed dependant test should be created instead)
435433
# Note that periodic discrete events are currently broken for jump processes (and unlikely to be fixed soon due to have events are implemented).
436434
callback = CallbackSet(cb_disc_1, cb_disc_2, cb_disc_3)
437-
dprob = DiscreteProblem(rn, u0, tspan, ps)
438-
dprob_events = DiscreteProblem(rn_dics_events, u0, tspan, ps)
439-
jprob = JumpProblem(rn, dprob, Direct(); rng)
440-
jprob_events = JumpProblem(rn_dics_events, dprob_events, Direct(); rng)
435+
jin = JumpInputs(rn, u0, tspan, ps)
436+
jin_events = JumpInputs(rn_dics_events, u0, tspan, ps)
437+
jprob = JumpProblem(jin)
438+
jprob_events = JumpProblem(jin_events; rng)
441439
sol = solve(jprob, SSAStepper(); seed, callback)
442440
sol_events = solve(jprob_events, SSAStepper(); seed)
443441
@test_broken sol == sol_events # seems to be not identical in the sample paths

test/reactionsystem_core/higher_order_reactions.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,12 +89,12 @@ let
8989
# Prepares JumpProblem via Catalyst.
9090
u0_base = rnd_u0_Int64(base_higher_order_network, rng)
9191
ps_base = rnd_ps(base_higher_order_network, rng)
92-
dprob_base = DiscreteProblem(base_higher_order_network, u0_base, (0.0, 100.0), ps_base)
93-
jprob_base = JumpProblem(base_higher_order_network, dprob_base, Direct(); rng = StableRNG(1234))
92+
jin_base = JumpInputs(base_higher_order_network, u0_base, (0.0, 100.0), ps_base)
93+
jprob_base = JumpProblem(jin_base; rng = StableRNG(1234))
9494

9595
# Prepares JumpProblem partially declared manually.
96-
dprob_alt1 = DiscreteProblem(higher_order_network_alt1, u0_base, (0.0, 100.0), ps_base)
97-
jprob_alt1 = JumpProblem(higher_order_network_alt1, dprob_alt1, Direct(); rng = StableRNG(1234))
96+
jin_alt1 = JumpInputs(higher_order_network_alt1, u0_base, (0.0, 100.0), ps_base)
97+
jprob_alt1 = JumpProblem(jin_alt1; rng = StableRNG(1234))
9898

9999
# Prepares JumpProblem via manually declared system.
100100
u0_alt2 = map_to_vec(u0_base, [:X1, :X2, :X3, :X4, :X5, :X6, :X7, :X8, :X9, :X10])

test/reactionsystem_core/parameter_type_designation.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -74,13 +74,13 @@ let
7474
oprob = ODEProblem(rs, u0, (0.0, 1.0), p_alts[1])
7575
sprob = SDEProblem(rs, u0, (0.0, 1.0), p_alts[1])
7676
dprob = DiscreteProblem(rs, u0, (0.0, 1.0), p_alts[1])
77-
jprob = JumpProblem(rs, dprob, Direct(); rng)
77+
jprob = JumpProblem(JumpInputs(rs, u0, (0.0, 1.0), p_alts[1]); rng)
7878
nprob = NonlinearProblem(rs, u0, p_alts[1])
7979

80-
oinit = init(oprob, Tsit5())
81-
sinit = init(sprob, ImplicitEM())
82-
jinit = init(jprob, SSAStepper())
83-
ninit = init(nprob, NewtonRaphson())
80+
oinit = init(oprob, Tsit5())
81+
sinit = init(sprob, ImplicitEM())
82+
jinit = init(jprob, SSAStepper())
83+
ninit = init(nprob, NewtonRaphson())
8484

8585
osol = solve(oprob, Tsit5())
8686
ssol = solve(sprob, ImplicitEM(); seed)
@@ -113,7 +113,7 @@ let
113113
@test unwrap(mtk_struct.ps[p5]) == 3//2
114114
@test unwrap(mtk_struct.ps[d5]) == Float32(1.5)
115115
end
116-
116+
117117
# Checks all stored variables (these should always be `Float64`).
118118
for mtk_struct in [oprob, sprob, dprob, jprob, nprob, oinit, sinit, jinit, ninit]
119119
# Checks that all variables have the correct type.

test/reactionsystem_core/reactionsystem.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -126,10 +126,10 @@ let
126126
sdesys = complete(convert(SDESystem, rs))
127127
js = complete(convert(JumpSystem, rs))
128128

129-
@test ModelingToolkit.get_defaults(rs) ==
129+
@test ModelingToolkit.get_defaults(rs) ==
130130
ModelingToolkit.get_defaults(js) == defs
131131

132-
# these systems add initial conditions to the defaults
132+
# these systems add initial conditions to the defaults
133133
@test ModelingToolkit.get_defaults(odesys) ==
134134
ModelingToolkit.get_defaults(sdesys)
135135
@test issubset(defs, ModelingToolkit.get_defaults(odesys))
@@ -551,9 +551,9 @@ let
551551
(@reaction k1, $A --> B2),
552552
(@reaction 10 * k1, ∅ --> B3)], t)
553553
rn = complete(rn)
554-
dprob = DiscreteProblem(rn, [A => 10, C => 10, B1 => 0, B2 => 0, B3 => 0], (0.0, 10.0),
554+
jin = JumpInputs(rn, [A => 10, C => 10, B1 => 0, B2 => 0, B3 => 0], (0.0, 10.0),
555555
[k1 => 1.0])
556-
jprob = JumpProblem(rn, dprob, Direct(); rng, save_positions = (false, false))
556+
jprob = JumpProblem(jin; rng, save_positions = (false, false))
557557
umean = zeros(4)
558558
Nsims = 40000
559559
for i in 1:Nsims
@@ -1017,7 +1017,7 @@ let
10171017
@test sys isa JumpSystem
10181018
@test MT.has_equations(sys)
10191019
@test length(massactionjumps(sys)) == 1
1020-
@test isempty(constantratejumps(sys))
1020+
@test isempty(constantratejumps(sys))
10211021
@test length(variableratejumps(sys)) == 3
10221022
@test length(odeeqs(sys)) == 4
10231023
@test length(continuous_events(sys)) == 1
@@ -1042,7 +1042,7 @@ let
10421042
@test sys isa JumpSystem
10431043
@test MT.has_equations(sys)
10441044
@test length(massactionjumps(sys)) == 1
1045-
@test isempty(constantratejumps(sys))
1045+
@test isempty(constantratejumps(sys))
10461046
@test length(variableratejumps(sys)) == 2
10471047
@test length(odeeqs(sys)) == 4
10481048
odes = union(eqs, [D(A) ~ 0, D(B) ~ -λ*A*B, D(C) ~ 0])
@@ -1069,8 +1069,8 @@ let
10691069
sys = jinput.sys
10701070
@test sys isa JumpSystem
10711071
@test MT.has_equations(sys)
1072-
@test isempty(massactionjumps(sys))
1073-
@test isempty(constantratejumps(sys))
1072+
@test isempty(massactionjumps(sys))
1073+
@test isempty(constantratejumps(sys))
10741074
@test length(variableratejumps(sys)) == 3
10751075
@test length(odeeqs(sys)) == 4
10761076
odes = union(eqs, [D(A) ~ 0, D(B) ~ -λ*A*B, D(C) ~ 0])

test/reactionsystem_core/symbolic_stoichiometry.jl

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ include("../test_functions.jl")
1919
### Base Tests ###
2020

2121
# Checks that systems with symbolic stoichiometries, created using different approaches, are identical.
22-
let
22+
let
2323
@parameters p k d::Float64 n1::Int64 n2 n3
2424
@species X(t) Y(t)
2525
rxs1 = [
@@ -71,7 +71,7 @@ begin
7171
end
7272

7373
# Compares the Catalyst-generated ODE function to a manually computed ODE function.
74-
let
74+
let
7575
# With combinatoric ratelaws.
7676
function oderhs(u, p, t)
7777
k,α = p
@@ -89,7 +89,7 @@ let
8989
end
9090
@test f_eval(rs, u0_1, ps_1, τ) oderhs(u0_2, ps_2, τ)
9191

92-
# Without combinatoric ratelaws.
92+
# Without combinatoric ratelaws.
9393
function oderhs_no_crl(u, p, t)
9494
k,α = p
9595
A,B,C,D = u
@@ -108,8 +108,8 @@ let
108108
end
109109

110110
# Compares the Catalyst-generated SDE noise function to a manually computed SDE noise function.
111-
let
112-
# With combinatoric ratelaws.
111+
let
112+
# With combinatoric ratelaws.
113113
function sdenoise(u, p, t)
114114
k,α = p
115115
A,B,C,D = u
@@ -126,7 +126,7 @@ let
126126
end
127127
@test g_eval(rs, u0_1, ps_1, τ) sdenoise(u0_2, ps_2, τ)
128128

129-
# Without combinatoric ratelaws.
129+
# Without combinatoric ratelaws.
130130
function sdenoise_no_crl(u, p, t)
131131
k,α = p
132132
A,B,C,D = u
@@ -192,7 +192,7 @@ end
192192
# Tests symbolic stoichiometries in simulations.
193193
# Tests for decimal numbered symbolic stoichiometries.
194194
let
195-
# Declares models. The references models have the `n` parameters so they can use the same
195+
# Declares models. The references models have the `n` parameters so they can use the same
196196
# parameter vectors as the non-reference ones.
197197
rs_int = @reaction_network begin
198198
@parameters n::Int64
@@ -211,9 +211,9 @@ let
211211
(k1, k2), 2.5*X1 <--> X2
212212
end
213213

214-
# Set simulation settings. Initial conditions are design to start, more or less, at
214+
# Set simulation settings. Initial conditions are design to start, more or less, at
215215
# steady state concentrations.
216-
# Values are selected so that stochastic tests should always pass within the bounds (independent
216+
# Values are selected so that stochastic tests should always pass within the bounds (independent
217217
# of seed).
218218
u0_int = [:X1 => 150, :X2 => 600]
219219
u0_dec = [:X1 => 100, :X2 => 600]
@@ -247,10 +247,10 @@ let
247247
@test mean(ssol_dec[:X1]) mean(ssol_dec_ref[:X1]) atol = 2*1e0
248248

249249
# Test Jump simulations with integer coefficients.
250-
dprob_int = DiscreteProblem(rs_int, u0_int, tspan_stoch, ps_int)
251-
dprob_int_ref = DiscreteProblem(rs_ref_int, u0_int, tspan_stoch, ps_int)
252-
jprob_int = JumpProblem(rs_int, dprob_int, Direct(); rng, save_positions = (false, false))
253-
jprob_int_ref = JumpProblem(rs_ref_int, dprob_int_ref, Direct(); rng, save_positions = (false, false))
250+
jin_int = JumpInputs(rs_int, u0_int, tspan_stoch, ps_int)
251+
jin_int_ref = JumpInputs(rs_ref_int, u0_int, tspan_stoch, ps_int)
252+
jprob_int = JumpProblem(jin_int; rng, save_positions = (false, false))
253+
jprob_int_ref = JumpProblem(jin_int_ref; rng, save_positions = (false, false))
254254
jsol_int = solve(jprob_int, SSAStepper(); seed, saveat = 1.0)
255255
jsol_int_ref = solve(jprob_int_ref, SSAStepper(); seed, saveat = 1.0)
256256
@test mean(jsol_int[:X1]) mean(jsol_int_ref[:X1]) atol = 1e-2 rtol = 1e-2
@@ -265,11 +265,11 @@ let
265265
@parameters n::Int64 k::Int64
266266
i, S + n*I --> k*I
267267
r, n*I --> n*R
268-
end
268+
end
269269
sir_ref = @reaction_network begin
270270
i, S + I --> 2I
271271
r, I --> R
272-
end
272+
end
273273

274274
ps = [:i => 1e-4, :r => 1e-2, :n => 1.0, :k => 2.0]
275275
ps_ref = [:i => 1e-4, :r => 1e-2]
@@ -283,10 +283,10 @@ let
283283
@test solve(oprob, Tsit5()) solve(oprob_ref, Tsit5())
284284

285285
# Jumps. First ensemble problems for each systems is created.
286-
dprob = DiscreteProblem(sir, u0, tspan, ps)
287-
dprob_ref = DiscreteProblem(sir_ref, u0, tspan, ps_ref)
288-
jprob = JumpProblem(sir, dprob, Direct(); rng, save_positions = (false, false))
289-
jprob_ref = JumpProblem(sir_ref, dprob_ref, Direct(); rng, save_positions = (false, false))
286+
jin = JumpInputs(sir, u0, tspan, ps)
287+
jin_ref = JumpInputs(sir_ref, u0, tspan, ps_ref)
288+
jprob = JumpProblem(jin; rng, save_positions = (false, false))
289+
jprob_ref = JumpProblem(jin_ref; rng, save_positions = (false, false))
290290
eprob = EnsembleProblem(jprob)
291291
eprob_ref = EnsembleProblem(jprob_ref)
292292

test/simulation_and_solving/simulate_jumps.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -130,12 +130,12 @@ let
130130
zip(catalyst_networks, manual_networks, u0_syms, ps_syms, u0s, ps, sps)
131131

132132
# Simulates the Catalyst-created model.
133-
dprob_1 = DiscreteProblem(rn_catalyst, u0_1, (0.0, 10000.0), ps_1)
134-
jprob_1 = JumpProblem(rn_catalyst, dprob_1, Direct(); rng)
133+
jin_1 = JumpInputs(rn_catalyst, u0_1, (0.0, 10000.0), ps_1)
134+
jprob_1 = JumpProblem(jin_1, Direct(); rng)
135135
sol1 = solve(jprob_1, SSAStepper(); seed, saveat = 1.0)
136136

137137
# simulate using auto-alg
138-
jprob_1b = JumpProblem(rn_catalyst, dprob_1; rng)
138+
jprob_1b = JumpProblem(jin_1; rng)
139139
sol1b = solve(jprob_1; seed, saveat = 1.0)
140140
@test mean(sol1[sp]) mean(sol1b[sp]) rtol = 1e-1
141141

@@ -157,8 +157,8 @@ let
157157
for rn in reaction_networks_all
158158
u0 = rnd_u0_Int64(rn, rng)
159159
ps = rnd_ps(rn, rng)
160-
dprob = DiscreteProblem(rn, u0, (0.0, 1.0), ps)
161-
jprob = JumpProblem(rn, dprob, Direct(); rng)
160+
jin = JumpInputs(rn, u0, (0.0, 1.0), ps)
161+
jprob = JumpProblem(jin; rng)
162162
@test SciMLBase.successful_retcode(solve(jprob, SSAStepper()))
163163
end
164164
end
@@ -169,8 +169,8 @@ let
169169
(1.2, 5), X1 X2
170170
end
171171
u0 = rnd_u0_Int64(no_param_network, rng)
172-
dprob = DiscreteProblem(no_param_network, u0, (0.0, 1000.0))
173-
jprob = JumpProblem(no_param_network, dprob, Direct(); rng)
172+
jin = JumpInputs(no_param_network, u0, (0.0, 1000.0))
173+
jprob = JumpProblem(jin; rng)
174174
sol = solve(jprob, SSAStepper())
175175
@test mean(sol[:X1]) > mean(sol[:X2])
176176
end

test/upstream/mtk_problem_inputs.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,8 @@ end
133133
# Perform jump simulations (singular and ensemble).
134134
let
135135
# Creates normal and ensemble problems.
136-
base_dprob = DiscreteProblem(model, u0_alts[1], tspan, p_alts[1])
137-
base_jprob = JumpProblem(model, base_dprob, Direct(); rng)
136+
base_jin = JumpInputs(model, u0_alts[1], tspan, p_alts[1])
137+
base_jprob = JumpProblem(base_jin; rng)
138138
base_sol = solve(base_jprob, SSAStepper(); seed, saveat = 1.0)
139139
base_eprob = EnsembleProblem(base_jprob)
140140
base_esol = solve(base_eprob, SSAStepper(); seed, trajectories = 2, saveat = 1.0)
@@ -325,8 +325,8 @@ end
325325
# Perform jump simulations (singular and ensemble).
326326
let
327327
# Creates normal and ensemble problems.
328-
base_dprob = DiscreteProblem(model_vec, u0_alts_vec[1], tspan, p_alts_vec[1])
329-
base_jprob = JumpProblem(model_vec, base_dprob, Direct(); rng)
328+
base_jin = JumpInputs(model_vec, u0_alts_vec[1], tspan, p_alts_vec[1])
329+
base_jprob = JumpProblem(base_jin; rng)
330330
base_sol = solve(base_jprob, SSAStepper(); seed, saveat = 1.0)
331331
base_eprob = EnsembleProblem(base_jprob)
332332
base_esol = solve(base_eprob, SSAStepper(); seed, trajectories = 2, saveat = 1.0)
@@ -415,7 +415,7 @@ let
415415
# Loops through all potential parameter sets, checking that their inputs yield errors.
416416
for ps in [[ps_valid]; ps_invalid], u0 in [[u0_valid]; u0s_invalid]
417417
# Handles all types of time-dependent systems. The `isequal` is because some case should pass.
418-
for XProblem in [ODEProblem, SDEProblem, DiscreteProblem]
418+
for XProblem in [ODEProblem, SDEProblem, JumpInputs]
419419
if isequal(ps, ps_valid) && isequal(u0, u0_valid)
420420
XProblem(rn, u0, (0.0, 1.0), ps)
421421
else

test/upstream/mtk_structure_indexing.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ begin
3333
oprob = ODEProblem(model, u0_vals, tspan, p_vals)
3434
sprob = SDEProblem(model,u0_vals, tspan, p_vals)
3535
dprob = DiscreteProblem(model, u0_vals, tspan, p_vals)
36-
jprob = JumpProblem(model, deepcopy(dprob), Direct(); rng)
36+
jprob = JumpProblem(JumpInputs(model, u0_vals, tspan, p_vals); rng)
3737
nprob = NonlinearProblem(model, u0_vals, p_vals)
3838
ssprob = SteadyStateProblem(model, u0_vals, p_vals)
3939
problems = [oprob, sprob, dprob, jprob, nprob, ssprob]
@@ -344,7 +344,7 @@ let
344344
ps = [k1 => 0.1, k2 => 0.2, V0 => 3.0]
345345
prob1 = XProblem(rs, u0, 0.001, ps; remove_conserved = true)
346346
Γ = prob1.f.sys.Γ
347-
347+
348348
# Creates various `remake` version of the problem.
349349
prob2 = remake(prob1, u0 = [X1 => 10.0])
350350
prob3 = remake(prob2, u0 = [X2 => 20.0])
@@ -431,8 +431,8 @@ let
431431
# Creates a JumpProblem and integrator. Checks that the initial mass action rate is correct.
432432
u0 = [:A => 1, :B => 2, :C => 3]
433433
ps = [:p1 => 3.0, :p2 => 2.0]
434-
dprob = DiscreteProblem(rn, u0, (0.0, 1.0), ps)
435-
jprob = JumpProblem(rn, dprob, Direct())
434+
jin = JumpInputs(rn, u0, (0.0, 1.0), ps)
435+
jprob = JumpProblem(jin)
436436
jint = init(jprob, SSAStepper())
437437
@test jprob.massaction_jump.scaled_rates[1] == 6.0
438438

0 commit comments

Comments
 (0)