Skip to content

Commit bc16e09

Browse files
committed
WIP: InitContext
1 parent f20e86c commit bc16e09

File tree

13 files changed

+387
-459
lines changed

13 files changed

+387
-459
lines changed

src/DynamicPPL.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,13 +97,14 @@ export AbstractVarInfo,
9797
values_as_in_model,
9898
# Samplers
9999
Sampler,
100-
SampleFromPrior,
101-
SampleFromUniform,
100+
# Initialisation strategies
101+
PriorInit,
102+
UniformInit,
103+
ParamsInit,
102104
# LogDensityFunction
103105
LogDensityFunction,
104106
# Contexts
105107
contextualize,
106-
SamplingContext,
107108
DefaultContext,
108109
PrefixContext,
109110
ConditionContext,
@@ -170,11 +171,12 @@ abstract type AbstractVarInfo <: AbstractModelTrace end
170171
# Necessary forward declarations
171172
include("utils.jl")
172173
include("chains.jl")
174+
include("contexts.jl")
175+
include("contexts/init.jl")
173176
include("model.jl")
174177
include("sampler.jl")
175178
include("varname.jl")
176179
include("distribution_wrappers.jl")
177-
include("contexts.jl")
178180
include("submodel.jl")
179181
include("varnamedvector.jl")
180182
include("accumulators.jl")

src/context_implementations.jl

Lines changed: 0 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,4 @@
11
# assume
2-
"""
3-
tilde_assume(context::SamplingContext, right, vn, vi)
4-
5-
Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
6-
accumulate the log probability, and return the sampled value with a context associated
7-
with a sampler.
8-
9-
Falls back to
10-
```julia
11-
tilde_assume(context.rng, context.context, context.sampler, right, vn, vi)
12-
```
13-
"""
14-
function tilde_assume(context::SamplingContext, right, vn, vi)
15-
return tilde_assume(context.rng, context.context, context.sampler, right, vn, vi)
16-
end
17-
182
function tilde_assume(context::AbstractContext, args...)
193
return tilde_assume(childcontext(context), args...)
204
end
@@ -71,17 +55,6 @@ function tilde_assume!!(context, right, vn, vi)
7155
end
7256

7357
# observe
74-
"""
75-
tilde_observe!!(context::SamplingContext, right, left, vi)
76-
77-
Handle observed constants with a `context` associated with a sampler.
78-
79-
Falls back to `tilde_observe!!(context.context, right, left, vi)`.
80-
"""
81-
function tilde_observe!!(context::SamplingContext, right, left, vn, vi)
82-
return tilde_observe!!(context.context, right, left, vn, vi)
83-
end
84-
8558
function tilde_observe!!(context::AbstractContext, right, left, vn, vi)
8659
return tilde_observe!!(childcontext(context), right, left, vn, vi)
8760
end
@@ -127,46 +100,3 @@ function assume(dist::Distribution, vn::VarName, vi)
127100
vi = accumulate_assume!!(vi, x, logjac, vn, dist)
128101
return x, vi
129102
end
130-
131-
# TODO: Remove this thing.
132-
# SampleFromPrior and SampleFromUniform
133-
function assume(
134-
rng::Random.AbstractRNG,
135-
sampler::Union{SampleFromPrior,SampleFromUniform},
136-
dist::Distribution,
137-
vn::VarName,
138-
vi::VarInfoOrThreadSafeVarInfo,
139-
)
140-
if haskey(vi, vn)
141-
# Always overwrite the parameters with new ones for `SampleFromUniform`.
142-
if sampler isa SampleFromUniform || is_flagged(vi, vn, "del")
143-
# TODO(mhauru) Is it important to unset the flag here? The `true` allows us
144-
# to ignore the fact that for VarNamedVector this does nothing, but I'm unsure
145-
# if that's okay.
146-
unset_flag!(vi, vn, "del", true)
147-
r = init(rng, dist, sampler)
148-
f = to_maybe_linked_internal_transform(vi, vn, dist)
149-
# TODO(mhauru) This should probably be call a function called setindex_internal!
150-
vi = BangBang.setindex!!(vi, f(r), vn)
151-
setorder!(vi, vn, get_num_produce(vi))
152-
else
153-
# Otherwise we just extract it.
154-
r = vi[vn, dist]
155-
end
156-
else
157-
r = init(rng, dist, sampler)
158-
if istrans(vi)
159-
f = to_linked_internal_transform(vi, vn, dist)
160-
vi = push!!(vi, vn, f(r), dist)
161-
# By default `push!!` sets the transformed flag to `false`.
162-
vi = settrans!!(vi, true, vn)
163-
else
164-
vi = push!!(vi, vn, r, dist)
165-
end
166-
end
167-
168-
# HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct.
169-
logjac = logabsdetjac(istrans(vi, vn) ? link_transform(dist) : identity, r)
170-
vi = accumulate_assume!!(vi, r, -logjac, vn, dist)
171-
return r, vi
172-
end

src/contexts.jl

Lines changed: 1 addition & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ effectively updating the child context.
4747
```jldoctest
4848
julia> using DynamicPPL: DynamicTransformationContext
4949
50-
julia> ctx = SamplingContext();
50+
julia> ctx = ConditionContext((; a = 1);
5151
5252
julia> DynamicPPL.childcontext(ctx)
5353
DefaultContext()
@@ -121,73 +121,6 @@ setleafcontext(::IsLeaf, ::IsParent, left, right) = right
121121
setleafcontext(::IsLeaf, ::IsLeaf, left, right) = right
122122

123123
# Contexts
124-
"""
125-
SamplingContext(
126-
[rng::Random.AbstractRNG=Random.default_rng()],
127-
[sampler::AbstractSampler=SampleFromPrior()],
128-
[context::AbstractContext=DefaultContext()],
129-
)
130-
131-
Create a context that allows you to sample parameters with the `sampler` when running the model.
132-
The `context` determines how the returned log density is computed when running the model.
133-
134-
See also: [`DefaultContext`](@ref)
135-
"""
136-
struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractContext
137-
rng::R
138-
sampler::S
139-
context::C
140-
end
141-
142-
function SamplingContext(
143-
rng::Random.AbstractRNG=Random.default_rng(), sampler::AbstractSampler=SampleFromPrior()
144-
)
145-
return SamplingContext(rng, sampler, DefaultContext())
146-
end
147-
148-
function SamplingContext(
149-
sampler::AbstractSampler, context::AbstractContext=DefaultContext()
150-
)
151-
return SamplingContext(Random.default_rng(), sampler, context)
152-
end
153-
154-
function SamplingContext(rng::Random.AbstractRNG, context::AbstractContext)
155-
return SamplingContext(rng, SampleFromPrior(), context)
156-
end
157-
158-
function SamplingContext(context::AbstractContext)
159-
return SamplingContext(Random.default_rng(), SampleFromPrior(), context)
160-
end
161-
162-
NodeTrait(context::SamplingContext) = IsParent()
163-
childcontext(context::SamplingContext) = context.context
164-
function setchildcontext(parent::SamplingContext, child)
165-
return SamplingContext(parent.rng, parent.sampler, child)
166-
end
167-
168-
"""
169-
hassampler(context)
170-
171-
Return `true` if `context` has a sampler.
172-
"""
173-
hassampler(::SamplingContext) = true
174-
hassampler(context::AbstractContext) = hassampler(NodeTrait(context), context)
175-
hassampler(::IsLeaf, context::AbstractContext) = false
176-
hassampler(::IsParent, context::AbstractContext) = hassampler(childcontext(context))
177-
178-
"""
179-
getsampler(context)
180-
181-
Return the sampler of the context `context`.
182-
183-
This will traverse the context tree until it reaches the first [`SamplingContext`](@ref),
184-
at which point it will return the sampler of that context.
185-
"""
186-
getsampler(context::SamplingContext) = context.sampler
187-
getsampler(context::AbstractContext) = getsampler(NodeTrait(context), context)
188-
getsampler(::IsParent, context::AbstractContext) = getsampler(childcontext(context))
189-
getsampler(::IsLeaf, ::AbstractContext) = error("No sampler found in context")
190-
191124
"""
192125
struct DefaultContext <: AbstractContext end
193126
@@ -280,41 +213,6 @@ function prefix_and_strip_contexts(::IsParent, ctx::AbstractContext, vn::VarName
280213
return vn, setchildcontext(ctx, new_ctx)
281214
end
282215

283-
"""
284-
prefix(model::Model, x::VarName)
285-
prefix(model::Model, x::Val{sym})
286-
prefix(model::Model, x::Any)
287-
288-
Return `model` but with all random variables prefixed by `x`, where `x` is either:
289-
- a `VarName` (e.g. `@varname(a)`),
290-
- a `Val{sym}` (e.g. `Val(:a)`), or
291-
- for any other type, `x` is converted to a Symbol and then to a `VarName`. Note that
292-
this will introduce runtime overheads so is not recommended unless absolutely
293-
necessary.
294-
295-
# Examples
296-
297-
```jldoctest
298-
julia> using DynamicPPL: prefix
299-
300-
julia> @model demo() = x ~ Dirac(1)
301-
demo (generic function with 2 methods)
302-
303-
julia> rand(prefix(demo(), @varname(my_prefix)))
304-
(var"my_prefix.x" = 1,)
305-
306-
julia> rand(prefix(demo(), Val(:my_prefix)))
307-
(var"my_prefix.x" = 1,)
308-
```
309-
"""
310-
prefix(model::Model, x::VarName) = contextualize(model, PrefixContext(x, model.context))
311-
function prefix(model::Model, x::Val{sym}) where {sym}
312-
return contextualize(model, PrefixContext(VarName{sym}(), model.context))
313-
end
314-
function prefix(model::Model, x)
315-
return contextualize(model, PrefixContext(VarName{Symbol(x)}(), model.context))
316-
end
317-
318216
"""
319217
320218
ConditionContext{Values<:Union{NamedTuple,AbstractDict},Ctx<:AbstractContext}

0 commit comments

Comments
 (0)