From b3d53c814afce1e876a13f81ae0df6e7b4ddb01e Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 10 Jul 2025 01:40:13 +0100 Subject: [PATCH 01/17] Replace `evaluate_and_sample!!` -> `init!!` --- docs/src/api.md | 12 +++--- src/extract_priors.jl | 2 +- src/model.jl | 47 ++++---------------- src/sampler.jl | 3 +- src/simple_varinfo.jl | 47 ++++++++++++-------- src/test_utils/contexts.jl | 72 ++++++++++++++++++++----------- src/test_utils/model_interface.jl | 4 +- src/varinfo.jl | 68 ++++++++++++++++------------- test/compiler.jl | 13 +++--- test/contexts.jl | 19 ++++---- test/model.jl | 25 +---------- test/sampler.jl | 4 +- test/simple_varinfo.jl | 8 ++-- test/varinfo.jl | 53 +++++++++++------------ test/varnamedvector.jl | 4 +- 15 files changed, 179 insertions(+), 202 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 86b70011f..ff26e2346 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -456,11 +456,6 @@ AbstractPPL.evaluate!! This method mutates the `varinfo` used for execution. By default, it does not perform any actual sampling: it only evaluates the model using the values of the variables that are already in the `varinfo`. -To perform sampling, you can either wrap `model.context` in a `SamplingContext`, or use this convenience method: - -```@docs -DynamicPPL.evaluate_and_sample!! -``` The behaviour of a model execution can be changed with evaluation contexts, which are a field of the model. Contexts are subtypes of `AbstractPPL.AbstractContext`. @@ -475,7 +470,12 @@ InitContext ### VarInfo initialisation -`InitContext` is used to initialise, or overwrite, values in a VarInfo. +The function `init!!` is used to initialise, or overwrite, values in a VarInfo. +It is really a thin wrapper around using `evaluate!!` with an `InitContext`. + +```@docs +DynamicPPL.init!! +``` To accomplish this, an initialisation _strategy_ is required, which defines how new values are to be obtained. 
There are three concrete strategies provided in DynamicPPL: diff --git a/src/extract_priors.jl b/src/extract_priors.jl index 64dcf2eea..5342d70c4 100644 --- a/src/extract_priors.jl +++ b/src/extract_priors.jl @@ -117,7 +117,7 @@ extract_priors(args::Union{Model,AbstractVarInfo}...) = function extract_priors(rng::Random.AbstractRNG, model::Model) varinfo = VarInfo() varinfo = setaccs!!(varinfo, (PriorDistributionAccumulator(),)) - varinfo = last(evaluate_and_sample!!(rng, model, varinfo)) + varinfo = last(init!!(rng, model, varinfo)) return getacc(varinfo, Val(:PriorDistributionAccumulator)).priors end diff --git a/src/model.jl b/src/model.jl index 6be4eb383..7d7a7921d 100644 --- a/src/model.jl +++ b/src/model.jl @@ -850,7 +850,7 @@ end # ^ Weird Documenter.jl bug means that we have to write the two above separately # as it can only detect the `function`-less syntax. function (model::Model)(rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInfo()) - return first(evaluate_and_sample!!(rng, model, varinfo)) + return first(init!!(rng, model, varinfo)) end """ @@ -863,46 +863,19 @@ function use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo) return Threads.nthreads() > 1 end -""" - evaluate_and_sample!!([rng::Random.AbstractRNG, ]model::Model, varinfo[, sampler]) - -Evaluate the `model` with the given `varinfo`, but perform sampling during the -evaluation using the given `sampler` by wrapping the model's context in a -`SamplingContext`. - -If `sampler` is not provided, defaults to [`SampleFromPrior`](@ref). - -Returns a tuple of the model's return value, plus the updated `varinfo` object. 
-"""
-function evaluate_and_sample!!(
-    rng::Random.AbstractRNG,
-    model::Model,
-    varinfo::AbstractVarInfo,
-    sampler::AbstractSampler=SampleFromPrior(),
-)
-    sampling_model = contextualize(model, SamplingContext(rng, sampler, model.context))
-    return evaluate!!(sampling_model, varinfo)
-end
-function evaluate_and_sample!!(
-    model::Model, varinfo::AbstractVarInfo, sampler::AbstractSampler=SampleFromPrior()
-)
-    return evaluate_and_sample!!(Random.default_rng(), model, varinfo, sampler)
-end
-
 """
     init!!(
-        [rng::Random.AbstractRNG,]
+        [rng::Random.AbstractRNG, ]
         model::Model,
         varinfo::AbstractVarInfo,
         [init_strategy::AbstractInitStrategy=PriorInit()]
     )
 
-Evaluate the `model` and replace the values of the model's random variables in
-the given `varinfo` with new values using a specified initialisation strategy.
-If the values in `varinfo` are not already present, they will be added using
-that same strategy.
-
-If `init_strategy` is not provided, defaults to PriorInit().
+Evaluate the `model` and replace the values of the model's random variables
+in the given `varinfo` with new values, using a specified initialisation strategy.
+If the values in `varinfo` are not already present, they will be added
+using that same strategy. If `init_strategy` is not provided,
+defaults to PriorInit().
 
 Returns a tuple of the model's return value, plus the updated `varinfo` object.
 """
@@ -1049,11 +1022,7 @@ Base.nameof(model::Model{<:Function}) = nameof(model.f)
 
 Generate a sample of type `T` from the prior distribution of the `model`.
""" function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::Model) where {T} - x = last( - evaluate_and_sample!!( - rng, model, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}()) - ), - ) + x = last(init!!(rng, model, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}()))) return values_as(x, T) end diff --git a/src/sampler.jl b/src/sampler.jl index 673b5128f..184e4b70a 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -58,7 +58,8 @@ function AbstractMCMC.step( kwargs..., ) vi = VarInfo() - DynamicPPL.evaluate_and_sample!!(rng, model, vi, sampler) + strategy = sampler isa SampleFromPrior ? PriorInit() : UniformInit() + DynamicPPL.init!!(rng, model, vi, strategy) return vi, nothing end diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index d897ae021..53c9b4715 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -39,7 +39,7 @@ julia> rng = StableRNG(42); julia> # In the `NamedTuple` version we need to provide the place-holder values for # the variables which are using "containers", e.g. `Array`. # In this case, this means that we need to specify `x` but not `m`. - _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo((x = ones(2), ))); + _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo((x = ones(2), ))); julia> # (✓) Vroom, vroom! FAST!!! vi[@varname(x[1])] @@ -57,12 +57,12 @@ julia> vi[@varname(x[1:2])] 1.3736306979834252 julia> # (×) If we don't provide the container... - _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo()); vi + _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo()); vi ERROR: type NamedTuple has no field x [...] julia> # If one does not know the varnames, we can use a `OrderedDict` instead. - _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}())); + _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}())); julia> # (✓) Sort of fast, but only possible at runtime. 
vi[@varname(x[1])] @@ -91,28 +91,28 @@ demo_constrained (generic function with 2 methods) julia> m = demo_constrained(); -julia> _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo()); +julia> _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo()); julia> vi[@varname(x)] # (✓) 0 ≤ x < ∞ 1.8632965762164932 -julia> _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)); +julia> _, vi = DynamicPPL.init!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)); julia> vi[@varname(x)] # (✓) -∞ < x < ∞ -0.21080155351918753 -julia> xs = [last(DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; +julia> xs = [last(DynamicPPL.init!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers! true julia> # And with `OrderedDict` of course! - _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict{VarName,Any}()), true)); + _, vi = DynamicPPL.init!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict{VarName,Any}()), true)); julia> vi[@varname(x)] # (✓) -∞ < x < ∞ 0.6225185067787314 -julia> xs = [last(DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; +julia> xs = [last(DynamicPPL.init!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers! true @@ -232,24 +232,25 @@ end # Constructor from `Model`. 
function SimpleVarInfo{T}( - rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() + rng::Random.AbstractRNG, model::Model, init_strategy::AbstractInitStrategy=PriorInit() ) where {T<:Real} - new_model = contextualize(model, SamplingContext(rng, sampler, model.context)) + new_context = setleafcontext(model.context, InitContext(rng, init_strategy)) + new_model = contextualize(model, new_context) return last(evaluate!!(new_model, SimpleVarInfo{T}())) end function SimpleVarInfo{T}( - model::Model, sampler::AbstractSampler=SampleFromPrior() + model::Model, init_strategy::AbstractInitStrategy=PriorInit() ) where {T<:Real} - return SimpleVarInfo{T}(Random.default_rng(), model, sampler) + return SimpleVarInfo{T}(Random.default_rng(), model, init_strategy) end # Constructors without type param function SimpleVarInfo( - rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() + rng::Random.AbstractRNG, model::Model, init_strategy::AbstractInitStrategy=PriorInit() ) - return SimpleVarInfo{LogProbType}(rng, model, sampler) + return SimpleVarInfo{LogProbType}(rng, model, init_strategy) end -function SimpleVarInfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) - return SimpleVarInfo{LogProbType}(Random.default_rng(), model, sampler) +function SimpleVarInfo(model::Model, init_strategy::AbstractInitStrategy=PriorInit()) + return SimpleVarInfo{LogProbType}(Random.default_rng(), model, init_strategy) end # Constructor from `VarInfo`. 
@@ -265,12 +266,12 @@ end function untyped_simple_varinfo(model::Model) varinfo = SimpleVarInfo(OrderedDict{VarName,Any}()) - return last(evaluate_and_sample!!(model, varinfo)) + return last(init!!(model, varinfo)) end function typed_simple_varinfo(model::Model) varinfo = SimpleVarInfo{Float64}() - return last(evaluate_and_sample!!(model, varinfo)) + return last(init!!(model, varinfo)) end function unflatten(svi::SimpleVarInfo, x::AbstractVector) @@ -480,7 +481,6 @@ function assume( return value, vi end -# NOTE: We don't implement `settrans!!(vi, trans, vn)`. function settrans!!(vi::SimpleVarInfo, trans) return settrans!!(vi, trans ? DynamicTransformation() : NoTransformation()) end @@ -490,6 +490,15 @@ end function settrans!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans) return Accessors.@set vi.varinfo = settrans!!(vi.varinfo, trans) end +function settrans!!(vi::SimpleOrThreadSafeSimple, trans::Bool, ::VarName) + # We keep this method around just to obey the AbstractVarInfo interface; however, + # this is only a valid operation if it would be a no-op. + if trans != istrans(vi) + error( + "Individual variables in SimpleVarInfo cannot have different `settrans` statuses.", + ) + end +end istrans(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation) istrans(vi::SimpleVarInfo, ::VarName) = istrans(vi) diff --git a/src/test_utils/contexts.jl b/src/test_utils/contexts.jl index 863db4262..4a019441b 100644 --- a/src/test_utils/contexts.jl +++ b/src/test_utils/contexts.jl @@ -29,21 +29,45 @@ function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Mod node_trait = DynamicPPL.NodeTrait(context) # Throw error immediately if it it's missing a `NodeTrait` implementation. node_trait isa Union{DynamicPPL.IsLeaf,DynamicPPL.IsParent} || - throw(ValueError("Invalid NodeTrait: $node_trait")) + error("Invalid NodeTrait: $node_trait") - # To see change, let's make sure we're using a different leaf context than the current. 
- leafcontext_new = if DynamicPPL.leafcontext(context) isa DefaultContext - DynamicPPL.DynamicTransformationContext{false}() + if node_trait isa DynamicPPL.IsLeaf + test_leaf_context(context, model) else - DefaultContext() + test_parent_context(context, model) end - @test DynamicPPL.leafcontext(DynamicPPL.setleafcontext(context, leafcontext_new)) == - leafcontext_new +end + +function test_leaf_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Model) + @test DynamicPPL.NodeTrait(context) isa DynamicPPL.IsLeaf + + # Note that for a leaf context we can't assume that it will work with an + # empty VarInfo. Thus we only test evaluation (i.e., assuming that the + # varinfo already contains all necessary variables). + @testset "evaluation" begin + # Generate a new filled untyped varinfo + _, untyped_vi = DynamicPPL.init!!(model, DynamicPPL.VarInfo()) + typed_vi = DynamicPPL.typed_varinfo(untyped_vi) + new_model = contextualize(model, context) + for vi in [untyped_vi, typed_vi] + _, vi = DynamicPPL.evaluate!!(new_model, vi) + @test vi isa DynamicPPL.VarInfo + end + end +end + +function test_parent_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Model) + @test DynamicPPL.NodeTrait(context) isa DynamicPPL.IsParent - # The interface methods. - if node_trait isa DynamicPPL.IsParent - # `childcontext` and `setchildcontext` - # With new child context + @testset "{set,}{leaf,child}context" begin + # Ensure we're using a different leaf context than the current. 
+ leafcontext_new = if DynamicPPL.leafcontext(context) isa DefaultContext + DynamicPPL.DynamicTransformationContext{false}() + else + DefaultContext() + end + @test DynamicPPL.leafcontext(DynamicPPL.setleafcontext(context, leafcontext_new)) == + leafcontext_new childcontext_new = TestParentContext() @test DynamicPPL.childcontext( DynamicPPL.setchildcontext(context, childcontext_new) @@ -56,19 +80,15 @@ function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Mod leafcontext_new end - # Make sure that the we can evaluate the model with the context (i.e. that none of the tilde-functions are incorrectly overloaded). - # The tilde-pipeline contains two different paths: with `SamplingContext` as a parent, and without it. - # NOTE(torfjelde): Need to sample with the untyped varinfo _using_ the context, since the - # context might alter which variables are present, their names, etc., e.g. `PrefixContext`. - # TODO(torfjelde): Make the `varinfo` used for testing a kwarg once it makes sense for other varinfos. - # Untyped varinfo. - varinfo_untyped = DynamicPPL.VarInfo() - model_with_spl = contextualize(model, SamplingContext(context)) - model_without_spl = contextualize(model, context) - @test DynamicPPL.evaluate!!(model_with_spl, varinfo_untyped) isa Any - @test DynamicPPL.evaluate!!(model_without_spl, varinfo_untyped) isa Any - # Typed varinfo. 
-    varinfo_typed = DynamicPPL.typed_varinfo(varinfo_untyped)
-    @test DynamicPPL.evaluate!!(model_with_spl, varinfo_typed) isa Any
-    @test DynamicPPL.evaluate!!(model_without_spl, varinfo_typed) isa Any
+    @testset "initialisation and evaluation" begin
+        new_model = contextualize(model, context)
+        for vi in [DynamicPPL.VarInfo(), DynamicPPL.typed_varinfo(DynamicPPL.VarInfo())]
+            # Initialisation
+            _, vi = DynamicPPL.init!!(new_model, vi)
+            @test vi isa DynamicPPL.VarInfo
+            # Evaluation
+            _, vi = DynamicPPL.evaluate!!(new_model, vi)
+            @test vi isa DynamicPPL.VarInfo
+        end
+    end
 end
diff --git a/src/test_utils/model_interface.jl b/src/test_utils/model_interface.jl
index 93aed074c..cb949464e 100644
--- a/src/test_utils/model_interface.jl
+++ b/src/test_utils/model_interface.jl
@@ -92,9 +92,7 @@ Even though it is recommended to implement this by hand for a particular `Model`
 a default implementation using [`SimpleVarInfo{<:Dict}`](@ref) is provided.
 """
 function varnames(model::Model)
-    return collect(
-        keys(last(DynamicPPL.evaluate_and_sample!!(model, SimpleVarInfo(Dict()))))
-    )
+    return collect(keys(last(DynamicPPL.init!!(model, SimpleVarInfo(Dict())))))
 end
 
 """
diff --git a/src/varinfo.jl b/src/varinfo.jl
index f280eb98b..82a648bf5 100644
--- a/src/varinfo.jl
+++ b/src/varinfo.jl
@@ -113,10 +113,14 @@ function VarInfo(meta=Metadata())
 end
 
 """
-    VarInfo([rng, ]model[, sampler])
+    VarInfo(
+        [rng::Random.AbstractRNG],
+        model,
+        [init_strategy::AbstractInitStrategy]
+    )
 
-Generate a `VarInfo` object for the given `model`, by evaluating it once using
-the given `rng`, `sampler`.
+Generate a `VarInfo` object for the given `model`, by initialising it with the
+given `rng` and `init_strategy`.
 
 !!! warning
 
@@ -129,12 +133,12 @@ the given `rng`, `sampler`.
     instead.
""" function VarInfo( - rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() + rng::Random.AbstractRNG, model::Model, init_strategy::AbstractInitStrategy=PriorInit() ) - return typed_varinfo(rng, model, sampler) + return typed_varinfo(rng, model, init_strategy) end -function VarInfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) - return VarInfo(Random.default_rng(), model, sampler) +function VarInfo(model::Model, init_strategy::AbstractInitStrategy=PriorInit()) + return VarInfo(Random.default_rng(), model, init_strategy) end const UntypedVectorVarInfo = VarInfo{<:VarNamedVector} @@ -195,7 +199,7 @@ end ######################## """ - untyped_varinfo([rng, ]model[, sampler]) + untyped_varinfo([rng, ]model[, init_strategy]) Construct a VarInfo object for the given `model`, which has just a single `Metadata` as its metadata field. @@ -203,15 +207,15 @@ Construct a VarInfo object for the given `model`, which has just a single # Arguments - `rng::Random.AbstractRNG`: The random number generator to use during model evaluation - `model::Model`: The model for which to create the varinfo object -- `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. +- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `PriorInit()`. 
""" function untyped_varinfo( - rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() + rng::Random.AbstractRNG, model::Model, init_strategy::AbstractInitStrategy=PriorInit() ) - return last(evaluate_and_sample!!(rng, model, VarInfo(Metadata()), sampler)) + return last(init!!(rng, model, VarInfo(Metadata()), init_strategy)) end -function untyped_varinfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) - return untyped_varinfo(Random.default_rng(), model, sampler) +function untyped_varinfo(model::Model, init_strategy::AbstractInitStrategy=PriorInit()) + return untyped_varinfo(Random.default_rng(), model, init_strategy) end """ @@ -270,7 +274,7 @@ function typed_varinfo(vi::NTVarInfo) return vi end """ - typed_varinfo([rng, ]model[, sampler]) + typed_varinfo([rng, ]model[, init_strategy]) Return a VarInfo object for the given `model`, which has a NamedTuple of `Metadata` structs as its metadata field. @@ -278,19 +282,19 @@ Return a VarInfo object for the given `model`, which has a NamedTuple of # Arguments - `rng::Random.AbstractRNG`: The random number generator to use during model evaluation - `model::Model`: The model for which to create the varinfo object -- `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. +- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `PriorInit()`. 
""" function typed_varinfo( - rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() + rng::Random.AbstractRNG, model::Model, init_strategy::AbstractInitStrategy=PriorInit() ) - return typed_varinfo(untyped_varinfo(rng, model, sampler)) + return typed_varinfo(untyped_varinfo(rng, model, init_strategy)) end -function typed_varinfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) - return typed_varinfo(Random.default_rng(), model, sampler) +function typed_varinfo(model::Model, init_strategy::AbstractInitStrategy=PriorInit()) + return typed_varinfo(Random.default_rng(), model, init_strategy) end """ - untyped_vector_varinfo([rng, ]model[, sampler]) + untyped_vector_varinfo([rng, ]model[, init_strategy]) Return a VarInfo object for the given `model`, which has just a single `VarNamedVector` as its metadata field. @@ -298,23 +302,25 @@ Return a VarInfo object for the given `model`, which has just a single # Arguments - `rng::Random.AbstractRNG`: The random number generator to use during model evaluation - `model::Model`: The model for which to create the varinfo object -- `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. +- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `PriorInit()`. 
""" function untyped_vector_varinfo(vi::UntypedVarInfo) md = metadata_to_varnamedvector(vi.metadata) return VarInfo(md, copy(vi.accs)) end function untyped_vector_varinfo( - rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() + rng::Random.AbstractRNG, model::Model, init_strategy::AbstractInitStrategy=PriorInit() ) - return untyped_vector_varinfo(untyped_varinfo(rng, model, sampler)) + return untyped_vector_varinfo(untyped_varinfo(rng, model, init_strategy)) end -function untyped_vector_varinfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) - return untyped_vector_varinfo(Random.default_rng(), model, sampler) +function untyped_vector_varinfo( + model::Model, init_strategy::AbstractInitStrategy=PriorInit() +) + return untyped_vector_varinfo(Random.default_rng(), model, init_strategy) end """ - typed_vector_varinfo([rng, ]model[, sampler]) + typed_vector_varinfo([rng, ]model[, init_strategy]) Return a VarInfo object for the given `model`, which has a NamedTuple of `VarNamedVector`s as its metadata field. @@ -322,7 +328,7 @@ Return a VarInfo object for the given `model`, which has a NamedTuple of # Arguments - `rng::Random.AbstractRNG`: The random number generator to use during model evaluation - `model::Model`: The model for which to create the varinfo object -- `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. +- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `PriorInit()`. 
""" function typed_vector_varinfo(vi::NTVarInfo) md = map(metadata_to_varnamedvector, vi.metadata) @@ -334,12 +340,12 @@ function typed_vector_varinfo(vi::UntypedVectorVarInfo) return VarInfo(nt, copy(vi.accs)) end function typed_vector_varinfo( - rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() + rng::Random.AbstractRNG, model::Model, init_strategy::AbstractInitStrategy=PriorInit() ) - return typed_vector_varinfo(untyped_vector_varinfo(rng, model, sampler)) + return typed_vector_varinfo(untyped_vector_varinfo(rng, model, init_strategy)) end -function typed_vector_varinfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) - return typed_vector_varinfo(Random.default_rng(), model, sampler) +function typed_vector_varinfo(model::Model, init_strategy::AbstractInitStrategy=PriorInit()) + return typed_vector_varinfo(Random.default_rng(), model, init_strategy) end """ diff --git a/test/compiler.jl b/test/compiler.jl index 97121715a..874b71204 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -193,11 +193,8 @@ module Issue537 end varinfo = VarInfo(model) @test getlogjoint(varinfo) == lp @test varinfo_ isa AbstractVarInfo - # During the model evaluation, its context is wrapped in a - # SamplingContext, so `model_` is not going to be equal to `model`. - # We can still check equality of `f` though. @test model_.f === model.f - @test model_.context isa SamplingContext + @test model_.context isa DynamicPPL.InitContext @test model_.context.rng isa Random.AbstractRNG # disable warnings @@ -598,13 +595,13 @@ module Issue537 end # an attempt at a `NamedTuple` of the form `(x = 1, __varinfo__)`. 
@model empty_model() = return x = 1 empty_vi = VarInfo() - retval_and_vi = DynamicPPL.evaluate_and_sample!!(empty_model(), empty_vi) + retval_and_vi = DynamicPPL.init!!(empty_model(), empty_vi) @test retval_and_vi isa Tuple{Int,typeof(empty_vi)} # Even if the return-value is `AbstractVarInfo`, we should return # a `Tuple` with `AbstractVarInfo` in the second component too. @model demo() = return __varinfo__ - retval, svi = DynamicPPL.evaluate_and_sample!!(demo(), SimpleVarInfo()) + retval, svi = DynamicPPL.init!!(demo(), SimpleVarInfo()) @test svi == SimpleVarInfo() if Threads.nthreads() > 1 @test retval isa DynamicPPL.ThreadSafeVarInfo{<:SimpleVarInfo} @@ -620,11 +617,11 @@ module Issue537 end f(x) = return x^2 return f(1.0) end - retval, svi = DynamicPPL.evaluate_and_sample!!(demo(), SimpleVarInfo()) + retval, svi = DynamicPPL.init!!(demo(), SimpleVarInfo()) @test retval isa Float64 @model demo() = x ~ Normal() - retval, svi = DynamicPPL.evaluate_and_sample!!(demo(), SimpleVarInfo()) + retval, svi = DynamicPPL.init!!(demo(), SimpleVarInfo()) # Return-value when using `to_submodel` @model inner() = x ~ Normal() diff --git a/test/contexts.jl b/test/contexts.jl index 3819c4564..61b40fae5 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -166,29 +166,30 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @test new_vn == @varname(a.x[1]) @test new_ctx == DefaultContext() - ctx2 = SamplingContext(PrefixContext(@varname(a))) + ctx2 = FixedContext((b=4,), PrefixContext(@varname(a))) new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx2, vn) @test new_vn == @varname(a.x[1]) - @test new_ctx == SamplingContext() + @test new_ctx == FixedContext((b=4,)) ctx3 = PrefixContext(@varname(a), ConditionContext((a=1,))) new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx3, vn) @test new_vn == @varname(a.x[1]) @test new_ctx == ConditionContext((a=1,)) - ctx4 = SamplingContext(PrefixContext(@varname(a), ConditionContext((a=1,)))) + ctx4 = 
FixedContext(
+            (b=4,), PrefixContext(@varname(a), ConditionContext((a=1,)))
+        )
         new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx4, vn)
         @test new_vn == @varname(a.x[1])
-        @test new_ctx == SamplingContext(ConditionContext((a=1,)))
+        @test new_ctx == FixedContext((b=4,), ConditionContext((a=1,)))
     end
 
     @testset "evaluation: $(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
         prefix_vn = @varname(my_prefix)
-        context = DynamicPPL.PrefixContext(prefix_vn, SamplingContext())
-        sampling_model = contextualize(model, context)
-        # Sample with the context.
-        varinfo = DynamicPPL.VarInfo()
-        DynamicPPL.evaluate!!(sampling_model, varinfo)
+        context = DynamicPPL.PrefixContext(prefix_vn, DefaultContext())
+        new_model = contextualize(model, context)
+        # Initialize a new varinfo with the prefixed model
+        _, varinfo = DynamicPPL.init!!(new_model, DynamicPPL.VarInfo())
 
         # Extract the resulting varnames
         vns_actual = Set(keys(varinfo))
diff --git a/test/model.jl b/test/model.jl
index daa3cc743..7f4313ee7 100644
--- a/test/model.jl
+++ b/test/model.jl
@@ -155,24 +155,6 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
         logjoint(model, chain)
     end
 
-    @testset "rng" begin
-        model = GDEMO_DEFAULT
-
-        for sampler in (SampleFromPrior(), SampleFromUniform())
-            for i in 1:10
-                Random.seed!(100 + i)
-                vi = VarInfo()
-                DynamicPPL.evaluate_and_sample!!(Random.default_rng(), model, vi, sampler)
-                vals = vi[:]
-
-                Random.seed!(100 + i)
-                vi = VarInfo()
-                DynamicPPL.evaluate_and_sample!!(Random.default_rng(), model, vi, sampler)
-                @test vi[:] == vals
-            end
-        end
-    end
-
     @testset "defaults without VarInfo, Sampler, and Context" begin
         model = GDEMO_DEFAULT
 
@@ -332,7 +314,7 @@
         @test logjoint(model, x) !=
             DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian(model, x...)
         # Ensure `varnames` is implemented.
-        vi = last(DynamicPPL.evaluate_and_sample!!(model, SimpleVarInfo(OrderedDict())))
+        vi = last(DynamicPPL.init!!(model, SimpleVarInfo(OrderedDict{VarName,Any}())))
         @test all(collect(keys(vi)) .== DynamicPPL.TestUtils.varnames(model))
         # Ensure `posterior_mean` is implemented.
         @test DynamicPPL.TestUtils.posterior_mean(model) isa typeof(x)
@@ -591,10 +573,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
             xs_train = 1:0.1:10
             ys_train = ground_truth_β .* xs_train + rand(Normal(0, 0.1), length(xs_train))
             m_lin_reg = linear_reg(xs_train, ys_train)
-            chain = [
-                last(DynamicPPL.evaluate_and_sample!!(m_lin_reg, VarInfo())) for
-                _ in 1:10000
-            ]
+            chain = [VarInfo(m_lin_reg) for _ in 1:10000]
 
             # chain is generated from the prior
             @test mean([chain[i][@varname(β)] for i in eachindex(chain)]) ≈ 1.0 atol = 0.1
diff --git a/test/sampler.jl b/test/sampler.jl
index fe9fd331a..d04f39ac9 100644
--- a/test/sampler.jl
+++ b/test/sampler.jl
@@ -69,8 +69,8 @@ end
 
     # initial samplers
-    DynamicPPL.initialsampler(::Sampler{OnlyInitAlgUniform}) = SampleFromUniform()
-    @test DynamicPPL.initialsampler(Sampler(OnlyInitAlgDefault())) == SampleFromPrior()
+    DynamicPPL.init_strategy(::Sampler{OnlyInitAlgUniform}) = UniformInit()
+    @test DynamicPPL.init_strategy(Sampler(OnlyInitAlgDefault())) == PriorInit()
 
     for alg in (OnlyInitAlgDefault(), OnlyInitAlgUniform())
         # model with one variable: initialization p = 0.2
diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl
index 2f080e66b..f57a955d2 100644
--- a/test/simple_varinfo.jl
+++ b/test/simple_varinfo.jl
@@ -158,7 +158,7 @@
         ### Sampling ###
 
         # Sample a new varinfo!
-        _, svi_new = DynamicPPL.evaluate_and_sample!!(model, svi)
+        _, svi_new = DynamicPPL.init!!(model, svi)
 
         # Realization for `m` should be different wp. 1.
         for vn in DynamicPPL.TestUtils.varnames(model)
@@ -226,9 +226,9 @@
             # Initialize.
svi_nt = DynamicPPL.settrans!!(SimpleVarInfo(), true) - svi_nt = last(DynamicPPL.evaluate_and_sample!!(model, svi_nt)) + svi_nt = last(DynamicPPL.init!!(model, svi_nt)) svi_vnv = DynamicPPL.settrans!!(SimpleVarInfo(DynamicPPL.VarNamedVector()), true) - svi_vnv = last(DynamicPPL.evaluate_and_sample!!(model, svi_vnv)) + svi_vnv = last(DynamicPPL.init!!(model, svi_vnv)) for svi in (svi_nt, svi_vnv) # Sample with large variations in unconstrained space. @@ -273,7 +273,7 @@ ) # Resulting varinfo should no longer be transformed. - vi_result = last(DynamicPPL.evaluate_and_sample!!(model, deepcopy(vi))) + vi_result = last(DynamicPPL.init!!(model, deepcopy(vi))) @test !DynamicPPL.istrans(vi_result) # Set the values to something that is out of domain if we're in constrained space. diff --git a/test/varinfo.jl b/test/varinfo.jl index 19052e9f1..68c5bef30 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -44,7 +44,7 @@ end end model = gdemo(1.0, 2.0) - vi = DynamicPPL.untyped_varinfo(model, SampleFromUniform()) + _, vi = DynamicPPL.init!!(model, VarInfo(), UniformInit()) tvi = DynamicPPL.typed_varinfo(vi) meta = vi.metadata @@ -486,17 +486,18 @@ end end model = gdemo([1.0, 1.5], [2.0, 2.5]) - # Check that instantiating the model using SampleFromUniform does not + # Check that instantiating the model using UniformInit does not # perform linking - # Note (penelopeysm): The purpose of using SampleFromUniform (SFU) - # specifically in this test is because SFU samples from the linked - # distribution i.e. in unconstrained space. However, it does this not - # by linking the varinfo but by transforming the distributions on the - # fly. That's why it's worth specifically checking that it can do this - # without having to change the VarInfo object. + # Note (penelopeysm): The purpose of using UniformInit specifically in + # this test is because it samples from the linked distribution i.e. in + # unconstrained space. 
However, it does this not by linking the varinfo + # but by transforming the distributions on the fly. That's why it's + # worth specifically checking that it can do this without having to + # change the VarInfo object. + # TODO(penelopeysm): Move this to UniformInit tests rather than here. vi = VarInfo() meta = vi.metadata - _, vi = DynamicPPL.evaluate_and_sample!!(model, vi, SampleFromUniform()) + _, vi = DynamicPPL.init!!(model, vi, UniformInit()) @test all(x -> !istrans(vi, x), meta.vns) # Check that linking and invlinking set the `trans` flag accordingly @@ -569,8 +570,8 @@ end ## `untyped_varinfo` vi = DynamicPPL.untyped_varinfo(model) vi = DynamicPPL.settrans!!(vi, true, vn) - # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) + # Resample in unconstrained space. + vi = last(DynamicPPL.init!!(model, vi, PriorInit())) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) @@ -578,8 +579,8 @@ end ## `typed_varinfo` vi = DynamicPPL.typed_varinfo(model) vi = DynamicPPL.settrans!!(vi, true, vn) - # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) + # Resample in unconstrained space. + vi = last(DynamicPPL.init!!(model, vi, PriorInit())) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) @@ -587,8 +588,8 @@ end ## `typed_varinfo` vi = DynamicPPL.typed_varinfo(model) vi = DynamicPPL.settrans!!(vi, true, vn) - # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) + # Resample in unconstrained space. 
+ vi = last(DynamicPPL.init!!(model, vi, PriorInit())) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) @@ -596,24 +597,24 @@ end ### `SimpleVarInfo` ## `SimpleVarInfo{<:NamedTuple}` vi = DynamicPPL.settrans!!(SimpleVarInfo(), true) - # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) + # Resample in unconstrained space. + vi = last(DynamicPPL.init!!(model, vi, PriorInit())) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) ## `SimpleVarInfo{<:Dict}` vi = DynamicPPL.settrans!!(SimpleVarInfo(Dict{VarName,Any}()), true) - # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) + # Resample in unconstrained space. + vi = last(DynamicPPL.init!!(model, vi, PriorInit())) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) ## `SimpleVarInfo{<:VarNamedVector}` vi = DynamicPPL.settrans!!(SimpleVarInfo(DynamicPPL.VarNamedVector()), true) - # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) + # Resample in unconstrained space. + vi = last(DynamicPPL.init!!(model, vi, PriorInit())) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) @@ -994,10 +995,9 @@ end end model1 = demo(1) varinfo1 = DynamicPPL.link!!(VarInfo(model1), model1) - # Sampling from `model2` should hit the `istrans(vi) == true` branches - # because all the existing variables are linked. + # Calling init!! should preserve the fact that the variables are linked. 
model2 = demo(2) - varinfo2 = last(DynamicPPL.evaluate_and_sample!!(model2, deepcopy(varinfo1))) + varinfo2 = last(DynamicPPL.init!!(model2, deepcopy(varinfo1), PriorInit())) for vn in [@varname(x[1]), @varname(x[2])] @test DynamicPPL.istrans(varinfo2, vn) end @@ -1013,10 +1013,9 @@ end end model1 = demo_dot(1) varinfo1 = DynamicPPL.link!!(DynamicPPL.untyped_varinfo(model1), model1) - # Sampling from `model2` should hit the `istrans(vi) == true` branches - # because all the existing variables are linked. + # Calling init!! should preserve the fact that the variables are linked. model2 = demo_dot(2) - varinfo2 = last(DynamicPPL.evaluate_and_sample!!(model2, deepcopy(varinfo1))) + varinfo2 = last(DynamicPPL.init!!(model2, deepcopy(varinfo1), PriorInit())) for vn in [@varname(x), @varname(y[1])] @test DynamicPPL.istrans(varinfo2, vn) end diff --git a/test/varnamedvector.jl b/test/varnamedvector.jl index 57a8175d4..af24be86f 100644 --- a/test/varnamedvector.jl +++ b/test/varnamedvector.jl @@ -610,9 +610,7 @@ end DynamicPPL.TestUtils.test_values(varinfo_eval, value_true, vns) # Is sampling correct? - varinfo_sample = last( - DynamicPPL.evaluate_and_sample!!(model, deepcopy(varinfo)) - ) + varinfo_sample = last(DynamicPPL.init!!(model, deepcopy(varinfo))) # Log density should be different. @test getlogjoint(varinfo_sample) != getlogjoint(varinfo) # Values should be different. 
From 92aa2745ce062305b61072338c1c5c480932e1ef Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 10 Jul 2025 01:51:13 +0100 Subject: [PATCH 02/17] Use `ParamsInit` for `predict`; remove `setval_and_resample!` and friends --- ext/DynamicPPLMCMCChainsExt.jl | 38 +++++--- src/model.jl | 11 ++- src/varinfo.jl | 143 ---------------------------- test/ext/DynamicPPLMCMCChainsExt.jl | 7 +- test/model.jl | 2 +- test/test_util.jl | 4 +- test/varinfo.jl | 60 +----------- 7 files changed, 49 insertions(+), 216 deletions(-) diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index a29696720..cd86cfb5e 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -28,7 +28,7 @@ end function _check_varname_indexing(c::MCMCChains.Chains) return DynamicPPL.supports_varname_indexing(c) || - error("Chains do not support indexing using `VarName`s.") + error("This `Chains` object does not support indexing using `VarName`s.") end function DynamicPPL.getindex_varname( @@ -42,6 +42,15 @@ function DynamicPPL.varnames(c::MCMCChains.Chains) return keys(c.info.varname_to_symbol) end +function chain_sample_to_varname_dict(c::MCMCChains.Chains, sample_idx, chain_idx) + _check_varname_indexing(c) + d = Dict{DynamicPPL.VarName,Any}() + for vn in DynamicPPL.varnames(c) + d[vn] = DynamicPPL.getindex_varname(c, sample_idx, vn, chain_idx) + end + return d +end + """ predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false) @@ -114,9 +123,15 @@ function DynamicPPL.predict( iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) predictive_samples = map(iters) do (sample_idx, chain_idx) - DynamicPPL.setval_and_resample!(varinfo, parameter_only_chain, sample_idx, chain_idx) - varinfo = last(DynamicPPL.evaluate_and_sample!!(rng, model, varinfo)) - + # Extract values from the chain + values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx) + # Resample any variables that are not present in 
`values_dict` + _, varinfo = DynamicPPL.init!!( + rng, + model, + varinfo, + DynamicPPL.ParamsInit(values_dict, DynamicPPL.PriorInit()), + ) vals = DynamicPPL.values_as_in_model(model, false, varinfo) varname_vals = mapreduce( collect, @@ -248,13 +263,14 @@ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Cha varinfo = DynamicPPL.VarInfo(model) iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) return map(iters) do (sample_idx, chain_idx) - # TODO: Use `fix` once we've addressed https://github.com/TuringLang/DynamicPPL.jl/issues/702. - # Update the varinfo with the current sample and make variables not present in `chain` - # to be sampled. - DynamicPPL.setval_and_resample!(varinfo, chain, sample_idx, chain_idx) - # NOTE: Some of the varialbes can be a view into the `varinfo`, so we need to - # `deepcopy` the `varinfo` before passing it to the `model`. - model(deepcopy(varinfo)) + # Extract values from the chain + values_dict = chain_sample_to_varname_dict(chain, sample_idx, chain_idx) + # Resample any variables that are not present in `values_dict`, and + # return the model's retval. + retval, _ = DynamicPPL.init!!( + model, varinfo, DynamicPPL.ParamsInit(values_dict, DynamicPPL.PriorInit()) + ) + retval end end diff --git a/src/model.jl b/src/model.jl index 7d7a7921d..af394987f 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1200,8 +1200,15 @@ function predict( varinfo = DynamicPPL.VarInfo(model) return map(chain) do params_varinfo vi = deepcopy(varinfo) - DynamicPPL.setval_and_resample!(vi, values_as(params_varinfo, NamedTuple)) - model(rng, vi) + # TODO(penelopeysm): Requires two model evaluations, one to extract the + # parameters and one to set them. The reason why we need values_as_in_model + # is because `params_varinfo` may well have some weird combination of + # linked/unlinked, whereas `varinfo` is always unlinked since it is + # freshly constructed. + # This is quite inefficient. 
It would of course be alright if + # ValuesAsInModelAccumulator was a default acc. + values_nt = values_as_in_model(model, false, params_varinfo) + _, vi = DynamicPPL.init!!(rng, model, vi, ParamsInit(values_nt, PriorInit())) return vi end end diff --git a/src/varinfo.jl b/src/varinfo.jl index 82a648bf5..1e6afe3ba 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1498,42 +1498,6 @@ function islinked(vi::VarInfo) return any(istrans(vi, vn) for vn in keys(vi)) end -function nested_setindex_maybe!(vi::UntypedVarInfo, val, vn::VarName) - return _nested_setindex_maybe!(vi, getmetadata(vi, vn), val, vn) -end -function nested_setindex_maybe!( - vi::VarInfo{<:NamedTuple{names}}, val, vn::VarName{sym} -) where {names,sym} - return if sym in names - _nested_setindex_maybe!(vi, getmetadata(vi, vn), val, vn) - else - nothing - end -end -function _nested_setindex_maybe!( - vi::VarInfo, md::Union{Metadata,VarNamedVector}, val, vn::VarName -) - # If `vn` is in `vns`, then we can just use the standard `setindex!`. - vns = Base.keys(md) - if vn in vns - setindex!(vi, val, vn) - return vn - end - - # Otherwise, we need to check if either of the `vns` subsumes `vn`. - i = findfirst(Base.Fix2(subsumes, vn), vns) - i === nothing && return nothing - - vn_parent = vns[i] - val_parent = getindex(vi, vn_parent) # TODO: Ensure that we're working with a view here. - # Split the varname into its tail optic. - optic = remove_parent_optic(vn_parent, vn) - # Update the value for the parent. 
- val_parent_updated = set!!(val_parent, optic, val) - setindex!(vi, val_parent_updated, vn_parent) - return vn_parent -end - # The default getindex & setindex!() for get & set values # NOTE: vi[vn] will always transform the variable to its original space and Julia type function getindex(vi::VarInfo, vn::VarName) @@ -1978,113 +1942,6 @@ function _setval_kernel!(vi::VarInfoOrThreadSafeVarInfo, vn::VarName, values, ke return indices end -""" - setval_and_resample!(vi::VarInfo, x) - setval_and_resample!(vi::VarInfo, values, keys) - setval_and_resample!(vi::VarInfo, chains::AbstractChains, sample_idx, chain_idx) - -Set the values in `vi` to the provided values and those which are not present -in `x` or `chains` to *be* resampled. - -Note that this does *not* resample the values not provided! It will call -`setflag!(vi, vn, "del")` for variables `vn` for which no values are provided, which means -that the next time we call `model(vi)` these variables will be resampled. - -## Note -- This suffers from the same limitations as [`setval!`](@ref). See `setval!` for more info. - -## Example -```jldoctest -julia> using DynamicPPL, Distributions, StableRNGs - -julia> @model function demo(x) - m ~ Normal() - for i in eachindex(x) - x[i] ~ Normal(m, 1) - end - end; - -julia> rng = StableRNG(42); - -julia> m = demo([missing]); - -julia> var_info = DynamicPPL.VarInfo(rng, m); - # Checking the setting of "del" flags only makes sense for VarInfo{<:Metadata}. For VarInfo{<:VarNamedVector} the flag is effectively always set. 
- -julia> var_info[@varname(m)] --0.6702516921145671 - -julia> var_info[@varname(x[1])] --0.22312984965118443 - -julia> DynamicPPL.setval_and_resample!(var_info, (m = 100.0, )); # set `m` and ready `x[1]` for resampling - -julia> var_info[@varname(m)] # [✓] changed -100.0 - -julia> var_info[@varname(x[1])] # [✓] unchanged --0.22312984965118443 - -julia> m(rng, var_info); # sample `x[1]` conditioned on `m = 100.0` - -julia> var_info[@varname(m)] # [✓] unchanged -100.0 - -julia> var_info[@varname(x[1])] # [✓] changed -101.37363069798343 -``` - -## See also -- [`setval!`](@ref) -""" -function setval_and_resample!(vi::VarInfoOrThreadSafeVarInfo, x) - return setval_and_resample!(vi, values(x), keys(x)) -end -function setval_and_resample!(vi::VarInfoOrThreadSafeVarInfo, values, keys) - return _apply!(_setval_and_resample_kernel!, vi, values, keys) -end -function setval_and_resample!( - vi::VarInfoOrThreadSafeVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int -) - if supports_varname_indexing(chains) - # First we need to set every variable to be resampled. - for vn in keys(vi) - set_flag!(vi, vn, "del") - end - # Then we set the variables in `varinfo` from `chain`. - for vn in varnames(chains) - vn_updated = nested_setindex_maybe!( - vi, getindex_varname(chains, sample_idx, vn, chain_idx), vn - ) - - # Unset the `del` flag if we found something. - if vn_updated !== nothing - # NOTE: This will be triggered even if only a subset of a variable has been set! 
- unset_flag!(vi, vn_updated, "del") - end - end - else - setval_and_resample!(vi, chains.value[sample_idx, :, chain_idx], keys(chains)) - end -end - -function _setval_and_resample_kernel!( - vi::VarInfoOrThreadSafeVarInfo, vn::VarName, values, keys -) - indices = findall(Base.Fix1(subsumes_string, string(vn)), keys) - if !isempty(indices) - val = reduce(vcat, values[indices]) - setval!(vi, val, vn) - settrans!!(vi, false, vn) - else - # Ensures that we'll resample the variable corresponding to `vn` if we run - # the model on `vi` again. - set_flag!(vi, vn, "del") - end - - return indices -end - values_as(vi::VarInfo) = vi.metadata values_as(vi::VarInfo, ::Type{Vector}) = copy(getindex_internal(vi, Colon())) function values_as(vi::UntypedVarInfo, ::Type{NamedTuple}) diff --git a/test/ext/DynamicPPLMCMCChainsExt.jl b/test/ext/DynamicPPLMCMCChainsExt.jl index 3ba5edfe1..79e13ad84 100644 --- a/test/ext/DynamicPPLMCMCChainsExt.jl +++ b/test/ext/DynamicPPLMCMCChainsExt.jl @@ -2,7 +2,12 @@ @model demo() = x ~ Normal() model = demo() - chain = MCMCChains.Chains(randn(1000, 2, 1), [:x, :y], Dict(:internals => [:y])) + chain = MCMCChains.Chains( + randn(1000, 2, 1), + [:x, :y], + Dict(:internals => [:y]); + info=(; varname_to_symbol=Dict(@varname(x) => :x)), + ) chain_generated = @test_nowarn returned(model, chain) @test size(chain_generated) == (1000, 1) @test mean(chain_generated) ≈ 0 atol = 0.1 diff --git a/test/model.jl b/test/model.jl index 7f4313ee7..18e5e633f 100644 --- a/test/model.jl +++ b/test/model.jl @@ -573,7 +573,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() xs_train = 1:0.1:10 ys_train = ground_truth_β .* xs_train + rand(Normal(0, 0.1), length(xs_train)) m_lin_reg = linear_reg(xs_train, ys_train) - chain = [VarInfo(m_lin_reg) _ in 1:10000] + chain = [VarInfo(m_lin_reg) for _ in 1:10000] # chain is generated from the prior @test mean([chain[i][@varname(β)] for i in eachindex(chain)]) ≈ 1.0 atol = 0.1 diff --git 
a/test/test_util.jl b/test/test_util.jl index d5335249d..b7c46ff34 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -81,8 +81,10 @@ function make_chain_from_prior(rng::Random.AbstractRNG, model::Model, n_iters::I varnames = collect(varnames) # Construct matrix of values vals = [get(dict, vn, missing) for dict in dicts, vn in varnames] + # Construct dict of varnames -> symbol + vn_to_sym_dict = Dict(zip(varnames, map(Symbol, varnames))) # Construct and return the Chains object - return Chains(vals, varnames) + return Chains(vals, varnames; info=(; varname_to_symbol=vn_to_sym_dict)) end function make_chain_from_prior(model::Model, n_iters::Int) return make_chain_from_prior(Random.default_rng(), model, n_iters) diff --git a/test/varinfo.jl b/test/varinfo.jl index 68c5bef30..63c681665 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -278,7 +278,7 @@ end @test typed_vi[vn_y] == 2.0 end - @testset "setval! & setval_and_resample!" begin + @testset "setval!" begin @model function testmodel(x) n = length(x) s ~ truncated(Normal(); lower=0) @@ -329,8 +329,8 @@ end else DynamicPPL.setval!(vicopy, (m=zeros(5),)) end - # Setting `m` fails for univariate due to limitations of `setval!` - # and `setval_and_resample!`. See docstring of `setval!` for more info. + # Setting `m` fails for univariate due to limitations of `setval!`. + # See docstring of `setval!` for more info. if model == model_uv && vi in [vi_untyped, vi_typed] @test_broken vicopy[m_vns] == zeros(5) else @@ -355,57 +355,6 @@ end DynamicPPL.setval!(vicopy, (s=42,)) @test vicopy[m_vns] == 1:5 @test vicopy[s_vns] == 42 - - ### `setval_and_resample!` ### - if model == model_mv && vi == vi_untyped - # Trying to re-run model with `MvNormal` on `vi_untyped` will call - # `MvNormal(μ::Vector{Real}, Σ)` which causes `StackOverflowError` - # so we skip this particular case. 
- continue - end - - if vi in [vi_vnv, vi_vnv_typed] - # `setval_and_resample!` works differently for `VarNamedVector`: All - # values will be resampled when model(vicopy) is called. Hence the below - # tests are not applicable. - continue - end - - vicopy = deepcopy(vi) - DynamicPPL.setval_and_resample!(vicopy, (m=zeros(5),)) - model(vicopy) - # Setting `m` fails for univariate due to limitations of `subsumes(::String, ::String)` - if model == model_uv - @test_broken vicopy[m_vns] == zeros(5) - else - @test vicopy[m_vns] == zeros(5) - end - @test vicopy[s_vns] != vi[s_vns] - - # Ordering is NOT preserved. - DynamicPPL.setval_and_resample!( - vicopy, (; (Symbol("m[$i]") => i for i in (1, 3, 5, 4, 2))...) - ) - model(vicopy) - if model == model_uv - @test vicopy[m_vns] == 1:5 - else - @test vicopy[m_vns] == [1, 3, 5, 4, 2] - end - @test vicopy[s_vns] != vi[s_vns] - - # Correct ordering. - DynamicPPL.setval_and_resample!( - vicopy, (; (Symbol("m[$i]") => i for i in (1, 2, 3, 4, 5))...) - ) - model(vicopy) - @test vicopy[m_vns] == 1:5 - @test vicopy[s_vns] != vi[s_vns] - - DynamicPPL.setval_and_resample!(vicopy, (s=42,)) - model(vicopy) - @test vicopy[m_vns] != 1:5 - @test vicopy[s_vns] == 42 end end @@ -419,9 +368,6 @@ end ks = [@varname(x[1, 1]), @varname(x[2, 1]), @varname(x[1, 2]), @varname(x[2, 2])] DynamicPPL.setval!(vi, vi.metadata.x.vals, ks) @test vals_prev == vi.metadata.x.vals - - DynamicPPL.setval_and_resample!(vi, vi.metadata.x.vals, ks) - @test vals_prev == vi.metadata.x.vals end @testset "setval! 
on chain" begin From 6d38b6024adbd271a57bcccdf3174d2d2c3fd79c Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 10 Jul 2025 12:31:37 +0100 Subject: [PATCH 03/17] Use `init!!` for initialisation --- docs/src/api.md | 2 +- src/sampler.jl | 148 +++++++++--------------------------------------- test/sampler.jl | 64 ++++----------------- 3 files changed, 40 insertions(+), 174 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index ff26e2346..ee62880e2 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -514,7 +514,7 @@ The default implementation of [`Sampler`](@ref) uses the following unexported fu ```@docs DynamicPPL.initialstep DynamicPPL.loadstate -DynamicPPL.initialsampler +DynamicPPL.init_strategy ``` Finally, to specify which varinfo type a [`Sampler`](@ref) should use for a given [`Model`](@ref), this is specified by [`DynamicPPL.default_varinfo`](@ref) and can thus be overloaded for each `model`-`sampler` combination. This can be useful in cases where one has explicit knowledge that one type of varinfo will be more performant for the given `model` and `sampler`. diff --git a/src/sampler.jl b/src/sampler.jl index 184e4b70a..d43b840a5 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -68,6 +68,8 @@ end Return a default varinfo object for the given `model` and `sampler`. +The default method for this returns an empty NTVarInfo (i.e. 'typed varinfo'). + # Arguments - `rng::Random.AbstractRNG`: Random number generator. - `model::Model`: Model for which we want to create a varinfo object. @@ -76,9 +78,10 @@ Return a default varinfo object for the given `model` and `sampler`. # Returns - `AbstractVarInfo`: Default varinfo object for the given `model` and `sampler`. 
""" -function default_varinfo(rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler) - init_sampler = initialsampler(sampler) - return typed_varinfo(rng, model, init_sampler) +function default_varinfo(::Random.AbstractRNG, ::Model, ::AbstractSampler) + # Note that variable values are unconditionally initialized later, so no + # point putting them in now. + return typed_varinfo(VarInfo()) end function AbstractMCMC.sample( @@ -96,24 +99,32 @@ function AbstractMCMC.sample( ) end -# initial step: general interface for resuming and +""" + init_strategy(sampler) + +Define the initialisation strategy used for generating initial values when +sampling with `sampler`. Defaults to `PriorInit()`, but can be overridden. +""" +init_strategy(::Sampler) = PriorInit() + function AbstractMCMC.step( - rng::Random.AbstractRNG, model::Model, spl::Sampler; initial_params=nothing, kwargs... + rng::Random.AbstractRNG, + model::Model, + spl::Sampler; + initial_params::AbstractInitStrategy=init_strategy(spl), + kwargs..., ) - # Sample initial values. + # Generate the default varinfo (usually this just makes an empty VarInfo + # with NamedTuple of Metadata). vi = default_varinfo(rng, model, spl) - # Update the parameters if provided. - if initial_params !== nothing - vi = initialize_parameters!!(vi, initial_params, model) - - # Update joint log probability. - # This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588 - # and https://github.com/TuringLang/Turing.jl/issues/1563 - # to avoid that existing variables are resampled - vi = last(evaluate!!(model, vi)) - end + # Fill it with initial parameters. Note that, if `ParamsInit` is used, the + # parameters provided must be in unlinked space (when inserted into the + # varinfo, they will be adjusted to match the linking status of the + # varinfo). + _, vi = init!!(rng, model, vi, initial_params) + # Call the actual function that does the first step. 
return initialstep(rng, model, spl, vi; initial_params, kwargs...) end @@ -131,110 +142,7 @@ loadstate(data) = data Default type of the chain of posterior samples from `sampler`. """ -default_chain_type(sampler::Sampler) = Any - -""" - initialsampler(sampler::Sampler) - -Return the sampler that is used for generating the initial parameters when sampling with -`sampler`. - -By default, it returns an instance of [`SampleFromPrior`](@ref). -""" -initialsampler(spl::Sampler) = SampleFromPrior() - -""" - set_initial_values(varinfo::AbstractVarInfo, initial_params::AbstractVector) - set_initial_values(varinfo::AbstractVarInfo, initial_params::NamedTuple) - -Take the values inside `initial_params`, replace the corresponding values in -the given VarInfo object, and return a new VarInfo object with the updated values. - -This differs from `DynamicPPL.unflatten` in two ways: - -1. It works with `NamedTuple` arguments. -2. For the `AbstractVector` method, if any of the elements are missing, it will not -overwrite the original value in the VarInfo (it will just use the original -value instead). -""" -function set_initial_values(varinfo::AbstractVarInfo, initial_params::AbstractVector) - throw( - ArgumentError( - "`initial_params` must be a vector of type `Union{Real,Missing}`. " * - "If `initial_params` is a vector of vectors, please flatten it (e.g. using `vcat`) first.", - ), - ) -end - -function set_initial_values( - varinfo::AbstractVarInfo, initial_params::AbstractVector{<:Union{Real,Missing}} -) - flattened_param_vals = varinfo[:] - length(flattened_param_vals) == length(initial_params) || throw( - DimensionMismatch( - "Provided initial value size ($(length(initial_params))) doesn't match " * - "the model size ($(length(flattened_param_vals))).", - ), - ) - - # Update values that are provided. - for i in eachindex(initial_params) - x = initial_params[i] - if x !== missing - flattened_param_vals[i] = x - end - end - - # Update in `varinfo`. 
- new_varinfo = unflatten(varinfo, flattened_param_vals) - return new_varinfo -end - -function set_initial_values(varinfo::AbstractVarInfo, initial_params::NamedTuple) - varinfo = deepcopy(varinfo) - vars_in_varinfo = keys(varinfo) - for v in keys(initial_params) - vn = VarName{v}() - if !(vn in vars_in_varinfo) - for vv in vars_in_varinfo - if subsumes(vn, vv) - throw( - ArgumentError( - "The current model contains sub-variables of $v, such as ($vv). " * - "Using NamedTuple for initial_params is not supported in such a case. " * - "Please use AbstractVector for initial_params instead of NamedTuple.", - ), - ) - end - end - throw(ArgumentError("Variable $v not found in the model.")) - end - end - initial_params = NamedTuple(k => v for (k, v) in pairs(initial_params) if v !== missing) - return update_values!!( - varinfo, initial_params, map(k -> VarName{k}(), keys(initial_params)) - ) -end - -function initialize_parameters!!(vi::AbstractVarInfo, initial_params, model::Model) - @debug "Using passed-in initial variable values" initial_params - - # `link` the varinfo if needed. - linked = islinked(vi) - if linked - vi = invlink!!(vi, model) - end - - # Set the values in `vi`. - vi = set_initial_values(vi, initial_params) - - # `invlink` if needed. - if linked - vi = link!!(vi, model) - end - - return vi -end +default_chain_type(::Sampler) = Any """ initialstep(rng, model, sampler, varinfo; kwargs...) 
diff --git a/test/sampler.jl b/test/sampler.jl index d04f39ac9..9555c80d6 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -82,7 +82,9 @@ sampler = Sampler(alg) lptrue = logpdf(Binomial(25, 0.2), 10) let inits = (; p=0.2) - chain = sample(model, sampler, 1; initial_params=inits, progress=false) + chain = sample( + model, sampler, 1; initial_params=ParamsInit(inits), progress=false + ) @test chain[1].metadata.p.vals == [0.2] @test getlogjoint(chain[1]) == lptrue @@ -110,7 +112,9 @@ model = twovars() lptrue = logpdf(InverseGamma(2, 3), 4) + logpdf(Normal(0, 2), -1) for inits in ([4, -1], (; s=4, m=-1)) - chain = sample(model, sampler, 1; initial_params=inits, progress=false) + chain = sample( + model, sampler, 1; initial_params=ParamsInit(inits), progress=false + ) @test chain[1].metadata.s.vals == [4] @test chain[1].metadata.m.vals == [-1] @test getlogjoint(chain[1]) == lptrue @@ -122,7 +126,7 @@ MCMCThreads(), 1, 10; - initial_params=fill(inits, 10), + initial_params=fill(ParamsInit(inits), 10), progress=false, ) for c in chains @@ -133,8 +137,10 @@ end # set only m = -1 - for inits in ([missing, -1], (; s=missing, m=-1), (; m=-1)) - chain = sample(model, sampler, 1; initial_params=inits, progress=false) + for inits in ((; s=missing, m=-1), (; m=-1)) + chain = sample( + model, sampler, 1; initial_params=ParamsInit(inits), progress=false + ) @test !ismissing(chain[1].metadata.s.vals[1]) @test chain[1].metadata.m.vals == [-1] @@ -153,54 +159,6 @@ @test c[1].metadata.m.vals == [-1] end end - - # specify `initial_params=nothing` - Random.seed!(1234) - chain1 = sample(model, sampler, 1; progress=false) - Random.seed!(1234) - chain2 = sample(model, sampler, 1; initial_params=nothing, progress=false) - @test_throws DimensionMismatch sample( - model, sampler, 1; progress=false, initial_params=zeros(10) - ) - @test chain1[1].metadata.m.vals == chain2[1].metadata.m.vals - @test chain1[1].metadata.s.vals == chain2[1].metadata.s.vals - - # parallel sampling - 
Random.seed!(1234) - chains1 = sample(model, sampler, MCMCThreads(), 1, 10; progress=false) - Random.seed!(1234) - chains2 = sample( - model, sampler, MCMCThreads(), 1, 10; initial_params=nothing, progress=false - ) - for (c1, c2) in zip(chains1, chains2) - @test c1[1].metadata.m.vals == c2[1].metadata.m.vals - @test c1[1].metadata.s.vals == c2[1].metadata.s.vals - end - end - - @testset "error handling" begin - # https://github.com/TuringLang/Turing.jl/issues/2452 - @model function constrained_uniform(n) - Z ~ Uniform(10, 20) - X = Vector{Float64}(undef, n) - for i in 1:n - X[i] ~ Uniform(0, Z) - end - end - - n = 2 - initial_z = 15 - initial_x = [0.2, 0.5] - model = constrained_uniform(n) - vi = VarInfo(model) - - @test_throws ArgumentError DynamicPPL.initialize_parameters!!( - vi, [initial_z, initial_x], model - ) - - @test_throws ArgumentError DynamicPPL.initialize_parameters!!( - vi, (X=initial_x, Z=initial_z), model - ) end end end From 7af03a385e6744828edb476ff68d8306c6c5f32f Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 10 Jul 2025 16:10:00 +0100 Subject: [PATCH 04/17] Paper over the `Sampling->Init` context stack (pending removal of SamplingContext) --- src/context_implementations.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 9e9a2d63d..c493647d8 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -28,6 +28,13 @@ end function tilde_assume(rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, vi) return assume(rng, sampler, right, vn, vi) end +function tilde_assume(rng::Random.AbstractRNG, ::InitContext, sampler, right, vn, vi) + @warn( + "Encountered SamplingContext->InitContext. This method will be removed in the next PR.", + ) + # just pretend the `InitContext` isn't there for now. 
+ return assume(rng, sampler, right, vn, vi) +end function tilde_assume(::DefaultContext, sampler, right, vn, vi) # same as above but no rng return assume(Random.default_rng(), sampler, right, vn, vi) From 2edcd10c6217de1ed1171ff9a0c207bbfe25c4a8 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 10 Jul 2025 16:25:00 +0100 Subject: [PATCH 05/17] Remove SamplingContext from JETExt to avoid triggering `Sampling->Init` pathway --- ext/DynamicPPLJETExt.jl | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/ext/DynamicPPLJETExt.jl b/ext/DynamicPPLJETExt.jl index 760d17bb0..89a36ffaf 100644 --- a/ext/DynamicPPLJETExt.jl +++ b/ext/DynamicPPLJETExt.jl @@ -21,22 +21,17 @@ end function DynamicPPL.Experimental._determine_varinfo_jet( model::DynamicPPL.Model; only_ddpl::Bool=true ) - # Use SamplingContext to test type stability. - sampling_model = DynamicPPL.contextualize( - model, DynamicPPL.SamplingContext(model.context) - ) - # First we try with the typed varinfo. - varinfo = DynamicPPL.typed_varinfo(sampling_model) + varinfo = DynamicPPL.typed_varinfo(model) - # Let's make sure that both evaluation and sampling doesn't result in type errors. + # Let's make sure that evaluation doesn't result in type errors. issuccess, result = DynamicPPL.Experimental.is_suitable_varinfo( - sampling_model, varinfo; only_ddpl + model, varinfo; only_ddpl ) if !issuccess # Useful information for debugging. - @debug "Evaluaton with typed varinfo failed with the following issues:" + @debug "Evaluation with typed varinfo failed with the following issues:" @debug result end @@ -46,7 +41,7 @@ function DynamicPPL.Experimental._determine_varinfo_jet( else # Warn the user that we can't use the type stable one. @warn "Model seems incompatible with typed varinfo. Falling back to untyped varinfo." 
- DynamicPPL.untyped_varinfo(sampling_model) + DynamicPPL.untyped_varinfo(model) end end From 8520ec3634354081656341d1617d048ec20ea8b6 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 20 Jul 2025 12:05:44 +0100 Subject: [PATCH 06/17] Un-fix `predict` on varinfo --- src/model.jl | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/src/model.jl b/src/model.jl index af394987f..cc2b54567 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1200,15 +1200,24 @@ function predict( varinfo = DynamicPPL.VarInfo(model) return map(chain) do params_varinfo vi = deepcopy(varinfo) - # TODO(penelopeysm): Requires two model evaluations, one to extract the - # parameters and one to set them. The reason why we need values_as_in_model - # is because `params_varinfo` may well have some weird combination of - # linked/unlinked, whereas `varinfo` is always unlinked since it is - # freshly constructed. - # This is quite inefficient. It would of course be alright if - # ValuesAsInModelAccumulator was a default acc. - values_nt = values_as_in_model(model, false, params_varinfo) - _, vi = DynamicPPL.init!!(rng, model, vi, ParamsInit(values_nt, PriorInit())) + # TODO(penelopeysm): BEWARE - Because `ParamsInit` expects unlinked + # values, this is bugged in the case where `params_varinfo` is linked. + # This could be solved by using `values_as_in_model`. However, we run + # into another problem: + # - The `model` passed into this function will have the target + # prediction variables (say `y`) set to `missing`. + # - Calling `values_as_in_model` on this new model will lead to + # DynamicPPL attempting to extract the values of `y` from the + # `params_varinfo`, which will fail since `y` will not have been + # present. + # I think that the solution is to pass in the ORIGINAL model, but ALSO + # pass in a set / iterable of VarNames that are to be predicted against. + # That way, we can: + # - Safely use `values_as_in_model` to extract a dict of values. 
+ # - Drop the values of the prediction variables from the dict. + # - Then use `ParamsInit` to generate the predictions. + values_dict = values_as(params_varinfo, Dict{VarName,Any}) + _, vi = DynamicPPL.init!!(rng, model, vi, ParamsInit(values_dict, PriorInit())) return vi end end From eb2b0d2dbed1f6f9c6fe252b790bf26e9bc5f6f6 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 20 Jul 2025 12:05:59 +0100 Subject: [PATCH 07/17] Fix some tests --- test/sampler.jl | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/test/sampler.jl b/test/sampler.jl index 9555c80d6..df1d5e908 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -81,10 +81,8 @@ model = coinflip() sampler = Sampler(alg) lptrue = logpdf(Binomial(25, 0.2), 10) - let inits = (; p=0.2) - chain = sample( - model, sampler, 1; initial_params=ParamsInit(inits), progress=false - ) + let inits = ParamsInit((; p=0.2)) + chain = sample(model, sampler, 1; initial_params=inits, progress=false) @test chain[1].metadata.p.vals == [0.2] @test getlogjoint(chain[1]) == lptrue @@ -111,10 +109,8 @@ end model = twovars() lptrue = logpdf(InverseGamma(2, 3), 4) + logpdf(Normal(0, 2), -1) - for inits in ([4, -1], (; s=4, m=-1)) - chain = sample( - model, sampler, 1; initial_params=ParamsInit(inits), progress=false - ) + let inits = ParamsInit((; s=4, m=-1)) + chain = sample(model, sampler, 1; initial_params=inits, progress=false) @test chain[1].metadata.s.vals == [4] @test chain[1].metadata.m.vals == [-1] @test getlogjoint(chain[1]) == lptrue @@ -126,7 +122,7 @@ MCMCThreads(), 1, 10; - initial_params=fill(ParamsInit(inits), 10), + initial_params=fill(inits, 10), progress=false, ) for c in chains @@ -137,10 +133,8 @@ end # set only m = -1 - for inits in ((; s=missing, m=-1), (; m=-1)) - chain = sample( - model, sampler, 1; initial_params=ParamsInit(inits), progress=false - ) + for inits in (ParamsInit((; s=missing, m=-1)), ParamsInit((; m=-1))) + chain = sample(model, sampler, 1; 
initial_params=inits, progress=false) @test !ismissing(chain[1].metadata.s.vals[1]) @test chain[1].metadata.m.vals == [-1] From 509f50e520188dd27fb32b4d7c48c7bc3cd0b0fd Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 20 Jul 2025 12:08:43 +0100 Subject: [PATCH 08/17] Fix more tests --- src/sampler.jl | 4 ++-- test/sampler.jl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/sampler.jl b/src/sampler.jl index d43b840a5..6b98b3fa2 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -59,8 +59,8 @@ function AbstractMCMC.step( ) vi = VarInfo() strategy = sampler isa SampleFromPrior ? PriorInit() : UniformInit() - DynamicPPL.init!!(rng, model, vi, strategy) - return vi, nothing + _, new_vi = DynamicPPL.init!!(rng, model, vi, strategy) + return new_vi, nothing end """ diff --git a/test/sampler.jl b/test/sampler.jl index df1d5e908..1a61b2f1f 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -25,8 +25,8 @@ @test length(chains) == N # `m` is Gaussian, i.e. no transformation is used, so it - # should have a mean equal to its prior, i.e. 2. - @test mean(vi[@varname(m)] for vi in chains) ≈ 2 atol = 0.1 + # will be drawn from U[-2, 2] and its mean should be 0. + @test mean(vi[@varname(m)] for vi in chains) ≈ 0.0 atol = 0.1 # Expected value of ``exp(X)`` where ``X ~ U[-2, 2]`` is ≈ 1.8. @test mean(vi[@varname(s)] for vi in chains) ≈ 1.8 atol = 0.1 From 34fbc547671e7ece23f47d33014a6554b1eab7f3 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 20 Jul 2025 12:23:39 +0100 Subject: [PATCH 09/17] Throw an error in `predict` if any variable is transformed --- src/model.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/model.jl b/src/model.jl index cc2b54567..8649fb598 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1216,6 +1216,14 @@ function predict( # - Safely use `values_as_in_model` to extract a dict of values. # - Drop the values of the prediction variables from the dict. # - Then use `ParamsInit` to generate the predictions. 
+        for vn in keys(params_varinfo)
+            # note that istrans(params_varinfo) checks that ALL variables are transformed,
+            # whereas the failure here happens if ANY variable is transformed.
+            # so we have to manually loop over keys
+            istrans(params_varinfo, vn) && error(
+                "`predict(rng, model, ::AbstractArray{<:AbstractVarInfo})` will give you wrong results if the `VarInfo` contains transformed variables and is therefore currently forbidden. Please provide an unlinked `VarInfo` instead.",
+            )
+        end
         values_dict = values_as(params_varinfo, Dict{VarName,Any})
         _, vi = DynamicPPL.init!!(rng, model, vi, ParamsInit(values_dict, PriorInit()))
         return vi

From 70b46822faa96b5475078cc1d5bc9703181e0228 Mon Sep 17 00:00:00 2001
From: Penelope Yong
Date: Sun, 20 Jul 2025 12:24:50 +0100
Subject: [PATCH 10/17] Fix docs

---
 src/sampler.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/sampler.jl b/src/sampler.jl
index 6b98b3fa2..c4d040959 100644
--- a/src/sampler.jl
+++ b/src/sampler.jl
@@ -41,7 +41,7 @@ Generic sampler type for inference algorithms of type `T` in DynamicPPL.
 provided that supports resuming sampling from a previous state and setting initial
 parameter values. It requires to overload [`loadstate`](@ref) and [`initialstep`](@ref)
 for loading previous states and actually performing the initial sampling step,
-respectively. Additionally, sometimes one might want to implement [`initialsampler`](@ref)
+respectively. Additionally, sometimes one might want to implement [`init_strategy`](@ref)
 that specifies how the initial parameter values are sampled if they are not provided.
 By default, values are sampled from the prior.
""" From 0f8804f904d58b9ee7ff78f8792be47835506c2a Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 20 Jul 2025 12:25:32 +0100 Subject: [PATCH 11/17] Fix typo in test --- test/contexts.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/contexts.jl b/test/contexts.jl index 61b40fae5..d40a80936 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -181,7 +181,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() ) new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx4, vn) @test new_vn == @varname(a.x[1]) - @test new_ctx == FixedContext((b=4,)ConditionContext((a=1,))) + @test new_ctx == FixedContext((b=4,), ConditionContext((a=1,))) end @testset "evaluation: $(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS From cbbfb46df65481ee96726bfa8f83e192808714ef Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 20 Jul 2025 12:31:16 +0100 Subject: [PATCH 12/17] Add varname info to chain --- test/model.jl | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/test/model.jl b/test/model.jl index 18e5e633f..d2479eca1 100644 --- a/test/model.jl +++ b/test/model.jl @@ -488,7 +488,11 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() # Construct a chain with 'sampled values' of β ground_truth_β = 2 - β_chain = MCMCChains.Chains(rand(Normal(ground_truth_β, 0.002), 1000), [:β]) + β_chain = MCMCChains.Chains( + rand(Normal(ground_truth_β, 0.002), 1000), + [:β]; + info=(; varname_to_symbol=Dict(@varname(β) => :β)), + ) # Generate predictions from that chain xs_test = [10 + 0.1, 10 + 2 * 0.1] @@ -534,7 +538,9 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() @testset "prediction from multiple chains" begin # Normal linreg model multiple_β_chain = MCMCChains.Chains( - reshape(rand(Normal(ground_truth_β, 0.002), 1000, 2), 1000, 1, 2), [:β] + reshape(rand(Normal(ground_truth_β, 0.002), 1000, 2), 1000, 1, 2), + [:β]; + info=(; 
varname_to_symbol=Dict(@varname(β) => :β)),
        )
        predictions = DynamicPPL.predict(m_lin_reg_test, multiple_β_chain)
        @test size(multiple_β_chain, 3) == size(predictions, 3)

From 9bfb5c262e8a0e9e91b756f97b79f8d9339ed913 Mon Sep 17 00:00:00 2001
From: Penelope Yong
Date: Sun, 20 Jul 2025 12:35:30 +0100
Subject: [PATCH 13/17] revert unnecessary changes that complicate the diff

---
 src/model.jl | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/model.jl b/src/model.jl
index 8649fb598..2e05dfff7 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -865,7 +865,7 @@ end

 """
     init!!(
-        [rng::Random.AbstractRNG, ]
+        [rng::Random.AbstractRNG,]
         model::Model,
         varinfo::AbstractVarInfo,
         [init_strategy::AbstractInitStrategy=PriorInit()]
@@ -874,8 +874,8 @@ end
 Evaluate the `model` and replace the values of the model's random variables in
 the given `varinfo` with new values, using a specified initialisation strategy.
 If the values in `varinfo` are not set, they will be added.
-using a specified initialisation strategy. If `init_strategy` is not provided,
-defaults to PriorInit().
+
+If `init_strategy` is not provided, it defaults to `PriorInit()`.

 Returns a tuple of the model's return value, plus the updated `varinfo` object.
""" From 726c4e24876edeb6729b11288825ceb236ab1a94 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 20 Jul 2025 12:40:33 +0100 Subject: [PATCH 14/17] Some simplifications --- src/simple_varinfo.jl | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 53c9b4715..465fd7f94 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -234,9 +234,7 @@ end function SimpleVarInfo{T}( rng::Random.AbstractRNG, model::Model, init_strategy::AbstractInitStrategy=PriorInit() ) where {T<:Real} - new_context = setleafcontext(model.context, InitContext(rng, init_strategy)) - new_model = contextualize(model, new_context) - return last(evaluate!!(new_model, SimpleVarInfo{T}())) + return last(init!!(rng, model, SimpleVarInfo{T}(), init_strategy)) end function SimpleVarInfo{T}( model::Model, init_strategy::AbstractInitStrategy=PriorInit() @@ -491,8 +489,9 @@ function settrans!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans) return Accessors.@set vi.varinfo = settrans!!(vi.varinfo, trans) end function settrans!!(vi::SimpleOrThreadSafeSimple, trans::Bool, ::VarName) - # We keep this method around just to obey the AbstractVarInfo interface; however, - # this is only a valid operation if it would be a no-op. + # We keep this method around just to obey the AbstractVarInfo interface. + # However, note that this would only be a valid operation if it would be a + # no-op, which we check here. 
if trans != istrans(vi) error( "Individual variables in SimpleVarInfo cannot have different `settrans` statuses.", From 0914281d945a5ffe919e6d4f2bfb62fe9d2508aa Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 20 Jul 2025 13:08:47 +0100 Subject: [PATCH 15/17] Remove duplicated test --- test/varinfo.jl | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/test/varinfo.jl b/test/varinfo.jl index 63c681665..4adab1205 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -512,18 +512,6 @@ end ## `untyped_varinfo` vi = DynamicPPL.untyped_varinfo(model) - - ## `untyped_varinfo` - vi = DynamicPPL.untyped_varinfo(model) - vi = DynamicPPL.settrans!!(vi, true, vn) - # Resample in unconstrained space. - vi = last(DynamicPPL.init!!(model, vi, PriorInit())) - f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) - x = f(DynamicPPL.getindex_internal(vi, vn)) - @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) - - ## `typed_varinfo` - vi = DynamicPPL.typed_varinfo(model) vi = DynamicPPL.settrans!!(vi, true, vn) # Resample in unconstrained space. vi = last(DynamicPPL.init!!(model, vi, PriorInit())) From 590c8f2778180617b82a0e22599a2946ec97cdd7 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 20 Jul 2025 13:19:33 +0100 Subject: [PATCH 16/17] Fix JET test (correctly) --- test/ext/DynamicPPLJETExt.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test/ext/DynamicPPLJETExt.jl b/test/ext/DynamicPPLJETExt.jl index 6737cf056..9f7e05cb0 100644 --- a/test/ext/DynamicPPLJETExt.jl +++ b/test/ext/DynamicPPLJETExt.jl @@ -40,6 +40,11 @@ end end @test DynamicPPL.Experimental.determine_suitable_varinfo(demo4()) isa + DynamicPPL.NTVarInfo + init_model = DynamicPPL.contextualize( + demo4(), DynamicPPL.InitContext(DynamicPPL.PriorInit()) + ) + @test DynamicPPL.Experimental.determine_suitable_varinfo(init_model) isa DynamicPPL.UntypedVarInfo # In this model, the type error occurs in the user code rather than in DynamicPPL. 
From 502559207ca80e4e3bfb3da13a4e11bfac30b07b Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 20 Jul 2025 17:31:12 +0100 Subject: [PATCH 17/17] Fix missing `istrans()` method for TSVI{SVI} --- src/simple_varinfo.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 465fd7f94..cdf99cbd5 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -502,6 +502,7 @@ end istrans(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation) istrans(vi::SimpleVarInfo, ::VarName) = istrans(vi) istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) = istrans(vi.varinfo, vn) +istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}) = istrans(vi.varinfo) islinked(vi::SimpleVarInfo) = istrans(vi)