Skip to content

Commit ef2d280

Browse files
authored
Merge pull request #305 from SciML/param_estim_doc_tutorial
Simple Parameter Estimation Tutorial
2 parents 27c1502 + 785e6f7 commit ef2d280

7 files changed

+4888
-1
lines changed

docs/make.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ makedocs(
4949
"tutorials/advanced.md",
5050
"tutorials/generated_systems.md",
5151
"tutorials/advanced_examples.md",
52-
"tutorials/bifurcation_diagram.md"
52+
"tutorials/bifurcation_diagram.md",
53+
"tutorials/parameter_estimation.md"
5354
],
5455
"API" => Any[
5556
"api/catalyst_api.md"

docs/src/assets/parameter_estimation_plot1.svg

Lines changed: 756 additions & 0 deletions
Loading

docs/src/assets/parameter_estimation_plot2.svg

Lines changed: 982 additions & 0 deletions
Loading

docs/src/assets/parameter_estimation_plot3.svg

Lines changed: 982 additions & 0 deletions
Loading

docs/src/assets/parameter_estimation_plot4.svg

Lines changed: 1000 additions & 0 deletions
Loading

docs/src/assets/parameter_estimation_plot5.svg

Lines changed: 1060 additions & 0 deletions
Loading
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# Parameter Estimation
2+
The parameters of a model, generated by Catalyst, can be estimated using various packages available in the Julia ecosystem. Refer [here](https://diffeq.sciml.ai/stable/analysis/parameter_estimation/) for more extensive information. Below follows a quick tutorial of how [DiffEqFlux](https://diffeqflux.sciml.ai/dev/) can be used to fit a parameter set to data.
3+
4+
First, we fetch the required packages.
5+
```julia
6+
using OrdinaryDiffEq
7+
using DiffEqFlux, Flux
8+
using Catalyst
9+
```
10+
11+
Next, we declare our model. For our example, we will use the Brusselator, a simple oscillator.
12+
```julia
13+
brusselator = @reaction_network begin
14+
A, ∅ X
15+
1, 2X + Y 3X
16+
B, X Y
17+
1, X
18+
end A B
19+
p_real = [1., 2.]
20+
```
21+
22+
We simulate our model, and from the simulation generate sampled data points (with added noise), to which we will attempt to fit a parameter et.
23+
```julia
24+
u0 = [1.0, 1.0]
25+
tspan = (0.0, 30.0)
26+
27+
sample_times = range(tspan[1],stop=tspan[2],length=100)
28+
prob = ODEProblem(brusselator, u0, tspan, p_real)
29+
sol_real = solve(prob, Rosenbrock23(), tstops=sample_times)
30+
31+
sample_vals = [sol_real.u[findfirst(sol_real.t .>= ts)][var] * (1+(0.1rand()-0.05)) for var in 1:2, ts in sample_times];
32+
```
33+
34+
We can plot the real solution, as well as the noisy samples.
35+
```julia
36+
using Plots
37+
plot(sol_real,size=(1200,400),label="",framestyle=:box,lw=3,color=[:darkblue :darkred])
38+
plot!(sample_times,sample_vals',seriestype=:scatter,color=[:blue :red],label="")
39+
```
40+
![parameter_estimation_plot1](../assets/parameter_estimation_plot1.svg)
41+
42+
Next, we create an optimisation function. For a given initial estimate of the parameter values, p, this function will fit parameter values to our data samples. However, it will only do so on the interval [0,tend].
43+
```julia
44+
function optimise_p(p_init,tend)
45+
function loss(p)
46+
sol = solve(remake(prob,tspan=(0.,tend),p=p), Rosenbrock23(), tstops=sample_times)
47+
vals = hcat(map(ts -> sol.u[findfirst(sol.t .>= ts)], sample_times[1:findlast(sample_times .<= tend)])...)
48+
loss = sum(abs2, vals .- sample_vals[:,1:size(vals)[2]])
49+
return loss, sol
50+
end
51+
return DiffEqFlux.sciml_train(loss,p_init,ADAM(0.1),maxiters = 100)
52+
end
53+
```
54+
55+
Next, we will fit a parameter set to the data on the interval [0,10].
56+
```julia
57+
p_estimate = optimise_p([5.,5.],10.).minimizer
58+
```
59+
60+
We can compare this to the real solution, as well as the sample data
61+
```julia
62+
sol_estimate = solve(remake(prob,tspan=(0.,10.),p=p_estimate), Rosenbrock23())
63+
plot(sol_real,size=(1200,400),color=[:blue :red],framestyle=:box,lw=3,label=["X real" "Y real"],linealpha=0.2)
64+
plot!(sample_times,sample_vals',seriestype=:scatter,color=[:blue :red],label=["Samples of X" "Samples of Y"],alpha=0.4)
65+
plot!(sol_estimate,color=[:darkblue :darkred], linestyle=:dash,lw=3,label=["X estimated" "Y estimated"],xlimit=tspan)
66+
```
67+
![parameter_estimation_plot2](../assets/parameter_estimation_plot2.svg)
68+
69+
Next, we use this parameter estimation as the input to the next iteration of our fitting process, this time on the interval [0,20].
70+
```julia
71+
p_estimate = optimise_p(p_estimate,20.).minimizer
72+
73+
sol_estimate = solve(remake(prob,tspan=(0.,20.),p=p_estimate), Rosenbrock23())
74+
plot(sol_real,size=(1200,400),color=[:blue :red],framestyle=:box,lw=3,label=["X real" "Y real"],linealpha=0.2)
75+
plot!(sample_times,sample_vals',seriestype=:scatter,color=[:blue :red],label=["Samples of X" "Samples of Y"],alpha=0.4)
76+
plot!(sol_estimate,color=[:darkblue :darkred], linestyle=:dash,lw=3,label=["X estimated" "Y estimated"],xlimit=tspan)
77+
```
78+
![parameter_estimation_plot3](../assets/parameter_estimation_plot3.svg)
79+
80+
Finally, we use this estimate as the input to fit a parameter set on the full interval of sampled data.
81+
```julia
82+
p_estimate = optimise_p(p_estimate,30.).minimizer
83+
84+
sol_estimate = solve(remake(prob,tspan=(0.,30.),p=p_estimate), Rosenbrock23())
85+
plot(sol_real,size=(1200,400),color=[:blue :red],framestyle=:box,lw=3,label=["X real" "Y real"],linealpha=0.2)
86+
plot!(sample_times,sample_vals',seriestype=:scatter,color=[:blue :red],label=["Samples of X" "Samples of Y"],alpha=0.4)
87+
plot!(sol_estimate,color=[:darkblue :darkred], linestyle=:dash,lw=3,label=["X estimated" "Y estimated"],xlimit=tspan)
88+
```
89+
![parameter_estimation_plot4](../assets/parameter_estimation_plot4.svg)
90+
91+
The final parameter set becomes `[0.9996559014056948, 2.005632696191224]` (the real one was `[1.0, 2.0]`).
92+
93+
94+
### Why we fit the parameters in iterations.
95+
The reason we chose to fit the model on a smaller interval to begin with, and then extend the interval, is to avoid getting stuck in a local minimum. Here specifically, we chose our initial interval to be smaller than a full cycle of the oscillation. If we had chosen to fit a parameter set on the full interval immediately we would have received an inferior solution.
96+
```julia
97+
p_estimate = optimise_p([5.,5.],30.).minimizer
98+
99+
sol_estimate = solve(remake(prob,tspan=(0.,30.),p=p_estimate), Rosenbrock23())
100+
plot(sol_real,size=(1200,400),color=[:blue :red],framestyle=:box,lw=3,label=["X real" "Y real"],linealpha=0.2)
101+
plot!(sample_times,sample_vals',seriestype=:scatter,color=[:blue :red],label=["Samples of X" "Samples of Y"],alpha=0.4)
102+
plot!(sol_estimate,color=[:darkblue :darkred], linestyle=:dash,lw=3,label=["X estimated" "Y estimated"],xlimit=tspan)
103+
```
104+
![parameter_estimation_plot5](../assets/parameter_estimation_plot5.svg)
105+
106+

0 commit comments

Comments
 (0)