diff --git a/HISTORY.md b/HISTORY.md index a956bd188..1af5c2ca3 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -4,10 +4,25 @@ **Breaking changes** -### VarInfo constructor +### VarInfo constructors `VarInfo(vi::VarInfo, values)` has been removed. You can replace this directly with `unflatten(vi, values)` instead. +The `metadata` argument to `VarInfo([rng, ]model[, sampler, context, metadata])` has been removed. +If you were not using this argument (most likely), then there is no change needed. +If you were using the `metadata` argument to specify a blank `VarNamedVector`, then you should replace calls to `VarInfo` with `DynamicPPL.typed_vector_varinfo` instead (see 'Other changes' below). + +The `UntypedVarInfo` constructor and type is no longer exported. +If you needed to construct one, you should now use `DynamicPPL.untyped_varinfo` instead. + +The `TypedVarInfo` constructor and type is no longer exported. +The _type_ has been replaced with `DynamicPPL.NTVarInfo`. +The _constructor_ has been replaced with `DynamicPPL.typed_varinfo`. + +Note that the exact kind of VarInfo returned by `VarInfo(rng, model, ...)` is an implementation detail. +Previously, it was guaranteed that this would always be a VarInfo whose metadata was a `NamedTuple` containing `Metadata` structs. +Going forward, this is no longer the case, and you should only assume that the returned object obeys the `AbstractVarInfo` interface. + ### VarName prefixing behaviour The way in which VarNames in submodels are prefixed has been changed. @@ -53,6 +68,20 @@ outer() | (a.x=1.0,) If you are sampling from a model with submodels, this doesn't affect the way you interact with the `MCMCChains.Chains` object, because VarNames are converted into Symbols when stored in the chain. (This behaviour will likely be changed in the future, in that Chains should be indexable by VarNames and not just Symbols, but that has not been implemented yet.) +**Other changes** + +While these are technically breaking, they are only internal changes and do not affect the public API. +The following four functions have been added and/or reworked to make it easier to construct VarInfos with different types of metadata: + + 1. `DynamicPPL.untyped_varinfo([rng, ]model[, sampler, context])` + 2. `DynamicPPL.typed_varinfo([rng, ]model[, sampler, context])` + 3. `DynamicPPL.untyped_vector_varinfo([rng, ]model[, sampler, context])` + 4. `DynamicPPL.typed_vector_varinfo([rng, ]model[, sampler, context])` + +The reason for this change is that there were several flavours of VarInfo. +Some, like `typed_varinfo`, were easy to construct because we had convenience methods for them; however, the others were more difficult. +This change makes it easier to access different VarInfo types, and also makes it more explicit which one you are constructing. + ## 0.35.5 Several internal methods have been removed: diff --git a/benchmarks/src/DynamicPPLBenchmarks.jl b/benchmarks/src/DynamicPPLBenchmarks.jl index 4c73bf355..16338de2f 100644 --- a/benchmarks/src/DynamicPPLBenchmarks.jl +++ b/benchmarks/src/DynamicPPLBenchmarks.jl @@ -52,8 +52,8 @@ end Create a benchmark suite for `model` using the selected varinfo type and AD backend. Available varinfo choices: - • `:untyped` → uses `VarInfo()` - • `:typed` → uses `VarInfo(model)` + • `:untyped` → uses `DynamicPPL.untyped_varinfo(model)` + • `:typed` → uses `DynamicPPL.typed_varinfo(model)` • `:simple_namedtuple` → uses `SimpleVarInfo{Float64}(model())` • `:simple_dict` → builds a `SimpleVarInfo{Float64}` from a Dict (pre-populated with the model’s outputs) @@ -67,11 +67,9 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked:: suite = BenchmarkGroup() vi = if varinfo_choice == :untyped - vi = VarInfo() - model(rng, vi) - vi + DynamicPPL.untyped_varinfo(rng, model) elseif varinfo_choice == :typed - VarInfo(rng, model) + DynamicPPL.typed_varinfo(rng, model) elseif varinfo_choice == :simple_namedtuple SimpleVarInfo{Float64}(model(rng)) elseif varinfo_choice == :simple_dict diff --git a/docs/src/api.md b/docs/src/api.md index 2f6376f5d..f83a96886 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -291,18 +291,17 @@ AbstractVarInfo But exactly how a [`AbstractVarInfo`](@ref) stores this information can vary. -For constructing the "default" typed and untyped varinfo types used in DynamicPPL (see [the section on varinfo design](@ref "Design of `VarInfo`") for more on this), we have the following two methods: +#### `VarInfo` ```@docs -DynamicPPL.untyped_varinfo -DynamicPPL.typed_varinfo +VarInfo ``` -#### `VarInfo` - ```@docs -VarInfo -TypedVarInfo +DynamicPPL.untyped_varinfo +DynamicPPL.typed_varinfo +DynamicPPL.untyped_vector_varinfo +DynamicPPL.typed_vector_varinfo ``` One main characteristic of [`VarInfo`](@ref) is that samples are transformed to unconstrained Euclidean space and stored in a linearized form, as described in the [main Turing documentation](https://turinglang.org/docs/developers/transforms/dynamicppl/). diff --git a/docs/src/internals/varinfo.md b/docs/src/internals/varinfo.md index e6e1f2619..b04913aaf 100644 --- a/docs/src/internals/varinfo.md +++ b/docs/src/internals/varinfo.md @@ -227,13 +227,13 @@ Continuing from the example from the previous section, we can use a `VarInfo` wi ```@example varinfo-design # Type-unstable -varinfo_untyped_vnv = DynamicPPL.VectorVarInfo(varinfo_untyped) +varinfo_untyped_vnv = DynamicPPL.untyped_vector_varinfo(varinfo_untyped) varinfo_untyped_vnv[@varname(x)], varinfo_untyped_vnv[@varname(y)] ``` ```@example varinfo-design # Type-stable -varinfo_typed_vnv = DynamicPPL.VectorVarInfo(varinfo_typed) +varinfo_typed_vnv = DynamicPPL.typed_vector_varinfo(varinfo_typed) varinfo_typed_vnv[@varname(x)], varinfo_typed_vnv[@varname(y)] ``` diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 9f45718c5..51fa53079 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -45,8 +45,6 @@ import Base: # VarInfo export AbstractVarInfo, VarInfo, - UntypedVarInfo, - TypedVarInfo, SimpleVarInfo, push!!, empty!!, diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 44edaa4e9..f11b8a3ec 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -247,11 +247,11 @@ julia> values_as(SimpleVarInfo(data), Vector) 2.0 ``` -`TypedVarInfo`: +`VarInfo` with `NamedTuple` of `Metadata`: ```jldoctest julia> # Just use an example model to construct the `VarInfo` because we're lazy. - vi = VarInfo(DynamicPPL.TestUtils.demo_assume_dot_observe()); + vi = DynamicPPL.typed_varinfo(DynamicPPL.TestUtils.demo_assume_dot_observe()); julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0; @@ -273,11 +273,11 @@ julia> values_as(vi, Vector) 2.0 ``` -`UntypedVarInfo`: +`VarInfo` with `Metadata`: ```jldoctest julia> # Just use an example model to construct the `VarInfo` because we're lazy. - vi = VarInfo(); DynamicPPL.TestUtils.demo_assume_dot_observe()(vi); + vi = DynamicPPL.untyped_varinfo(DynamicPPL.TestUtils.demo_assume_dot_observe()); julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0; diff --git a/src/sampler.jl b/src/sampler.jl index ff008cc93..49d910fec 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -86,7 +86,7 @@ function default_varinfo( context::AbstractContext, ) init_sampler = initialsampler(sampler) - return VarInfo(rng, model, init_sampler, context) + return typed_varinfo(rng, model, init_sampler, context) end function AbstractMCMC.sample( diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 064483ddd..abf14b8fc 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -10,7 +10,7 @@ Currently only implemented for `NT<:NamedTuple` and `NT<:AbstractDict`. $(FIELDS) # Notes -The major differences between this and `TypedVarInfo` are: +The major differences between this and `NTVarInfo` are: 1. `SimpleVarInfo` does not require linearization. 2. `SimpleVarInfo` can use more efficient bijectors. 3. `SimpleVarInfo` is only type-stable if `NT<:NamedTuple` and either @@ -244,7 +244,7 @@ function SimpleVarInfo{T}( end # Constructor from `VarInfo`. -function SimpleVarInfo(vi::TypedVarInfo, ::Type{D}=NamedTuple; kwargs...) where {D} +function SimpleVarInfo(vi::NTVarInfo, (::Type{D})=NamedTuple; kwargs...) where {D} return SimpleVarInfo{eltype(getlogp(vi))}(vi, D; kwargs...) end function SimpleVarInfo{T}( diff --git a/src/test_utils/contexts.jl b/src/test_utils/contexts.jl index 5150be64b..7404a9af7 100644 --- a/src/test_utils/contexts.jl +++ b/src/test_utils/contexts.jl @@ -94,7 +94,7 @@ function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Mod @test (DynamicPPL.evaluate!!(model, varinfo_untyped, SamplingContext(context)); true) @test (DynamicPPL.evaluate!!(model, varinfo_untyped, context); true) # Typed varinfo. - varinfo_typed = DynamicPPL.TypedVarInfo(varinfo_untyped) + varinfo_typed = DynamicPPL.typed_varinfo(varinfo_untyped) @test (DynamicPPL.evaluate!!(model, varinfo_typed, SamplingContext(context)); true) @test (DynamicPPL.evaluate!!(model, varinfo_typed, context); true) end diff --git a/src/test_utils/varinfo.jl b/src/test_utils/varinfo.jl index 6a655ded4..539872143 100644 --- a/src/test_utils/varinfo.jl +++ b/src/test_utils/varinfo.jl @@ -27,12 +27,10 @@ function setup_varinfos( model::Model, example_values::NamedTuple, varnames; include_threadsafe::Bool=false ) # VarInfo - vi_untyped_metadata = VarInfo(DynamicPPL.Metadata()) - vi_untyped_vnv = VarInfo(DynamicPPL.VarNamedVector()) - model(vi_untyped_metadata) - model(vi_untyped_vnv) - vi_typed_metadata = DynamicPPL.TypedVarInfo(vi_untyped_metadata) - vi_typed_vnv = DynamicPPL.TypedVarInfo(vi_untyped_vnv) + vi_untyped_metadata = DynamicPPL.untyped_varinfo(model) + vi_untyped_vnv = DynamicPPL.untyped_vector_varinfo(model) + vi_typed_metadata = DynamicPPL.typed_varinfo(model) + vi_typed_vnv = DynamicPPL.typed_vector_varinfo(model) # SimpleVarInfo svi_typed = SimpleVarInfo(example_values) diff --git a/src/varinfo.jl b/src/varinfo.jl index 94b1f1c07..360857ef7 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -69,34 +69,91 @@ end ########### """ -``` -struct VarInfo{Tmeta, Tlogp} <: AbstractVarInfo - metadata::Tmeta - logp::Base.RefValue{Tlogp} - num_produce::Base.RefValue{Int} -end -``` + struct VarInfo{Tmeta,Tlogp} <: AbstractVarInfo + metadata::Tmeta + logp::Base.RefValue{Tlogp} + num_produce::Base.RefValue{Int} + end + +A light wrapper over some kind of metadata. -A light wrapper over one or more instances of `Metadata`. Let `vi` be an instance of -`VarInfo`. If `vi isa VarInfo{<:Metadata}`, then only one `Metadata` instance is used -for all the sybmols. `VarInfo{<:Metadata}` is aliased `UntypedVarInfo`. If -`vi isa VarInfo{<:NamedTuple}`, then `vi.metadata` is a `NamedTuple` that maps each -symbol used on the LHS of `~` in the model to its `Metadata` instance. The latter allows -for the type specialization of `vi` after the first sampling iteration when all the -symbols have been observed. `VarInfo{<:NamedTuple}` is aliased `TypedVarInfo`. +The type of the metadata can be one of a number of options. It may either be a +`Metadata` or a `VarNamedVector`, _or_, it may be a `NamedTuple` which maps +symbols to `Metadata` or `VarNamedVector` instances. Here, a _symbol_ refers +to a Julia variable and may consist of one or more `VarName`s which appear on +the left-hand side of tilde statements. For example, `x[1]` and `x[2]` both +have the same symbol `x`. -Note: It is the user's responsibility to ensure that each "symbol" is visited at least -once whenever the model is called, regardless of any stochastic branching. Each symbol -refers to a Julia variable and can be a hierarchical array of many random variables, e.g. `x[1] ~ ...` and `x[2] ~ ...` both have the same symbol `x`. +Several type aliases are provided for these forms of VarInfos: +- `VarInfo{<:Metadata}` is `UntypedVarInfo` +- `VarInfo{<:VarNamedVector}` is `UntypedVectorVarInfo` +- `VarInfo{<:NamedTuple}` is `NTVarInfo` + +The NamedTuple form, i.e. `NTVarInfo`, is useful for maintaining type stability +of model evaluation. However, the element type of NamedTuples are not contained +in its type itself: thus, there is no way to use the type system to determine +whether the elements of the NamedTuple are `Metadata` or `VarNamedVector`. + +Note that for NTVarInfo, it is the user's responsibility to ensure that each +symbol is visited at least once during model evaluation, regardless of any +stochastic branching. """ struct VarInfo{Tmeta,Tlogp} <: AbstractVarInfo metadata::Tmeta logp::Base.RefValue{Tlogp} num_produce::Base.RefValue{Int} end -const VectorVarInfo = VarInfo{<:VarNamedVector} +VarInfo(meta=Metadata()) = VarInfo(meta, Ref{LogProbType}(0.0), Ref(0)) +""" + VarInfo([rng, ]model[, sampler, context]) + +Generate a `VarInfo` object for the given `model`, by evaluating it once using +the given `rng`, `sampler`, and `context`. + +!!! warning + + This function currently returns a `VarInfo` with its metadata field set to + a `NamedTuple` of `Metadata`. This is an implementation detail. In general, + this function may return any kind of object that satisfies the + `AbstractVarInfo` interface. If you require precise control over the type + of `VarInfo` returned, use the internal functions `untyped_varinfo`, + `typed_varinfo`, `untyped_vector_varinfo`, or `typed_vector_varinfo` + instead. +""" +function VarInfo( + rng::Random.AbstractRNG, + model::Model, + sampler::AbstractSampler=SampleFromPrior(), + context::AbstractContext=DefaultContext(), +) + return typed_varinfo(rng, model, sampler, context) +end +function VarInfo( + model::Model, + sampler::AbstractSampler=SampleFromPrior(), + context::AbstractContext=DefaultContext(), +) + # No rng + return VarInfo(Random.default_rng(), model, sampler, context) +end +function VarInfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) + # No sampler + return VarInfo(rng, model, SampleFromPrior(), context) +end +function VarInfo(model::Model, context::AbstractContext) + # No sampler, no rng + return VarInfo(Random.default_rng(), model, SampleFromPrior(), context) +end + +const UntypedVectorVarInfo = VarInfo{<:VarNamedVector} const UntypedVarInfo = VarInfo{<:Metadata} -const TypedVarInfo = VarInfo{<:NamedTuple} +# TODO: NTVarInfo carries no information about the type of the actual metadata +# i.e. the elements of the NamedTuple. It could be Metadata or it could be +# VarNamedVector. +# Resolving this ambiguity would likely require us to replace NamedTuple with +# something which carried both its keys as well as its values' types as type +# parameters. +const NTVarInfo = VarInfo{<:NamedTuple} const VarInfoOrThreadSafeVarInfo{Tmeta} = Union{ VarInfo{Tmeta},ThreadSafeVarInfo{<:VarInfo{Tmeta}} } @@ -132,70 +189,245 @@ function metadata_to_varnamedvector(md::Metadata) ) end -function VectorVarInfo(vi::UntypedVarInfo) - md = metadata_to_varnamedvector(vi.metadata) - lp = getlogp(vi) - return VarInfo(md, Base.RefValue{eltype(lp)}(lp), Ref(get_num_produce(vi))) -end - -function VectorVarInfo(vi::TypedVarInfo) - md = map(metadata_to_varnamedvector, vi.metadata) - lp = getlogp(vi) - return VarInfo(md, Base.RefValue{eltype(lp)}(lp), Ref(get_num_produce(vi))) -end - function has_varnamedvector(vi::VarInfo) return vi.metadata isa VarNamedVector || - (vi isa TypedVarInfo && any(Base.Fix2(isa, VarNamedVector), values(vi.metadata))) + (vi isa NTVarInfo && any(Base.Fix2(isa, VarNamedVector), values(vi.metadata))) end +######################## +# VarInfo constructors # +######################## + """ - untyped_varinfo(model[, context, metadata]) + untyped_varinfo([rng, ]model[, sampler, context, metadata]) -Return an untyped varinfo object for the given `model` and `context`. +Return a VarInfo object for the given `model` and `context`, which has just a +single `Metadata` as its metadata field. # Arguments -- `model::Model`: The model for which to create the varinfo object. -- `context::AbstractContext`: The context in which to evaluate the model. Default: `SamplingContext()`. -- `metadata::Union{Metadata,VarNamedVector}`: The metadata to use for the varinfo object. - Default: `Metadata()`. +- `rng::Random.AbstractRNG`: The random number generator to use during model evaluation +- `model::Model`: The model for which to create the varinfo object +- `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. +- `context::AbstractContext`: The context in which to evaluate the model. Defaults to `DefaultContext()`. """ +function untyped_varinfo( + rng::Random.AbstractRNG, + model::Model, + sampler::AbstractSampler=SampleFromPrior(), + context::AbstractContext=DefaultContext(), +) + varinfo = VarInfo(Metadata()) + context = SamplingContext(rng, sampler, context) + return last(evaluate!!(model, varinfo, context)) +end function untyped_varinfo( model::Model, - context::AbstractContext=SamplingContext(), - metadata::Union{Metadata,VarNamedVector}=Metadata(), + sampler::AbstractSampler=SampleFromPrior(), + context::AbstractContext=DefaultContext(), ) - varinfo = VarInfo(metadata) - return last( - evaluate!!(model, varinfo, hassampler(context) ? context : SamplingContext(context)) + # No rng + return untyped_varinfo(Random.default_rng(), model, sampler, context) +end +function untyped_varinfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) + # No sampler + return untyped_varinfo(rng, model, SampleFromPrior(), context) +end +function untyped_varinfo(model::Model, context::AbstractContext) + # No sampler, no rng + return untyped_varinfo(model, SampleFromPrior(), context) +end + +""" + typed_varinfo(vi::UntypedVarInfo) + +This function finds all the unique `sym`s from the instances of `VarName{sym}` found in +`vi.metadata.vns`. It then extracts the metadata associated with each symbol from the +global `vi.metadata` field. Finally, a new `VarInfo` is created with a new `metadata` as +a `NamedTuple` mapping from symbols to type-stable `Metadata` instances, one for each +symbol. +""" +function typed_varinfo(vi::UntypedVarInfo) + meta = vi.metadata + new_metas = Metadata[] + # Symbols of all instances of `VarName{sym}` in `vi.vns` + syms_tuple = Tuple(syms(vi)) + for s in syms_tuple + # Find all indices in `vns` with symbol `s` + inds = findall(vn -> getsym(vn) === s, meta.vns) + n = length(inds) + # New `vns` + sym_vns = getindex.((meta.vns,), inds) + # New idcs + sym_idcs = Dict(a => i for (i, a) in enumerate(sym_vns)) + # New dists + sym_dists = getindex.((meta.dists,), inds) + # New orders + sym_orders = getindex.((meta.orders,), inds) + # New flags + sym_flags = Dict(a => meta.flags[a][inds] for a in keys(meta.flags)) + + # Extract new ranges and vals + _ranges = getindex.((meta.ranges,), inds) + # `copy.()` is a workaround to reduce the eltype from Real to Int or Float64 + _vals = [copy.(meta.vals[_ranges[i]]) for i in 1:n] + sym_ranges = Vector{eltype(_ranges)}(undef, n) + start = 0 + for i in 1:n + sym_ranges[i] = (start + 1):(start + length(_vals[i])) + start += length(_vals[i]) + end + sym_vals = foldl(vcat, _vals) + + push!( + new_metas, + Metadata( + sym_idcs, sym_vns, sym_ranges, sym_vals, sym_dists, sym_orders, sym_flags + ), + ) + end + logp = getlogp(vi) + num_produce = get_num_produce(vi) + nt = NamedTuple{syms_tuple}(Tuple(new_metas)) + return VarInfo(nt, Ref(logp), Ref(num_produce)) +end +function typed_varinfo(vi::NTVarInfo) + # This function preserves the behaviour of typed_varinfo(vi) where vi is + # already a NTVarInfo + has_varnamedvector(vi) && error( + "Cannot convert VarInfo with NamedTuple of VarNamedVector to VarInfo with NamedTuple of Metadata", ) + return vi +end +""" + typed_varinfo([rng, ]model[, sampler, context, metadata]) + +Return a VarInfo object for the given `model` and `context`, which has a NamedTuple of +`Metadata` structs as its metadata field. + +# Arguments +- `rng::Random.AbstractRNG`: The random number generator to use during model evaluation +- `model::Model`: The model for which to create the varinfo object +- `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. +- `context::AbstractContext`: The context in which to evaluate the model. Defaults to `DefaultContext()`. +""" +function typed_varinfo( + rng::Random.AbstractRNG, + model::Model, + sampler::AbstractSampler=SampleFromPrior(), + context::AbstractContext=DefaultContext(), +) + return typed_varinfo(untyped_varinfo(rng, model, sampler, context)) +end +function typed_varinfo( + model::Model, + sampler::AbstractSampler=SampleFromPrior(), + context::AbstractContext=DefaultContext(), +) + # No rng + return typed_varinfo(Random.default_rng(), model, sampler, context) +end +function typed_varinfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) + # No sampler + return typed_varinfo(rng, model, SampleFromPrior(), context) +end +function typed_varinfo(model::Model, context::AbstractContext) + # No sampler, no rng + return typed_varinfo(model, SampleFromPrior(), context) end """ - typed_varinfo(model[, context, metadata]) + untyped_vector_varinfo([rng, ]model[, sampler, context, metadata]) -Return a typed varinfo object for the given `model`, `sampler` and `context`. +Return a VarInfo object for the given `model` and `context`, which has just a +single `VarNamedVector` as its metadata field. -This simply calls [`DynamicPPL.untyped_varinfo`](@ref) and converts the resulting -varinfo object to a typed varinfo object. +# Arguments +- `rng::Random.AbstractRNG`: The random number generator to use during model evaluation +- `model::Model`: The model for which to create the varinfo object +- `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. +- `context::AbstractContext`: The context in which to evaluate the model. Defaults to `DefaultContext()`. +""" +function untyped_vector_varinfo(vi::UntypedVarInfo) + md = metadata_to_varnamedvector(vi.metadata) + lp = getlogp(vi) + return VarInfo(md, Base.RefValue{eltype(lp)}(lp), Ref(get_num_produce(vi))) +end +function untyped_vector_varinfo( + rng::Random.AbstractRNG, + model::Model, + sampler::AbstractSampler=SampleFromPrior(), + context::AbstractContext=DefaultContext(), +) + return untyped_vector_varinfo(untyped_varinfo(rng, model, sampler, context)) +end +function untyped_vector_varinfo( + model::Model, + sampler::AbstractSampler=SampleFromPrior(), + context::AbstractContext=DefaultContext(), +) + # No rng + return untyped_vector_varinfo(Random.default_rng(), model, sampler, context) +end +function untyped_vector_varinfo( + rng::Random.AbstractRNG, model::Model, context::AbstractContext +) + # No sampler + return untyped_vector_varinfo(rng, model, SampleFromPrior(), context) +end +function untyped_vector_varinfo(model::Model, context::AbstractContext) + # No sampler, no rng + return untyped_vector_varinfo(model, SampleFromPrior(), context) +end -See also: [`DynamicPPL.untyped_varinfo`](@ref) """ -typed_varinfo(args...) = TypedVarInfo(untyped_varinfo(args...)) + typed_vector_varinfo([rng, ]model[, sampler, context, metadata]) -function VarInfo( +Return a VarInfo object for the given `model` and `context`, which has a +NamedTuple of `VarNamedVector`s as its metadata field. + +# Arguments +- `rng::Random.AbstractRNG`: The random number generator to use during model evaluation +- `model::Model`: The model for which to create the varinfo object +- `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. +- `context::AbstractContext`: The context in which to evaluate the model. Defaults to `DefaultContext()`. +""" +function typed_vector_varinfo(vi::NTVarInfo) + md = map(metadata_to_varnamedvector, vi.metadata) + lp = getlogp(vi) + return VarInfo(md, Base.RefValue{eltype(lp)}(lp), Ref(get_num_produce(vi))) +end +function typed_vector_varinfo(vi::UntypedVectorVarInfo) + new_metas = group_by_symbol(vi.metadata) + logp = getlogp(vi) + num_produce = get_num_produce(vi) + nt = NamedTuple(new_metas) + return VarInfo(nt, Ref(logp), Ref(num_produce)) +end +function typed_vector_varinfo( rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior(), context::AbstractContext=DefaultContext(), - metadata::Union{Metadata,VarNamedVector}=Metadata(), ) - return typed_varinfo(model, SamplingContext(rng, sampler, context), metadata) + return typed_vector_varinfo(untyped_vector_varinfo(rng, model, sampler, context)) end -function VarInfo( - model::Model, args::Union{AbstractSampler,AbstractContext,Metadata,VarNamedVector}... +function typed_vector_varinfo( + model::Model, + sampler::AbstractSampler=SampleFromPrior(), + context::AbstractContext=DefaultContext(), ) - return VarInfo(Random.default_rng(), model, args...) + # No rng + return typed_vector_varinfo(Random.default_rng(), model, sampler, context) +end +function typed_vector_varinfo( + rng::Random.AbstractRNG, model::Model, context::AbstractContext +) + # No sampler + return typed_vector_varinfo(rng, model, SampleFromPrior(), context) +end +function typed_vector_varinfo(model::Model, context::AbstractContext) + # No sampler, no rng + return typed_vector_varinfo(model, SampleFromPrior(), context) end """ @@ -204,7 +436,7 @@ end Return the length of the vector representation of `varinfo`. """ vector_length(varinfo::VarInfo) = length(varinfo.metadata) -vector_length(varinfo::TypedVarInfo) = sum(length, varinfo.metadata) +vector_length(varinfo::NTVarInfo) = sum(length, varinfo.metadata) vector_length(md::Metadata) = sum(length, md.ranges) function unflatten(vi::VarInfo, x::AbstractVector) @@ -241,11 +473,6 @@ end unflatten_metadata(vnv::VarNamedVector, x::AbstractVector) = unflatten(vnv, x) -# without AbstractSampler -function VarInfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) - return VarInfo(rng, model, SampleFromPrior(), context) -end - #### #### Internal functions #### @@ -500,7 +727,7 @@ setval!(vi::UntypedVarInfo, val, vview::VarView) = vi.metadata.vals[vview] = val Return the metadata in `vi` that belongs to `vn`. """ getmetadata(vi::VarInfo, vn::VarName) = vi.metadata -getmetadata(vi::TypedVarInfo, vn::VarName) = getfield(vi.metadata, getsym(vn)) +getmetadata(vi::NTVarInfo, vn::VarName) = getfield(vi.metadata, getsym(vn)) """ getidx(vi::VarInfo, vn::VarName) @@ -541,7 +768,7 @@ end Return the range corresponding to `varname` in the vector representation of `varinfo`. """ vector_getrange(vi::VarInfo, vn::VarName) = getrange(vi.metadata, vn) -function vector_getrange(vi::TypedVarInfo, vn::VarName) +function vector_getrange(vi::NTVarInfo, vn::VarName) offset = 0 for md in values(vi.metadata) # First, we need to check if `vn` is in `md`. @@ -563,8 +790,8 @@ Return the range corresponding to `varname` in the vector representation of `var function vector_getranges(varinfo::VarInfo, varname::Vector{<:VarName}) return map(Base.Fix1(vector_getrange, varinfo), varname) end -# Specialized version for `TypedVarInfo`. -function vector_getranges(varinfo::TypedVarInfo, vns::Vector{<:VarName}) +# Specialized version for `NTVarInfo`. +function vector_getranges(varinfo::NTVarInfo, vns::Vector{<:VarName}) # TODO: Does it help if we _don't_ convert to a vector here? metadatas = collect(values(varinfo.metadata)) # Extract the offsets. @@ -624,7 +851,7 @@ end getindex_internal(vi::VarInfo, ::Colon) = getindex_internal(vi.metadata, Colon()) # NOTE: `mapreduce` over `NamedTuple` results in worse type-inference. # See for example https://github.com/JuliaLang/julia/pull/46381. -function getindex_internal(vi::TypedVarInfo, ::Colon) +function getindex_internal(vi::NTVarInfo, ::Colon) return reduce(vcat, map(Base.Fix2(getindex_internal, Colon()), vi.metadata)) end function getindex_internal(md::Metadata, ::Colon) @@ -684,10 +911,10 @@ settrans!!(vi::VarInfo, trans::AbstractTransformation) = settrans!!(vi, true) Returns a tuple of the unique symbols of random variables in `vi`. """ syms(vi::UntypedVarInfo) = Tuple(unique!(map(getsym, vi.metadata.vns))) # get all symbols -syms(vi::TypedVarInfo) = keys(vi.metadata) +syms(vi::NTVarInfo) = keys(vi.metadata) _getidcs(vi::UntypedVarInfo) = 1:length(vi.metadata.idcs) -_getidcs(vi::TypedVarInfo) = _getidcs(vi.metadata) +_getidcs(vi::NTVarInfo) = _getidcs(vi.metadata) @generated function _getidcs(metadata::NamedTuple{names}) where {names} exprs = [] @@ -702,12 +929,11 @@ end findinds(vnv::VarNamedVector) = 1:length(vnv.varnames) """ - all_varnames_grouped_by_symbol(vi::TypedVarInfo) + all_varnames_grouped_by_symbol(vi::NTVarInfo) Return a `NamedTuple` of the variables in `vi` grouped by symbol. """ -all_varnames_grouped_by_symbol(vi::TypedVarInfo) = - all_varnames_grouped_by_symbol(vi.metadata) +all_varnames_grouped_by_symbol(vi::NTVarInfo) = all_varnames_grouped_by_symbol(vi.metadata) @generated function all_varnames_grouped_by_symbol(md::NamedTuple{names}) where {names} expr = Expr(:tuple) @@ -745,73 +971,6 @@ end #### APIs for typed and untyped VarInfo #### -# VarInfo - -VarInfo(meta=Metadata()) = VarInfo(meta, Ref{LogProbType}(0.0), Ref(0)) - -function TypedVarInfo(vi::VectorVarInfo) - new_metas = group_by_symbol(vi.metadata) - logp = getlogp(vi) - num_produce = get_num_produce(vi) - nt = NamedTuple(new_metas) - return VarInfo(nt, Ref(logp), Ref(num_produce)) -end - -""" - TypedVarInfo(vi::UntypedVarInfo) - -This function finds all the unique `sym`s from the instances of `VarName{sym}` found in -`vi.metadata.vns`. It then extracts the metadata associated with each symbol from the -global `vi.metadata` field. Finally, a new `VarInfo` is created with a new `metadata` as -a `NamedTuple` mapping from symbols to type-stable `Metadata` instances, one for each -symbol. -""" -function TypedVarInfo(vi::UntypedVarInfo) - meta = vi.metadata - new_metas = Metadata[] - # Symbols of all instances of `VarName{sym}` in `vi.vns` - syms_tuple = Tuple(syms(vi)) - for s in syms_tuple - # Find all indices in `vns` with symbol `s` - inds = findall(vn -> getsym(vn) === s, meta.vns) - n = length(inds) - # New `vns` - sym_vns = getindex.((meta.vns,), inds) - # New idcs - sym_idcs = Dict(a => i for (i, a) in enumerate(sym_vns)) - # New dists - sym_dists = getindex.((meta.dists,), inds) - # New orders - sym_orders = getindex.((meta.orders,), inds) - # New flags - sym_flags = Dict(a => meta.flags[a][inds] for a in keys(meta.flags)) - - # Extract new ranges and vals - _ranges = getindex.((meta.ranges,), inds) - # `copy.()` is a workaround to reduce the eltype from Real to Int or Float64 - _vals = [copy.(meta.vals[_ranges[i]]) for i in 1:n] - sym_ranges = Vector{eltype(_ranges)}(undef, n) - start = 0 - for i in 1:n - sym_ranges[i] = (start + 1):(start + length(_vals[i])) - start += length(_vals[i]) - end - sym_vals = foldl(vcat, _vals) - - push!( - new_metas, - Metadata( - sym_idcs, sym_vns, sym_ranges, sym_vals, sym_dists, sym_orders, sym_flags - ), - ) - end - logp = getlogp(vi) - num_produce = get_num_produce(vi) - nt = NamedTuple{syms_tuple}(Tuple(new_metas)) - return VarInfo(nt, Ref(logp), Ref(num_produce)) -end -TypedVarInfo(vi::TypedVarInfo) = vi - function BangBang.empty!!(vi::VarInfo) _empty!(vi.metadata) resetlogp!!(vi) @@ -834,8 +993,8 @@ Base.keys(vi::VarInfo) = Base.keys(vi.metadata) # HACK: Necessary to avoid returning `Any[]` which won't dispatch correctly # on other methods in the codebase which requires `Vector{<:VarName}`. -Base.keys(vi::TypedVarInfo{<:NamedTuple{()}}) = VarName[] -@generated function Base.keys(vi::TypedVarInfo{<:NamedTuple{names}}) where {names} +Base.keys(vi::NTVarInfo{<:NamedTuple{()}}) = VarName[] +@generated function Base.keys(vi::NTVarInfo{<:NamedTuple{names}}) where {names} expr = Expr(:call) push!(expr.args, :vcat) @@ -898,7 +1057,7 @@ _isempty(vnv::VarNamedVector) = isempty(vnv) return Expr(:&&, (:(_isempty(metadata.$f)) for f in names)...) end -function link!!(::DynamicTransformation, vi::TypedVarInfo, model::Model) +function link!!(::DynamicTransformation, vi::NTVarInfo, model::Model) vns = all_varnames_grouped_by_symbol(vi) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _link(model, vi, vns) @@ -952,13 +1111,13 @@ function _link!(vi::UntypedVarInfo, vns) end end -# If we try to _link! a TypedVarInfo with a Tuple of VarNames, first convert it to a -# NamedTuple that matches the structure of the TypedVarInfo. -function _link!(vi::TypedVarInfo, vns::VarNameTuple) +# If we try to _link! a NTVarInfo with a Tuple of VarNames, first convert it to a +# NamedTuple that matches the structure of the NTVarInfo. +function _link!(vi::NTVarInfo, vns::VarNameTuple) return _link!(vi, group_varnames_by_symbol(vns)) end -function _link!(vi::TypedVarInfo, vns::NamedTuple) +function _link!(vi::NTVarInfo, vns::NamedTuple) return _link!(vi.metadata, vi, vns) end @@ -1002,7 +1161,7 @@ end return expr end -function invlink!!(::DynamicTransformation, vi::TypedVarInfo, model::Model) +function invlink!!(::DynamicTransformation, vi::NTVarInfo, model::Model) vns = all_varnames_grouped_by_symbol(vi) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _invlink(model, vi, vns) @@ -1064,13 +1223,13 @@ function _invlink!(vi::UntypedVarInfo, vns) end end -# If we try to _invlink! a TypedVarInfo with a Tuple of VarNames, first convert it to a -# NamedTuple that matches the structure of the TypedVarInfo. -function _invlink!(vi::TypedVarInfo, vns::VarNameTuple) +# If we try to _invlink! a NTVarInfo with a Tuple of VarNames, first convert it to a +# NamedTuple that matches the structure of the NTVarInfo. +function _invlink!(vi::NTVarInfo, vns::VarNameTuple) return _invlink!(vi.metadata, vi, group_varnames_by_symbol(vns)) end -function _invlink!(vi::TypedVarInfo, vns::NamedTuple) +function _invlink!(vi::NTVarInfo, vns::NamedTuple) return _invlink!(vi.metadata, vi, vns) end @@ -1121,7 +1280,7 @@ function _inner_transform!(md::Metadata, vi::VarInfo, vn::VarName, f) return vi end -function link(::DynamicTransformation, vi::TypedVarInfo, model::Model) +function link(::DynamicTransformation, vi::NTVarInfo, model::Model) return _link(model, vi, all_varnames_grouped_by_symbol(vi)) end @@ -1156,13 +1315,13 @@ function _link(model::Model, varinfo::VarInfo, vns) return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) end -# If we try to _link a TypedVarInfo with a Tuple of VarNames, first convert it to a -# NamedTuple that matches the structure of the TypedVarInfo. -function _link(model::Model, varinfo::TypedVarInfo, vns::VarNameTuple) +# If we try to _link a NTVarInfo with a Tuple of VarNames, first convert it to a +# NamedTuple that matches the structure of the NTVarInfo. +function _link(model::Model, varinfo::NTVarInfo, vns::VarNameTuple) return _link(model, varinfo, group_varnames_by_symbol(vns)) end -function _link(model::Model, varinfo::TypedVarInfo, vns::NamedTuple) +function _link(model::Model, varinfo::NTVarInfo, vns::NamedTuple) varinfo = deepcopy(varinfo) md = _link_metadata!(model, varinfo, varinfo.metadata, vns) return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) @@ -1257,7 +1416,7 @@ function _link_metadata!!( return metadata end -function invlink(::DynamicTransformation, vi::TypedVarInfo, model::Model) +function invlink(::DynamicTransformation, vi::NTVarInfo, model::Model) return _invlink(model, vi, all_varnames_grouped_by_symbol(vi)) end @@ -1297,13 +1456,13 @@ function _invlink(model::Model, varinfo::VarInfo, vns) ) end -# If we try to _invlink a TypedVarInfo with a Tuple of VarNames, first convert it to a -# NamedTuple that matches the structure of the TypedVarInfo. -function _invlink(model::Model, varinfo::TypedVarInfo, vns::VarNameTuple) +# If we try to _invlink a NTVarInfo with a Tuple of VarNames, first convert it to a +# NamedTuple that matches the structure of the NTVarInfo. +function _invlink(model::Model, varinfo::NTVarInfo, vns::VarNameTuple) return _invlink(model, varinfo, group_varnames_by_symbol(vns)) end -function _invlink(model::Model, varinfo::TypedVarInfo, vns::NamedTuple) +function _invlink(model::Model, varinfo::NTVarInfo, vns::NamedTuple) varinfo = deepcopy(varinfo) md = _invlink_metadata!(model, varinfo, varinfo.metadata, vns) return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) @@ -1394,7 +1553,7 @@ end # TODO(mhauru) The treatment of the case when some variables are linked and others are not # should be revised. It used to be the case that for UntypedVarInfo `islinked` returned -# whether the first variable was linked. For TypedVarInfo we did an OR over the first +# whether the first variable was linked. For NTVarInfo we did an OR over the first # variables under each symbol. We now more consistently use OR, but I'm not convinced this # is really the right thing to do. """ @@ -1538,7 +1697,7 @@ Base.haskey(metadata::Metadata, vn::VarName) = haskey(metadata.idcs, vn) Check whether `vn` has a value in `vi`. """ Base.haskey(vi::VarInfo, vn::VarName) = haskey(getmetadata(vi, vn), vn) -function Base.haskey(vi::TypedVarInfo, vn::VarName) +function Base.haskey(vi::NTVarInfo, vn::VarName) md_haskey = map(vi.metadata) do metadata haskey(metadata, vn) end @@ -1601,12 +1760,12 @@ the `VarInfo` `vi`, mutating if it makes sense. function BangBang.push!!(vi::VarInfo, vn::VarName, r, dist::Distribution) if vi isa UntypedVarInfo @assert ~(vn in keys(vi)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist" - elseif vi isa TypedVarInfo - @assert ~(haskey(vi, vn)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to TypedVarInfo of syms $(syms(vi)) with dist=$dist" + elseif vi isa NTVarInfo + @assert ~(haskey(vi, vn)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to NTVarInfo of syms $(syms(vi)) with dist=$dist" end sym = getsym(vn) - if vi isa TypedVarInfo && ~haskey(vi.metadata, sym) + if vi isa NTVarInfo && ~haskey(vi.metadata, sym) # The NamedTuple doesn't have an entry for this variable, let's add one. val = tovec(r) md = Metadata( @@ -1627,18 +1786,18 @@ function BangBang.push!!(vi::VarInfo, vn::VarName, r, dist::Distribution) return vi end -function Base.push!(vi::VectorVarInfo, vn::VarName, val, args...) +function Base.push!(vi::UntypedVectorVarInfo, vn::VarName, val, args...) push!(getmetadata(vi, vn), vn, val, args...) return vi end -function Base.push!(vi::VectorVarInfo, pair::Pair, args...) +function Base.push!(vi::UntypedVectorVarInfo, pair::Pair, args...) vn, val = pair return push!(vi, vn, val, args...) end -# TODO(mhauru) push! can't be implemented in-place for TypedVarInfo if the symbol doesn't -# exist in the TypedVarInfo already. We could implement it in the cases where it it does +# TODO(mhauru) push! can't be implemented in-place for NTVarInfo if the symbol doesn't +# exist in the NTVarInfo already. We could implement it in the cases where it it does # exist, but that feels a bit pointless. I think we should rather rely on `push!!`. function Base.push!(meta::Metadata, vn, r, dist, num_produce) @@ -1760,7 +1919,7 @@ function set_retained_vns_del!(vi::UntypedVarInfo) end return nothing end -function set_retained_vns_del!(vi::TypedVarInfo) +function set_retained_vns_del!(vi::NTVarInfo) idcs = _getidcs(vi) return _set_retained_vns_del!(vi.metadata, idcs, get_num_produce(vi)) end @@ -1821,12 +1980,12 @@ function _apply!(kernel!, vi::VarInfoOrThreadSafeVarInfo, values, keys) return vi end -function _apply!(kernel!, vi::TypedVarInfo, values, keys) +function _apply!(kernel!, vi::NTVarInfo, values, keys) return _typed_apply!(kernel!, vi, vi.metadata, values, collect_maybe(keys)) end @generated function _typed_apply!( - kernel!, vi::TypedVarInfo, metadata::NamedTuple{names}, values, keys + kernel!, vi::NTVarInfo, metadata::NamedTuple{names}, values, keys ) where {names} updates = map(names) do n quote @@ -1963,7 +2122,8 @@ julia> rng = StableRNG(42); julia> m = demo([missing]); -julia> var_info = DynamicPPL.VarInfo(rng, m, SampleFromPrior(), DefaultContext(), DynamicPPL.Metadata()); # Checking the setting of "del" flags only makes sense for VarInfo{<:Metadata}. For VarInfo{<:VarNamedVector} the flag is effectively always set. +julia> var_info = DynamicPPL.VarInfo(rng, m); + # Checking the setting of "del" flags only makes sense for VarInfo{<:Metadata}. For VarInfo{<:VarNamedVector} the flag is effectively always set. julia> var_info[@varname(m)] -0.6702516921145671 @@ -2061,8 +2221,8 @@ function values_as( return ConstructionBase.constructorof(D)(iter) end -values_as(vi::VectorVarInfo, args...) = values_as(vi.metadata, args...) -values_as(vi::VectorVarInfo, T::Type{Vector}) = values_as(vi.metadata, T) +values_as(vi::UntypedVectorVarInfo, args...) = values_as(vi.metadata, args...) +values_as(vi::UntypedVectorVarInfo, T::Type{Vector}) = values_as(vi.metadata, T) function values_from_metadata(md::Metadata) return ( diff --git a/test/ext/DynamicPPLJETExt.jl b/test/ext/DynamicPPLJETExt.jl index 933bfb1d1..86329a51d 100644 --- a/test/ext/DynamicPPLJETExt.jl +++ b/test/ext/DynamicPPLJETExt.jl @@ -14,7 +14,7 @@ @model demo2() = x ~ Normal() @test DynamicPPL.Experimental.determine_suitable_varinfo(demo2()) isa - DynamicPPL.TypedVarInfo + DynamicPPL.NTVarInfo @model function demo3() # Just making sure that nothing strange happens when type inference fails. @@ -53,7 +53,7 @@ end # Should pass if we're only checking the tilde statements. @test DynamicPPL.Experimental.determine_suitable_varinfo(demo5()) isa - DynamicPPL.TypedVarInfo + DynamicPPL.NTVarInfo # Should fail if we're including errors in the model body. @test DynamicPPL.Experimental.determine_suitable_varinfo( demo5(); only_ddpl=false @@ -75,11 +75,11 @@ ) JET.test_call(f_sample, argtypes_sample) # For our demo models, they should all result in typed. - is_typed = varinfo isa DynamicPPL.TypedVarInfo + is_typed = varinfo isa DynamicPPL.NTVarInfo @test is_typed # If the test failed, check why it didn't infer a typed varinfo if !is_typed - typed_vi = VarInfo(model) + typed_vi = DynamicPPL.typed_varinfo(model) f_eval, argtypes_eval = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( model, typed_vi ) diff --git a/test/model.jl b/test/model.jl index 447a9ecaa..dd5a35fe6 100644 --- a/test/model.jl +++ b/test/model.jl @@ -25,9 +25,9 @@ function innermost_distribution_type(d::Distributions.Product) return dists[1] end -is_typed_varinfo(::DynamicPPL.AbstractVarInfo) = false -is_typed_varinfo(varinfo::DynamicPPL.TypedVarInfo) = true -is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true +is_type_stable_varinfo(::DynamicPPL.AbstractVarInfo) = false +is_type_stable_varinfo(varinfo::DynamicPPL.NTVarInfo) = true +is_type_stable_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() @@ -233,8 +233,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() @testset "Dynamic constraints, Metadata" begin model = DynamicPPL.TestUtils.demo_dynamic_constraint() - spl = SampleFromPrior() - vi = VarInfo(model, spl, DefaultContext(), DynamicPPL.Metadata()) + vi = VarInfo(model) vi = link!!(vi, model) for i in 1:10 @@ -250,8 +249,11 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() @testset "Dynamic constraints, VectorVarInfo" begin model = DynamicPPL.TestUtils.demo_dynamic_constraint() for i in 1:10 - vi = VarInfo(model) - @test vi[@varname(x)] >= vi[@varname(m)] + for vi_constructor in + [DynamicPPL.typed_vector_varinfo, DynamicPPL.untyped_vector_varinfo] + vi = vi_constructor(model) + @test vi[@varname(x)] >= vi[@varname(m)] + end end end @@ -400,7 +402,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() vns = DynamicPPL.TestUtils.varnames(model) example_values = DynamicPPL.TestUtils.rand_prior_true(model) varinfos = filter( - is_typed_varinfo, + is_type_stable_varinfo, DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns), ) @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 8e48814a4..aa3b592f7 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -92,7 +92,7 @@ SimpleVarInfo(Dict()), SimpleVarInfo(values_constrained), SimpleVarInfo(DynamicPPL.VarNamedVector()), - VarInfo(model), + DynamicPPL.typed_varinfo(model), ) for vn in DynamicPPL.TestUtils.varnames(model) vi = DynamicPPL.setindex!!(vi, get(values_constrained, vn), vn) diff --git a/test/test_util.jl b/test/test_util.jl index 87c69b5fe..902dd7230 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -33,14 +33,18 @@ end Return string representing a short description of `vi`. """ -short_varinfo_name(vi::DynamicPPL.ThreadSafeVarInfo) = - "threadsafe($(short_varinfo_name(vi.varinfo)))" -function short_varinfo_name(vi::TypedVarInfo) - DynamicPPL.has_varnamedvector(vi) && return "TypedVarInfo with VarNamedVector" - return "TypedVarInfo" +function short_varinfo_name(vi::DynamicPPL.ThreadSafeVarInfo) + return "threadsafe($(short_varinfo_name(vi.varinfo)))" end -short_varinfo_name(::UntypedVarInfo) = "UntypedVarInfo" -short_varinfo_name(::DynamicPPL.VectorVarInfo) = "VectorVarInfo" +function short_varinfo_name(vi::DynamicPPL.NTVarInfo) + return if DynamicPPL.has_varnamedvector(vi) + "TypedVectorVarInfo" + else + "TypedVarInfo" + end +end +short_varinfo_name(::DynamicPPL.UntypedVarInfo) = "UntypedVarInfo" +short_varinfo_name(::DynamicPPL.UntypedVectorVarInfo) = "UntypedVectorVarInfo" function short_varinfo_name(::SimpleVarInfo{<:NamedTuple,<:Ref}) return "SimpleVarInfo{<:NamedTuple,<:Ref}" end diff --git a/test/varinfo.jl b/test/varinfo.jl index 74feb42f6..777917aa6 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -34,7 +34,7 @@ function randr(vi::DynamicPPL.VarInfo, vn::VarName, dist::Distribution) end @testset "varinfo.jl" begin - @testset "TypedVarInfo with Metadata" begin + @testset "VarInfo with NT of Metadata" begin @model gdemo(x, y) = begin s ~ InverseGamma(2, 3) m ~ truncated(Normal(0.0, sqrt(s)), 0.0, 2.0) @@ -43,9 +43,8 @@ end end model = gdemo(1.0, 2.0) - vi = VarInfo(DynamicPPL.Metadata()) - model(vi, SampleFromUniform()) - tvi = TypedVarInfo(vi) + vi = DynamicPPL.untyped_varinfo(model, SampleFromUniform()) + tvi = DynamicPPL.typed_varinfo(vi) meta = vi.metadata for f in fieldnames(typeof(tvi.metadata)) @@ -102,7 +101,7 @@ end @test vi[vn] == 2 * r # TODO(mhauru) Implement these functions for other VarInfo types too. - if vi isa DynamicPPL.VectorVarInfo + if vi isa DynamicPPL.UntypedVectorVarInfo delete!(vi, vn) @test isempty(vi) vi = push!!(vi, vn, r, dist) @@ -116,7 +115,7 @@ end vi = VarInfo() test_base!!(vi) - test_base!!(TypedVarInfo(vi)) + test_base!!(DynamicPPL.typed_varinfo(vi)) test_base!!(SimpleVarInfo()) test_base!!(SimpleVarInfo(Dict())) test_base!!(SimpleVarInfo(DynamicPPL.VarNamedVector())) @@ -135,7 +134,7 @@ end vi = VarInfo() test_varinfo_logp!(vi) - test_varinfo_logp!(TypedVarInfo(vi)) + test_varinfo_logp!(DynamicPPL.typed_varinfo(vi)) test_varinfo_logp!(SimpleVarInfo()) test_varinfo_logp!(SimpleVarInfo(Dict())) test_varinfo_logp!(SimpleVarInfo(DynamicPPL.VarNamedVector())) @@ -160,17 +159,17 @@ end unset_flag!(vi, vn_x, "del") @test !is_flagged(vi, vn_x, "del") end - vi = VarInfo(DynamicPPL.Metadata()) + vi = VarInfo() test_varinfo!(vi) - test_varinfo!(empty!!(TypedVarInfo(vi))) + test_varinfo!(empty!!(DynamicPPL.typed_varinfo(vi))) end - @testset "push!! to TypedVarInfo" begin + @testset "push!! to VarInfo with NT of Metadata" begin vn_x = @varname x vn_y = @varname y untyped_vi = VarInfo() untyped_vi = push!!(untyped_vi, vn_x, 1.0, Normal(0, 1)) - typed_vi = TypedVarInfo(untyped_vi) + typed_vi = DynamicPPL.typed_varinfo(untyped_vi) typed_vi = push!!(typed_vi, vn_y, 2.0, Normal(0, 1)) @test typed_vi[vn_x] == 1.0 @test typed_vi[vn_y] == 2.0 @@ -206,16 +205,10 @@ end m_vns = model == model_uv ? [@varname(m[i]) for i in 1:5] : @varname(m) s_vns = @varname(s) - vi_typed = VarInfo( - model, SampleFromPrior(), DefaultContext(), DynamicPPL.Metadata() - ) - vi_untyped = VarInfo(DynamicPPL.Metadata()) - vi_vnv = VarInfo(DynamicPPL.VarNamedVector()) - vi_vnv_typed = VarInfo( - model, SampleFromPrior(), DefaultContext(), DynamicPPL.VarNamedVector() - ) - model(vi_untyped, SampleFromPrior()) - model(vi_vnv, SampleFromPrior()) + vi_typed = DynamicPPL.typed_varinfo(model) + vi_untyped = DynamicPPL.untyped_varinfo(model) + vi_vnv = DynamicPPL.untyped_vector_varinfo(model) + vi_vnv_typed = DynamicPPL.typed_vector_varinfo(model) model_name = model == model_uv ? "univariate" : "multivariate" @testset "$(model_name), $(short_varinfo_name(vi))" for vi in [ @@ -405,7 +398,7 @@ end @test meta.vals ≈ v atol = 1e-10 # Check that linking and invlinking preserves the values - vi = TypedVarInfo(vi) + vi = DynamicPPL.typed_varinfo(vi) meta = vi.metadata v_s = copy(meta.s.vals) v_m = copy(meta.m.vals) @@ -459,9 +452,9 @@ end # Need to run once since we can't specify that we want to _sample_ # in the unconstrained space for `VarInfo` without having `vn` # present in the `varinfo`. - ## `UntypedVarInfo` - vi = VarInfo() - vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) + + ## `untyped_varinfo` + vi = DynamicPPL.untyped_varinfo(model) vi = DynamicPPL.settrans!!(vi, true, vn) # Sample in unconstrained space. vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) @@ -469,8 +462,8 @@ end x = f(DynamicPPL.getindex_internal(vi, vn)) @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) - ## `TypedVarInfo` - vi = VarInfo(model) + ## `typed_varinfo` + vi = DynamicPPL.typed_varinfo(model) vi = DynamicPPL.settrans!!(vi, true, vn) # Sample in unconstrained space. vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) @@ -979,7 +972,7 @@ end @test vi.metadata.orders == [1, 1, 2, 2, 3, 3] @test DynamicPPL.get_num_produce(vi) == 3 - vi = empty!!(DynamicPPL.TypedVarInfo(vi)) + vi = empty!!(DynamicPPL.typed_varinfo(vi)) # First iteration, variables are added to vi # variables samples in order: z1,a1,z2,a2,z3 DynamicPPL.increment_num_produce!(vi)