Skip to content

Commit c8db5a4

Browse files
committed
Change evaluate!! API, add sample!!
1 parent 80db9e2 commit c8db5a4

File tree

1 file changed

+99
-70
lines changed

1 file changed

+99
-70
lines changed

src/model.jl

Lines changed: 99 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -794,15 +794,26 @@ julia> # Now `a.x` will be sampled.
794794
fixed(model::Model) = fixed(model.context)
795795

796796
"""
797-
(model::Model)([rng, varinfo, sampler, context])
797+
(model::Model)()
798+
(model::Model)(rng[, varinfo, sampler, context])
798799
799-
Sample from the `model` using the `sampler` with random number generator `rng` and the
800-
`context`, and store the sample and log joint probability in `varinfo`.
800+
Sample from the `model` using the `sampler` with random number generator `rng`
801+
and the `context`, and store the sample and log joint probability in `varinfo`.
801802
802-
The method resets the log joint probability of `varinfo` and increases the evaluation
803-
number of `sampler`.
803+
Returns the model's return value.
804+
805+
If no arguments are provided, uses the default random number generator and
806+
samples from the prior.
804807
"""
805-
(model::Model)(args...) = first(evaluate!!(model, args...))
808+
(model::Model)() = model(Random.default_rng())
809+
function (model::Model)(
810+
rng::AbstractRNG,
811+
varinfo::AbstractVarInfo=VarInfo(),
812+
sampler::AbstractSampler=SampleFromPrior(),
813+
)
814+
spl_ctx = SamplingContext(rng, sampler, DefaultContext())
815+
return evaluate!!(model, varinfo, spl_ctx)
816+
end
806817

807818
"""
808819
use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo)
@@ -815,65 +826,51 @@ function use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo)
815826
end
816827

817828
"""
818-
evaluate!!(model::Model[, rng, varinfo, sampler, context])
819-
820-
Sample from the `model` using the `sampler` with random number generator `rng` and the
821-
`context`, and store the sample and log joint probability in `varinfo`.
829+
sample!!([rng::Random.AbstractRNG, ]model::Model, varinfo)
822830
823-
Returns both the return-value of the original model, and the resulting varinfo.
831+
Evaluate the `model` with the given `varinfo`, but perform sampling during the
832+
evaluation by wrapping the model's context in a `SamplingContext`.
824833
825-
The method resets the log joint probability of `varinfo` and increases the evaluation
826-
number of `sampler`.
834+
Returns a tuple of the model's return value, plus the updated `varinfo` object.
827835
"""
828-
function AbstractPPL.evaluate!!(
829-
model::Model, varinfo::AbstractVarInfo, context::AbstractContext
830-
)
831-
return if use_threadsafe_eval(context, varinfo)
832-
evaluate_threadsafe!!(model, varinfo, context)
833-
else
834-
evaluate_threadunsafe!!(model, varinfo, context)
835-
end
836+
function sample!!(rng::AbstractRNG, model::Model, varinfo::AbstractVarInfo)
837+
sampling_model = contextualize(
838+
model, SamplingContext(rng, SampleFromPrior(), model.context)
839+
)
840+
return evaluate!!(sampling_model, varinfo)
836841
end
837842

838-
function AbstractPPL.evaluate!!(
839-
model::Model,
840-
rng::Random.AbstractRNG,
841-
varinfo::AbstractVarInfo=VarInfo(),
842-
sampler::AbstractSampler=SampleFromPrior(),
843-
context::AbstractContext=DefaultContext(),
844-
)
845-
return evaluate!!(model, varinfo, SamplingContext(rng, sampler, context))
846-
end
843+
"""
844+
evaluate!!(model::Model, varinfo)
845+
evaluate!!(model::Model, varinfo, context)
847846
848-
function AbstractPPL.evaluate!!(model::Model, context::AbstractContext)
849-
return evaluate!!(model, VarInfo(), context)
850-
end
847+
Evaluate the `model` with the given `varinfo`. If an extra context stack is
848+
provided, the model's context is inserted into that context stack. See
849+
[`combine_model_and_external_contexts`](@ref).
851850
852-
function AbstractPPL.evaluate!!(
853-
model::Model, args::Union{AbstractVarInfo,AbstractSampler,AbstractContext}...
854-
)
855-
return evaluate!!(model, Random.default_rng(), args...)
856-
end
851+
If multiple threads are available, the varinfo provided will be wrapped in a
852+
[`DynamicPPL.ThreadSafeVarInfo`](@ref) before evaluation.
857853
858-
# without VarInfo
859-
function AbstractPPL.evaluate!!(
860-
model::Model,
861-
rng::Random.AbstractRNG,
862-
sampler::AbstractSampler,
863-
args::AbstractContext...,
864-
)
865-
return evaluate!!(model, rng, VarInfo(), sampler, args...)
854+
Returns a tuple of the model's return value, plus the updated `varinfo`
855+
(unwrapped if necessary).
856+
"""
857+
function AbstractPPL.evaluate!!(model::Model, varinfo::AbstractVarInfo)
858+
return if use_threadsafe_eval(model.context, varinfo)
859+
evaluate_threadsafe!!(model, varinfo)
860+
else
861+
evaluate_threadunsafe!!(model, varinfo)
862+
end
866863
end
867-
868-
# without VarInfo and without AbstractSampler
869864
function AbstractPPL.evaluate!!(
870-
model::Model, rng::Random.AbstractRNG, context::AbstractContext
865+
model::Model, varinfo::AbstractVarInfo, context::AbstractContext
871866
)
872-
return evaluate!!(model, rng, VarInfo(), SampleFromPrior(), context)
867+
new_ctx = combine_model_and_external_contexts(model.context, context)
868+
model = contextualize(model, new_ctx)
869+
return evaluate!!(model, varinfo)
873870
end
874871

875872
"""
876-
evaluate_threadunsafe!!(model, varinfo, context)
873+
evaluate_threadunsafe!!(model, varinfo)
877874
878875
Evaluate the `model` without wrapping `varinfo` inside a `ThreadSafeVarInfo`.
879876
@@ -882,8 +879,8 @@ This method is not exposed and supposed to be used only internally in DynamicPPL
882879
883880
See also: [`evaluate_threadsafe!!`](@ref)
884881
"""
885-
function evaluate_threadunsafe!!(model, varinfo, context)
886-
return _evaluate!!(model, resetlogp!!(varinfo), context)
882+
function evaluate_threadunsafe!!(model, varinfo)
883+
return _evaluate!!(model, resetlogp!!(varinfo))
887884
end
888885

889886
"""
@@ -897,31 +894,74 @@ This method is not exposed and supposed to be used only internally in DynamicPPL
897894
898895
See also: [`evaluate_threadunsafe!!`](@ref)
899896
"""
900-
function evaluate_threadsafe!!(model, varinfo, context)
897+
function evaluate_threadsafe!!(model, varinfo)
901898
wrapper = ThreadSafeVarInfo(resetlogp!!(varinfo))
902-
result, wrapper_new = _evaluate!!(model, wrapper, context)
899+
result, wrapper_new = _evaluate!!(model, wrapper)
900+
# TODO(penelopeysm): If seems that if you pass a TSVI to this method, it
901+
# will return the underlying VI, which is a bit counterintuitive (because
902+
# calling TSVI(::TSVI) returns the original TSVI, instead of wrapping it
903+
# again).
903904
return result, setaccs!!(wrapper_new.varinfo, getaccs(wrapper_new))
904905
end
905906

906907
"""
908+
_evaluate!!(model::Model, varinfo)
907909
_evaluate!!(model::Model, varinfo, context)
908910
909-
Evaluate the `model` with the arguments matching the given `context` and `varinfo` object.
911+
Evaluate the `model` with the given `varinfo`. If an additional `context` is provided,
912+
the model's context is combined with that context.
913+
914+
This function does not wrap the varinfo in a `ThreadSafeVarInfo`.
910915
"""
911-
function _evaluate!!(model::Model, varinfo::AbstractVarInfo, context::AbstractContext)
912-
args, kwargs = make_evaluate_args_and_kwargs(model, varinfo, context)
916+
function _evaluate!!(model::Model, varinfo::AbstractVarInfo)
917+
args, kwargs = make_evaluate_args_and_kwargs(model, varinfo)
913918
return model.f(args...; kwargs...)
914919
end
920+
function _evaluate!!(model::Model, varinfo::AbstractVarInfo, context::AbstractContext)
921+
# TODO(penelopeysm): We don't really need this, but it's a useful
922+
# convenience method. We could remove it after we get rid of the
923+
# evaluate_threadsafe!! stuff (in favour of making users call evaluate!!
924+
# with a TSVI themselves).
925+
new_ctx = combine_model_and_external_contexts(model.context, context)
926+
model = contextualize(model, new_ctx)
927+
return _evaluate!!(model, varinfo)
928+
end
915929

916930
is_splat_symbol(s::Symbol) = startswith(string(s), "#splat#")
917931

932+
"""
933+
combine_model_and_external_contexts(model_context, external_context)
934+
935+
Combine a context from a model and an external context into a single context.
936+
937+
The resulting context stack has the following structure:
938+
939+
`external_context` -> `childcontext(external_context)` -> ... ->
940+
`model_context` -> `childcontext(model_context)` -> ... ->
941+
`leafcontext(external_context)`
942+
943+
The reason for this is that we want to give `external_context` precedence over
944+
`model_context`, while also preserving the leaf context of `external_context`.
945+
We can do this by
946+
947+
1. Set the leaf context of `model_context` to `leafcontext(external_context)`.
948+
2. Set leaf context of `external_context` to the context resulting from (1).
949+
"""
950+
function combine_model_and_external_contexts(
951+
model_context::AbstractContext, external_context::AbstractContext
952+
)
953+
return setleafcontext(
954+
external_context, setleafcontext(model_context, leafcontext(external_context))
955+
)
956+
end
957+
918958
"""
919959
make_evaluate_args_and_kwargs(model, varinfo, context)
920960
921961
Return the arguments and keyword arguments to be passed to the evaluator of the model, i.e. `model.f`e.
922962
"""
923963
@generated function make_evaluate_args_and_kwargs(
924-
model::Model{_F,argnames}, varinfo::AbstractVarInfo, context::AbstractContext
964+
model::Model{_F,argnames}, varinfo::AbstractVarInfo
925965
) where {_F,argnames}
926966
unwrap_args = [
927967
if is_splat_symbol(var)
@@ -930,18 +970,7 @@ Return the arguments and keyword arguments to be passed to the evaluator of the
930970
:($matchingvalue(varinfo, model.args.$var))
931971
end for var in argnames
932972
]
933-
934-
# We want to give `context` precedence over `model.context` while also
935-
# preserving the leaf context of `context`. We can do this by
936-
# 1. Set the leaf context of `model.context` to `leafcontext(context)`.
937-
# 2. Set leaf context of `context` to the context resulting from (1).
938-
# The result is:
939-
# `context` -> `childcontext(context)` -> ... -> `model.context`
940-
# -> `childcontext(model.context)` -> ... -> `leafcontext(context)`
941973
return quote
942-
context_new = setleafcontext(
943-
context, setleafcontext(model.context, leafcontext(context))
944-
)
945974
args = (
946975
model,
947976
# Maybe perform `invlink!!` once prior to evaluation to avoid

0 commit comments

Comments
 (0)