Skip to content

Commit 9c648dc

Browse files
committed
Clean up submodel code, remove 3-arg _evaluate!!
1 parent b4830e0 commit 9c648dc

File tree

5 files changed

+58
-142
lines changed

5 files changed

+58
-142
lines changed

src/DynamicPPL.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,11 +172,11 @@ abstract type AbstractVarInfo <: AbstractModelTrace end
172172
include("utils.jl")
173173
include("chains.jl")
174174
include("model.jl")
175-
include("submodel.jl")
176175
include("sampler.jl")
177176
include("varname.jl")
178177
include("distribution_wrappers.jl")
179178
include("contexts.jl")
179+
include("submodel.jl")
180180
include("varnamedvector.jl")
181181
include("accumulators.jl")
182182
include("default_accumulators.jl")

src/compiler.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -176,11 +176,7 @@ function check_tilde_rhs(@nospecialize(x))
176176
end
177177
check_tilde_rhs(x::Distribution) = x
178178
check_tilde_rhs(x::AbstractArray{<:Distribution}) = x
179-
check_tilde_rhs(x::ReturnedModelWrapper) = x
180-
function check_tilde_rhs(x::Sampleable{<:Any,AutoPrefix}) where {AutoPrefix}
181-
model = check_tilde_rhs(x.model)
182-
return Sampleable{typeof(model),AutoPrefix}(model)
183-
end
179+
check_tilde_rhs(x::Submodel{M,AutoPrefix}) where {M,AutoPrefix} = x
184180

185181
"""
186182
check_dot_tilde_rhs(x)

src/context_implementations.jl

Lines changed: 7 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -63,31 +63,10 @@ By default, calls `tilde_assume(context, right, vn, vi)` and accumulates the log
6363
probability of `vi` with the returned value.
6464
"""
6565
function tilde_assume!!(context, right, vn, vi)
66-
return if is_rhs_model(right)
67-
# Here, we apply the PrefixContext _not_ to the parent `context`, but
68-
# to the context of the submodel being evaluated. This means that later=
69-
# on in `make_evaluate_args_and_kwargs`, the context stack will be
70-
# correctly arranged such that it goes like this:
71-
# parent_context[1] -> parent_context[2] -> ... -> PrefixContext ->
72-
# submodel_context[1] -> submodel_context[2] -> ... -> leafcontext
73-
# See the docstring of `make_evaluate_args_and_kwargs`, and the internal
74-
# DynamicPPL documentation on submodel conditioning, for more details.
75-
#
76-
# NOTE: This relies on the existence of `right.model.model`. Right now,
77-
# the only thing that can return true for `is_rhs_model` is something
78-
# (a `Sampleable`) that has a `model` field that itself (a
79-
# `ReturnedModelWrapper`) has a `model` field. This may or may not
80-
# change in the future.
81-
if should_auto_prefix(right)
82-
dppl_model = right.model.model # This isa DynamicPPL.Model
83-
prefixed_submodel_context = PrefixContext(vn, dppl_model.context)
84-
new_dppl_model = contextualize(dppl_model, prefixed_submodel_context)
85-
right = to_submodel(new_dppl_model, true)
86-
end
87-
rand_like!!(right, context, vi)
66+
return if right isa DynamicPPL.Submodel
67+
_evaluate!!(right, vi, context, vn)
8868
else
89-
value, vi = tilde_assume(context, right, vn, vi)
90-
return value, vi
69+
tilde_assume(context, right, vn, vi)
9170
end
9271
end
9372

@@ -129,17 +108,17 @@ accumulate the log probability, and return the observed value and updated `vi`.
129108
Falls back to `tilde_observe!!(context, right, left, vi)` ignoring the information about variable name
130109
and indices; if needed, these can be accessed through this function, though.
131110
"""
132-
function tilde_observe!!(context::DefaultContext, right, left, vn, vi)
133-
is_rhs_model(right) && throw(
111+
function tilde_observe!!(::DefaultContext, right, left, vn, vi)
112+
right isa DynamicPPL.Submodel && throw(
134113
ArgumentError(
135-
"`~` with a model on the right-hand side of an observe statement is not supported",
114+
"`~` with a submodel on the right-hand side of an observe statement is not supported",
136115
),
137116
)
138117
vi = accumulate_observe!!(vi, right, left, vn)
139118
return left, vi
140119
end
141120

142-
function assume(rng::Random.AbstractRNG, spl::Sampler, dist)
121+
function assume(::Random.AbstractRNG, spl::Sampler, dist)
143122
return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))")
144123
end
145124

src/model.jl

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -932,25 +932,11 @@ Evaluate the `model` with the given `varinfo`.
932932
933933
This function does not wrap the varinfo in a `ThreadSafeVarInfo`. It also does not
934934
reset the log probability of the `varinfo` before running.
935-
936-
_evaluate!!(model::Model, varinfo, context)
937-
938-
If an additional `context` is provided, the model's context is combined with
939-
that context before evaluation.
940935
"""
941936
function _evaluate!!(model::Model, varinfo::AbstractVarInfo)
942937
args, kwargs = make_evaluate_args_and_kwargs(model, varinfo)
943938
return model.f(args...; kwargs...)
944939
end
945-
function _evaluate!!(model::Model, varinfo::AbstractVarInfo, context::AbstractContext)
946-
# TODO(penelopeysm): We don't really need this, but it's a useful
947-
# convenience method. We could remove it after we get rid of the
948-
# evaluate_threadsafe!! stuff (in favour of making users call evaluate!!
949-
# with a TSVI themselves).
950-
new_ctx = combine_model_and_external_contexts(model.context, context)
951-
model = contextualize(model, new_ctx)
952-
return _evaluate!!(model, varinfo)
953-
end
954940

955941
is_splat_symbol(s::Symbol) = startswith(string(s), "#splat#")
956942

src/submodel.jl

Lines changed: 49 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -1,98 +1,13 @@
11
"""
2-
is_rhs_model(x)
2+
Submodel{M,AutoPrefix}
33
4-
Return `true` if `x` is a model or model wrapper, and `false` otherwise.
4+
A wrapper around a model, plus a flag indicating whether it should be automatically
5+
prefixed with the left-hand variable in a `~` statement.
56
"""
6-
is_rhs_model(x) = false
7-
8-
"""
9-
Distributional
10-
11-
Abstract type for type indicating that something is "distributional".
12-
"""
13-
abstract type Distributional end
14-
15-
"""
16-
should_auto_prefix(distributional)
17-
18-
Return `true` if the `distributional` should use automatic prefixing, and `false` otherwise.
19-
"""
20-
function should_auto_prefix end
21-
22-
"""
23-
is_rhs_model(x)
24-
25-
Return `true` if the `distributional` is a model, and `false` otherwise.
26-
"""
27-
function is_rhs_model end
28-
29-
"""
30-
Sampleable{M} <: Distributional
31-
32-
A wrapper around a model indicating it is sampleable.
33-
"""
34-
struct Sampleable{M,AutoPrefix} <: Distributional
35-
model::M
36-
end
37-
38-
should_auto_prefix(::Sampleable{<:Any,AutoPrefix}) where {AutoPrefix} = AutoPrefix
39-
is_rhs_model(x::Sampleable) = is_rhs_model(x.model)
40-
41-
# TODO: Export this if it end up having a purpose beyond `to_submodel`.
42-
"""
43-
to_sampleable(model[, auto_prefix])
44-
45-
Return a wrapper around `model` indicating it is sampleable.
46-
47-
# Arguments
48-
- `model::Model`: the model to wrap.
49-
- `auto_prefix::Bool`: whether to prefix the variables in the model. Default: `true`.
50-
"""
51-
to_sampleable(model, auto_prefix::Bool=true) = Sampleable{typeof(model),auto_prefix}(model)
52-
53-
"""
54-
rand_like!!(model_wrap, context, varinfo)
55-
56-
Returns a tuple with the first element being the realization and the second the updated varinfo.
57-
58-
# Arguments
59-
- `model_wrap::ReturnedModelWrapper`: the wrapper of the model to use.
60-
- `context::AbstractContext`: the context to use for evaluation.
61-
- `varinfo::AbstractVarInfo`: the varinfo to use for evaluation.
62-
"""
63-
function rand_like!!(
64-
model_wrap::Sampleable, context::AbstractContext, varinfo::AbstractVarInfo
65-
)
66-
return rand_like!!(model_wrap.model, context, varinfo)
67-
end
68-
69-
"""
70-
ReturnedModelWrapper
71-
72-
A wrapper around a model indicating it is a model over its return values.
73-
74-
This should rarely be constructed explicitly; see [`returned(model)`](@ref) instead.
75-
"""
76-
struct ReturnedModelWrapper{M<:Model}
7+
struct Submodel{M,AutoPrefix}
778
model::M
789
end
7910

80-
is_rhs_model(::ReturnedModelWrapper) = true
81-
82-
function rand_like!!(
83-
model_wrap::ReturnedModelWrapper, context::AbstractContext, varinfo::AbstractVarInfo
84-
)
85-
# Return's the value and the (possibly mutated) varinfo.
86-
return _evaluate!!(model_wrap.model, varinfo, context)
87-
end
88-
89-
"""
90-
returned(model)
91-
92-
Return a `model` wrapper indicating that it is a model over its return-values.
93-
"""
94-
returned(model::Model) = ReturnedModelWrapper(model)
95-
9611
"""
9712
to_submodel(model::Model[, auto_prefix::Bool])
9813
@@ -106,8 +21,8 @@ the model can be sampled from but not necessarily evaluated for its log density.
10621
`left ~ right` such as [`condition`](@ref), will also not work with `to_submodel`.
10722
10823
!!! warning
109-
To avoid variable names clashing between models, it is recommend leave argument `auto_prefix` equal to `true`.
110-
If one does not use automatic prefixing, then it's recommended to use [`prefix(::Model, input)`](@ref) explicitly.
24+
To avoid variable names clashing between models, it is recommended to leave the argument `auto_prefix` equal to `true`.
25+
If one does not use automatic prefixing, then it's recommended to use [`prefix(::Model, input)`](@ref) explicitly, i.e. `to_submodel(prefix(model, @varname(my_prefix)))`
11126
11227
# Arguments
11328
- `model::Model`: the model to wrap.
@@ -229,11 +144,51 @@ julia> @model illegal_likelihood() = a ~ to_submodel(inner())
229144
illegal_likelihood (generic function with 2 methods)
230145
231146
julia> model = illegal_likelihood() | (a = 1.0,);
232-
233147
julia> model()
234148
ERROR: ArgumentError: `~` with a model on the right-hand side of an observe statement is not supported
235149
[...]
236150
```
237151
"""
238-
to_submodel(model::Model, auto_prefix::Bool=true) =
239-
to_sampleable(returned(model), auto_prefix)
152+
to_submodel(m::Model, auto_prefix::Bool=true) = Submodel{typeof(m),auto_prefix}(m)
153+
154+
# When automatic prefixing is used, the submodel itself doesn't carry the
155+
# prefix, as the prefix is obtained from the LHS of `~` (whereas the submodel
156+
# is on the RHS). The prefix can only be obtained in `tilde_assume!!`, and then
157+
# passed into this function.
158+
#
159+
# `parent_context` here refers to the context of the model that contains the
160+
# submodel.
161+
function _evaluate!!(
162+
submodel::Submodel{M,AutoPrefix},
163+
vi::AbstractVarInfo,
164+
parent_context::AbstractContext,
165+
left_vn::VarName,
166+
) where {M<:Model,AutoPrefix}
167+
# First, we construct the context to be used when evaluating the submodel. There
168+
# are several considerations here:
169+
# (1) We need to apply an appropriate PrefixContext when evaluating the submodel, but
170+
# _only_ if automatic prefixing is supposed to be applied.
171+
submodel_context_prefixed = if AutoPrefix
172+
PrefixContext(left_vn, submodel.model.context)
173+
else
174+
submodel.model.context
175+
end
176+
177+
# (2) We need to respect the leaf-context of the parent model. This, unfortunately,
178+
# means disregarding the leaf-context of the submodel.
179+
submodel_context = setleafcontext(
180+
submodel_context_prefixed, leafcontext(parent_context)
181+
)
182+
183+
# (3) We need to use the parent model's context to wrap the whole thing, so that
184+
# e.g. if the user conditions the parent model, the conditioned variables will be
185+
# correctly picked up when evaluating the submodel.
186+
eval_context = setleafcontext(parent_context, submodel_context)
187+
188+
# (4) Finally, we need to store that context inside the submodel.
189+
model = contextualize(submodel.model, eval_context)
190+
191+
# Once that's all set up nicely, we can just _evaluate!! the wrapped model. This
192+
# returns a tuple of submodel.model's return value and the new varinfo.
193+
return _evaluate!!(model, vi)
194+
end

0 commit comments

Comments
 (0)