Commit b2b5873

Use init!! for initialisation
1 parent f856389 commit b2b5873

3 files changed: +45 -174 lines changed

docs/src/api.md

Lines changed: 6 additions & 1 deletion
@@ -456,6 +456,11 @@ AbstractPPL.evaluate!!
 
 This method mutates the `varinfo` used for execution.
 By default, it does not perform any actual sampling: it only evaluates the model using the values of the variables that are already in the `varinfo`.
+To perform sampling, you can either wrap `model.context` in a `SamplingContext`, or use this convenience method:
+
+```@docs
+DynamicPPL.evaluate_and_sample!!
+```
 
 The behaviour of a model execution can be changed with evaluation contexts, which are a field of the model.
 Contexts are subtypes of `AbstractPPL.AbstractContext`.
@@ -514,7 +519,7 @@ The default implementation of [`Sampler`](@ref) uses the following unexported fu
 ```@docs
 DynamicPPL.initialstep
 DynamicPPL.loadstate
-DynamicPPL.initialsampler
+DynamicPPL.init_strategy
 ```
 
 Finally, to specify which varinfo type a [`Sampler`](@ref) should use for a given [`Model`](@ref), this is specified by [`DynamicPPL.default_varinfo`](@ref) and can thus be overloaded for each `model`-`sampler` combination. This can be useful in cases where one has explicit knowledge that one type of varinfo will be more performant for the given `model` and `sampler`.
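
Usage note: a minimal sketch of the convenience method added to the docs above, assuming `DynamicPPL.evaluate_and_sample!!` takes `(rng, model, varinfo)` positionally and returns a `(retval, varinfo)` tuple like `AbstractPPL.evaluate!!`; the `coinflip` model is a hypothetical example.

```julia
using DynamicPPL, Distributions, Random

@model function coinflip()
    p ~ Beta(2, 2)
    return p
end

model = coinflip()
vi = VarInfo()  # start from an empty varinfo; sampling fills it in

# Assumed signature: (rng, model, varinfo) -> (model return value, updated varinfo),
# mirroring `AbstractPPL.evaluate!!`.
retval, vi = DynamicPPL.evaluate_and_sample!!(Random.default_rng(), model, vi)
```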

src/sampler.jl

Lines changed: 28 additions & 120 deletions
@@ -68,6 +68,8 @@ end
 
 Return a default varinfo object for the given `model` and `sampler`.
 
+The default method for this returns an empty NTVarInfo (i.e. 'typed varinfo').
+
 # Arguments
 - `rng::Random.AbstractRNG`: Random number generator.
 - `model::Model`: Model for which we want to create a varinfo object.
@@ -76,9 +78,10 @@ Return a default varinfo object for the given `model` and `sampler`.
 # Returns
 - `AbstractVarInfo`: Default varinfo object for the given `model` and `sampler`.
 """
-function default_varinfo(rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler)
-    init_sampler = initialsampler(sampler)
-    return typed_varinfo(rng, model, init_sampler)
+function default_varinfo(::Random.AbstractRNG, ::Model, ::AbstractSampler)
+    # Note that variable values are unconditionally initialized later, so no
+    # point putting them in now.
+    return typed_varinfo(VarInfo())
 end
 
 function AbstractMCMC.sample(
@@ -96,24 +99,32 @@ function AbstractMCMC.sample(
     )
 end
 
-# initial step: general interface for resuming and
+"""
+    init_strategy(sampler)
+
+Define the initialisation strategy used for generating initial values when
+sampling with `sampler`. Defaults to `PriorInit()`, but can be overridden.
+"""
+init_strategy(::Sampler) = PriorInit()
+
 function AbstractMCMC.step(
-    rng::Random.AbstractRNG, model::Model, spl::Sampler; initial_params=nothing, kwargs...
+    rng::Random.AbstractRNG,
+    model::Model,
+    spl::Sampler;
+    initial_params::AbstractInitStrategy=init_strategy(spl),
+    kwargs...,
 )
-    # Sample initial values.
+    # Generate the default varinfo (usually this just makes an empty VarInfo
+    # with NamedTuple of Metadata).
     vi = default_varinfo(rng, model, spl)
 
-    # Update the parameters if provided.
-    if initial_params !== nothing
-        vi = initialize_parameters!!(vi, initial_params, model)
-
-        # Update joint log probability.
-        # This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588
-        # and https://github.com/TuringLang/Turing.jl/issues/1563
-        # to avoid that existing variables are resampled
-        vi = last(evaluate!!(model, vi))
-    end
+    # Fill it with initial parameters. Note that, if `ParamsInit` is used, the
+    # parameters provided must be in unlinked space (when inserted into the
+    # varinfo, they will be adjusted to match the linking status of the
+    # varinfo).
+    _, vi = init!!(rng, model, vi, initial_params)
 
+    # Call the actual function that does the first step.
     return initialstep(rng, model, spl, vi; initial_params, kwargs...)
 end
 
@@ -131,110 +142,7 @@ loadstate(data) = data
 
 Default type of the chain of posterior samples from `sampler`.
 """
-default_chain_type(sampler::Sampler) = Any
-
-"""
-    initialsampler(sampler::Sampler)
-
-Return the sampler that is used for generating the initial parameters when sampling with
-`sampler`.
-
-By default, it returns an instance of [`SampleFromPrior`](@ref).
-"""
-initialsampler(spl::Sampler) = SampleFromPrior()
-
-"""
-    set_initial_values(varinfo::AbstractVarInfo, initial_params::AbstractVector)
-    set_initial_values(varinfo::AbstractVarInfo, initial_params::NamedTuple)
-
-Take the values inside `initial_params`, replace the corresponding values in
-the given VarInfo object, and return a new VarInfo object with the updated values.
-
-This differs from `DynamicPPL.unflatten` in two ways:
-
-1. It works with `NamedTuple` arguments.
-2. For the `AbstractVector` method, if any of the elements are missing, it will not
-   overwrite the original value in the VarInfo (it will just use the original
-   value instead).
-"""
-function set_initial_values(varinfo::AbstractVarInfo, initial_params::AbstractVector)
-    throw(
-        ArgumentError(
-            "`initial_params` must be a vector of type `Union{Real,Missing}`. " *
-            "If `initial_params` is a vector of vectors, please flatten it (e.g. using `vcat`) first.",
-        ),
-    )
-end
-
-function set_initial_values(
-    varinfo::AbstractVarInfo, initial_params::AbstractVector{<:Union{Real,Missing}}
-)
-    flattened_param_vals = varinfo[:]
-    length(flattened_param_vals) == length(initial_params) || throw(
-        DimensionMismatch(
-            "Provided initial value size ($(length(initial_params))) doesn't match " *
-            "the model size ($(length(flattened_param_vals))).",
-        ),
-    )
-
-    # Update values that are provided.
-    for i in eachindex(initial_params)
-        x = initial_params[i]
-        if x !== missing
-            flattened_param_vals[i] = x
-        end
-    end
-
-    # Update in `varinfo`.
-    new_varinfo = unflatten(varinfo, flattened_param_vals)
-    return new_varinfo
-end
-
-function set_initial_values(varinfo::AbstractVarInfo, initial_params::NamedTuple)
-    varinfo = deepcopy(varinfo)
-    vars_in_varinfo = keys(varinfo)
-    for v in keys(initial_params)
-        vn = VarName{v}()
-        if !(vn in vars_in_varinfo)
-            for vv in vars_in_varinfo
-                if subsumes(vn, vv)
-                    throw(
-                        ArgumentError(
-                            "The current model contains sub-variables of $v, such as ($vv). " *
-                            "Using NamedTuple for initial_params is not supported in such a case. " *
-                            "Please use AbstractVector for initial_params instead of NamedTuple.",
-                        ),
-                    )
-                end
-            end
-            throw(ArgumentError("Variable $v not found in the model."))
-        end
-    end
-    initial_params = NamedTuple(k => v for (k, v) in pairs(initial_params) if v !== missing)
-    return update_values!!(
-        varinfo, initial_params, map(k -> VarName{k}(), keys(initial_params))
-    )
-end
-
-function initialize_parameters!!(vi::AbstractVarInfo, initial_params, model::Model)
-    @debug "Using passed-in initial variable values" initial_params
-
-    # `link` the varinfo if needed.
-    linked = islinked(vi)
-    if linked
-        vi = invlink!!(vi, model)
-    end
-
-    # Set the values in `vi`.
-    vi = set_initial_values(vi, initial_params)
-
-    # `invlink` if needed.
-    if linked
-        vi = link!!(vi, model)
-    end
-
-    return vi
-end
+default_chain_type(::Sampler) = Any
 
 """
     initialstep(rng, model, sampler, varinfo; kwargs...)
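
Usage note: the removed `initialsampler` hook is replaced by `init_strategy`, which a downstream sampler can override to change how starting values are generated. A minimal sketch, assuming `Sampler{T}` is parametrised by its algorithm type (as suggested by `Sampler(alg)` in the tests) and that `PriorInit`/`ParamsInit` are the strategy types introduced in this changeset; the algorithm type and the variable names `s`/`m` are placeholders.

```julia
using DynamicPPL

# Hypothetical algorithm type, used only for illustration.
struct FixedStartAlg end

# By default, a Sampler's starting point is drawn from the prior (PriorInit()).
# Overriding init_strategy makes every chain for this algorithm start from
# fixed values instead (here for a model with variables s and m).
function DynamicPPL.init_strategy(::DynamicPPL.Sampler{FixedStartAlg})
    return DynamicPPL.ParamsInit((; s=1.0, m=0.0))
end
```

The new `AbstractMCMC.step` picks this strategy up through its `initial_params=init_strategy(spl)` default and passes it to `init!!`, so no keyword is needed at the call site.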

test/sampler.jl

Lines changed: 11 additions & 53 deletions
@@ -82,7 +82,9 @@
 sampler = Sampler(alg)
 lptrue = logpdf(Binomial(25, 0.2), 10)
 let inits = (; p=0.2)
-    chain = sample(model, sampler, 1; initial_params=inits, progress=false)
+    chain = sample(
+        model, sampler, 1; initial_params=ParamsInit(inits), progress=false
+    )
     @test chain[1].metadata.p.vals == [0.2]
     @test getlogjoint(chain[1]) == lptrue
 
@@ -110,7 +112,9 @@
 model = twovars()
 lptrue = logpdf(InverseGamma(2, 3), 4) + logpdf(Normal(0, 2), -1)
 for inits in ([4, -1], (; s=4, m=-1))
-    chain = sample(model, sampler, 1; initial_params=inits, progress=false)
+    chain = sample(
+        model, sampler, 1; initial_params=ParamsInit(inits), progress=false
+    )
     @test chain[1].metadata.s.vals == [4]
     @test chain[1].metadata.m.vals == [-1]
     @test getlogjoint(chain[1]) == lptrue
@@ -122,7 +126,7 @@
     MCMCThreads(),
     1,
     10;
-    initial_params=fill(inits, 10),
+    initial_params=fill(ParamsInit(inits), 10),
     progress=false,
 )
 for c in chains
@@ -133,8 +137,10 @@
 end
 
 # set only m = -1
-for inits in ([missing, -1], (; s=missing, m=-1), (; m=-1))
-    chain = sample(model, sampler, 1; initial_params=inits, progress=false)
+for inits in ((; s=missing, m=-1), (; m=-1))
+    chain = sample(
+        model, sampler, 1; initial_params=ParamsInit(inits), progress=false
+    )
     @test !ismissing(chain[1].metadata.s.vals[1])
     @test chain[1].metadata.m.vals == [-1]
 
@@ -153,54 +159,6 @@
     @test c[1].metadata.m.vals == [-1]
 end
 end
-
-# specify `initial_params=nothing`
-Random.seed!(1234)
-chain1 = sample(model, sampler, 1; progress=false)
-Random.seed!(1234)
-chain2 = sample(model, sampler, 1; initial_params=nothing, progress=false)
-@test_throws DimensionMismatch sample(
-    model, sampler, 1; progress=false, initial_params=zeros(10)
-)
-@test chain1[1].metadata.m.vals == chain2[1].metadata.m.vals
-@test chain1[1].metadata.s.vals == chain2[1].metadata.s.vals
-
-# parallel sampling
-Random.seed!(1234)
-chains1 = sample(model, sampler, MCMCThreads(), 1, 10; progress=false)
-Random.seed!(1234)
-chains2 = sample(
-    model, sampler, MCMCThreads(), 1, 10; initial_params=nothing, progress=false
-)
-for (c1, c2) in zip(chains1, chains2)
-    @test c1[1].metadata.m.vals == c2[1].metadata.m.vals
-    @test c1[1].metadata.s.vals == c2[1].metadata.s.vals
-end
-end
-
-@testset "error handling" begin
-    # https://github.com/TuringLang/Turing.jl/issues/2452
-    @model function constrained_uniform(n)
-        Z ~ Uniform(10, 20)
-        X = Vector{Float64}(undef, n)
-        for i in 1:n
-            X[i] ~ Uniform(0, Z)
-        end
-    end
-
-    n = 2
-    initial_z = 15
-    initial_x = [0.2, 0.5]
-    model = constrained_uniform(n)
-    vi = VarInfo(model)
-
-    @test_throws ArgumentError DynamicPPL.initialize_parameters!!(
-        vi, [initial_z, initial_x], model
-    )
-
-    @test_throws ArgumentError DynamicPPL.initialize_parameters!!(
-        vi, (X=initial_x, Z=initial_z), model
-    )
 end
 end
 end
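
Usage note: the updated tests wrap user-supplied starting values in `ParamsInit`. Below is a self-contained sketch of the path those tests now exercise, calling `init!!` directly as the new `AbstractMCMC.step` does; the `(retval, varinfo)` return shape of `init!!` and the `typed_varinfo(VarInfo())` construction are taken from the `src/sampler.jl` hunks above, and `ParamsInit` is assumed to accept a `NamedTuple` as in the tests.

```julia
using DynamicPPL, Distributions, Random

@model function twovars()
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, 2)
end

model = twovars()
vi = DynamicPPL.typed_varinfo(VarInfo())  # as in the new default_varinfo
rng = Random.default_rng()

# Supply m only; s is absent, so it is filled in by sampling, matching the
# `(; m=-1)` cases checked in the tests above.
_, vi = DynamicPPL.init!!(rng, model, vi, DynamicPPL.ParamsInit((; m=-1.0)))
@assert vi[@varname(m)] == -1.0
```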
