Skip to content

Commit 20d0909

Browse files
committed
Continue
1 parent 0caa3bd commit 20d0909

File tree

8 files changed

+206
-287
lines changed

8 files changed

+206
-287
lines changed

src/DynamicPPL.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,9 @@ export AbstractVarInfo,
9898
# Samplers
9999
Sampler,
100100
# Initialisation strategies
101-
Prior,
102-
Uniform,
101+
PriorInit,
102+
UniformInit,
103+
ParamsInit,
103104
# LogDensityFunction
104105
LogDensityFunction,
105106
# Contexts

src/contexts/init.jl

Lines changed: 182 additions & 132 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,8 @@
1-
# Uniform random numbers with range 4 for robust initializations
1+
# UniformInit random numbers with range 4 for robust initializations
22
# Reference: https://mc-stan.org/docs/2_19/reference-manual/initialization.html
33
randrealuni(rng::Random.AbstractRNG) = 4 * rand(rng) - 2
44
randrealuni(rng::Random.AbstractRNG, args...) = 4 .* rand(rng, args...) .- 2
55

6-
istransformable(dist) = link_transform(dist) !== identity
7-
8-
#################################
9-
# Single-sample initialisations #
10-
#################################
11-
inittrans(rng, dist::UnivariateDistribution) = Bijectors.invlink(dist, randrealuni(rng))
12-
function inittrans(rng, dist::MultivariateDistribution)
13-
# Get the length of the unconstrained vector
14-
b = link_transform(dist)
15-
d = Bijectors.output_length(b, length(dist))
16-
return Bijectors.invlink(dist, randrealuni(rng, d))
17-
end
18-
function inittrans(rng, dist::MatrixDistribution)
19-
# Get the size of the unconstrained vector
20-
b = link_transform(dist)
21-
sz = Bijectors.output_size(b, size(dist))
22-
return Bijectors.invlink(dist, randrealuni(rng, sz...))
23-
end
24-
function inittrans(rng, dist::Distribution{CholeskyVariate})
25-
# Get the size of the unconstrained vector
26-
b = link_transform(dist)
27-
sz = Bijectors.output_size(b, size(dist))
28-
return Bijectors.invlink(dist, randrealuni(rng, sz...))
29-
end
30-
################################
31-
# Multi-sample initialisations #
32-
################################
33-
function inittrans(rng, dist::UnivariateDistribution, n::Int)
34-
return Bijectors.invlink(dist, randrealuni(rng, n))
35-
end
36-
function inittrans(rng, dist::MultivariateDistribution, n::Int)
37-
return Bijectors.invlink(dist, randrealuni(rng, size(dist)[1], n))
38-
end
39-
function inittrans(rng, dist::MatrixDistribution, n::Int)
40-
return Bijectors.invlink(dist, [randrealuni(rng, size(dist)...) for _ in 1:n])
41-
end
42-
436
"""
447
AbstractInitStrategy
458
@@ -49,15 +12,29 @@ the random variables in a model (e.g., when creating a new VarInfo).
4912
abstract type AbstractInitStrategy end
5013

5114
"""
52-
Prior()
15+
init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, strategy::AbstractInitStrategy)
16+
17+
Generate a new value for a random variable with the given distribution.
18+
19+
!!! warning "Values must be unlinked"
20+
The values returned by `init` are always in the untransformed space, i.e.,
21+
they must be within the support of the original distribution. That means that,
22+
for example, `init(rng, dist, u::UniformInit)` will in general return values that
23+
are outside the range [u.lower, u.upper].
24+
"""
25+
function init end
5326

54-
Obtain new values by sampling from the prior.
5527
"""
56-
struct Prior <: AbstractInitStrategy end
28+
PriorInit()
5729
30+
Obtain new values by sampling from the prior distribution.
5831
"""
59-
Uniform()
60-
Uniform(lower, upper)
32+
struct PriorInit <: AbstractInitStrategy end
33+
init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, ::PriorInit) = rand(rng, dist)
34+
35+
"""
36+
UniformInit()
37+
UniformInit(lower, upper)
6138
6239
Obtain new values by first transforming the distribution of the random variable
6340
to unconstrained space, and then sampling a value uniformly between `lower` and
@@ -70,41 +47,65 @@ default initialisation strategy.
7047
7148
[Stan reference manual page on initialization](https://mc-stan.org/docs/reference-manual/execution.html#initialization)
7249
"""
73-
struct Uniform{T<:AbstractFloat} <: AbstractInitStrategy
50+
struct UniformInit{T<:AbstractFloat} <: AbstractInitStrategy
7451
lower::T
7552
upper::T
53+
function UniformInit(lower::T, upper::T) where {T<:AbstractFloat}
54+
lower > upper &&
55+
throw(ArgumentError("`lower` must be less than or equal to `upper`"))
56+
return new{T}(lower, upper)
57+
end
58+
UniformInit() = UniformInit(-2.0, 2.0)
59+
end
60+
function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::UniformInit)
61+
b = Bijectors.bijector(dist)
62+
sz = Bijectors.output_size(b, size(dist))
63+
y = rand(rng, Uniform(u.lower, u.upper), sz)
64+
b_inv = Bijectors.inverse(b)
65+
return b_inv(y)
7666
end
77-
Uniform() = Uniform(-2, 2)
7867

7968
"""
80-
Params(params::AbstractDict{VarName, Any}, default::AbstractInitStrategy)
81-
Params(params::NamedTuple, default::AbstractInitStrategy)
69+
ParamsInit(params::AbstractDict{<:VarName}, default::AbstractInitStrategy=PriorInit())
70+
ParamsInit(params::NamedTuple, default::AbstractInitStrategy=PriorInit())
8271
8372
Obtain new values by extracting them from the given dictionary or NamedTuple.
84-
These values are assumed to be provided in the space of the untransformed
85-
distribution.
86-
8773
The parameter `default` specifies how new values are to be obtained if they
88-
cannot be found in `params`. The default for `default` is `Prior()`.
74+
cannot be found in `params`, or they are specified as `missing`. The default
75+
for `default` is `PriorInit()`.
76+
77+
!!! note
78+
These values must be provided in the space of the untransformed distribution.
8979
"""
90-
struct Params{P,S<:AbstractInitStrategy} <: AbstractInitStrategy
80+
struct ParamsInit{P,S<:AbstractInitStrategy} <: AbstractInitStrategy
9181
params::P
9282
default::S
93-
94-
function Params(
95-
params::AbstractDict{VarName,Any}, default::AbstractInitStrategy=Prior()
96-
)
83+
function ParamsInit(params::AbstractDict{<:VarName}, default::AbstractInitStrategy)
9784
return new{typeof(params),typeof(default)}(params, default)
9885
end
99-
function Params(params::NamedTuple, default::AbstractInitStrategy=Prior())
100-
return Params(to_varname_dict(params), default)
86+
ParamsInit(params::AbstractDict{<:VarName}) = ParamsInit(params, PriorInit())
87+
function ParamsInit(params::NamedTuple, default::AbstractInitStrategy=PriorInit())
88+
return ParamsInit(to_varname_dict(params), default)
89+
end
90+
end
91+
function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::ParamsInit)
92+
return if hasvalue(p.params, vn)
93+
x = getvalue(p.params, vn)
94+
if x === missing
95+
init(rng, vn, dist, p.default)
96+
else
97+
# TODO: Check that the type of x matches the dist?
98+
x
99+
end
100+
else
101+
init(rng, vn, dist, p.default)
101102
end
102103
end
103104

104105
"""
105106
InitContext(
106107
[rng::Random.AbstractRNG=Random.default_rng()],
107-
[strategy::AbstractInitStrategy=Prior()],
108+
[strategy::AbstractInitStrategy=PriorInit()],
108109
)
109110
110111
A leaf context that indicates that new values for random variables are
@@ -115,95 +116,144 @@ VarInfo. Note that, if `leafcontext(model.context) isa InitContext`, then
115116
struct InitContext{R<:Random.AbstractRNG,S<:AbstractInitStrategy} <: AbstractContext
116117
rng::R
117118
strategy::S
118-
function InitContext(rng::Random.AbstractRNG, strategy::AbstractInitStrategy=Prior())
119+
function InitContext(
120+
rng::Random.AbstractRNG, strategy::AbstractInitStrategy=PriorInit()
121+
)
119122
return new{typeof(rng),typeof(strategy)}(rng, strategy)
120123
end
121-
function InitContext(strategy::AbstractInitStrategy=Prior())
124+
function InitContext(strategy::AbstractInitStrategy=PriorInit())
122125
return InitContext(Random.default_rng(), strategy)
123126
end
124127
end
125128
NodeTrait(::InitContext) = IsLeaf()
126129

127130
function tilde_assume(
128-
ctx::InitContext{<:Random.AbstractRNG,Prior},
129-
dist::Distribution,
130-
vn::VarName,
131-
vi::AbstractVarInfo,
131+
ctx::InitContext, dist::Distribution, vn::VarName, vi::AbstractVarInfo
132132
)
133-
r = rand(ctx.rng, dist)
134-
vi[vn] = r
135-
# TODO: FIX
136-
logjac = 0
137-
vi = accumulate_assume!!(vi, r, -logjac, vn, dist)
138-
println("sampled $r from $dist for $vn")
139-
return r, vi
133+
in_varinfo = haskey(vi, vn)
134+
# `init()` always returns values in original space, i.e. possibly
135+
# constrained
136+
x = init(ctx.rng, vn, dist, ctx.strategy)
137+
# There is a function `to_maybe_linked_internal_transform` that does this,
138+
# but unfortunately it uses `istrans(vi, vn)` which fails if vn is not in
139+
# vi, so we have to manually check. By default we will insert an unlinked
140+
# value into the varinfo.
141+
is_transformed = in_varinfo ? istrans(vi, vn) : false
142+
f = if is_transformed
143+
to_linked_internal_transform(vi, vn, dist)
144+
else
145+
to_internal_transform(vi, vn, dist)
146+
end
147+
# TODO(penelopeysm): We would really like to do:
148+
# y, logjac = with_logabsdet_jacobian(f, x)
149+
# Unfortunately, `to_{linked_}internal_transform` returns a function that
150+
# always converts x to a vector, i.e., if dist is univariate, f(x) will be
151+
# a vector of length 1. It would be nice if we could unify these.
152+
y = f(x)
153+
logjac = logabsdetjac(is_transformed ? Bijectors.bijector(dist) : identity, x)
154+
# Add the new value to the VarInfo. `push!!` errors if the value already
155+
# exists, hence the need for setindex!!
156+
if in_varinfo
157+
vi = setindex!!(vi, y, vn)
158+
else
159+
vi = push!!(vi, vn, y, dist)
160+
end
161+
# `accumulate_assume!!` wants untransformed values as the second argument.
162+
vi = accumulate_assume!!(vi, x, -logjac, vn, dist)
163+
# We always return the untransformed value here, as that will determine
164+
# what the lhs of the tilde-statement is set to.
165+
return x, vi
140166
end
141167

142-
# TODO: Remove this thing.
143-
# function assume(
144-
# rng::Random.AbstractRNG,
145-
# init_strategy::AbstractInitStrategy,
146-
# dist::Distribution,
147-
# vn::VarName,
148-
# vi::AbstractVarInfo,
168+
# """
169+
# set_initial_values(varinfo::AbstractVarInfo, initial_params::AbstractVector)
170+
# set_initial_values(varinfo::AbstractVarInfo, initial_params::NamedTuple)
171+
#
172+
# Take the values inside `initial_params`, replace the corresponding values in
173+
# the given VarInfo object, and return a new VarInfo object with the updated values.
174+
#
175+
# This differs from `DynamicPPL.unflatten` in two ways:
176+
#
177+
# 1. It works with `NamedTuple` arguments.
178+
# 2. For the `AbstractVector` method, if any of the elements are missing, it will not
179+
# overwrite the original value in the VarInfo (it will just use the original
180+
# value instead).
181+
# """
182+
# function set_initial_values(varinfo::AbstractVarInfo, initial_params::AbstractVector)
183+
# throw(
184+
# ArgumentError(
185+
# "`initial_params` must be a vector of type `Union{Real,Missing}`. " *
186+
# "If `initial_params` is a vector of vectors, please flatten it (e.g. using `vcat`) first.",
187+
# ),
188+
# )
189+
# end
190+
#
191+
# function set_initial_values(
192+
# varinfo::AbstractVarInfo, initial_params::AbstractVector{<:Union{Real,Missing}}
149193
# )
150-
# if haskey(vi, vn)
151-
# # Always overwrite the parameters with new ones for `SampleFromUniform`.
152-
# if sampler isa SampleFromUniform || is_flagged(vi, vn, "del")
153-
# # TODO(mhauru) Is it important to unset the flag here? The `true` allows us
154-
# # to ignore the fact that for VarNamedVector this does nothing, but I'm unsure
155-
# # if that's okay.
156-
# unset_flag!(vi, vn, "del", true)
157-
# r = init(rng, dist, sampler)
158-
# f = to_maybe_linked_internal_transform(vi, vn, dist)
159-
# # TODO(mhauru) This should probably be call a function called setindex_internal!
160-
# vi = BangBang.setindex!!(vi, f(r), vn)
161-
# setorder!(vi, vn, get_num_produce(vi))
162-
# else
163-
# # Otherwise we just extract it.
164-
# r = vi[vn, dist]
165-
# end
166-
# else
167-
# r = init(rng, dist, sampler)
168-
# if istrans(vi)
169-
# f = to_linked_internal_transform(vi, vn, dist)
170-
# vi = push!!(vi, vn, f(r), dist)
171-
# # By default `push!!` sets the transformed flag to `false`.
172-
# vi = settrans!!(vi, true, vn)
173-
# else
174-
# vi = push!!(vi, vn, r, dist)
194+
# flattened_param_vals = varinfo[:]
195+
# length(flattened_param_vals) == length(initial_params) || throw(
196+
# DimensionMismatch(
197+
# "Provided initial value size ($(length(initial_params))) doesn't match " *
198+
# "the model size ($(length(flattened_param_vals))).",
199+
# ),
200+
# )
201+
#
202+
# # Update values that are provided.
203+
# for i in eachindex(initial_params)
204+
# x = initial_params[i]
205+
# if x !== missing
206+
# flattened_param_vals[i] = x
175207
# end
176208
# end
177209
#
178-
# # HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct.
179-
# logjac = logabsdetjac(istrans(vi, vn) ? link_transform(dist) : identity, r)
180-
# vi = accumulate_assume!!(vi, r, -logjac, vn, dist)
181-
# return r, vi
210+
# # Update in `varinfo`.
211+
# new_varinfo = unflatten(varinfo, flattened_param_vals)
212+
# return new_varinfo
182213
# end
183-
184-
# function assume(
185-
# rng::Random.AbstractRNG,
186-
# sampler::Union{SampleFromPrior,SampleFromUniform},
187-
# dist::Distribution,
188-
# vn::VarName,
189-
# vi::SimpleOrThreadSafeSimple,
190-
# )
191-
# value = init(rng, dist, sampler)
192-
# # Transform if we're working in unconstrained space.
193-
# f = to_maybe_linked_internal_transform(vi, vn, dist)
194-
# value_raw, logjac = with_logabsdet_jacobian(f, value)
195-
# vi = BangBang.push!!(vi, vn, value_raw, dist)
196-
# vi = accumulate_assume!!(vi, value, -logjac, vn, dist)
197-
# return value, vi
198-
# end
199-
200-
# Initializations.
201-
# init(rng, dist, ::SampleFromPrior) = rand(rng, dist)
202-
# function init(rng, dist, ::SampleFromUniform)
203-
# return istransformable(dist) ? inittrans(rng, dist) : rand(rng, dist)
214+
#
215+
# function set_initial_values(varinfo::AbstractVarInfo, initial_params::NamedTuple)
216+
# varinfo = deepcopy(varinfo)
217+
# vars_in_varinfo = keys(varinfo)
218+
# for v in keys(initial_params)
219+
# vn = VarName{v}()
220+
# if !(vn in vars_in_varinfo)
221+
# for vv in vars_in_varinfo
222+
# if subsumes(vn, vv)
223+
# throw(
224+
# ArgumentError(
225+
# "The current model contains sub-variables of $v, such as ($vv). " *
226+
# "Using NamedTuple for initial_params is not supported in such a case. " *
227+
# "Please use AbstractVector for initial_params instead of NamedTuple.",
228+
# ),
229+
# )
230+
# end
231+
# end
232+
# throw(ArgumentError("Variable $v not found in the model."))
233+
# end
234+
# end
235+
# initial_params = NamedTuple(k => v for (k, v) in pairs(initial_params) if v !== missing)
236+
# return update_values!!(
237+
# varinfo, initial_params, map(k -> VarName{k}(), keys(initial_params))
238+
# )
204239
# end
205240
#
206-
# init(rng, dist, ::SampleFromPrior, n::Int) = rand(rng, dist, n)
207-
# function init(rng, dist, ::SampleFromUniform, n::Int)
208-
# return istransformable(dist) ? inittrans(rng, dist, n) : rand(rng, dist, n)
241+
# function initialize_parameters!!(vi::AbstractVarInfo, initial_params, model::Model)
242+
# @debug "Using passed-in initial variable values" initial_params
243+
#
244+
# # `link` the varinfo if needed.
245+
# linked = islinked(vi)
246+
# if linked
247+
# vi = invlink!!(vi, model)
248+
# end
249+
#
250+
# # Set the values in `vi`.
251+
# vi = set_initial_values(vi, initial_params)
252+
#
253+
# # `invlink` if needed.
254+
# if linked
255+
# vi = link!!(vi, model)
256+
# end
257+
#
258+
# return vi
209259
# end

0 commit comments

Comments
 (0)