Skip to content

Commit 247e53b

Browse files
committed
Replace evaluate_and_sample!! -> init!!
1 parent f1d5f20 commit 247e53b

File tree

12 files changed

+176
-153
lines changed

12 files changed

+176
-153
lines changed

docs/src/api.md

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -456,11 +456,6 @@ AbstractPPL.evaluate!!
456456

457457
This method mutates the `varinfo` used for execution.
458458
By default, it does not perform any actual sampling: it only evaluates the model using the values of the variables that are already in the `varinfo`.
459-
To perform sampling, you can either wrap `model.context` in a `SamplingContext`, or use this convenience method:
460-
461-
```@docs
462-
DynamicPPL.evaluate_and_sample!!
463-
```
464459

465460
The behaviour of a model execution can be changed with evaluation contexts, which are a field of the model.
466461
Contexts are subtypes of `AbstractPPL.AbstractContext`.

src/extract_priors.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ function extract_priors(rng::Random.AbstractRNG, model::Model)
116116
# workaround for the fact that `order` is still hardcoded in VarInfo, and hence you
117117
# can't push new variables without knowing the num_produce. Remove this when possible.
118118
varinfo = setaccs!!(varinfo, (PriorDistributionAccumulator(), NumProduceAccumulator()))
119-
varinfo = last(evaluate_and_sample!!(rng, model, varinfo))
119+
varinfo = last(init!!(rng, model, varinfo))
120120
return getacc(varinfo, Val(:PriorDistributionAccumulator)).priors
121121
end
122122

src/model.jl

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -850,7 +850,7 @@ end
850850
# ^ Weird Documenter.jl bug means that we have to write the two above separately
851851
# as it can only detect the `function`-less syntax.
852852
function (model::Model)(rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInfo())
853-
return first(evaluate_and_sample!!(rng, model, varinfo))
853+
return first(init!!(rng, model, varinfo))
854854
end
855855

856856
"""
@@ -864,29 +864,35 @@ function use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo)
864864
end
865865

866866
"""
867-
evaluate_and_sample!!([rng::Random.AbstractRNG, ]model::Model, varinfo[, sampler])
868-
869-
Evaluate the `model` with the given `varinfo`, but perform sampling during the
870-
evaluation using the given `sampler` by wrapping the model's context in a
871-
`SamplingContext`.
867+
init!!(
868+
[rng::Random.AbstractRNG, ]
869+
model::Model,
870+
varinfo::AbstractVarInfo,
871+
[init_strategy::AbstractInitStrategy=PriorInit()]
872+
)
872873
873-
If `sampler` is not provided, defaults to [`SampleFromPrior`](@ref).
874+
Evaluate the `model` and replace the values of the model's random variables
875+
in the given `varinfo` with new values, using a specified initialisation strategy.
876+
If the values in `varinfo` are not set, they will be added.
877+
using a specified initialisation strategy. If `init_strategy` is not provided,
878+
defaults to PriorInit().
874879
875880
Returns a tuple of the model's return value, plus the updated `varinfo` object.
876881
"""
877-
function evaluate_and_sample!!(
882+
function init!!(
878883
rng::Random.AbstractRNG,
879884
model::Model,
880885
varinfo::AbstractVarInfo,
881-
sampler::AbstractSampler=SampleFromPrior(),
886+
init_strategy::AbstractInitStrategy=PriorInit(),
882887
)
883-
sampling_model = contextualize(model, SamplingContext(rng, sampler, model.context))
884-
return evaluate!!(sampling_model, varinfo)
888+
new_context = setleafcontext(model.context, InitContext(rng, init_strategy))
889+
new_model = contextualize(model, new_context)
890+
return evaluate!!(new_model, varinfo)
885891
end
886-
function evaluate_and_sample!!(
887-
model::Model, varinfo::AbstractVarInfo, sampler::AbstractSampler=SampleFromPrior()
892+
function init!!(
893+
model::Model, varinfo::AbstractVarInfo, init_strategy::AbstractInitStrategy=PriorInit()
888894
)
889-
return evaluate_and_sample!!(Random.default_rng(), model, varinfo, sampler)
895+
return init!!(Random.default_rng(), model, varinfo, init_strategy)
890896
end
891897

892898
"""
@@ -1049,11 +1055,7 @@ Base.nameof(model::Model{<:Function}) = nameof(model.f)
10491055
Generate a sample of type `T` from the prior distribution of the `model`.
10501056
"""
10511057
function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::Model) where {T}
1052-
x = last(
1053-
evaluate_and_sample!!(
1054-
rng, model, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}())
1055-
),
1056-
)
1058+
x = last(init!!(rng, model, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}())))
10571059
return values_as(x, T)
10581060
end
10591061

src/simple_varinfo.jl

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ julia> rng = StableRNG(42);
3939
julia> # In the `NamedTuple` version we need to provide the place-holder values for
4040
# the variables which are using "containers", e.g. `Array`.
4141
# In this case, this means that we need to specify `x` but not `m`.
42-
_, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo((x = ones(2), )));
42+
_, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo((x = ones(2), )));
4343
4444
julia> # (✓) Vroom, vroom! FAST!!!
4545
vi[@varname(x[1])]
@@ -57,12 +57,12 @@ julia> vi[@varname(x[1:2])]
5757
1.3736306979834252
5858
5959
julia> # (×) If we don't provide the container...
60-
_, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo()); vi
60+
_, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo()); vi
6161
ERROR: type NamedTuple has no field x
6262
[...]
6363
6464
julia> # If one does not know the varnames, we can use a `OrderedDict` instead.
65-
_, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}()));
65+
_, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}()));
6666
6767
julia> # (✓) Sort of fast, but only possible at runtime.
6868
vi[@varname(x[1])]
@@ -91,28 +91,28 @@ demo_constrained (generic function with 2 methods)
9191
9292
julia> m = demo_constrained();
9393
94-
julia> _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo());
94+
julia> _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo());
9595
9696
julia> vi[@varname(x)] # (✓) 0 ≤ x < ∞
9797
1.8632965762164932
9898
99-
julia> _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true));
99+
julia> _, vi = DynamicPPL.init!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true));
100100
101101
julia> vi[@varname(x)] # (✓) -∞ < x < ∞
102102
-0.21080155351918753
103103
104-
julia> xs = [last(DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10];
104+
julia> xs = [last(DynamicPPL.init!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10];
105105
106106
julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers!
107107
true
108108
109109
julia> # And with `OrderedDict` of course!
110-
_, vi = DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict{VarName,Any}()), true));
110+
_, vi = DynamicPPL.init!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict{VarName,Any}()), true));
111111
112112
julia> vi[@varname(x)] # (✓) -∞ < x < ∞
113113
0.6225185067787314
114114
115-
julia> xs = [last(DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10];
115+
julia> xs = [last(DynamicPPL.init!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10];
116116
117117
julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers!
118118
true
@@ -226,24 +226,25 @@ end
226226

227227
# Constructor from `Model`.
228228
function SimpleVarInfo{T}(
229-
rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior()
229+
rng::Random.AbstractRNG, model::Model, init_strategy::AbstractInitStrategy=PriorInit()
230230
) where {T<:Real}
231-
new_model = contextualize(model, SamplingContext(rng, sampler, model.context))
231+
new_context = setleafcontext(model.context, InitContext(rng, init_strategy))
232+
new_model = contextualize(model, new_context)
232233
return last(evaluate!!(new_model, SimpleVarInfo{T}()))
233234
end
234235
function SimpleVarInfo{T}(
235-
model::Model, sampler::AbstractSampler=SampleFromPrior()
236+
model::Model, init_strategy::AbstractInitStrategy=PriorInit()
236237
) where {T<:Real}
237-
return SimpleVarInfo{T}(Random.default_rng(), model, sampler)
238+
return SimpleVarInfo{T}(Random.default_rng(), model, init_strategy)
238239
end
239240
# Constructors without type param
240241
function SimpleVarInfo(
241-
rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior()
242+
rng::Random.AbstractRNG, model::Model, init_strategy::AbstractInitStrategy=PriorInit()
242243
)
243-
return SimpleVarInfo{LogProbType}(rng, model, sampler)
244+
return SimpleVarInfo{LogProbType}(rng, model, init_strategy)
244245
end
245-
function SimpleVarInfo(model::Model, sampler::AbstractSampler=SampleFromPrior())
246-
return SimpleVarInfo{LogProbType}(Random.default_rng(), model, sampler)
246+
function SimpleVarInfo(model::Model, init_strategy::AbstractInitStrategy=PriorInit())
247+
return SimpleVarInfo{LogProbType}(Random.default_rng(), model, init_strategy)
247248
end
248249

249250
# Constructor from `VarInfo`.
@@ -259,12 +260,12 @@ end
259260

260261
function untyped_simple_varinfo(model::Model)
261262
varinfo = SimpleVarInfo(OrderedDict{VarName,Any}())
262-
return last(evaluate_and_sample!!(model, varinfo))
263+
return last(init!!(model, varinfo))
263264
end
264265

265266
function typed_simple_varinfo(model::Model)
266267
varinfo = SimpleVarInfo{Float64}()
267-
return last(evaluate_and_sample!!(model, varinfo))
268+
return last(init!!(model, varinfo))
268269
end
269270

270271
function unflatten(svi::SimpleVarInfo, x::AbstractVector)
@@ -474,7 +475,6 @@ function assume(
474475
return value, vi
475476
end
476477

477-
# NOTE: We don't implement `settrans!!(vi, trans, vn)`.
478478
function settrans!!(vi::SimpleVarInfo, trans)
479479
return settrans!!(vi, trans ? DynamicTransformation() : NoTransformation())
480480
end
@@ -484,6 +484,15 @@ end
484484
function settrans!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans)
485485
return Accessors.@set vi.varinfo = settrans!!(vi.varinfo, trans)
486486
end
487+
function settrans!!(vi::SimpleOrThreadSafeSimple, trans::Bool, ::VarName)
488+
# We keep this method around just to obey the AbstractVarInfo interface; however,
489+
# this is only a valid operation if it would be a no-op.
490+
if trans != istrans(vi)
491+
error(
492+
"Individual variables in SimpleVarInfo cannot have different `settrans` statuses.",
493+
)
494+
end
495+
end
487496

488497
istrans(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation)
489498
istrans(vi::SimpleVarInfo, ::VarName) = istrans(vi)

src/test_utils/contexts.jl

Lines changed: 46 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -29,21 +29,45 @@ function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Mod
2929
node_trait = DynamicPPL.NodeTrait(context)
3030
# Throw error immediately if it it's missing a `NodeTrait` implementation.
3131
node_trait isa Union{DynamicPPL.IsLeaf,DynamicPPL.IsParent} ||
32-
throw(ValueError("Invalid NodeTrait: $node_trait"))
32+
error("Invalid NodeTrait: $node_trait")
3333

34-
# To see change, let's make sure we're using a different leaf context than the current.
35-
leafcontext_new = if DynamicPPL.leafcontext(context) isa DefaultContext
36-
DynamicPPL.DynamicTransformationContext{false}()
34+
if node_trait isa DynamicPPL.IsLeaf
35+
test_leaf_context(context, model)
3736
else
38-
DefaultContext()
37+
test_parent_context(context, model)
3938
end
40-
@test DynamicPPL.leafcontext(DynamicPPL.setleafcontext(context, leafcontext_new)) ==
41-
leafcontext_new
39+
end
40+
41+
function test_leaf_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Model)
42+
@test DynamicPPL.NodeTrait(context) isa DynamicPPL.IsLeaf
43+
44+
# Note that for a leaf context we can't assume that it will work with an
45+
# empty VarInfo. Thus we only test evaluation (i.e., assuming that the
46+
# varinfo already contains all necessary variables).
47+
@testset "evaluation" begin
48+
# Generate a new filled untyped varinfo
49+
_, untyped_vi = DynamicPPL.init!!(model, DynamicPPL.VarInfo())
50+
typed_vi = DynamicPPL.typed_varinfo(untyped_vi)
51+
new_model = contextualize(model, context)
52+
for vi in [untyped_vi, typed_vi]
53+
_, vi = DynamicPPL.evaluate!!(new_model, vi)
54+
@test vi isa DynamicPPL.VarInfo
55+
end
56+
end
57+
end
58+
59+
function test_parent_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Model)
60+
@test DynamicPPL.NodeTrait(context) isa DynamicPPL.IsParent
4261

43-
# The interface methods.
44-
if node_trait isa DynamicPPL.IsParent
45-
# `childcontext` and `setchildcontext`
46-
# With new child context
62+
@testset "{set,}{leaf,child}context" begin
63+
# Ensure we're using a different leaf context than the current.
64+
leafcontext_new = if DynamicPPL.leafcontext(context) isa DefaultContext
65+
DynamicPPL.DynamicTransformationContext{false}()
66+
else
67+
DefaultContext()
68+
end
69+
@test DynamicPPL.leafcontext(DynamicPPL.setleafcontext(context, leafcontext_new)) ==
70+
leafcontext_new
4771
childcontext_new = TestParentContext()
4872
@test DynamicPPL.childcontext(
4973
DynamicPPL.setchildcontext(context, childcontext_new)
@@ -56,19 +80,15 @@ function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Mod
5680
leafcontext_new
5781
end
5882

59-
# Make sure that the we can evaluate the model with the context (i.e. that none of the tilde-functions are incorrectly overloaded).
60-
# The tilde-pipeline contains two different paths: with `SamplingContext` as a parent, and without it.
61-
# NOTE(torfjelde): Need to sample with the untyped varinfo _using_ the context, since the
62-
# context might alter which variables are present, their names, etc., e.g. `PrefixContext`.
63-
# TODO(torfjelde): Make the `varinfo` used for testing a kwarg once it makes sense for other varinfos.
64-
# Untyped varinfo.
65-
varinfo_untyped = DynamicPPL.VarInfo()
66-
model_with_spl = contextualize(model, SamplingContext(context))
67-
model_without_spl = contextualize(model, context)
68-
@test DynamicPPL.evaluate!!(model_with_spl, varinfo_untyped) isa Any
69-
@test DynamicPPL.evaluate!!(model_without_spl, varinfo_untyped) isa Any
70-
# Typed varinfo.
71-
varinfo_typed = DynamicPPL.typed_varinfo(varinfo_untyped)
72-
@test DynamicPPL.evaluate!!(model_with_spl, varinfo_typed) isa Any
73-
@test DynamicPPL.evaluate!!(model_without_spl, varinfo_typed) isa Any
83+
@testset "initialisation and evaluation" begin
84+
new_model = contextualize(model, context)
85+
for vi in [DynamicPPL.VarInfo(), DynamicPPL.typed_varinfo(DynamicPPL.VarInfo())]
86+
# Initialisation
87+
_, vi = DynamicPPL.init!!(new_model, DynamicPPL.VarInfo())
88+
@test vi isa DynamicPPL.VarInfo
89+
# Evaluation
90+
_, vi = DynamicPPL.evaluate!!(new_model, vi)
91+
@test vi isa DynamicPPL.VarInfo
92+
end
93+
end
7494
end

src/test_utils/model_interface.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,7 @@ Even though it is recommended to implement this by hand for a particular `Model`
9292
a default implementation using [`SimpleVarInfo{<:Dict}`](@ref) is provided.
9393
"""
9494
function varnames(model::Model)
95-
return collect(
96-
keys(last(DynamicPPL.evaluate_and_sample!!(model, SimpleVarInfo(Dict()))))
97-
)
95+
return collect(keys(last(DynamicPPL.init!!(model, SimpleVarInfo(Dict())))))
9896
end
9997

10098
"""

0 commit comments

Comments
 (0)