Skip to content

Commit f20e86c

Browse files
authored
Remove 3-argument {_,}evaluate!!; clean up submodel code (#960)
* Clean up submodel code, remove 3-arg `_evaluate!!` * Remove 3-argument `evaluate!!` as well * Update changelog * Improve submodel error message * Fix doctest * Add error hint for three-argument evaluate!!
1 parent 7f20709 commit f20e86c

File tree

8 files changed

+85
-202
lines changed

8 files changed

+85
-202
lines changed

HISTORY.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,11 @@ This version therefore excises the context argument, and instead uses `model.con
3131
The upshot of this is that many functions that previously took a context argument now no longer do.
3232
There were very few such functions where the context argument was actually used (most of them simply took `DefaultContext()` as the default value).
3333

34-
`evaluate!!(model, varinfo, ext_context)` is deprecated, and broadly speaking you should replace calls to that with `new_model = contextualize(model, ext_context); evaluate!!(new_model, varinfo)`.
34+
`evaluate!!(model, varinfo, ext_context)` is removed, and broadly speaking you should replace calls to that with `new_model = contextualize(model, ext_context); evaluate!!(new_model, varinfo)`.
3535
If the 'external context' `ext_context` is a parent context, then you should wrap `model.context` appropriately to ensure that its information content is not lost.
3636
If, on the other hand, `ext_context` is a `DefaultContext`, then you can just drop the argument entirely.
3737

38-
To aid with this process, `contextualize` is now exported from DynamicPPL.
38+
**To aid with this process, `contextualize` is now exported from DynamicPPL.**
3939

4040
The main situation where one _did_ want to specify an additional evaluation context was when that context was a `SamplingContext`.
4141
Doing this would allow you to run the model and sample fresh values, instead of just using the values that existed in the VarInfo object.
@@ -54,9 +54,10 @@ However, here are the more user-facing ones:
5454

5555
And a couple of more internal changes:
5656

57-
- `evaluate!!`, `evaluate_threadsafe!!`, and `evaluate_threadunsafe!!` no longer accept context arguments
57+
- Just like `evaluate!!`, the other functions `_evaluate!!`, `evaluate_threadsafe!!`, and `evaluate_threadunsafe!!` now no longer accept context arguments
5858
- `evaluate!!` no longer takes rng and sampler (if you used this, you should use `evaluate_and_sample!!` instead, or construct your own `SamplingContext`)
5959
- The model evaluation function, `model.f` for some `model::Model`, no longer takes a context as an argument
60+
- The internal representation and API dealing with submodels (i.e., `ReturnedModelWrapper`, `Sampleable`, `should_auto_prefix`, `is_rhs_model`) has been simplified. If you need to check whether something is a submodel, just use `x isa DynamicPPL.Submodel`. Note that the public API i.e. `to_submodel` remains completely untouched.
6061

6162
## 0.36.12
6263

docs/src/api.md

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -152,12 +152,6 @@ In the context of including models within models, it's also useful to prefix the
152152
DynamicPPL.prefix
153153
```
154154

155-
Under the hood, [`to_submodel`](@ref) makes use of the following method to indicate that the model it's wrapping is a model over its return-values rather than something else
156-
157-
```@docs
158-
returned(::Model)
159-
```
160-
161155
## Utilities
162156

163157
It is possible to manually increase (or decrease) the accumulated log likelihood or prior from within a model function.

src/DynamicPPL.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,11 +171,11 @@ abstract type AbstractVarInfo <: AbstractModelTrace end
171171
include("utils.jl")
172172
include("chains.jl")
173173
include("model.jl")
174-
include("submodel.jl")
175174
include("sampler.jl")
176175
include("varname.jl")
177176
include("distribution_wrappers.jl")
178177
include("contexts.jl")
178+
include("submodel.jl")
179179
include("varnamedvector.jl")
180180
include("accumulators.jl")
181181
include("default_accumulators.jl")
@@ -226,6 +226,21 @@ if isdefined(Base.Experimental, :register_error_hint)
226226
)
227227
end
228228
end
229+
230+
Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, _
231+
is_evaluate_three_arg =
232+
exc.f === AbstractPPL.evaluate!! &&
233+
length(argtypes) == 3 &&
234+
argtypes[1] <: Model &&
235+
argtypes[2] <: AbstractVarInfo &&
236+
argtypes[3] <: AbstractContext
237+
if is_evaluate_three_arg
238+
print(
239+
io,
240+
"\n\nThe method `evaluate!!(model, varinfo, new_ctx)` has been removed. Instead, you should store the `new_ctx` in the `model.context` field using `new_model = contextualize(model, new_ctx)`, and then call `evaluate!!(new_model, varinfo)` on the new model. (Note that, if the model already contained a non-default context, you will need to wrap the existing context.)",
241+
)
242+
end
243+
end
229244
end
230245
end
231246

src/compiler.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -176,11 +176,7 @@ function check_tilde_rhs(@nospecialize(x))
176176
end
177177
check_tilde_rhs(x::Distribution) = x
178178
check_tilde_rhs(x::AbstractArray{<:Distribution}) = x
179-
check_tilde_rhs(x::ReturnedModelWrapper) = x
180-
function check_tilde_rhs(x::Sampleable{<:Any,AutoPrefix}) where {AutoPrefix}
181-
model = check_tilde_rhs(x.model)
182-
return Sampleable{typeof(model),AutoPrefix}(model)
183-
end
179+
check_tilde_rhs(x::Submodel{M,AutoPrefix}) where {M,AutoPrefix} = x
184180

185181
"""
186182
check_dot_tilde_rhs(x)

src/context_implementations.jl

Lines changed: 7 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -63,31 +63,10 @@ By default, calls `tilde_assume(context, right, vn, vi)` and accumulates the log
6363
probability of `vi` with the returned value.
6464
"""
6565
function tilde_assume!!(context, right, vn, vi)
66-
return if is_rhs_model(right)
67-
# Here, we apply the PrefixContext _not_ to the parent `context`, but
68-
# to the context of the submodel being evaluated. This means that later=
69-
# on in `make_evaluate_args_and_kwargs`, the context stack will be
70-
# correctly arranged such that it goes like this:
71-
# parent_context[1] -> parent_context[2] -> ... -> PrefixContext ->
72-
# submodel_context[1] -> submodel_context[2] -> ... -> leafcontext
73-
# See the docstring of `make_evaluate_args_and_kwargs`, and the internal
74-
# DynamicPPL documentation on submodel conditioning, for more details.
75-
#
76-
# NOTE: This relies on the existence of `right.model.model`. Right now,
77-
# the only thing that can return true for `is_rhs_model` is something
78-
# (a `Sampleable`) that has a `model` field that itself (a
79-
# `ReturnedModelWrapper`) has a `model` field. This may or may not
80-
# change in the future.
81-
if should_auto_prefix(right)
82-
dppl_model = right.model.model # This isa DynamicPPL.Model
83-
prefixed_submodel_context = PrefixContext(vn, dppl_model.context)
84-
new_dppl_model = contextualize(dppl_model, prefixed_submodel_context)
85-
right = to_submodel(new_dppl_model, true)
86-
end
87-
rand_like!!(right, context, vi)
66+
return if right isa DynamicPPL.Submodel
67+
_evaluate!!(right, vi, context, vn)
8868
else
89-
value, vi = tilde_assume(context, right, vn, vi)
90-
return value, vi
69+
tilde_assume(context, right, vn, vi)
9170
end
9271
end
9372

@@ -129,17 +108,14 @@ accumulate the log probability, and return the observed value and updated `vi`.
129108
Falls back to `tilde_observe!!(context, right, left, vi)` ignoring the information about variable name
130109
and indices; if needed, these can be accessed through this function, though.
131110
"""
132-
function tilde_observe!!(context::DefaultContext, right, left, vn, vi)
133-
is_rhs_model(right) && throw(
134-
ArgumentError(
135-
"`~` with a model on the right-hand side of an observe statement is not supported",
136-
),
137-
)
111+
function tilde_observe!!(::DefaultContext, right, left, vn, vi)
112+
right isa DynamicPPL.Submodel &&
113+
throw(ArgumentError("`x ~ to_submodel(...)` is not supported when `x` is observed"))
138114
vi = accumulate_observe!!(vi, right, left, vn)
139115
return left, vi
140116
end
141117

142-
function assume(rng::Random.AbstractRNG, spl::Sampler, dist)
118+
function assume(::Random.AbstractRNG, spl::Sampler, dist)
143119
return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))")
144120
end
145121

src/model.jl

Lines changed: 1 addition & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ julia> # However, it's not possible to condition `inner` directly.
258258
conditioned_model_fail = model | (inner = 1.0, );
259259
260260
julia> conditioned_model_fail()
261-
ERROR: ArgumentError: `~` with a model on the right-hand side of an observe statement is not supported
261+
ERROR: ArgumentError: `x ~ to_submodel(...)` is not supported when `x` is observed
262262
[...]
263263
```
264264
"""
@@ -864,12 +864,6 @@ If multiple threads are available, the varinfo provided will be wrapped in a
864864
865865
Returns a tuple of the model's return value, plus the updated `varinfo`
866866
(unwrapped if necessary).
867-
868-
evaluate!!(model::Model, varinfo, context)
869-
870-
When an extra context stack is provided, the model's context is inserted into
871-
that context stack. See `combine_model_and_external_contexts`. This method is
872-
deprecated.
873867
"""
874868
function AbstractPPL.evaluate!!(model::Model, varinfo::AbstractVarInfo)
875869
return if use_threadsafe_eval(model.context, varinfo)
@@ -878,17 +872,6 @@ function AbstractPPL.evaluate!!(model::Model, varinfo::AbstractVarInfo)
878872
evaluate_threadunsafe!!(model, varinfo)
879873
end
880874
end
881-
function AbstractPPL.evaluate!!(
882-
model::Model, varinfo::AbstractVarInfo, context::AbstractContext
883-
)
884-
Base.depwarn(
885-
"The `context` argument to evaluate!!(model, varinfo, context) is deprecated.",
886-
:dynamicppl_evaluate_context,
887-
)
888-
new_ctx = combine_model_and_external_contexts(model.context, context)
889-
model = contextualize(model, new_ctx)
890-
return evaluate!!(model, varinfo)
891-
end
892875

893876
"""
894877
evaluate_threadunsafe!!(model, varinfo)
@@ -932,54 +915,14 @@ Evaluate the `model` with the given `varinfo`.
932915
933916
This function does not wrap the varinfo in a `ThreadSafeVarInfo`. It also does not
934917
reset the log probability of the `varinfo` before running.
935-
936-
_evaluate!!(model::Model, varinfo, context)
937-
938-
If an additional `context` is provided, the model's context is combined with
939-
that context before evaluation.
940918
"""
941919
function _evaluate!!(model::Model, varinfo::AbstractVarInfo)
942920
args, kwargs = make_evaluate_args_and_kwargs(model, varinfo)
943921
return model.f(args...; kwargs...)
944922
end
945-
function _evaluate!!(model::Model, varinfo::AbstractVarInfo, context::AbstractContext)
946-
# TODO(penelopeysm): We don't really need this, but it's a useful
947-
# convenience method. We could remove it after we get rid of the
948-
# evaluate_threadsafe!! stuff (in favour of making users call evaluate!!
949-
# with a TSVI themselves).
950-
new_ctx = combine_model_and_external_contexts(model.context, context)
951-
model = contextualize(model, new_ctx)
952-
return _evaluate!!(model, varinfo)
953-
end
954923

955924
is_splat_symbol(s::Symbol) = startswith(string(s), "#splat#")
956925

957-
"""
958-
combine_model_and_external_contexts(model_context, external_context)
959-
960-
Combine a context from a model and an external context into a single context.
961-
962-
The resulting context stack has the following structure:
963-
964-
`external_context` -> `childcontext(external_context)` -> ... ->
965-
`model_context` -> `childcontext(model_context)` -> ... ->
966-
`leafcontext(external_context)`
967-
968-
The reason for this is that we want to give `external_context` precedence over
969-
`model_context`, while also preserving the leaf context of `external_context`.
970-
We can do this by
971-
972-
1. Set the leaf context of `model_context` to `leafcontext(external_context)`.
973-
2. Set leaf context of `external_context` to the context resulting from (1).
974-
"""
975-
function combine_model_and_external_contexts(
976-
model_context::AbstractContext, external_context::AbstractContext
977-
)
978-
return setleafcontext(
979-
external_context, setleafcontext(model_context, leafcontext(external_context))
980-
)
981-
end
982-
983926
"""
984927
make_evaluate_args_and_kwargs(model, varinfo)
985928

0 commit comments

Comments
 (0)