@@ -794,15 +794,26 @@ julia> # Now `a.x` will be sampled.
794
794
fixed (model:: Model ) = fixed (model. context)
795
795
796
796
"""
797
- (model::Model)([rng, varinfo, sampler, context])
797
+ (model::Model)()
798
+ (model::Model)(rng[, varinfo, sampler, context])
798
799
799
- Sample from the `model` using the `sampler` with random number generator `rng` and the
800
- `context`, and store the sample and log joint probability in `varinfo`.
800
+ Sample from the `model` using the `sampler` with random number generator `rng`
801
+ and the `context`, and store the sample and log joint probability in `varinfo`.
801
802
802
- The method resets the log joint probability of `varinfo` and increases the evaluation
803
- number of `sampler`.
803
+ Returns the model's return value.
804
+
805
+ If no arguments are provided, uses the default random number generator and
806
+ samples from the prior.
804
807
"""
805
- (model:: Model )(args... ) = first (evaluate!! (model, args... ))
808
+ (model:: Model )() = model (Random. default_rng ())
809
+ function (model:: Model )(
810
+ rng:: AbstractRNG ,
811
+ varinfo:: AbstractVarInfo = VarInfo (),
812
+ sampler:: AbstractSampler = SampleFromPrior (),
813
+ )
814
+ spl_ctx = SamplingContext (rng, sampler, DefaultContext ())
815
+ return evaluate!! (model, varinfo, spl_ctx)
816
+ end
806
817
807
818
"""
808
819
use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo)
@@ -815,65 +826,51 @@ function use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo)
815
826
end
816
827
817
828
"""
818
- evaluate!!(model::Model[, rng, varinfo, sampler, context])
819
-
820
- Sample from the `model` using the `sampler` with random number generator `rng` and the
821
- `context`, and store the sample and log joint probability in `varinfo`.
829
+ sample!!([rng::Random.AbstractRNG, ]model::Model, varinfo)
822
830
823
- Returns both the return-value of the original model, and the resulting varinfo.
831
+ Evaluate the `model` with the given `varinfo`, but perform sampling during the
832
+ evaluation by wrapping the model's context in a `SamplingContext`.
824
833
825
- The method resets the log joint probability of `varinfo` and increases the evaluation
826
- number of `sampler`.
834
+ Returns a tuple of the model's return value, plus the updated `varinfo` object.
827
835
"""
828
- function AbstractPPL. evaluate!! (
829
- model:: Model , varinfo:: AbstractVarInfo , context:: AbstractContext
830
- )
831
- return if use_threadsafe_eval (context, varinfo)
832
- evaluate_threadsafe!! (model, varinfo, context)
833
- else
834
- evaluate_threadunsafe!! (model, varinfo, context)
835
- end
836
+ function sample!! (rng:: AbstractRNG , model:: Model , varinfo:: AbstractVarInfo )
837
+ sampling_model = contextualize (
838
+ model, SamplingContext (rng, SampleFromPrior (), model. context)
839
+ )
840
+ return evaluate!! (sampling_model, varinfo)
836
841
end
837
842
838
- function AbstractPPL. evaluate!! (
839
- model:: Model ,
840
- rng:: Random.AbstractRNG ,
841
- varinfo:: AbstractVarInfo = VarInfo (),
842
- sampler:: AbstractSampler = SampleFromPrior (),
843
- context:: AbstractContext = DefaultContext (),
844
- )
845
- return evaluate!! (model, varinfo, SamplingContext (rng, sampler, context))
846
- end
843
+ """
844
+ evaluate!!(model::Model, varinfo)
845
+ evaluate!!(model::Model, varinfo, context)
847
846
848
- function AbstractPPL . evaluate!! ( model:: Model , context:: AbstractContext )
849
- return evaluate!! (model, VarInfo (), context)
850
- end
847
+ Evaluate the ` model` with the given `varinfo`. If an extra context stack is
848
+ provided, the model's context is inserted into that context stack. See
849
+ [`combine_model_and_external_contexts`](@ref).
851
850
852
- function AbstractPPL. evaluate!! (
853
- model:: Model , args:: Union{AbstractVarInfo,AbstractSampler,AbstractContext} ...
854
- )
855
- return evaluate!! (model, Random. default_rng (), args... )
856
- end
851
+ If multiple threads are available, the varinfo provided will be wrapped in a
852
+ [`DynamicPPL.ThreadSafeVarInfo`](@ref) before evaluation.
857
853
858
- # without VarInfo
859
- function AbstractPPL. evaluate!! (
860
- model:: Model ,
861
- rng:: Random.AbstractRNG ,
862
- sampler:: AbstractSampler ,
863
- args:: AbstractContext... ,
864
- )
865
- return evaluate!! (model, rng, VarInfo (), sampler, args... )
854
+ Returns a tuple of the model's return value, plus the updated `varinfo`
855
+ (unwrapped if necessary).
856
+ """
857
+ function AbstractPPL. evaluate!! (model:: Model , varinfo:: AbstractVarInfo )
858
+ return if use_threadsafe_eval (model. context, varinfo)
859
+ evaluate_threadsafe!! (model, varinfo)
860
+ else
861
+ evaluate_threadunsafe!! (model, varinfo)
862
+ end
866
863
end
867
-
868
- # without VarInfo and without AbstractSampler
869
864
function AbstractPPL. evaluate!! (
870
- model:: Model , rng :: Random.AbstractRNG , context:: AbstractContext
865
+ model:: Model , varinfo :: AbstractVarInfo , context:: AbstractContext
871
866
)
872
- return evaluate!! (model, rng, VarInfo (), SampleFromPrior (), context)
867
+ new_ctx = combine_model_and_external_contexts (model. context, context)
868
+ model = contextualize (model, new_ctx)
869
+ return evaluate!! (model, varinfo)
873
870
end
874
871
875
872
"""
876
- evaluate_threadunsafe!!(model, varinfo, context )
873
+ evaluate_threadunsafe!!(model, varinfo)
877
874
878
875
Evaluate the `model` without wrapping `varinfo` inside a `ThreadSafeVarInfo`.
879
876
@@ -882,8 +879,8 @@ This method is not exposed and supposed to be used only internally in DynamicPPL
882
879
883
880
See also: [`evaluate_threadsafe!!`](@ref)
884
881
"""
885
- function evaluate_threadunsafe!! (model, varinfo, context )
886
- return _evaluate!! (model, resetlogp!! (varinfo), context )
882
+ function evaluate_threadunsafe!! (model, varinfo)
883
+ return _evaluate!! (model, resetlogp!! (varinfo))
887
884
end
888
885
889
886
"""
@@ -897,31 +894,74 @@ This method is not exposed and supposed to be used only internally in DynamicPPL
897
894
898
895
See also: [`evaluate_threadunsafe!!`](@ref)
899
896
"""
900
- function evaluate_threadsafe!! (model, varinfo, context )
897
+ function evaluate_threadsafe!! (model, varinfo)
901
898
wrapper = ThreadSafeVarInfo (resetlogp!! (varinfo))
902
- result, wrapper_new = _evaluate!! (model, wrapper, context)
899
+ result, wrapper_new = _evaluate!! (model, wrapper)
900
+ # TODO (penelopeysm): If seems that if you pass a TSVI to this method, it
901
+ # will return the underlying VI, which is a bit counterintuitive (because
902
+ # calling TSVI(::TSVI) returns the original TSVI, instead of wrapping it
903
+ # again).
903
904
return result, setaccs!! (wrapper_new. varinfo, getaccs (wrapper_new))
904
905
end
905
906
906
907
"""
908
+ _evaluate!!(model::Model, varinfo)
907
909
_evaluate!!(model::Model, varinfo, context)
908
910
909
- Evaluate the `model` with the arguments matching the given `context` and `varinfo` object.
911
+ Evaluate the `model` with the given `varinfo`. If an additional `context` is provided,
912
+ the model's context is combined with that context.
913
+
914
+ This function does not wrap the varinfo in a `ThreadSafeVarInfo`.
910
915
"""
911
- function _evaluate!! (model:: Model , varinfo:: AbstractVarInfo , context :: AbstractContext )
912
- args, kwargs = make_evaluate_args_and_kwargs (model, varinfo, context )
916
+ function _evaluate!! (model:: Model , varinfo:: AbstractVarInfo )
917
+ args, kwargs = make_evaluate_args_and_kwargs (model, varinfo)
913
918
return model. f (args... ; kwargs... )
914
919
end
920
+ function _evaluate!! (model:: Model , varinfo:: AbstractVarInfo , context:: AbstractContext )
921
+ # TODO (penelopeysm): We don't really need this, but it's a useful
922
+ # convenience method. We could remove it after we get rid of the
923
+ # evaluate_threadsafe!! stuff (in favour of making users call evaluate!!
924
+ # with a TSVI themselves).
925
+ new_ctx = combine_model_and_external_contexts (model. context, context)
926
+ model = contextualize (model, new_ctx)
927
+ return _evaluate!! (model, varinfo)
928
+ end
915
929
916
930
is_splat_symbol (s:: Symbol ) = startswith (string (s), " #splat#" )
917
931
932
+ """
933
+ combine_model_and_external_contexts(model_context, external_context)
934
+
935
+ Combine a context from a model and an external context into a single context.
936
+
937
+ The resulting context stack has the following structure:
938
+
939
+ `external_context` -> `childcontext(external_context)` -> ... ->
940
+ `model_context` -> `childcontext(model_context)` -> ... ->
941
+ `leafcontext(external_context)`
942
+
943
+ The reason for this is that we want to give `external_context` precedence over
944
+ `model_context`, while also preserving the leaf context of `external_context`.
945
+ We can do this by
946
+
947
+ 1. Set the leaf context of `model_context` to `leafcontext(external_context)`.
948
+ 2. Set leaf context of `external_context` to the context resulting from (1).
949
+ """
950
+ function combine_model_and_external_contexts (
951
+ model_context:: AbstractContext , external_context:: AbstractContext
952
+ )
953
+ return setleafcontext (
954
+ external_context, setleafcontext (model_context, leafcontext (external_context))
955
+ )
956
+ end
957
+
918
958
"""
919
959
make_evaluate_args_and_kwargs(model, varinfo, context)
920
960
921
961
Return the arguments and keyword arguments to be passed to the evaluator of the model, i.e. `model.f`e.
922
962
"""
923
963
@generated function make_evaluate_args_and_kwargs (
924
- model:: Model{_F,argnames} , varinfo:: AbstractVarInfo , context :: AbstractContext
964
+ model:: Model{_F,argnames} , varinfo:: AbstractVarInfo
925
965
) where {_F,argnames}
926
966
unwrap_args = [
927
967
if is_splat_symbol (var)
@@ -930,18 +970,7 @@ Return the arguments and keyword arguments to be passed to the evaluator of the
930
970
:($ matchingvalue (varinfo, model. args.$ var))
931
971
end for var in argnames
932
972
]
933
-
934
- # We want to give `context` precedence over `model.context` while also
935
- # preserving the leaf context of `context`. We can do this by
936
- # 1. Set the leaf context of `model.context` to `leafcontext(context)`.
937
- # 2. Set leaf context of `context` to the context resulting from (1).
938
- # The result is:
939
- # `context` -> `childcontext(context)` -> ... -> `model.context`
940
- # -> `childcontext(model.context)` -> ... -> `leafcontext(context)`
941
973
return quote
942
- context_new = setleafcontext (
943
- context, setleafcontext (model. context, leafcontext (context))
944
- )
945
974
args = (
946
975
model,
947
976
# Maybe perform `invlink!!` once prior to evaluation to avoid
0 commit comments