|
| 1 | +# Parameter Estimation |
| 2 | +The parameters of a model generated by Catalyst can be estimated using various packages available in the Julia ecosystem. See the [parameter estimation documentation](https://diffeq.sciml.ai/stable/analysis/parameter_estimation/) for more extensive information. Below is a quick tutorial on how [DiffEqFlux](https://diffeqflux.sciml.ai/dev/) can be used to fit a parameter set to data. |
| 3 | + |
| 4 | +First, we load the required packages. |
| 5 | +```julia |
| 6 | +using OrdinaryDiffEq |
| 7 | +using DiffEqFlux, Flux |
| 8 | +using Catalyst |
| 9 | +``` |
| 10 | + |
| 11 | +Next, we declare our model. For our example, we will use the Brusselator, a simple oscillator. |
| 12 | +```julia |
| 13 | +brusselator = @reaction_network begin |
| 14 | + A, ∅ → X |
| 15 | + 1, 2X + Y → 3X |
| 16 | + B, X → Y |
| 17 | + 1, X → ∅ |
| 18 | +end A B |
| 19 | +p_real = [1., 2.] |
| 20 | +``` |
| 21 | + |
| 22 | +We simulate our model and, from the simulation, generate sampled data points (with added noise), to which we will attempt to fit a parameter set. |
| 23 | +```julia |
| 24 | +u0 = [1.0, 1.0] |
| 25 | +tspan = (0.0, 30.0) |
| 26 | + |
| 27 | +sample_times = range(tspan[1],stop=tspan[2],length=100) |
| 28 | +prob = ODEProblem(brusselator, u0, tspan, p_real) |
| 29 | +sol_real = solve(prob, Rosenbrock23(), tstops=sample_times) |
| 30 | + |
| 31 | +sample_vals = [sol_real.u[findfirst(sol_real.t .>= ts)][var] * (1+(0.1rand()-0.05)) for var in 1:2, ts in sample_times]; |
| 32 | +``` |
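The noise factor `1 + (0.1rand() - 0.05)` used above scales each sampled value by a number drawn uniformly from [0.95, 1.05), that is, relative noise of at most 5%. A minimal, self-contained sketch of that noise model follows. The helper name `add_relative_noise`, its `width` keyword, and the fixed seed are assumptions introduced here for illustration; they are not part of the tutorial code.

```julia
using Random

Random.seed!(1234)  # fixed seed: an assumption, used only to make the sketch repeatable

# Hypothetical helper: scale x by a factor drawn uniformly from [1 - width/2, 1 + width/2).
add_relative_noise(x; width = 0.1) = x * (1 + (width * rand() - width / 2))

clean = [1.0, 2.0, 4.0]
noisy = add_relative_noise.(clean)  # each element gets its own random factor

# Largest relative deviation from the clean values; expected to be at most 0.05.
max_rel_dev = maximum(abs.(noisy .- clean) ./ clean)
```

Because the noise is multiplicative, larger concentrations receive larger absolute perturbations, which matters when interpreting a sum-of-squares loss.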
| 33 | + |
| 34 | +We can plot the real solution, as well as the noisy samples. |
| 35 | +```julia |
| 36 | +using Plots |
| 37 | +plot(sol_real,size=(1200,400),label="",framestyle=:box,lw=3,color=[:darkblue :darkred]) |
| 38 | +plot!(sample_times,sample_vals',seriestype=:scatter,color=[:blue :red],label="") |
| 39 | +``` |
| 40 | + |
| 41 | + |
| 42 | +Next, we create an optimisation function. Given an initial estimate of the parameter values, `p_init`, this function fits parameter values to our data samples, but only using the samples on the interval `[0, tend]`. |
| 43 | +```julia |
| 44 | +function optimise_p(p_init,tend) |
| 45 | + function loss(p) |
| 46 | + sol = solve(remake(prob,tspan=(0.,tend),p=p), Rosenbrock23(), tstops=sample_times) |
| 47 | + vals = hcat(map(ts -> sol.u[findfirst(sol.t .>= ts)], sample_times[1:findlast(sample_times .<= tend)])...) |
| 48 | + l = sum(abs2, vals .- sample_vals[:,1:size(vals)[2]]) # sum of squared errors; named `l` so it does not shadow the function `loss` |
| 49 | + return l, sol |
| 50 | + end |
| 51 | + return DiffEqFlux.sciml_train(loss,p_init,ADAM(0.1),maxiters = 100) |
| 52 | +end |
| 53 | +``` |
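The loss computed inside `optimise_p` is a plain sum of squared errors between the simulated values and the samples, restricted to the sample times that fall inside `[0, tend]`. The toy matrices below are made up for illustration (they are not tutorial data), and `sse` is a hypothetical helper name. The sketch only shows the arithmetic and the column truncation.

```julia
# Rows are species, columns are sample times (the same layout as `sample_vals`).
observed  = [1.0 2.5 3.0;
             4.0 5.0 7.0]
# Pretend the simulation only covers the first two sample times.
predicted = [1.0 2.0;
             4.0 5.0]

# Hypothetical helper: sum of squared errors between two equally sized arrays.
sse(pred, obs) = sum(abs2, pred .- obs)

# Compare against only as many columns of the data as were simulated.
n = size(predicted, 2)
partial_loss = sse(predicted, observed[:, 1:n])
```

Here the only mismatch is 2.0 against 2.5, so `partial_loss` is 0.25; the third column of `observed` is ignored because it lies beyond the simulated range.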
| 54 | + |
| 55 | +Next, we will fit a parameter set to the data on the interval [0,10]. |
| 56 | +```julia |
| 57 | +p_estimate = optimise_p([5.,5.],10.).minimizer |
| 58 | +``` |
| 59 | + |
| 60 | +We can compare this to the real solution, as well as the sample data: |
| 61 | +```julia |
| 62 | +sol_estimate = solve(remake(prob,tspan=(0.,10.),p=p_estimate), Rosenbrock23()) |
| 63 | +plot(sol_real,size=(1200,400),color=[:blue :red],framestyle=:box,lw=3,label=["X real" "Y real"],linealpha=0.2) |
| 64 | +plot!(sample_times,sample_vals',seriestype=:scatter,color=[:blue :red],label=["Samples of X" "Samples of Y"],alpha=0.4) |
| 65 | +plot!(sol_estimate,color=[:darkblue :darkred], linestyle=:dash,lw=3,label=["X estimated" "Y estimated"],xlimit=tspan) |
| 66 | +``` |
| 67 | + |
| 68 | + |
| 69 | +Next, we use this parameter estimate as the starting point for the next iteration of our fitting process, this time on the interval [0,20]. |
| 70 | +```julia |
| 71 | +p_estimate = optimise_p(p_estimate,20.).minimizer |
| 72 | + |
| 73 | +sol_estimate = solve(remake(prob,tspan=(0.,20.),p=p_estimate), Rosenbrock23()) |
| 74 | +plot(sol_real,size=(1200,400),color=[:blue :red],framestyle=:box,lw=3,label=["X real" "Y real"],linealpha=0.2) |
| 75 | +plot!(sample_times,sample_vals',seriestype=:scatter,color=[:blue :red],label=["Samples of X" "Samples of Y"],alpha=0.4) |
| 76 | +plot!(sol_estimate,color=[:darkblue :darkred], linestyle=:dash,lw=3,label=["X estimated" "Y estimated"],xlimit=tspan) |
| 77 | +``` |
| 78 | + |
| 79 | + |
| 80 | +Finally, we use this estimate as the input to fit a parameter set on the full interval of sampled data. |
| 81 | +```julia |
| 82 | +p_estimate = optimise_p(p_estimate,30.).minimizer |
| 83 | + |
| 84 | +sol_estimate = solve(remake(prob,tspan=(0.,30.),p=p_estimate), Rosenbrock23()) |
| 85 | +plot(sol_real,size=(1200,400),color=[:blue :red],framestyle=:box,lw=3,label=["X real" "Y real"],linealpha=0.2) |
| 86 | +plot!(sample_times,sample_vals',seriestype=:scatter,color=[:blue :red],label=["Samples of X" "Samples of Y"],alpha=0.4) |
| 87 | +plot!(sol_estimate,color=[:darkblue :darkred], linestyle=:dash,lw=3,label=["X estimated" "Y estimated"],xlimit=tspan) |
| 88 | +``` |
| 89 | + |
| 90 | + |
| 91 | +The final parameter estimate is `[0.9996559014056948, 2.005632696191224]` (the true values are `[1.0, 2.0]`). Because the added noise is random, the exact digits will vary between runs. |
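The three fitting steps above all follow one pattern: fit on a short interval, then reuse the result as the starting point on a longer one. The sketch below captures that pattern in a loop. `fit_in_stages` and `toy_fit` are hypothetical names introduced here; `toy_fit` is a stand-in that simply moves the estimate halfway toward `[1.0, 2.0]`, so the block runs without the ODE solver or the optimiser.

```julia
# Hypothetical driver: `fit(p, tend)` must return an updated parameter vector.
function fit_in_stages(fit, p_init, tends)
    p = p_init
    for tend in tends
        p = fit(p, tend)  # warm-start each stage from the previous estimate
    end
    return p
end

# Stand-in for a real fitting routine (an assumption, not the tutorial's optimiser):
# it ignores `tend` and moves `p` halfway toward the target [1.0, 2.0].
toy_fit(p, tend) = p .+ 0.5 .* ([1.0, 2.0] .- p)

p_staged = fit_in_stages(toy_fit, [5.0, 5.0], (10.0, 20.0, 30.0))
```

With the tutorial's definitions in scope, the same driver could presumably be called as `fit_in_stages((p, tend) -> optimise_p(p, tend).minimizer, [5., 5.], (10., 20., 30.))`, which would reproduce the three manual steps.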
| 92 | + |
| 93 | + |
| 94 | +### Why we fit the parameters in iterations |
| 95 | +We first fit the model on a smaller interval and then extend the interval in order to avoid getting stuck in a local minimum. Here specifically, we chose our initial interval to be shorter than a full cycle of the oscillation. If we had fitted a parameter set on the full interval immediately, we would have obtained an inferior solution. |
| 96 | +```julia |
| 97 | +p_estimate = optimise_p([5.,5.],30.).minimizer |
| 98 | + |
| 99 | +sol_estimate = solve(remake(prob,tspan=(0.,30.),p=p_estimate), Rosenbrock23()) |
| 100 | +plot(sol_real,size=(1200,400),color=[:blue :red],framestyle=:box,lw=3,label=["X real" "Y real"],linealpha=0.2) |
| 101 | +plot!(sample_times,sample_vals',seriestype=:scatter,color=[:blue :red],label=["Samples of X" "Samples of Y"],alpha=0.4) |
| 102 | +plot!(sol_estimate,color=[:darkblue :darkred], linestyle=:dash,lw=3,label=["X estimated" "Y estimated"],xlimit=tspan) |
| 103 | +``` |
| 104 | + |
| 105 | + |
| 106 | + |