Skip to content

Commit f027a45

Browse files
committed
add capability for dynamic models
1 parent edc619f commit f027a45

File tree

8 files changed

+293
-41
lines changed

8 files changed

+293
-41
lines changed

Examples/Delay_Discounting/Run_Delay_Discounting.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ end
6363
#######################################################################################
6464
# Random Experiment
6565
#######################################################################################
66-
randomizer = Optimizer(;design_list, parm_list, data_list, model, approach=Randomize);
66+
randomizer = Optimizer(;design_list, parm_list, data_list, model, design_type=Randomize);
6767
design = randomizer.best_design
6868
new_data = [:random, 0, mean_post(randomizer)..., std_post(randomizer)...]
6969
push!(df, new_data)

Examples/Dynamic/Delay_Discounting.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
discount(t, κ) = 1/(1 + κ*t)
2+
3+
function prob(κ, τ, t_ss, t_ll, r_ss, r_ll)
4+
u_ll = r_ll * discount(t_ll, κ)
5+
u_ss = r_ss * discount(t_ss, κ)
6+
return 1/(1 + exp(-τ * (u_ll - u_ss)))
7+
end
8+
9+
function loglike(κ, τ, t_ss, t_ll, r_ss, r_ll, data)
10+
p = prob(κ, τ, t_ss, t_ll, r_ss, r_ll)
11+
p = p == 1 ? 1 - eps() : p
12+
p = p == 0 ? eps() : p
13+
LL = data ? log(p) : log(1 - p)
14+
# println(" choice ", data, " kappa ", κ, " tau ", τ, " t_ss ", t_ss,
15+
# " t_ll ", t_ll, " r_ss ", r_ss, " r_ll ", r_ll, " LL ", LL)
16+
return LL
17+
end
18+
19+
function simulate(κ, τ, t_ss, t_ll, r_ss, r_ll)
20+
p = prob(κ, τ, t_ss, t_ll, r_ss, r_ll)
21+
return rand() p ? true : false
22+
end

Examples/Dynamic/Run_Dynamic.jl

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
#######################################################################################
2+
# Load Packages
3+
#######################################################################################
4+
# set the working directory to the directory in which this file is contained
5+
cd(@__DIR__)
6+
# load the package manager
7+
using Pkg
8+
# activate the project environment
9+
Pkg.activate("../../")
10+
using Revise, AdaptiveDesignOptimization, Random, UtilityModels, Distributions
11+
include("Delay_Discounting.jl")
12+
#######################################################################################
13+
# Define Model
14+
#######################################################################################
15+
Random.seed!(1204)
16+
prior = [Uniform(-5, 5), Uniform(-5, 50)]
17+
18+
model = Model(;prior, loglike)
19+
20+
parm_list == range(-5, 0, length=50) .|> x->10^x,
21+
τ = range(0, 5, length=11)[2:end])
22+
23+
# parm_list = (κ = [.1],
24+
# τ = [.2,.5])
25+
26+
design_list = (
27+
t_ss = [0.0],
28+
t_ll = [0.43, 0.714, 1, 2, 3,
29+
4.3, 6.44, 8.6, 10.8, 12.9,
30+
17.2, 21.5, 26, 52, 104,
31+
156, 260, 520],
32+
r_ss = 12.5:12.5:787.5,
33+
r_ll = [800.0]
34+
)
35+
36+
# design_list = (
37+
# t_ss = [0.0],
38+
# t_ll = [5.0, 10.0],
39+
# r_ss = [12.0, 20.0],
40+
# r_ll = [80.0]
41+
# )
42+
43+
data_list = (choice=[true, false],)
44+
#######################################################################################
45+
# Simulate Experiment
46+
#######################################################################################
47+
using DataFrames
48+
true_parms ==.12, τ=1.5)
49+
n_trials = 100
50+
optimizer = Optimizer(;design_list, parm_list, data_list, model);
51+
design = optimizer.best_design
52+
df = DataFrame(design=Symbol[], trial=Int[], mean_κ=Float64[], mean_τ=Float64[],
53+
std_κ=Float64[], std_τ=Float64[])
54+
new_data = [:optimal, 0, mean_post(optimizer)..., std_post(optimizer)...]
55+
push!(df, new_data)
56+
57+
for trial in 1:n_trials
58+
data = simulate(true_parms..., design...)
59+
design = update!(optimizer, data)
60+
new_data = [:optimal, trial, mean_post(optimizer)..., std_post(optimizer)...]
61+
push!(df, new_data)
62+
end
63+
#######################################################################################
64+
# Random Experiment
65+
#######################################################################################
66+
randomizer = Optimizer(;design_list, parm_list, data_list, model, approach=Randomize);
67+
design = randomizer.best_design
68+
new_data = [:random, 0, mean_post(randomizer)..., std_post(randomizer)...]
69+
push!(df, new_data)
70+
71+
for trial in 1:n_trials
72+
data = simulate(true_parms..., design...)
73+
design = update!(randomizer, data)
74+
new_data = [:random, trial, mean_post(randomizer)..., std_post(randomizer)...]
75+
push!(df, new_data)
76+
end
77+
#######################################################################################
78+
# Plot Results
79+
#######################################################################################
80+
using StatsPlots
81+
@df df plot(:trial, :mean_κ, xlabel="trial", ylabel="mean κ", group=:design, grid=false)
82+
hline!([true_parms.κ], label="true")
83+
84+
@df df plot(:trial, :mean_τ, xlabel="trial", ylabel="mean τ", group=:design, grid=false)
85+
hline!([true_parms.τ], label="true")
86+
87+
@df df plot(:trial, :std_κ, xlabel="trial", ylabel="σ of κ", grid=false, group=:design, ylims=(0,.3))
88+
89+
@df df plot(:trial, :std_τ, xlabel="trial", ylabel="σ of τ", grid=false, group=:design, ylims=(0,2))

Examples/Monetary_Gambles/Run_Monetary_Gamble.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,13 @@ design_names = (:p1,:v1,:p2,:v2)
3333
design_list = (design_names,design_vals[1:100])
3434

3535
data_list = (choice=[true, false],)
36-
37-
optimizer = Optimizer(;design_list, parm_list, data_list, model)
3836
#######################################################################################
3937
# Simulate Experiment
4038
#######################################################################################
4139
using DataFrames
4240
true_parms ==-1.0, β=1.0, γ=.7, θ=1.5)
4341
n_trials = 100
42+
optimizer = Optimizer(;design_list, parm_list, data_list, model)
4443
design = optimizer.best_design
4544
df = DataFrame(design=Symbol[], trial=Int[], mean_δ=Float64[], mean_β=Float64[],
4645
mean_γ=Float64[], mean_θ=Float64[], std_δ=Float64[], std_β=Float64[],
@@ -57,7 +56,7 @@ end
5756
#######################################################################################
5857
# Random Experiment
5958
#######################################################################################
60-
randomizer = Randomizer(;design_list, parm_list, data_list, model);
59+
randomizer = Optimizer(;design_list, parm_list, data_list, model, design_type=Randomize);
6160
design = randomizer.best_design
6261
new_data = [:random, 0, mean_post(randomizer)..., std_post(randomizer)...]
6362
push!(df, new_data)

src/AdaptiveDesignOptimization.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module AdaptiveDesignOptimization
22
using Distributions, Parameters, StatsFuns
33
export Optimizer, Model, Optimize, Randomize
4+
export Dynamic, Static
45
export update!, mean_post, std_post, get_best_design
56

67
include("structs.jl")

src/functions.jl

Lines changed: 68 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,14 @@ function prior_probs(prior::Nothing, parm_grid)
1111
return fill(1/length(parm_grid), size(parm_grid))
1212
end
1313

14-
function loglikelihood(model::Model, design_grid, parm_grid, data_grid)
14+
function loglikelihood(model::Model, design_grid, parm_grid, data_grid, model_type, model_state)
1515
return loglikelihood(model.loglike, design_grid, parm_grid, data_grid)
1616
end
1717

18+
function loglikelihood(model::Model, design_grid, parm_grid, data_grid, model_type::Dyn, model_state)
19+
return loglikelihood(model.loglike, design_grid, parm_grid, data_grid, model_state)
20+
end
21+
1822
function loglikelihood(loglike, design_grid, parm_grid, data_grid)
1923
LLs = zeros(length(parm_grid), length(design_grid), length(data_grid))
2024
for (d, data) in enumerate(data_grid)
@@ -27,6 +31,38 @@ function loglikelihood(loglike, design_grid, parm_grid, data_grid)
2731
return LLs
2832
end
2933

34+
function loglikelihood(loglike, design_grid, parm_grid, data_grid, model_state)
35+
LLs = zeros(length(parm_grid), length(design_grid), length(data_grid))
36+
i = 0
37+
for (d, data) in enumerate(data_grid)
38+
for (k,design) in enumerate(design_grid)
39+
for (p,parms) in enumerate(parm_grid)
40+
i += 1
41+
LLs[p,k,d] = loglike(parms..., design..., data..., model_state[i])
42+
end
43+
end
44+
end
45+
return LLs
46+
end
47+
48+
function loglikelihood!(optimizer)
49+
@unpack model, log_like, design_grid, parm_grid, data_grid, model_state = optimizer
50+
return loglikelihood!(model.loglike, log_like, design_grid, parm_grid, data_grid, model_state)
51+
end
52+
53+
function loglikelihood!(loglike, log_like, design_grid, parm_grid, data_grid, model_state)
54+
i = 0
55+
for (d, data) in enumerate(data_grid)
56+
for (k,design) in enumerate(design_grid)
57+
for (p,parms) in enumerate(parm_grid)
58+
i += 1
59+
log_like[p,k,d] = loglike(parms..., design..., data..., model_state[i])
60+
end
61+
end
62+
end
63+
return nothing
64+
end
65+
3066
function marginal_log_like!(optimizer)
3167
@unpack marg_log_like,log_like,log_post = optimizer
3268
marg_log_like .= marginal_log_like(log_post, log_like)
@@ -120,12 +156,36 @@ function update!(optimizer, data)
120156
return best_design
121157
end
122158

159+
function update!(optimizer::Optimizer{A,MT}, data, args...; kwargs...) where {A,MT<:Dyn}
160+
update_posterior!(optimizer, data)
161+
update_states!(optimizer, data, args...; kwargs...)
162+
loglikelihood!(optimizer)
163+
marginal_log_like!(optimizer)
164+
marginal_entropy!(optimizer)
165+
conditional_entropy!(optimizer)
166+
mutual_information!(optimizer)
167+
best_design = find_best_design!(optimizer)
168+
return best_design
169+
end
170+
123171
function update!(optimizer::Optimizer{A}, data) where {A<:Rand}
124172
update_posterior!(optimizer, data)
125173
best_design = find_best_design!(optimizer)
126174
return best_design
127175
end
128176

177+
function update_states!(optimizer, obs_data, args...; kwargs...)
178+
@unpack model_state, data_grid, design_grid, parm_grid, update_state! = optimizer
179+
for (d, data) in enumerate(data_grid)
180+
for (k,design) in enumerate(design_grid)
181+
for (p,parms) in enumerate(parm_grid)
182+
update_state!(model_state, parms, design, data, obs_data, args...; kwargs...)
183+
end
184+
end
185+
end
186+
return nothing
187+
end
188+
129189
function to_grid(vals::NamedTuple)
130190
k = keys(vals)
131191
v = product(vals...) |> collect
@@ -171,8 +231,13 @@ function std_post(optimizer)
171231
return std_post(post, parm_grid)
172232
end
173233

174-
function create_state(T, dims, args...; kwargs...)
234+
function create_state(model_type::Dyn, T, dims, args...; kwargs...)
175235
state = fill(T(args...; kwargs...), dims)
176-
state .= deepcopy.(state)
236+
return state .= deepcopy.(state)
177237
end
178238

239+
function create_state(model_type::Stat, T, dims, args...; kwargs...)
240+
return nothing
241+
end
242+
243+

0 commit comments

Comments
 (0)