Skip to content

Commit 25c1db7

Browse files
committed
[no ci] The Rest
1 parent 7265885 commit 25c1db7

18 files changed

+120
-559
lines changed

docs/src/api.md

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -461,7 +461,6 @@ The behaviour of a model execution can be changed with evaluation contexts, whic
461461
Contexts are subtypes of `AbstractPPL.AbstractContext`.
462462

463463
```@docs
464-
SamplingContext
465464
DefaultContext
466465
PrefixContext
467466
ConditionContext
@@ -490,15 +489,7 @@ DynamicPPL.init
490489

491490
### Samplers
492491

493-
In DynamicPPL two samplers are defined that are used to initialize unobserved random variables:
494-
[`SampleFromPrior`](@ref) which samples from the prior distribution, and [`SampleFromUniform`](@ref) which samples from a uniform distribution.
495-
496-
```@docs
497-
SampleFromPrior
498-
SampleFromUniform
499-
```
500-
501-
Additionally, a generic sampler for inference is implemented.
492+
In DynamicPPL a generic sampler for inference is implemented.
502493

503494
```@docs
504495
Sampler
@@ -509,7 +500,7 @@ The default implementation of [`Sampler`](@ref) uses the following unexported fu
509500
```@docs
510501
DynamicPPL.initialstep
511502
DynamicPPL.loadstate
512-
DynamicPPL.initialsampler
503+
DynamicPPL.init_strategy
513504
```
514505

515506
Finally, to specify which varinfo type a [`Sampler`](@ref) should use for a given [`Model`](@ref), this is specified by [`DynamicPPL.default_varinfo`](@ref) and can thus be overloaded for each `model`-`sampler` combination. This can be useful in cases where one has explicit knowledge that one type of varinfo will be more performant for the given `model` and `sampler`.

ext/DynamicPPLEnzymeCoreExt.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@ else
88
using ..EnzymeCore
99
end
1010

11-
@inline EnzymeCore.EnzymeRules.inactive_type(::Type{<:DynamicPPL.SamplingContext}) = true
12-
1311
# Mark istrans as having 0 derivative. The `nothing` return value is not significant, Enzyme
1412
# only checks whether such a method exists, and never runs it.
1513
@inline EnzymeCore.EnzymeRules.inactive_noinl(::typeof(DynamicPPL.istrans), args...) =

ext/DynamicPPLJETExt.jl

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,12 @@ end
2121
function DynamicPPL.Experimental._determine_varinfo_jet(
2222
model::DynamicPPL.Model; only_ddpl::Bool=true
2323
)
24-
# Use SamplingContext to test type stability.
25-
sampling_model = DynamicPPL.contextualize(
26-
model, DynamicPPL.SamplingContext(model.context)
27-
)
28-
2924
# First we try with the typed varinfo.
30-
varinfo = DynamicPPL.typed_varinfo(sampling_model)
25+
varinfo = DynamicPPL.typed_varinfo(model)
3126

3227
# Let's make sure that both evaluation and sampling doesn't result in type errors.
3328
issuccess, result = DynamicPPL.Experimental.is_suitable_varinfo(
34-
sampling_model, varinfo; only_ddpl
29+
model, varinfo; only_ddpl
3530
)
3631

3732
if !issuccess
@@ -46,7 +41,7 @@ function DynamicPPL.Experimental._determine_varinfo_jet(
4641
else
4742
# Warn the user that we can't use the type stable one.
4843
@warn "Model seems incompatible with typed varinfo. Falling back to untyped varinfo."
49-
DynamicPPL.untyped_varinfo(sampling_model)
44+
DynamicPPL.untyped_varinfo(model)
5045
end
5146
end
5247

src/DynamicPPL.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,13 +97,10 @@ export AbstractVarInfo,
9797
values_as_in_model,
9898
# Samplers
9999
Sampler,
100-
SampleFromPrior,
101-
SampleFromUniform,
102100
# LogDensityFunction
103101
LogDensityFunction,
104102
# Contexts
105103
contextualize,
106-
SamplingContext,
107104
DefaultContext,
108105
PrefixContext,
109106
ConditionContext,

src/context_implementations.jl

Lines changed: 0 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,11 @@
11
# assume
2-
"""
3-
tilde_assume(context::SamplingContext, right, vn, vi)
4-
5-
Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
6-
accumulate the log probability, and return the sampled value with a context associated
7-
with a sampler.
8-
9-
Falls back to
10-
```julia
11-
tilde_assume(context.rng, context.context, context.sampler, right, vn, vi)
12-
```
13-
"""
14-
function tilde_assume(context::SamplingContext, right, vn, vi)
15-
return tilde_assume(context.rng, context.context, context.sampler, right, vn, vi)
16-
end
17-
182
function tilde_assume(context::AbstractContext, args...)
193
return tilde_assume(childcontext(context), args...)
204
end
215
function tilde_assume(::DefaultContext, right, vn, vi)
226
return assume(right, vn, vi)
237
end
248

25-
function tilde_assume(rng::Random.AbstractRNG, context::AbstractContext, args...)
26-
return tilde_assume(rng, childcontext(context), args...)
27-
end
28-
function tilde_assume(rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, vi)
29-
return assume(rng, sampler, right, vn, vi)
30-
end
31-
function tilde_assume(::DefaultContext, sampler, right, vn, vi)
32-
# same as above but no rng
33-
return assume(Random.default_rng(), sampler, right, vn, vi)
34-
end
35-
369
function tilde_assume(context::PrefixContext, right, vn, vi)
3710
# Note that we can't use something like this here:
3811
# new_vn = prefix(context, vn)
@@ -46,12 +19,6 @@ function tilde_assume(context::PrefixContext, right, vn, vi)
4619
new_vn, new_context = prefix_and_strip_contexts(context, vn)
4720
return tilde_assume(new_context, right, new_vn, vi)
4821
end
49-
function tilde_assume(
50-
rng::Random.AbstractRNG, context::PrefixContext, sampler, right, vn, vi
51-
)
52-
new_vn, new_context = prefix_and_strip_contexts(context, vn)
53-
return tilde_assume(rng, new_context, sampler, right, new_vn, vi)
54-
end
5522

5623
"""
5724
tilde_assume!!(context, right, vn, vi)
@@ -71,17 +38,6 @@ function tilde_assume!!(context, right, vn, vi)
7138
end
7239

7340
# observe
74-
"""
75-
tilde_observe!!(context::SamplingContext, right, left, vi)
76-
77-
Handle observed constants with a `context` associated with a sampler.
78-
79-
Falls back to `tilde_observe!!(context.context, right, left, vi)`.
80-
"""
81-
function tilde_observe!!(context::SamplingContext, right, left, vn, vi)
82-
return tilde_observe!!(context.context, right, left, vn, vi)
83-
end
84-
8541
function tilde_observe!!(context::AbstractContext, right, left, vn, vi)
8642
return tilde_observe!!(childcontext(context), right, left, vn, vi)
8743
end
@@ -115,10 +71,6 @@ function tilde_observe!!(::DefaultContext, right, left, vn, vi)
11571
return left, vi
11672
end
11773

118-
function assume(::Random.AbstractRNG, spl::Sampler, dist)
119-
return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))")
120-
end
121-
12274
# fallback without sampler
12375
function assume(dist::Distribution, vn::VarName, vi)
12476
y = getindex_internal(vi, vn)
@@ -127,46 +79,3 @@ function assume(dist::Distribution, vn::VarName, vi)
12779
vi = accumulate_assume!!(vi, x, logjac, vn, dist)
12880
return x, vi
12981
end
130-
131-
# TODO: Remove this thing.
132-
# SampleFromPrior and SampleFromUniform
133-
function assume(
134-
rng::Random.AbstractRNG,
135-
sampler::Union{SampleFromPrior,SampleFromUniform},
136-
dist::Distribution,
137-
vn::VarName,
138-
vi::VarInfoOrThreadSafeVarInfo,
139-
)
140-
if haskey(vi, vn)
141-
# Always overwrite the parameters with new ones for `SampleFromUniform`.
142-
if sampler isa SampleFromUniform || is_flagged(vi, vn, "del")
143-
# TODO(mhauru) Is it important to unset the flag here? The `true` allows us
144-
# to ignore the fact that for VarNamedVector this does nothing, but I'm unsure
145-
# if that's okay.
146-
unset_flag!(vi, vn, "del", true)
147-
r = init(rng, dist, sampler)
148-
f = to_maybe_linked_internal_transform(vi, vn, dist)
149-
# TODO(mhauru) This should probably be call a function called setindex_internal!
150-
vi = BangBang.setindex!!(vi, f(r), vn)
151-
setorder!(vi, vn, get_num_produce(vi))
152-
else
153-
# Otherwise we just extract it.
154-
r = vi[vn, dist]
155-
end
156-
else
157-
r = init(rng, dist, sampler)
158-
if istrans(vi)
159-
f = to_linked_internal_transform(vi, vn, dist)
160-
vi = push!!(vi, vn, f(r), dist)
161-
# By default `push!!` sets the transformed flag to `false`.
162-
vi = settrans!!(vi, true, vn)
163-
else
164-
vi = push!!(vi, vn, r, dist)
165-
end
166-
end
167-
168-
# HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct.
169-
logjac = logabsdetjac(istrans(vi, vn) ? link_transform(dist) : identity, r)
170-
vi = accumulate_assume!!(vi, r, -logjac, vn, dist)
171-
return r, vi
172-
end

src/contexts.jl

Lines changed: 1 addition & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ effectively updating the child context.
4747
```jldoctest
4848
julia> using DynamicPPL: DynamicTransformationContext
4949
50-
julia> ctx = SamplingContext();
50+
julia> ctx = ConditionContext((; a = 1);
5151
5252
julia> DynamicPPL.childcontext(ctx)
5353
DefaultContext()
@@ -121,73 +121,6 @@ setleafcontext(::IsLeaf, ::IsParent, left, right) = right
121121
setleafcontext(::IsLeaf, ::IsLeaf, left, right) = right
122122

123123
# Contexts
124-
"""
125-
SamplingContext(
126-
[rng::Random.AbstractRNG=Random.default_rng()],
127-
[sampler::AbstractSampler=SampleFromPrior()],
128-
[context::AbstractContext=DefaultContext()],
129-
)
130-
131-
Create a context that allows you to sample parameters with the `sampler` when running the model.
132-
The `context` determines how the returned log density is computed when running the model.
133-
134-
See also: [`DefaultContext`](@ref)
135-
"""
136-
struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractContext
137-
rng::R
138-
sampler::S
139-
context::C
140-
end
141-
142-
function SamplingContext(
143-
rng::Random.AbstractRNG=Random.default_rng(), sampler::AbstractSampler=SampleFromPrior()
144-
)
145-
return SamplingContext(rng, sampler, DefaultContext())
146-
end
147-
148-
function SamplingContext(
149-
sampler::AbstractSampler, context::AbstractContext=DefaultContext()
150-
)
151-
return SamplingContext(Random.default_rng(), sampler, context)
152-
end
153-
154-
function SamplingContext(rng::Random.AbstractRNG, context::AbstractContext)
155-
return SamplingContext(rng, SampleFromPrior(), context)
156-
end
157-
158-
function SamplingContext(context::AbstractContext)
159-
return SamplingContext(Random.default_rng(), SampleFromPrior(), context)
160-
end
161-
162-
NodeTrait(context::SamplingContext) = IsParent()
163-
childcontext(context::SamplingContext) = context.context
164-
function setchildcontext(parent::SamplingContext, child)
165-
return SamplingContext(parent.rng, parent.sampler, child)
166-
end
167-
168-
"""
169-
hassampler(context)
170-
171-
Return `true` if `context` has a sampler.
172-
"""
173-
hassampler(::SamplingContext) = true
174-
hassampler(context::AbstractContext) = hassampler(NodeTrait(context), context)
175-
hassampler(::IsLeaf, context::AbstractContext) = false
176-
hassampler(::IsParent, context::AbstractContext) = hassampler(childcontext(context))
177-
178-
"""
179-
getsampler(context)
180-
181-
Return the sampler of the context `context`.
182-
183-
This will traverse the context tree until it reaches the first [`SamplingContext`](@ref),
184-
at which point it will return the sampler of that context.
185-
"""
186-
getsampler(context::SamplingContext) = context.sampler
187-
getsampler(context::AbstractContext) = getsampler(NodeTrait(context), context)
188-
getsampler(::IsParent, context::AbstractContext) = getsampler(childcontext(context))
189-
getsampler(::IsLeaf, ::AbstractContext) = error("No sampler found in context")
190-
191124
"""
192125
struct DefaultContext <: AbstractContext end
193126

src/debug_utils.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -438,9 +438,10 @@ function check_model_and_trace(
438438
kwargs...,
439439
)
440440
# Execute the model with the debug context.
441-
debug_context = DebugContext(
442-
SamplingContext(rng, model.context); error_on_failure=error_on_failure, kwargs...
441+
new_context = DynamicPPL.setleafcontext(
442+
model.context, DynamicPPL.InitContext(rng, DynamicPPL.PriorInit())
443443
)
444+
debug_context = DebugContext(new_context; error_on_failure=error_on_failure, kwargs...)
444445
debug_model = DynamicPPL.contextualize(model, debug_context)
445446

446447
# Perform checks before evaluating the model.

0 commit comments

Comments
 (0)