Skip to content

Commit 2d20194

Browse files
committed
Use init!! for initialisation
1 parent 09283d2 commit 2d20194

File tree

1 file changed

+28
-120
lines changed

1 file changed

+28
-120
lines changed

src/sampler.jl

Lines changed: 28 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ end
6868
6969
Return a default varinfo object for the given `model` and `sampler`.
7070
71+
The default method for this returns an empty NTVarInfo (i.e. 'typed varinfo').
72+
7173
# Arguments
7274
- `rng::Random.AbstractRNG`: Random number generator.
7375
- `model::Model`: Model for which we want to create a varinfo object.
@@ -76,9 +78,10 @@ Return a default varinfo object for the given `model` and `sampler`.
7678
# Returns
7779
- `AbstractVarInfo`: Default varinfo object for the given `model` and `sampler`.
7880
"""
79-
function default_varinfo(rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler)
80-
init_sampler = initialsampler(sampler)
81-
return typed_varinfo(rng, model, init_sampler)
81+
function default_varinfo(::Random.AbstractRNG, ::Model, ::AbstractSampler)
82+
# Note that variable values are unconditionally initialized later, so no
83+
# point putting them in now.
84+
return typed_varinfo(VarInfo())
8285
end
8386

8487
function AbstractMCMC.sample(
@@ -96,24 +99,32 @@ function AbstractMCMC.sample(
9699
)
97100
end
98101

99-
# initial step: general interface for resuming and
102+
"""
103+
init_strategy(sampler)
104+
105+
Define the initialisation strategy used for generating initial values when
106+
sampling with `sampler`. Defaults to `PriorInit()`, but can be overridden.
107+
"""
108+
init_strategy(::Sampler) = PriorInit()
109+
100110
function AbstractMCMC.step(
101-
rng::Random.AbstractRNG, model::Model, spl::Sampler; initial_params=nothing, kwargs...
111+
rng::Random.AbstractRNG,
112+
model::Model,
113+
spl::Sampler;
114+
initial_params::AbstractInitStrategy=init_strategy(spl),
115+
kwargs...,
102116
)
103-
# Sample initial values.
117+
# Generate the default varinfo (usually this just makes an empty VarInfo
118+
# with NamedTuple of Metadata).
104119
vi = default_varinfo(rng, model, spl)
105120

106-
# Update the parameters if provided.
107-
if initial_params !== nothing
108-
vi = initialize_parameters!!(vi, initial_params, model)
109-
110-
# Update joint log probability.
111-
# This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588
112-
# and https://github.com/TuringLang/Turing.jl/issues/1563
113-
# to avoid that existing variables are resampled
114-
vi = last(evaluate!!(model, vi))
115-
end
121+
# Fill it with initial parameters. Note that, if `ParamsInit` is used, the
122+
# parameters provided must be in unlinked space (when inserted into the
123+
# varinfo, they will be adjusted to match the linking status of the
124+
# varinfo).
125+
_, vi = init!!(rng, model, vi, initial_params)
116126

127+
# Call the actual function that does the first step.
117128
return initialstep(rng, model, spl, vi; initial_params, kwargs...)
118129
end
119130

@@ -131,110 +142,7 @@ loadstate(data) = data
131142
132143
Default type of the chain of posterior samples from `sampler`.
133144
"""
134-
default_chain_type(sampler::Sampler) = Any
135-
136-
"""
137-
initialsampler(sampler::Sampler)
138-
139-
Return the sampler that is used for generating the initial parameters when sampling with
140-
`sampler`.
141-
142-
By default, it returns an instance of [`SampleFromPrior`](@ref).
143-
"""
144-
initialsampler(spl::Sampler) = SampleFromPrior()
145-
146-
"""
147-
set_initial_values(varinfo::AbstractVarInfo, initial_params::AbstractVector)
148-
set_initial_values(varinfo::AbstractVarInfo, initial_params::NamedTuple)
149-
150-
Take the values inside `initial_params`, replace the corresponding values in
151-
the given VarInfo object, and return a new VarInfo object with the updated values.
152-
153-
This differs from `DynamicPPL.unflatten` in two ways:
154-
155-
1. It works with `NamedTuple` arguments.
156-
2. For the `AbstractVector` method, if any of the elements are missing, it will not
157-
overwrite the original value in the VarInfo (it will just use the original
158-
value instead).
159-
"""
160-
function set_initial_values(varinfo::AbstractVarInfo, initial_params::AbstractVector)
161-
throw(
162-
ArgumentError(
163-
"`initial_params` must be a vector of type `Union{Real,Missing}`. " *
164-
"If `initial_params` is a vector of vectors, please flatten it (e.g. using `vcat`) first.",
165-
),
166-
)
167-
end
168-
169-
function set_initial_values(
170-
varinfo::AbstractVarInfo, initial_params::AbstractVector{<:Union{Real,Missing}}
171-
)
172-
flattened_param_vals = varinfo[:]
173-
length(flattened_param_vals) == length(initial_params) || throw(
174-
DimensionMismatch(
175-
"Provided initial value size ($(length(initial_params))) doesn't match " *
176-
"the model size ($(length(flattened_param_vals))).",
177-
),
178-
)
179-
180-
# Update values that are provided.
181-
for i in eachindex(initial_params)
182-
x = initial_params[i]
183-
if x !== missing
184-
flattened_param_vals[i] = x
185-
end
186-
end
187-
188-
# Update in `varinfo`.
189-
new_varinfo = unflatten(varinfo, flattened_param_vals)
190-
return new_varinfo
191-
end
192-
193-
function set_initial_values(varinfo::AbstractVarInfo, initial_params::NamedTuple)
194-
varinfo = deepcopy(varinfo)
195-
vars_in_varinfo = keys(varinfo)
196-
for v in keys(initial_params)
197-
vn = VarName{v}()
198-
if !(vn in vars_in_varinfo)
199-
for vv in vars_in_varinfo
200-
if subsumes(vn, vv)
201-
throw(
202-
ArgumentError(
203-
"The current model contains sub-variables of $v, such as ($vv). " *
204-
"Using NamedTuple for initial_params is not supported in such a case. " *
205-
"Please use AbstractVector for initial_params instead of NamedTuple.",
206-
),
207-
)
208-
end
209-
end
210-
throw(ArgumentError("Variable $v not found in the model."))
211-
end
212-
end
213-
initial_params = NamedTuple(k => v for (k, v) in pairs(initial_params) if v !== missing)
214-
return update_values!!(
215-
varinfo, initial_params, map(k -> VarName{k}(), keys(initial_params))
216-
)
217-
end
218-
219-
function initialize_parameters!!(vi::AbstractVarInfo, initial_params, model::Model)
220-
@debug "Using passed-in initial variable values" initial_params
221-
222-
# `link` the varinfo if needed.
223-
linked = islinked(vi)
224-
if linked
225-
vi = invlink!!(vi, model)
226-
end
227-
228-
# Set the values in `vi`.
229-
vi = set_initial_values(vi, initial_params)
230-
231-
# `invlink` if needed.
232-
if linked
233-
vi = link!!(vi, model)
234-
end
235-
236-
return vi
237-
end
145+
default_chain_type(::Sampler) = Any
238146

239147
"""
240148
initialstep(rng, model, sampler, varinfo; kwargs...)

0 commit comments

Comments
 (0)