Skip to content

Commit 57bf292

Browse files
committed
Implement InitContext
1 parent 90e95e1 commit 57bf292

File tree

4 files changed

+225
-1
lines changed

4 files changed

+225
-1
lines changed

src/DynamicPPL.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@ include("sampler.jl")
175175
include("varname.jl")
176176
include("distribution_wrappers.jl")
177177
include("contexts.jl")
178+
include("contexts/init.jl")
178179
include("submodel.jl")
179180
include("varnamedvector.jl")
180181
include("accumulators.jl")

src/contexts/init.jl

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
"""
2+
AbstractInitStrategy
3+
4+
Abstract type representing the possible ways of initialising new values for
5+
the random variables in a model (e.g., when creating a new VarInfo).
6+
"""
7+
abstract type AbstractInitStrategy end
8+
9+
"""
10+
init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, strategy::AbstractInitStrategy)
11+
12+
Generate a new value for a random variable with the given distribution.
13+
14+
!!! warning "Values must be unlinked"
15+
The values returned by `init` are always in the untransformed space, i.e.,
16+
they must be within the support of the original distribution. That means that,
17+
for example, `init(rng, dist, u::UniformInit)` will in general return values that
18+
are outside the range [u.lower, u.upper].
19+
"""
20+
function init end
21+
22+
"""
23+
PriorInit()
24+
25+
Obtain new values by sampling from the prior distribution.
26+
"""
27+
struct PriorInit <: AbstractInitStrategy end
28+
init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, ::PriorInit) = rand(rng, dist)
29+
30+
"""
31+
UniformInit()
32+
UniformInit(lower, upper)
33+
34+
Obtain new values by first transforming the distribution of the random variable
35+
to unconstrained space, and then sampling a value uniformly between `lower` and
36+
`upper`.
37+
38+
If unspecified, defaults to `(lower, upper) = (-2, 2)`, which mimics Stan's
39+
default initialisation strategy.
40+
41+
# References
42+
43+
[Stan reference manual page on initialization](https://mc-stan.org/docs/reference-manual/execution.html#initialization)
44+
"""
45+
struct UniformInit{T<:AbstractFloat} <: AbstractInitStrategy
46+
lower::T
47+
upper::T
48+
function UniformInit(lower::T, upper::T) where {T<:AbstractFloat}
49+
lower > upper &&
50+
throw(ArgumentError("`lower` must be less than or equal to `upper`"))
51+
return new{T}(lower, upper)
52+
end
53+
UniformInit() = UniformInit(-2.0, 2.0)
54+
end
55+
function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::UniformInit)
56+
b = Bijectors.bijector(dist)
57+
sz = Bijectors.output_size(b, size(dist))
58+
y = rand(rng, Uniform(u.lower, u.upper), sz)
59+
b_inv = Bijectors.inverse(b)
60+
x = b_inv(y)
61+
# 0-dim arrays: https://github.com/TuringLang/Bijectors.jl/issues/398
62+
if x isa Array{<:Any,0}
63+
x = x[]
64+
end
65+
return x
66+
end
67+
68+
"""
69+
ParamsInit(params::AbstractDict{<:VarName}, default::AbstractInitStrategy=PriorInit())
70+
ParamsInit(params::NamedTuple, default::AbstractInitStrategy=PriorInit())
71+
72+
Obtain new values by extracting them from the given dictionary or NamedTuple.
73+
The parameter `default` specifies how new values are to be obtained if they
74+
cannot be found in `params`, or they are specified as `missing`. The default
75+
for `default` is `PriorInit()`.
76+
77+
!!! note
78+
These values must be provided in the space of the untransformed distribution.
79+
"""
80+
struct ParamsInit{P,S<:AbstractInitStrategy} <: AbstractInitStrategy
81+
params::P
82+
default::S
83+
function ParamsInit(params::AbstractDict{<:VarName}, default::AbstractInitStrategy)
84+
return new{typeof(params),typeof(default)}(params, default)
85+
end
86+
ParamsInit(params::AbstractDict{<:VarName}) = ParamsInit(params, PriorInit())
87+
function ParamsInit(params::NamedTuple, default::AbstractInitStrategy=PriorInit())
88+
return ParamsInit(to_varname_dict(params), default)
89+
end
90+
end
91+
function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::ParamsInit)
92+
# TODO(penelopeysm): We should do a check to make sure that all of the
93+
# parameters in `p.params` were actually used, and either warn or error if
94+
# they aren't. This is non-trivial (we need to use something like
95+
# varname_leaves), so I'm going to defer it to a later PR.
96+
return if hasvalue(p.params, vn, dist)
97+
x = getvalue(p.params, vn, dist)
98+
if x === missing
99+
init(rng, vn, dist, p.default)
100+
else
101+
# TODO(penelopeysm): We could also check that the type of x matches
102+
# the dist?
103+
x
104+
end
105+
else
106+
init(rng, vn, dist, p.default)
107+
end
108+
end
109+
110+
"""
111+
InitContext(
112+
[rng::Random.AbstractRNG=Random.default_rng()],
113+
[strategy::AbstractInitStrategy=PriorInit()],
114+
)
115+
116+
A leaf context that indicates that new values for random variables are
117+
currently being obtained through sampling. Used e.g. when initialising a fresh
118+
VarInfo. Note that, if `leafcontext(model.context) isa InitContext`, then
119+
`evaluate!!(model, varinfo)` will override all values in the VarInfo.
120+
"""
121+
struct InitContext{R<:Random.AbstractRNG,S<:AbstractInitStrategy} <: AbstractContext
122+
rng::R
123+
strategy::S
124+
function InitContext(
125+
rng::Random.AbstractRNG, strategy::AbstractInitStrategy=PriorInit()
126+
)
127+
return new{typeof(rng),typeof(strategy)}(rng, strategy)
128+
end
129+
function InitContext(strategy::AbstractInitStrategy=PriorInit())
130+
return InitContext(Random.default_rng(), strategy)
131+
end
132+
end
133+
NodeTrait(::InitContext) = IsLeaf()
134+
135+
function tilde_assume(
136+
ctx::InitContext, dist::Distribution, vn::VarName, vi::AbstractVarInfo
137+
)
138+
in_varinfo = haskey(vi, vn)
139+
# `init()` always returns values in original space, i.e. possibly
140+
# constrained
141+
x = init(ctx.rng, vn, dist, ctx.strategy)
142+
# Determine whether to insert a transformed value into the VarInfo.
143+
# If the VarInfo alrady had a value for this variable, we will
144+
# keep the same linked status as in the original VarInfo. If not, we
145+
# check the rest of the VarInfo to see if other variables are linked.
146+
# istrans(vi) returns true if vi is nonempty and all variables in vi
147+
# are linked.
148+
insert_transformed_value = in_varinfo ? istrans(vi, vn) : istrans(vi)
149+
f = if insert_transformed_value
150+
to_linked_internal_transform(vi, vn, dist)
151+
else
152+
to_internal_transform(vi, vn, dist)
153+
end
154+
# TODO(penelopeysm): We would really like to do:
155+
# y, logjac = with_logabsdet_jacobian(f, x)
156+
# Unfortunately, `to_{linked_}internal_transform` returns a function that
157+
# always converts x to a vector, i.e., if dist is univariate, f(x) will be
158+
# a vector of length 1. It would be nice if we could unify these.
159+
y = f(x)
160+
logjac = logabsdetjac(insert_transformed_value ? link_transform(dist) : identity, x)
161+
# Add the new value to the VarInfo. `push!!` errors if the value already
162+
# exists, hence the need for setindex!!.
163+
if in_varinfo
164+
vi = setindex!!(vi, y, vn)
165+
else
166+
vi = push!!(vi, vn, y, dist)
167+
end
168+
# Neither of these set the `trans` flag so we have to do it manually if
169+
# necessary.
170+
insert_transformed_value && settrans!!(vi, true, vn)
171+
# `accumulate_assume!!` wants untransformed values as the second argument.
172+
vi = accumulate_assume!!(vi, x, -logjac, vn, dist)
173+
# We always return the untransformed value here, as that will determine
174+
# what the lhs of the tilde-statement is set to.
175+
return x, vi
176+
end
177+
178+
function tilde_observe!!(::InitContext, right, left, vn, vi)
179+
return tilde_observe!!(DefaultContext(), right, left, vn, vi)
180+
end

src/model.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -854,6 +854,39 @@ function evaluate_and_sample!!(
854854
return evaluate_and_sample!!(Random.default_rng(), model, varinfo, sampler)
855855
end
856856

857+
"""
858+
init!!(
859+
[rng::Random.AbstractRNG,]
860+
model::Model,
861+
varinfo::AbstractVarInfo,
862+
[init_strategy::AbstractInitStrategy=PriorInit()]
863+
)
864+
865+
Evaluate the `model` and replace the values of the model's random variables in
866+
the given `varinfo` with new values using a specified initialisation strategy.
867+
If the values in `varinfo` are not already present, they will be added using
868+
that same strategy.
869+
870+
If `init_strategy` is not provided, defaults to PriorInit().
871+
872+
Returns a tuple of the model's return value, plus the updated `varinfo` object.
873+
"""
874+
function init!!(
875+
rng::Random.AbstractRNG,
876+
model::Model,
877+
varinfo::AbstractVarInfo,
878+
init_strategy::AbstractInitStrategy=PriorInit(),
879+
)
880+
new_context = setleafcontext(model.context, InitContext(rng, init_strategy))
881+
new_model = contextualize(model, new_context)
882+
return evaluate!!(new_model, varinfo)
883+
end
884+
function init!!(
885+
model::Model, varinfo::AbstractVarInfo, init_strategy::AbstractInitStrategy=PriorInit()
886+
)
887+
return init!!(Random.default_rng(), model, varinfo, init_strategy)
888+
end
889+
857890
"""
858891
evaluate!!(model::Model, varinfo)
859892

test/contexts.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using Test, DynamicPPL, Accessors
2-
using AbstractPPL: getoptic
2+
using AbstractPPL: getoptic, hasvalue, getvalue
33
using DynamicPPL:
44
leafcontext,
55
setleafcontext,
@@ -431,4 +431,14 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
431431
@test fixed(c6) == Dict(@varname(a.b.d) => 2)
432432
end
433433
end
434+
435+
@testset "InitContext" begin
436+
@testset "PriorInit" begin end
437+
438+
@testset "UniformInit" begin end
439+
440+
@testset "ParamsInit" begin end
441+
442+
@testset "rng is respected (at least with PriorInit" begin end
443+
end
434444
end

0 commit comments

Comments
 (0)