Skip to content

Commit 26f96a7

Browse files
committed
Remove SamplingContext for good
1 parent a392451 commit 26f96a7

15 files changed

+17
-445
lines changed

docs/src/api.md

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -456,17 +456,12 @@ AbstractPPL.evaluate!!
456456

457457
This method mutates the `varinfo` used for execution.
458458
By default, it does not perform any actual sampling: it only evaluates the model using the values of the variables that are already in the `varinfo`.
459-
To perform sampling, you can either wrap `model.context` in a `SamplingContext`, or use this convenience method:
460-
461-
```@docs
462-
DynamicPPL.evaluate_and_sample!!
463-
```
459+
If you wish to sample new values, see the section on [VarInfo initialisation](#VarInfo-initialisation) just below this.
464460

465461
The behaviour of a model execution can be changed with evaluation contexts, which are a field of the model.
466462
Contexts are subtypes of `AbstractPPL.AbstractContext`.
467463

468464
```@docs
469-
SamplingContext
470465
DefaultContext
471466
PrefixContext
472467
ConditionContext
@@ -500,15 +495,7 @@ DynamicPPL.init
500495

501496
### Samplers
502497

503-
In DynamicPPL two samplers are defined that are used to initialize unobserved random variables:
504-
[`SampleFromPrior`](@ref) which samples from the prior distribution, and [`SampleFromUniform`](@ref) which samples from a uniform distribution.
505-
506-
```@docs
507-
SampleFromPrior
508-
SampleFromUniform
509-
```
510-
511-
Additionally, a generic sampler for inference is implemented.
498+
In DynamicPPL a generic sampler for inference is implemented.
512499

513500
```@docs
514501
Sampler

ext/DynamicPPLEnzymeCoreExt.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@ else
88
using ..EnzymeCore
99
end
1010

11-
@inline EnzymeCore.EnzymeRules.inactive_type(::Type{<:DynamicPPL.SamplingContext}) = true
12-
1311
# Mark istrans as having 0 derivative. The `nothing` return value is not significant, Enzyme
1412
# only checks whether such a method exists, and never runs it.
1513
@inline EnzymeCore.EnzymeRules.inactive_noinl(::typeof(DynamicPPL.istrans), args...) =

src/DynamicPPL.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,13 +97,10 @@ export AbstractVarInfo,
9797
values_as_in_model,
9898
# Samplers
9999
Sampler,
100-
SampleFromPrior,
101-
SampleFromUniform,
102100
# LogDensityFunction
103101
LogDensityFunction,
104102
# Contexts
105103
contextualize,
106-
SamplingContext,
107104
DefaultContext,
108105
PrefixContext,
109106
ConditionContext,

src/context_implementations.jl

Lines changed: 5 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,14 @@
11
# assume
2-
"""
3-
tilde_assume(context::SamplingContext, right, vn, vi)
4-
5-
Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
6-
accumulate the log probability, and return the sampled value with a context associated
7-
with a sampler.
8-
9-
Falls back to
10-
```julia
11-
tilde_assume(context.rng, context.context, context.sampler, right, vn, vi)
12-
```
13-
"""
14-
function tilde_assume(context::SamplingContext, right, vn, vi)
15-
return tilde_assume(context.rng, context.context, context.sampler, right, vn, vi)
16-
end
17-
182
function tilde_assume(context::AbstractContext, args...)
193
return tilde_assume(childcontext(context), args...)
204
end
215
function tilde_assume(::DefaultContext, right, vn, vi)
22-
return assume(right, vn, vi)
23-
end
24-
25-
function tilde_assume(rng::Random.AbstractRNG, context::AbstractContext, args...)
26-
return tilde_assume(rng, childcontext(context), args...)
27-
end
28-
function tilde_assume(rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, vi)
29-
return assume(rng, sampler, right, vn, vi)
30-
end
31-
function tilde_assume(rng::Random.AbstractRNG, ::InitContext, sampler, right, vn, vi)
32-
@warn(
33-
"Encountered SamplingContext->InitContext. This method will be removed in the next PR.",
34-
)
35-
# just pretend the `InitContext` isn't there for now.
36-
return assume(rng, sampler, right, vn, vi)
37-
end
38-
function tilde_assume(::DefaultContext, sampler, right, vn, vi)
39-
# same as above but no rng
40-
return assume(Random.default_rng(), sampler, right, vn, vi)
6+
y = getindex_internal(vi, vn)
7+
f = from_maybe_linked_internal_transform(vi, vn, right)
8+
x, logjac = with_logabsdet_jacobian(f, y)
9+
vi = accumulate_assume!!(vi, x, logjac, vn, right)
10+
return x, vi
4111
end
42-
4312
function tilde_assume(context::PrefixContext, right, vn, vi)
4413
# Note that we can't use something like this here:
4514
# new_vn = prefix(context, vn)
@@ -53,12 +22,6 @@ function tilde_assume(context::PrefixContext, right, vn, vi)
5322
new_vn, new_context = prefix_and_strip_contexts(context, vn)
5423
return tilde_assume(new_context, right, new_vn, vi)
5524
end
56-
function tilde_assume(
57-
rng::Random.AbstractRNG, context::PrefixContext, sampler, right, vn, vi
58-
)
59-
new_vn, new_context = prefix_and_strip_contexts(context, vn)
60-
return tilde_assume(rng, new_context, sampler, right, new_vn, vi)
61-
end
6225

6326
"""
6427
tilde_assume!!(context, right, vn, vi)
@@ -78,17 +41,6 @@ function tilde_assume!!(context, right, vn, vi)
7841
end
7942

8043
# observe
81-
"""
82-
tilde_observe!!(context::SamplingContext, right, left, vi)
83-
84-
Handle observed constants with a `context` associated with a sampler.
85-
86-
Falls back to `tilde_observe!!(context.context, right, left, vi)`.
87-
"""
88-
function tilde_observe!!(context::SamplingContext, right, left, vn, vi)
89-
return tilde_observe!!(context.context, right, left, vn, vi)
90-
end
91-
9244
function tilde_observe!!(context::AbstractContext, right, left, vn, vi)
9345
return tilde_observe!!(childcontext(context), right, left, vn, vi)
9446
end
@@ -121,59 +73,3 @@ function tilde_observe!!(::DefaultContext, right, left, vn, vi)
12173
vi = accumulate_observe!!(vi, right, left, vn)
12274
return left, vi
12375
end
124-
125-
function assume(::Random.AbstractRNG, spl::Sampler, dist)
126-
return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))")
127-
end
128-
129-
# fallback without sampler
130-
function assume(dist::Distribution, vn::VarName, vi)
131-
y = getindex_internal(vi, vn)
132-
f = from_maybe_linked_internal_transform(vi, vn, dist)
133-
x, logjac = with_logabsdet_jacobian(f, y)
134-
vi = accumulate_assume!!(vi, x, logjac, vn, dist)
135-
return x, vi
136-
end
137-
138-
# TODO: Remove this thing.
139-
# SampleFromPrior and SampleFromUniform
140-
function assume(
141-
rng::Random.AbstractRNG,
142-
sampler::Union{SampleFromPrior,SampleFromUniform},
143-
dist::Distribution,
144-
vn::VarName,
145-
vi::VarInfoOrThreadSafeVarInfo,
146-
)
147-
if haskey(vi, vn)
148-
# Always overwrite the parameters with new ones for `SampleFromUniform`.
149-
if sampler isa SampleFromUniform || is_flagged(vi, vn, "del")
150-
# TODO(mhauru) Is it important to unset the flag here? The `true` allows us
151-
# to ignore the fact that for VarNamedVector this does nothing, but I'm unsure
152-
# if that's okay.
153-
unset_flag!(vi, vn, "del", true)
154-
r = init(rng, dist, sampler)
155-
f = to_maybe_linked_internal_transform(vi, vn, dist)
156-
# TODO(mhauru) This should probably be call a function called setindex_internal!
157-
vi = BangBang.setindex!!(vi, f(r), vn)
158-
setorder!(vi, vn, get_num_produce(vi))
159-
else
160-
# Otherwise we just extract it.
161-
r = vi[vn, dist]
162-
end
163-
else
164-
r = init(rng, dist, sampler)
165-
if istrans(vi)
166-
f = to_linked_internal_transform(vi, vn, dist)
167-
vi = push!!(vi, vn, f(r), dist)
168-
# By default `push!!` sets the transformed flag to `false`.
169-
vi = settrans!!(vi, true, vn)
170-
else
171-
vi = push!!(vi, vn, r, dist)
172-
end
173-
end
174-
175-
# HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct.
176-
logjac = logabsdetjac(istrans(vi, vn) ? link_transform(dist) : identity, r)
177-
vi = accumulate_assume!!(vi, r, -logjac, vn, dist)
178-
return r, vi
179-
end

src/contexts.jl

Lines changed: 1 addition & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ effectively updating the child context.
4747
```jldoctest
4848
julia> using DynamicPPL: DynamicTransformationContext
4949
50-
julia> ctx = SamplingContext();
50+
julia> ctx = ConditionContext((; a = 1);
5151
5252
julia> DynamicPPL.childcontext(ctx)
5353
DefaultContext()
@@ -121,73 +121,6 @@ setleafcontext(::IsLeaf, ::IsParent, left, right) = right
121121
setleafcontext(::IsLeaf, ::IsLeaf, left, right) = right
122122

123123
# Contexts
124-
"""
125-
SamplingContext(
126-
[rng::Random.AbstractRNG=Random.default_rng()],
127-
[sampler::AbstractSampler=SampleFromPrior()],
128-
[context::AbstractContext=DefaultContext()],
129-
)
130-
131-
Create a context that allows you to sample parameters with the `sampler` when running the model.
132-
The `context` determines how the returned log density is computed when running the model.
133-
134-
See also: [`DefaultContext`](@ref)
135-
"""
136-
struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractContext
137-
rng::R
138-
sampler::S
139-
context::C
140-
end
141-
142-
function SamplingContext(
143-
rng::Random.AbstractRNG=Random.default_rng(), sampler::AbstractSampler=SampleFromPrior()
144-
)
145-
return SamplingContext(rng, sampler, DefaultContext())
146-
end
147-
148-
function SamplingContext(
149-
sampler::AbstractSampler, context::AbstractContext=DefaultContext()
150-
)
151-
return SamplingContext(Random.default_rng(), sampler, context)
152-
end
153-
154-
function SamplingContext(rng::Random.AbstractRNG, context::AbstractContext)
155-
return SamplingContext(rng, SampleFromPrior(), context)
156-
end
157-
158-
function SamplingContext(context::AbstractContext)
159-
return SamplingContext(Random.default_rng(), SampleFromPrior(), context)
160-
end
161-
162-
NodeTrait(context::SamplingContext) = IsParent()
163-
childcontext(context::SamplingContext) = context.context
164-
function setchildcontext(parent::SamplingContext, child)
165-
return SamplingContext(parent.rng, parent.sampler, child)
166-
end
167-
168-
"""
169-
hassampler(context)
170-
171-
Return `true` if `context` has a sampler.
172-
"""
173-
hassampler(::SamplingContext) = true
174-
hassampler(context::AbstractContext) = hassampler(NodeTrait(context), context)
175-
hassampler(::IsLeaf, context::AbstractContext) = false
176-
hassampler(::IsParent, context::AbstractContext) = hassampler(childcontext(context))
177-
178-
"""
179-
getsampler(context)
180-
181-
Return the sampler of the context `context`.
182-
183-
This will traverse the context tree until it reaches the first [`SamplingContext`](@ref),
184-
at which point it will return the sampler of that context.
185-
"""
186-
getsampler(context::SamplingContext) = context.sampler
187-
getsampler(context::AbstractContext) = getsampler(NodeTrait(context), context)
188-
getsampler(::IsParent, context::AbstractContext) = getsampler(childcontext(context))
189-
getsampler(::IsLeaf, ::AbstractContext) = error("No sampler found in context")
190-
191124
"""
192125
struct DefaultContext <: AbstractContext end
193126

src/debug_utils.jl

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -255,15 +255,7 @@ end
255255

256256
function DynamicPPL.tilde_assume(context::DebugContext, right, vn, vi)
257257
record_pre_tilde_assume!(context, vn, right, vi)
258-
value, vi = DynamicPPL.tilde_assume(childcontext(context), right, vn, vi)
259-
record_post_tilde_assume!(context, vn, right, value, vi)
260-
return value, vi
261-
end
262-
function DynamicPPL.tilde_assume(
263-
rng::Random.AbstractRNG, context::DebugContext, sampler, right, vn, vi
264-
)
265-
record_pre_tilde_assume!(context, vn, right, vi)
266-
value, vi = DynamicPPL.tilde_assume(rng, childcontext(context), sampler, right, vn, vi)
258+
value, vi = DynamicPPL.tilde_assume!!(childcontext(context), right, vn, vi)
267259
record_post_tilde_assume!(context, vn, right, value, vi)
268260
return value, vi
269261
end
@@ -438,9 +430,10 @@ function check_model_and_trace(
438430
kwargs...,
439431
)
440432
# Execute the model with the debug context.
441-
debug_context = DebugContext(
442-
SamplingContext(rng, model.context); error_on_failure=error_on_failure, kwargs...
433+
new_context = DynamicPPL.setleafcontext(
434+
model.context, DynamicPPL.InitContext(rng, DynamicPPL.PriorInit())
443435
)
436+
debug_context = DebugContext(new_context; error_on_failure=error_on_failure, kwargs...)
444437
debug_model = DynamicPPL.contextualize(model, debug_context)
445438

446439
# Perform checks before evaluating the model.

src/sampler.jl

Lines changed: 0 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,3 @@
1-
# TODO: Make `UniformSampling` and `Prior` algs + just use `Sampler`
2-
# That would let us use all defaults for Sampler, combine it with other samplers etc.
3-
"""
4-
SampleFromUniform
5-
6-
Sampling algorithm that samples unobserved random variables from a uniform distribution.
7-
8-
# References
9-
10-
[Stan reference manual](https://mc-stan.org/docs/2_28/reference-manual/initialization.html#random-initial-values)
11-
"""
12-
struct SampleFromUniform <: AbstractSampler end
13-
14-
"""
15-
SampleFromPrior
16-
17-
Sampling algorithm that samples unobserved random variables from their prior distribution.
18-
"""
19-
struct SampleFromPrior <: AbstractSampler end
20-
21-
# Initializations.
22-
init(rng, dist, ::SampleFromPrior) = rand(rng, dist)
23-
function init(rng, dist, ::SampleFromUniform)
24-
return istransformable(dist) ? inittrans(rng, dist) : rand(rng, dist)
25-
end
26-
27-
init(rng, dist, ::SampleFromPrior, n::Int) = rand(rng, dist, n)
28-
function init(rng, dist, ::SampleFromUniform, n::Int)
29-
return istransformable(dist) ? inittrans(rng, dist, n) : rand(rng, dist, n)
30-
end
31-
321
# TODO(mhauru) Could we get rid of Sampler now that it's just a wrapper around `alg`?
332
# (Selector has been removed).
343
"""
@@ -49,20 +18,6 @@ struct Sampler{T} <: AbstractSampler
4918
alg::T
5019
end
5120

52-
# AbstractMCMC interface for SampleFromUniform and SampleFromPrior
53-
function AbstractMCMC.step(
54-
rng::Random.AbstractRNG,
55-
model::Model,
56-
sampler::Union{SampleFromUniform,SampleFromPrior},
57-
state=nothing;
58-
kwargs...,
59-
)
60-
vi = VarInfo()
61-
strategy = sampler isa SampleFromPrior ? PriorInit() : UniformInit()
62-
DynamicPPL.init!!(rng, model, vi, strategy)
63-
return vi, nothing
64-
end
65-
6621
"""
6722
default_varinfo(rng, model, sampler)
6823

src/simple_varinfo.jl

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -457,23 +457,6 @@ function Base.merge(varinfo_left::SimpleVarInfo, varinfo_right::SimpleVarInfo)
457457
end
458458

459459
# Context implementations
460-
# NOTE: Evaluations, i.e. those without `rng` are shared with other
461-
# implementations of `AbstractVarInfo`.
462-
function assume(
463-
rng::Random.AbstractRNG,
464-
sampler::Union{SampleFromPrior,SampleFromUniform},
465-
dist::Distribution,
466-
vn::VarName,
467-
vi::SimpleOrThreadSafeSimple,
468-
)
469-
value = init(rng, dist, sampler)
470-
# Transform if we're working in unconstrained space.
471-
f = to_maybe_linked_internal_transform(vi, vn, dist)
472-
value_raw, logjac = with_logabsdet_jacobian(f, value)
473-
vi = BangBang.push!!(vi, vn, value_raw, dist)
474-
vi = accumulate_assume!!(vi, value, -logjac, vn, dist)
475-
return value, vi
476-
end
477460

478461
function settrans!!(vi::SimpleVarInfo, trans)
479462
return settrans!!(vi, trans ? DynamicTransformation() : NoTransformation())

0 commit comments

Comments
 (0)