Skip to content

Commit 3af63d5

Browse files
authored
Remove context from model evaluation (use model.context instead) (#952)
* Change `evaluate!!` API, add `sample!!` * Fix literally everything else that I broke * Fix some docstrings * fix ForwardDiffExt (look, multiple dispatch bad...) * Changelog * fix a test * Fix docstrings * use `sample!!` * Fix a couple more cases * Globally rename `sample!!` -> `evaluate_and_sample!!`, add changelog warning
1 parent bec523a commit 3af63d5

37 files changed

+477
-575
lines changed

HISTORY.md

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,42 @@ This release overhauls how VarInfo objects track variables such as the log joint
1818
- `getlogp` now returns a `NamedTuple` with keys `logprior` and `loglikelihood`. If you want the log joint probability, which is what `getlogp` used to return, use `getlogjoint`.
1919
- Correspondingly `setlogp!!` and `acclogp!!` should now be called with a `NamedTuple` with keys `logprior` and `loglikelihood`. The `acclogp!!` method with a single scalar value has been deprecated and falls back on `accloglikelihood!!`, and the single scalar version of `setlogp!!` has been removed. Corresponding setter/accumulator functions exist for the log prior as well.
2020

21+
### Evaluation contexts
22+
23+
Historically, evaluating a DynamicPPL model has required three arguments: a model, some kind of VarInfo, and a context.
24+
It's less known, though, that since DynamicPPL 0.14.0 the _model_ itself actually contains a context as well.
25+
This version therefore excises the context argument, and instead uses `model.context` as the evaluation context.
26+
27+
The upshot of this is that many functions that previously took a context argument now no longer do.
28+
There were very few such functions where the context argument was actually used (most of them simply took `DefaultContext()` as the default value).
29+
30+
`evaluate!!(model, varinfo, ext_context)` is deprecated, and broadly speaking you should replace calls to that with `new_model = contextualize(model, ext_context); evaluate!!(new_model, varinfo)`.
31+
If the 'external context' `ext_context` is a parent context, then you should wrap `model.context` appropriately to ensure that its information content is not lost.
32+
If, on the other hand, `ext_context` is a `DefaultContext`, then you can just drop the argument entirely.
33+
34+
To aid with this process, `contextualize` is now exported from DynamicPPL.
35+
36+
The main situation where one _did_ want to specify an additional evaluation context was when that context was a `SamplingContext`.
37+
Doing this would allow you to run the model and sample fresh values, instead of just using the values that existed in the VarInfo object.
38+
Thus, this release also introduces the **unexported** function `evaluate_and_sample!!`.
39+
Essentially, `evaluate_and_sample!!(rng, model, varinfo, sampler)` is a drop-in replacement for `evaluate!!(model, varinfo, SamplingContext(rng, sampler))`.
40+
**Do note that this is an internal method**, and its name or semantics are liable to change in the future without warning.
41+
42+
There are many methods that no longer take a context argument, and listing them all would be too much.
43+
However, here are the more user-facing ones:
44+
45+
- `LogDensityFunction` no longer has a context field (or type parameter)
46+
- `DynamicPPL.TestUtils.AD.run_ad` no longer uses a context (and the returned `ADResult` object no longer has a context field)
47+
- `VarInfo(rng, model, sampler)` and other VarInfo constructors / functions that made VarInfos (e.g. `typed_varinfo`) from a model
48+
- `(::Model)(args...)`: specifically, this now only takes `rng` and `varinfo` arguments (with both being optional)
49+
- If you are using the `__context__` special variable inside a model, you will now have to use `__model__.context` instead
50+
51+
And a couple of more internal changes:
52+
53+
- `evaluate!!`, `evaluate_threadsafe!!`, and `evaluate_threadunsafe!!` no longer accept context arguments
54+
- `evaluate!!` no longer takes rng and sampler (if you used this, you should use `evaluate_and_sample!!` instead, or construct your own `SamplingContext`)
55+
- The model evaluation function, `model.f` for some `model::Model`, no longer takes a context as an argument
56+
2157
## 0.36.12
2258

2359
Removed several unexported functions.

benchmarks/src/DynamicPPLBenchmarks.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,13 +81,12 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::
8181
end
8282

8383
adbackend = to_backend(adbackend)
84-
context = DynamicPPL.DefaultContext()
8584

8685
if islinked
8786
vi = DynamicPPL.link(vi, model)
8887
end
8988

90-
f = DynamicPPL.LogDensityFunction(model, vi, context; adtype=adbackend)
89+
f = DynamicPPL.LogDensityFunction(model, vi; adtype=adbackend)
9190
# The parameters at which we evaluate f.
9291
θ = vi[:]
9392

docs/src/api.md

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@ getargnames
3636
getmissings
3737
```
3838

39+
The context of a model can be set using [`contextualize`](@ref):
40+
41+
```@docs
42+
contextualize
43+
```
44+
3945
## Evaluation
4046

4147
With [`rand`](@ref) one can draw samples from the prior distribution of a [`Model`](@ref).
@@ -438,13 +444,21 @@ DynamicPPL.varname_and_value_leaves
438444

439445
### Evaluation Contexts
440446

441-
Internally, both sampling and evaluation of log densities are performed with [`AbstractPPL.evaluate!!`](@ref).
447+
Internally, model evaluation is performed with [`AbstractPPL.evaluate!!`](@ref).
442448

443449
```@docs
444450
AbstractPPL.evaluate!!
445451
```
446452

447-
The behaviour of a model execution can be changed with evaluation contexts that are passed as additional argument to the model function.
453+
This method mutates the `varinfo` used for execution.
454+
By default, it does not perform any actual sampling: it only evaluates the model using the values of the variables that are already in the `varinfo`.
455+
To perform sampling, you can either wrap `model.context` in a `SamplingContext`, or use this convenience method:
456+
457+
```@docs
458+
DynamicPPL.evaluate_and_sample!!
459+
```
460+
461+
The behaviour of a model execution can be changed with evaluation contexts, which are a field of the model.
448462
Contexts are subtypes of `AbstractPPL.AbstractContext`.
449463

450464
```@docs

ext/DynamicPPLForwardDiffExt.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ function DynamicPPL.tweak_adtype(
1111
ad::ADTypes.AutoForwardDiff{chunk_size},
1212
::DynamicPPL.Model,
1313
vi::DynamicPPL.AbstractVarInfo,
14-
::DynamicPPL.AbstractContext,
1514
) where {chunk_size}
1615
params = vi[:]
1716

ext/DynamicPPLJETExt.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,10 @@ using DynamicPPL: DynamicPPL
44
using JET: JET
55

66
function DynamicPPL.Experimental.is_suitable_varinfo(
7-
model::DynamicPPL.Model,
8-
context::DynamicPPL.AbstractContext,
9-
varinfo::DynamicPPL.AbstractVarInfo;
10-
only_ddpl::Bool=true,
7+
model::DynamicPPL.Model, varinfo::DynamicPPL.AbstractVarInfo; only_ddpl::Bool=true
118
)
129
# Let's make sure that both evaluation and sampling doesn't result in type errors.
13-
f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(
14-
model, varinfo, context
15-
)
10+
f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(model, varinfo)
1611
# If specified, we only check errors originating somewhere in the DynamicPPL.jl.
1712
# This way we don't just fall back to untyped if the user's code is the issue.
1813
result = if only_ddpl
@@ -24,14 +19,19 @@ function DynamicPPL.Experimental.is_suitable_varinfo(
2419
end
2520

2621
function DynamicPPL.Experimental._determine_varinfo_jet(
27-
model::DynamicPPL.Model, context::DynamicPPL.AbstractContext; only_ddpl::Bool=true
22+
model::DynamicPPL.Model; only_ddpl::Bool=true
2823
)
24+
# Use SamplingContext to test type stability.
25+
sampling_model = DynamicPPL.contextualize(
26+
model, DynamicPPL.SamplingContext(model.context)
27+
)
28+
2929
# First we try with the typed varinfo.
30-
varinfo = DynamicPPL.typed_varinfo(model, context)
30+
varinfo = DynamicPPL.typed_varinfo(sampling_model)
3131

3232
# Let's make sure that both evaluation and sampling doesn't result in type errors.
3333
issuccess, result = DynamicPPL.Experimental.is_suitable_varinfo(
34-
model, context, varinfo; only_ddpl
34+
sampling_model, varinfo; only_ddpl
3535
)
3636

3737
if !issuccess
@@ -46,7 +46,7 @@ function DynamicPPL.Experimental._determine_varinfo_jet(
4646
else
4747
# Warn the user that we can't use the type stable one.
4848
@warn "Model seems incompatible with typed varinfo. Falling back to untyped varinfo."
49-
DynamicPPL.untyped_varinfo(model, context)
49+
DynamicPPL.untyped_varinfo(sampling_model)
5050
end
5151
end
5252

ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ function DynamicPPL.predict(
115115
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
116116
predictive_samples = map(iters) do (sample_idx, chain_idx)
117117
DynamicPPL.setval_and_resample!(varinfo, parameter_only_chain, sample_idx, chain_idx)
118-
model(rng, varinfo, DynamicPPL.SampleFromPrior())
118+
varinfo = last(DynamicPPL.evaluate_and_sample!!(rng, model, varinfo))
119119

120120
vals = DynamicPPL.values_as_in_model(model, false, varinfo)
121121
varname_vals = mapreduce(

src/DynamicPPL.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ export AbstractVarInfo,
102102
# LogDensityFunction
103103
LogDensityFunction,
104104
# Contexts
105+
contextualize,
105106
SamplingContext,
106107
DefaultContext,
107108
PrefixContext,

src/compiler.jl

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
const INTERNALNAMES = (:__model__, :__context__, :__varinfo__)
1+
const INTERNALNAMES = (:__model__, :__varinfo__)
22

33
"""
44
need_concretize(expr)
@@ -63,9 +63,9 @@ used in its place.
6363
function isassumption(expr::Union{Expr,Symbol}, vn=make_varname_expression(expr))
6464
return quote
6565
if $(DynamicPPL.contextual_isassumption)(
66-
__context__, $(DynamicPPL.prefix)(__context__, $vn)
66+
__model__.context, $(DynamicPPL.prefix)(__model__.context, $vn)
6767
)
68-
# Considered an assumption by `__context__` which means either:
68+
# Considered an assumption by `__model__.context` which means either:
6969
# 1. We hit the default implementation, e.g. using `DefaultContext`,
7070
# which in turn means that we haven't considered if it's one of
7171
# the model arguments, hence we need to check this.
@@ -116,7 +116,7 @@ end
116116
isfixed(expr, vn) = false
117117
function isfixed(::Union{Symbol,Expr}, vn)
118118
return :($(DynamicPPL.contextual_isfixed)(
119-
__context__, $(DynamicPPL.prefix)(__context__, $vn)
119+
__model__.context, $(DynamicPPL.prefix)(__model__.context, $vn)
120120
))
121121
end
122122

@@ -417,7 +417,7 @@ function generate_assign(left, right)
417417
return quote
418418
$right_val = $right
419419
if $(DynamicPPL.is_extracting_values)(__varinfo__)
420-
$vn = $(DynamicPPL.prefix)(__context__, $(make_varname_expression(left)))
420+
$vn = $(DynamicPPL.prefix)(__model__.context, $(make_varname_expression(left)))
421421
__varinfo__ = $(map_accumulator!!)(
422422
$acc -> push!($acc, $vn, $right_val), __varinfo__, Val(:ValuesAsInModel)
423423
)
@@ -431,7 +431,11 @@ function generate_tilde_literal(left, right)
431431
@gensym value
432432
return quote
433433
$value, __varinfo__ = $(DynamicPPL.tilde_observe!!)(
434-
__context__, $(DynamicPPL.check_tilde_rhs)($right), $left, nothing, __varinfo__
434+
__model__.context,
435+
$(DynamicPPL.check_tilde_rhs)($right),
436+
$left,
437+
nothing,
438+
__varinfo__,
435439
)
436440
$value
437441
end
@@ -456,20 +460,20 @@ function generate_tilde(left, right)
456460
$isassumption = $(DynamicPPL.isassumption(left, vn))
457461
if $(DynamicPPL.isfixed(left, vn))
458462
$left = $(DynamicPPL.getfixed_nested)(
459-
__context__, $(DynamicPPL.prefix)(__context__, $vn)
463+
__model__.context, $(DynamicPPL.prefix)(__model__.context, $vn)
460464
)
461465
elseif $isassumption
462466
$(generate_tilde_assume(left, dist, vn))
463467
else
464468
# If `vn` is not in `argnames`, we need to make sure that the variable is defined.
465469
if !$(DynamicPPL.inargnames)($vn, __model__)
466470
$left = $(DynamicPPL.getconditioned_nested)(
467-
__context__, $(DynamicPPL.prefix)(__context__, $vn)
471+
__model__.context, $(DynamicPPL.prefix)(__model__.context, $vn)
468472
)
469473
end
470474

471475
$value, __varinfo__ = $(DynamicPPL.tilde_observe!!)(
472-
__context__,
476+
__model__.context,
473477
$(DynamicPPL.check_tilde_rhs)($dist),
474478
$(maybe_view(left)),
475479
$vn,
@@ -494,7 +498,7 @@ function generate_tilde_assume(left, right, vn)
494498

495499
return quote
496500
$value, __varinfo__ = $(DynamicPPL.tilde_assume!!)(
497-
__context__,
501+
__model__.context,
498502
$(DynamicPPL.unwrap_right_vn)($(DynamicPPL.check_tilde_rhs)($right), $vn)...,
499503
__varinfo__,
500504
)
@@ -652,11 +656,7 @@ function build_output(modeldef, linenumbernode)
652656

653657
# Add the internal arguments to the user-specified arguments (positional + keywords).
654658
evaluatordef[:args] = vcat(
655-
[
656-
:(__model__::$(DynamicPPL.Model)),
657-
:(__varinfo__::$(DynamicPPL.AbstractVarInfo)),
658-
:(__context__::$(DynamicPPL.AbstractContext)),
659-
],
659+
[:(__model__::$(DynamicPPL.Model)), :(__varinfo__::$(DynamicPPL.AbstractVarInfo))],
660660
args,
661661
)
662662

0 commit comments

Comments
 (0)