Skip to content

Commit edc619f

Browse files
committed
remove randomizer
1 parent f1cbd74 commit edc619f

File tree

5 files changed

+55
-44
lines changed

5 files changed

+55
-44
lines changed

Examples/Delay_Discounting/Run_Delay_Discounting.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@ cd(@__DIR__)
77
using Pkg
88
# activate the project environment
99
Pkg.activate("../../")
10-
using AdaptiveDesignOptimization, Random, UtilityModels, Distributions
10+
using Revise, AdaptiveDesignOptimization, Random, UtilityModels, Distributions
1111
include("Delay_Discounting.jl")
1212
#######################################################################################
1313
# Define Model
1414
#######################################################################################
15-
Random.seed!(204)
15+
Random.seed!(1204)
1616
prior = [Uniform(-5, 5), Uniform(-5, 50)]
1717

1818
model = Model(;prior, loglike)
@@ -41,14 +41,13 @@ design_list = (
4141
# )
4242

4343
data_list = (choice=[true, false],)
44-
45-
optimizer = Optimizer(;design_list, parm_list, data_list, model);
4644
#######################################################################################
4745
# Simulate Experiment
4846
#######################################################################################
4947
using DataFrames
5048
true_parms ==.12, τ=1.5)
5149
n_trials = 100
50+
optimizer = Optimizer(;design_list, parm_list, data_list, model);
5251
design = optimizer.best_design
5352
df = DataFrame(design=Symbol[], trial=Int[], mean_κ=Float64[], mean_τ=Float64[],
5453
std_κ=Float64[], std_τ=Float64[])
@@ -64,7 +63,7 @@ end
6463
#######################################################################################
6564
# Random Experiment
6665
#######################################################################################
67-
randomizer = Randomizer(;design_list, parm_list, data_list, model);
66+
randomizer = Optimizer(;design_list, parm_list, data_list, model, approach=Randomize);
6867
design = randomizer.best_design
6968
new_data = [:random, 0, mean_post(randomizer)..., std_post(randomizer)...]
7069
push!(df, new_data)
@@ -88,4 +87,3 @@ hline!([true_parms.τ], label="true")
8887
@df df plot(:trial, :std_κ, xlabel="trial", ylabel="σ of κ", grid=false, group=:design, ylims=(0,.3))
8988

9089
@df df plot(:trial, :std_τ, xlabel="trial", ylabel="σ of τ", grid=false, group=:design, ylims=(0,2))
91-

src/AdaptiveDesignOptimization.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module AdaptiveDesignOptimization
22
using Distributions, Parameters, StatsFuns
3-
export Optimizer, Model, Randomizer
3+
export Optimizer, Model, Optimize, Randomize
44
export update!, mean_post, std_post, get_best_design
55

66
include("structs.jl")

src/functions.jl

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,10 @@ function find_best_design(mutual_info, design_grid, design_names)
9292
return best_design
9393
end
9494

95-
function find_best_design!(randomizer::Randomizer)
96-
@unpack design_grid,design_names = randomizer
95+
function find_best_design!(optimizer::Optimizer{A}) where {A<:Rand}
96+
@unpack design_grid,design_names = optimizer
9797
best_design = rand(design_grid)
98-
randomizer.best_design = best_design
98+
optimizer.best_design = best_design
9999
return best_design
100100
end
101101

@@ -120,9 +120,9 @@ function update!(optimizer, data)
120120
return best_design
121121
end
122122

123-
function update!(randomizer::Randomizer, data)
124-
update_posterior!(randomizer, data)
125-
best_design = find_best_design!(randomizer)
123+
function update!(optimizer::Optimizer{A}, data) where {A<:Rand}
124+
update_posterior!(optimizer, data)
125+
best_design = find_best_design!(optimizer)
126126
return best_design
127127
end
128128

@@ -169,4 +169,10 @@ function std_post(optimizer)
169169
@unpack log_post,parm_grid = optimizer
170170
post = exp.(log_post)
171171
return std_post(post, parm_grid)
172-
end
172+
end
173+
174+
function create_state(T, dims, args...; kwargs...)
175+
state = fill(T(args...; kwargs...), dims)
176+
state .= deepcopy.(state)
177+
end
178+

src/structs.jl

Lines changed: 35 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
11
import Base.Iterators: product
22

3+
abstract type Approach end
4+
5+
struct Opt <: Approach
6+
end
7+
struct Rand <:Approach
8+
end
9+
10+
const Optimize = Opt()
11+
const Randomize = Rand()
12+
313
"""
414
**Model**
515
@@ -46,7 +56,8 @@ Constructor
4656
Optimizer(;task, model, grid_design, grid_parms, grid_response)
4757
````
4858
"""
49-
mutable struct Optimizer{M<:Model,T1,T2,T3,T4,T5,T6,T7,T8,T9}
59+
mutable struct Optimizer{A,M<:Model,T1,T2,T3,T4,T5,T6,T7,T8,T9}
60+
approach::A
5061
model::M
5162
design_grid::T1
5263
parm_grid::T2
@@ -64,7 +75,7 @@ mutable struct Optimizer{M<:Model,T1,T2,T3,T4,T5,T6,T7,T8,T9}
6475
design_names::T9
6576
end
6677

67-
function Optimizer(;model, design_list, parm_list, data_list)
78+
function Optimizer(; model, design_list, parm_list, data_list, approach=Optimize)
6879
design_names,design_grid = to_grid(design_list)
6980
parm_names,parm_grid = to_grid(parm_list)
7081
_,data_grid = to_grid(data_list)
@@ -78,34 +89,29 @@ function Optimizer(;model, design_list, parm_list, data_list)
7889
cond_entropy = conditional_entropy(entropy, post)
7990
mutual_info = mutual_information(marg_entropy, cond_entropy)
8091
best_design = find_best_design(mutual_info, design_grid, design_names)
81-
return Optimizer(model, design_grid, parm_grid, data_grid, log_like,
92+
return Optimizer(approach, model, design_grid, parm_grid, data_grid, log_like,
8293
marg_log_like, priors, log_post, entropy, marg_entropy, cond_entropy,
8394
mutual_info, best_design, parm_names, design_names)
8495
end
8596

86-
87-
mutable struct Randomizer{M<:Model,T1,T2,T3,T4,T5,T6,T7}
88-
model::M
89-
design_grid::T1
90-
parm_grid::T2
91-
data_grid::T3
92-
log_like::Array{Float64,3}
93-
priors::T4
94-
log_post::Vector{Float64}
95-
best_design::T5
96-
parm_names::T6
97-
design_names::T7
98-
end
99-
100-
function Randomizer(;model, design_list, parm_list, data_list)
101-
design_names,design_grid = to_grid(design_list)
102-
parm_names,parm_grid = to_grid(parm_list)
103-
_,data_grid = to_grid(data_list)
104-
log_like = loglikelihood(model, design_grid, parm_grid, data_grid)
105-
priors = prior_probs(model, parm_grid)
106-
post = priors[:]
107-
log_post = log.(post)
108-
best_design = rand(design_grid)
109-
return Randomizer(model, design_grid, parm_grid, data_grid, log_like,
110-
priors, log_post, best_design, parm_names, design_names)
111-
end
97+
# function Optimizer(args...; update_log_like, model, design_list, parm_list, data_list,
98+
# state_type, kwargs...)
99+
# design_names,design_grid = to_grid(design_list)
100+
# parm_names,parm_grid = to_grid(parm_list)
101+
# _,data_grid = to_grid(data_list)
102+
# dims = map(length, (parm_grid,design_grid,data_grid))
103+
# state = create_state(state_type, dims, args...; kwargs...)
104+
# log_like = loglikelihood(model, design_grid, parm_grid, data_grid)
105+
# priors = prior_probs(model, parm_grid)
106+
# post = priors[:]
107+
# log_post = log.(post)
108+
# entropy = compute_entropy(log_like)
109+
# marg_log_like = marginal_log_like(log_post, log_like)
110+
# marg_entropy = marginal_entropy(marg_log_like)
111+
# cond_entropy = conditional_entropy(entropy, post)
112+
# mutual_info = mutual_information(marg_entropy, cond_entropy)
113+
# best_design = find_best_design(mutual_info, design_grid, design_names)
114+
# return Optimizer(model, design_grid, parm_grid, data_grid, log_like,
115+
# marg_log_like, priors, log_post, entropy, marg_entropy, cond_entropy,
116+
# mutual_info, best_design, parm_names, design_names, update_log_like, state)
117+
# end

test/runtests.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ using SafeTestsets
1717

1818
data_list = (choice=[true, false],)
1919

20-
randomizer = Randomizer(;design_list, parm_list, data_list, model)
20+
randomizer = Optimizer(;design_list, parm_list, data_list, model,
21+
approach=Randomize)
2122

2223
@test mean_post(randomizer)[1] mean(Beta(α,β)) atol = 5e-3
2324
@test std_post(randomizer)[1] std(Beta(α,β)) atol = 5e-3

0 commit comments

Comments
 (0)