From cca239f8b513d409c7106e9a6b61ca7b1a6a7713 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 30 Jan 2025 16:11:35 +0000 Subject: [PATCH 1/3] Remove Selectors and Gibbs IDs --- HISTORY.md | 2 + docs/src/api.md | 7 --- ext/DynamicPPLChainRulesCoreExt.jl | 10 +--- src/DynamicPPL.jl | 9 --- src/abstract_varinfo.jl | 46 --------------- src/context_implementations.jl | 12 ++-- src/sampler.jl | 7 +-- src/selector.jl | 13 ---- src/simple_varinfo.jl | 28 ++------- src/threadsafe.jl | 9 +-- src/varinfo.jl | 95 +++++------------------------- src/varnamedvector.jl | 8 +-- test/runtests.jl | 2 +- 13 files changed, 39 insertions(+), 209 deletions(-) delete mode 100644 src/selector.jl diff --git a/HISTORY.md b/HISTORY.md index 6b7247c8d..7ad8f526a 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -15,6 +15,8 @@ This release removes the feature of `VarInfo` where it kept track of which varia - `eltype(::VarInfo)` no longer accepts a sampler as an argument - `keys(::VarInfo)` no longer accepts a sampler as an argument - `VarInfo(::VarInfo, ::Sampler, ::AbstactVector)` no longer accepts the sampler argument. + - `push!!` and `push!` no longer accept samplers or `Selector`s as arguments + - `getgid`, `setgid!`, `updategid!`, and `inspace` no longer exist ### Reverse prefixing order diff --git a/docs/src/api.md b/docs/src/api.md index 6c58264fe..b9cafaaf4 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -289,13 +289,6 @@ unset_flag! is_flagged ``` -For Gibbs sampling the following functions were added. - -```@docs -setgid! -updategid! -``` - The following functions were used for sequential Monte Carlo methods. ```@docs diff --git a/ext/DynamicPPLChainRulesCoreExt.jl b/ext/DynamicPPLChainRulesCoreExt.jl index 1559467f8..12b816c60 100644 --- a/ext/DynamicPPLChainRulesCoreExt.jl +++ b/ext/DynamicPPLChainRulesCoreExt.jl @@ -10,15 +10,7 @@ end # See https://github.com/TuringLang/Turing.jl/issues/1199 ChainRulesCore.@non_differentiable BangBang.push!!( - vi::DynamicPPL.VarInfo, - vn::DynamicPPL.VarName, - r, - dist::Distributions.Distribution, - gidset::Set{DynamicPPL.Selector}, -) - -ChainRulesCore.@non_differentiable DynamicPPL.updategid!( - vi::DynamicPPL.AbstractVarInfo, vn::DynamicPPL.VarName, spl::DynamicPPL.Sampler + vi::DynamicPPL.VarInfo, vn::DynamicPPL.VarName, r, dist::Distributions.Distribution ) # No need + causes issues for some AD backends, e.g. Zygote. diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 55e1f7e88..7d3d9508a 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -63,8 +63,6 @@ export AbstractVarInfo, is_flagged, set_flag!, unset_flag!, - setgid!, - updategid!, setorder!, istrans, link, @@ -74,7 +72,6 @@ export AbstractVarInfo, values_as, # VarName (reexport from AbstractPPL) VarName, - inspace, subsumes, @varname, # Compiler @@ -166,14 +163,8 @@ See also: [`VarInfo`](@ref), [`SimpleVarInfo`](@ref). """ abstract type AbstractVarInfo <: AbstractModelTrace end -const LEGACY_WARNING = """ -!!! warning - This method is considered legacy, and is likely to be deprecated in the future. -""" - # Necessary forward declarations include("utils.jl") -include("selector.jl") include("chains.jl") include("model.jl") include("sampler.jl") diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 4e9e5c554..66b098370 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -169,51 +169,6 @@ See also: [`getindex(vi::AbstractVarInfo, vn::VarName, dist::Distribution)`](@re """ function getindex_internal end -""" - push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution) - -Push a new random variable `vn` with a sampled value `r` from a distribution `dist` to -the `VarInfo` `vi`, mutating if it makes sense. -""" -function BangBang.push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution) - return BangBang.push!!(vi, vn, r, dist, Set{Selector}([])) -end - -""" - push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::AbstractSampler) - -Push a new random variable `vn` with a sampled value `r` sampled with a sampler `spl` -from a distribution `dist` to `VarInfo` `vi`, if it makes sense. - -The sampler is passed here to invalidate its cache where defined. - -$(LEGACY_WARNING) -""" -function BangBang.push!!( - vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::Sampler -) - return BangBang.push!!(vi, vn, r, dist, spl.selector) -end -function BangBang.push!!( - vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::AbstractSampler -) - return BangBang.push!!(vi, vn, r, dist) -end - -""" - push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector) - -Push a new random variable `vn` with a sampled value `r` sampled with a sampler of -selector `gid` from a distribution `dist` to `VarInfo` `vi`. - -$(LEGACY_WARNING) -""" -function BangBang.push!!( - vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector -) - return BangBang.push!!(vi, vn, r, dist, Set([gid])) -end - @doc """ empty!!(vi::AbstractVarInfo) @@ -768,7 +723,6 @@ end # Legacy code that is currently overloaded for the sake of simplicity. # TODO: Remove when possible. increment_num_produce!(::AbstractVarInfo) = nothing -setgid!(vi::AbstractVarInfo, gid::Selector, vn::VarName) = nothing """ from_internal_transform(varinfo::AbstractVarInfo, vn::VarName[, dist]) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 462012676..4594902dc 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -239,11 +239,11 @@ function assume( r = init(rng, dist, sampler) if istrans(vi) f = to_linked_internal_transform(vi, vn, dist) - push!!(vi, vn, f(r), dist, sampler) + push!!(vi, vn, f(r), dist) # By default `push!!` sets the transformed flag to `false`. settrans!!(vi, true, vn) else - push!!(vi, vn, r, dist, sampler) + push!!(vi, vn, r, dist) end end @@ -466,11 +466,11 @@ function get_and_set_val!( vn = vns[i] if istrans(vi) ri_linked = _link_broadcast_new(vi, vn, dist, r[:, i]) - push!!(vi, vn, ri_linked, dist, spl) + push!!(vi, vn, ri_linked, dist) # `push!!` sets the trans-flag to `false` by default. settrans!!(vi, true, vn) else - push!!(vi, vn, r[:, i], dist, spl) + push!!(vi, vn, r[:, i], dist) end end end @@ -513,14 +513,14 @@ function get_and_set_val!( # 2. Define an anonymous function which returns `nothing`, which # we then broadcast. This will allocate a vector of `nothing` though. if istrans(vi) - push!!.((vi,), vns, _link_broadcast_new.((vi,), vns, dists, r), dists, (spl,)) + push!!.((vi,), vns, _link_broadcast_new.((vi,), vns, dists, r), dists) # NOTE: Need to add the correction. # FIXME: This is not great. acclogp!!(vi, sum(logabsdetjac.(link_transform.(dists), r))) # `push!!` sets the trans-flag to `false` by default. settrans!!.((vi,), true, vns) else - push!!.((vi,), vns, r, dists, (spl,)) + push!!.((vi,), vns, r, dists) end end return r diff --git a/src/sampler.jl b/src/sampler.jl index 56cd8404e..31d364daf 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -31,6 +31,8 @@ function init(rng, dist, ::SampleFromUniform, n::Int) return istransformable(dist) ? inittrans(rng, dist, n) : rand(rng, dist, n) end +# TODO(mhauru) Could we get rid of Sampler now that it's just a wrapper around `alg`? +# (Selector has been removed). """ Sampler{T} @@ -47,12 +49,7 @@ By default, values are sampled from the prior. """ struct Sampler{T} <: AbstractSampler alg::T - selector::Selector # Can we remove it? - # TODO: add space such that we can integrate existing external samplers in DynamicPPL end -Sampler(alg) = Sampler(alg, Selector()) -Sampler(alg, model::Model) = Sampler(alg, model, Selector()) -Sampler(alg, model::Model, s::Selector) = Sampler(alg, s) # AbstractMCMC interface for SampleFromUniform and SampleFromPrior function AbstractMCMC.step( diff --git a/src/selector.jl b/src/selector.jl deleted file mode 100644 index fd4aa6d1c..000000000 --- a/src/selector.jl +++ /dev/null @@ -1,13 +0,0 @@ -struct Selector - gid::UInt64 - tag::Symbol # :default, :invalid, :Gibbs, :HMC, etc. - rerun::Bool -end -function Selector(tag::Symbol=:default, rerun=tag != :default) - return Selector(time_ns(), tag, rerun) -end -function Selector(gid::Integer, tag::Symbol=:default) - return Selector(gid, tag, tag != :default) -end -hash(s::Selector) = hash(s.gid) -==(s1::Selector, s2::Selector) = s1.gid == s2.gid diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 07296c3f7..00d6b3437 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -374,42 +374,26 @@ end # `NamedTuple` function BangBang.push!!( - vi::SimpleVarInfo{<:NamedTuple}, - vn::VarName{sym,typeof(identity)}, - value, - dist::Distribution, - gidset::Set{Selector}, + vi::SimpleVarInfo{<:NamedTuple}, ::VarName{sym,typeof(identity)}, value, ::Distribution ) where {sym} return Accessors.@set vi.values = merge(vi.values, NamedTuple{(sym,)}((value,))) end function BangBang.push!!( - vi::SimpleVarInfo{<:NamedTuple}, - vn::VarName{sym}, - value, - dist::Distribution, - gidset::Set{Selector}, + vi::SimpleVarInfo{<:NamedTuple}, vn::VarName{sym}, value, ::Distribution ) where {sym} return Accessors.@set vi.values = set!!(vi.values, vn, value) end # `AbstractDict` function BangBang.push!!( - vi::SimpleVarInfo{<:AbstractDict}, - vn::VarName, - value, - dist::Distribution, - gidset::Set{Selector}, + vi::SimpleVarInfo{<:AbstractDict}, vn::VarName, value, ::Distribution ) vi.values[vn] = value return vi end function BangBang.push!!( - vi::SimpleVarInfo{<:VarNamedVector}, - vn::VarName, - value, - dist::Distribution, - gidset::Set{Selector}, + vi::SimpleVarInfo{<:VarNamedVector}, vn::VarName, value, ::Distribution ) # The semantics of push!! for SimpleVarInfo and VarNamedVector are different. For # SimpleVarInfo, push!! allows the key to exist already, for VarNamedVector it does not. @@ -483,7 +467,7 @@ function assume( value = init(rng, dist, sampler) # Transform if we're working in unconstrained space. value_raw = to_maybe_linked_internal(vi, vn, dist, value) - vi = BangBang.push!!(vi, vn, value_raw, dist, sampler) + vi = BangBang.push!!(vi, vn, value_raw, dist) return value, Bijectors.logpdf_with_trans(dist, value, istrans(vi, vn)), vi end @@ -550,7 +534,7 @@ function settrans!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans) end istrans(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation) -istrans(vi::SimpleVarInfo, vn::VarName) = istrans(vi) +istrans(vi::SimpleVarInfo, ::VarName) = istrans(vi) istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) = istrans(vi.varinfo, vn) islinked(vi::SimpleVarInfo) = istrans(vi) diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 4367ff06d..539c1e9d6 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -57,10 +57,8 @@ end has_varnamedvector(vi::DynamicPPL.ThreadSafeVarInfo) = has_varnamedvector(vi.varinfo) -function BangBang.push!!( - vi::ThreadSafeVarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector} -) - return Accessors.@set vi.varinfo = push!!(vi.varinfo, vn, r, dist, gidset) +function BangBang.push!!(vi::ThreadSafeVarInfo, vn::VarName, r, dist::Distribution) + return Accessors.@set vi.varinfo = push!!(vi.varinfo, vn, r, dist) end get_num_produce(vi::ThreadSafeVarInfo) = get_num_produce(vi.varinfo) @@ -70,9 +68,6 @@ set_num_produce!(vi::ThreadSafeVarInfo, n::Int) = set_num_produce!(vi.varinfo, n syms(vi::ThreadSafeVarInfo) = syms(vi.varinfo) -function setgid!(vi::ThreadSafeVarInfo, gid::Selector, vn::VarName) - return setgid!(vi.varinfo, gid, vn) -end setorder!(vi::ThreadSafeVarInfo, vn::VarName, index::Int) = setorder!(vi.varinfo, vn, index) setval!(vi::ThreadSafeVarInfo, val, vn::VarName) = setval!(vi.varinfo, val, vn) diff --git a/src/varinfo.jl b/src/varinfo.jl index 8f7f7b6c1..ca143ea63 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -18,8 +18,6 @@ Let `md` be an instance of `Metadata`: `md.vns`, `md.ranges` `md.dists`, `md.orders` and `md.flags`. - `md.vns[md.idcs[vn]] == vn`. - `md.dists[md.idcs[vn]]` is the distribution of `vn`. -- `md.gids[md.idcs[vn]]` is the set of algorithms used to sample `vn`. This is used in - the Gibbs sampling process. - `md.orders[md.idcs[vn]]` is the number of `observe` statements before `vn` is sampled. - `md.ranges[md.idcs[vn]]` is the index range of `vn` in `md.vals`. - `md.vals[md.ranges[md.idcs[vn]]]` is the vector of values of corresponding to `vn`. @@ -41,7 +39,6 @@ struct Metadata{ TDists<:AbstractVector{<:Distribution}, TVN<:AbstractVector{<:VarName}, TVal<:AbstractVector{<:Real}, - TGIds<:AbstractVector{Set{Selector}}, } # Mapping from the `VarName` to its integer index in `vns`, `ranges` and `dists` idcs::TIdcs # Dict{<:VarName,Int} @@ -60,10 +57,6 @@ struct Metadata{ # Vector of distributions correpsonding to `vns` dists::TDists # AbstractVector{<:Distribution} - # Vector of sampler ids corresponding to `vns` - # Each random variable can be sampled using multiple samplers, e.g. in Gibbs, hence the `Set` - gids::TGIds # AbstractVector{Set{Selector}} - # Number of `observe` statements before each random variable is sampled orders::Vector{Int} @@ -261,7 +254,6 @@ function replace_values(metadata::Metadata, x) metadata.ranges, x, metadata.dists, - metadata.gids, metadata.orders, metadata.flags, ) @@ -301,7 +293,6 @@ function Metadata() Vector{UnitRange{Int}}(), vals, Vector{Distribution}(), - Vector{Set{Selector}}(), Vector{Int}(), flags, ) @@ -320,7 +311,6 @@ function empty!(meta::Metadata) empty!(meta.ranges) empty!(meta.vals) empty!(meta.dists) - empty!(meta.gids) empty!(meta.orders) for k in keys(meta.flags) empty!(meta.flags[k]) @@ -418,7 +408,6 @@ function subset(metadata::Metadata, vns_given::AbstractVector{VN}) where {VN<:Va ranges, vals, metadata.dists[indices_for_vns], - metadata.gids[indices_for_vns], metadata.orders[indices_for_vns], flags, ) @@ -493,7 +482,6 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) ranges = Vector{UnitRange{Int}}() vals = T[] dists = D[] - gids = Set{Selector}[] orders = Int[] flags = Dict{String,BitVector}() # Initialize the `flags`. @@ -516,15 +504,13 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) offset = r[end] dist = getdist(metadata_for_vn, vn) push!(dists, dist) - gid = metadata_for_vn.gids[getidx(metadata_for_vn, vn)] - push!(gids, gid) push!(orders, getorder(metadata_for_vn, vn)) for k in keys(flags) push!(flags[k], is_flagged(metadata_for_vn, vn, k)) end end - return Metadata(idcs, vns, ranges, vals, dists, gids, orders, flags) + return Metadata(idcs, vns, ranges, vals, dists, orders, flags) end const VarView = Union{Int,UnitRange,Vector{Int}} @@ -730,13 +716,6 @@ end return expr end -""" - getgid(vi::VarInfo, vn::VarName) - -Return the set of sampler selectors associated with `vn` in `vi`. -""" -getgid(vi::VarInfo, vn::VarName) = getmetadata(vi, vn).gids[getidx(vi, vn)] - function settrans!!(vi::VarInfo, trans::Bool, vn::VarName) settrans!!(getmetadata(vi, vn), trans, vn) return vi @@ -787,14 +766,8 @@ _getidcs(vi::TypedVarInfo) = _getidcs(vi.metadata) return :($(exprs...),) end -@inline function findinds(f_meta::Metadata) - # Get all the idcs of the vns - return filter((i) -> isempty(f_meta.gids[i]), 1:length(f_meta.gids)) -end - -function findinds(vnv::VarNamedVector) - return 1:length(vnv.varnames) -end +@inline findinds(f_meta::Metadata) = eachindex(f_meta.vns) +findinds(vnv::VarNamedVector) = 1:length(vnv.varnames) """ all_varnames_grouped_by_symbol(vi::TypedVarInfo) @@ -876,9 +849,6 @@ function TypedVarInfo(vi::UntypedVarInfo) sym_idcs = Dict(a => i for (i, a) in enumerate(sym_vns)) # New dists sym_dists = getindex.((meta.dists,), inds) - # New gids, can make a resizeable FillArray - sym_gids = getindex.((meta.gids,), inds) - @assert length(sym_gids) <= 1 || all(x -> x == sym_gids[1], @view sym_gids[2:end]) # New orders sym_orders = getindex.((meta.orders,), inds) # New flags @@ -899,14 +869,7 @@ function TypedVarInfo(vi::UntypedVarInfo) push!( new_metas, Metadata( - sym_idcs, - sym_vns, - sym_ranges, - sym_vals, - sym_dists, - sym_gids, - sym_orders, - sym_flags, + sym_idcs, sym_vns, sym_ranges, sym_vals, sym_dists, sym_orders, sym_flags ), ) end @@ -951,21 +914,6 @@ Base.keys(vi::TypedVarInfo{<:NamedTuple{()}}) = VarName[] return expr end -""" - setgid!(vi::VarInfo, gid::Selector, vn::VarName) - -Add `gid` to the set of sampler selectors associated with `vn` in `vi`. -""" -setgid!(vi::VarInfo, gid::Selector, vn::VarName) = setgid!(getmetadata(vi, vn), gid, vn) - -function setgid!(m::Metadata, gid::Selector, vn::VarName) - return push!(m.gids[getidx(m, vn)], gid) -end - -function setgid!(vnv::VarNamedVector, gid::Selector, vn::VarName) - throw(ErrorException("Calling setgid! on a VarNamedVector isn't valid.")) -end - istrans(vi::VarInfo, vn::VarName) = istrans(getmetadata(vi, vn), vn) istrans(md::Metadata, vn::VarName) = is_flagged(md, vn, "trans") @@ -1348,7 +1296,6 @@ function _link_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_ ranges_new, reduce(vcat, vals_new), metadata.dists, - metadata.gids, metadata.orders, metadata.flags, ) @@ -1491,7 +1438,6 @@ function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, targ ranges_new, reduce(vcat, vals_new), metadata.dists, - metadata.gids, metadata.orders, metadata.flags, ) @@ -1675,7 +1621,6 @@ function Base.show(io::IO, ::MIME"text/plain", vi::UntypedVarInfo) | Varnames : $(string(vi.metadata.vns)) | Range : $(vi.metadata.ranges) | Vals : $(vi.metadata.vals) - | GIDs : $(vi.metadata.gids) | Orders : $(vi.metadata.orders) | Logp : $(getlogp(vi)) | #produce : $(get_num_produce(vi)) @@ -1715,13 +1660,17 @@ function Base.show(io::IO, vi::UntypedVarInfo) return print(io, ")") end -function BangBang.push!!( - vi::VarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector} -) +""" + push!!(vi::VarInfo, vn::VarName, r, dist::Distribution) + +Push a new random variable `vn` with a sampled value `r` from a distribution `dist` to +the `VarInfo` `vi`, mutating if it makes sense. +""" +function BangBang.push!!(vi::VarInfo, vn::VarName, r, dist::Distribution) if vi isa UntypedVarInfo - @assert ~(vn in keys(vi)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist, gid=$gidset" + @assert ~(vn in keys(vi)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist" elseif vi isa TypedVarInfo - @assert ~(haskey(vi, vn)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to TypedVarInfo of syms $(syms(vi)) with dist=$dist, gid=$gidset" + @assert ~(haskey(vi, vn)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to TypedVarInfo of syms $(syms(vi)) with dist=$dist" end sym = getsym(vn) @@ -1734,14 +1683,13 @@ function BangBang.push!!( [1:length(val)], val, [dist], - [gidset], [get_num_produce(vi)], Dict{String,BitVector}("trans" => [false], "del" => [false]), ) vi = Accessors.@set vi.metadata[sym] = md else meta = getmetadata(vi, vn) - push!(meta, vn, r, dist, gidset, get_num_produce(vi)) + push!(meta, vn, r, dist, get_num_produce(vi)) end return vi @@ -1761,7 +1709,7 @@ end # exist in the TypedVarInfo already. We could implement it in the cases where it it does # exist, but that feels a bit pointless. I think we should rather rely on `push!!`. -function Base.push!(meta::Metadata, vn, r, dist, gidset, num_produce) +function Base.push!(meta::Metadata, vn, r, dist, num_produce) val = tovec(r) meta.idcs[vn] = length(meta.idcs) + 1 push!(meta.vns, vn) @@ -1770,7 +1718,6 @@ function Base.push!(meta::Metadata, vn, r, dist, gidset, num_produce) push!(meta.ranges, (l + 1):(l + n)) append!(meta.vals, val) push!(meta.dists, dist) - push!(meta.gids, gidset) push!(meta.orders, num_produce) push!(meta.flags["del"], false) push!(meta.flags["trans"], false) @@ -1914,18 +1861,6 @@ end return expr end -""" - updategid!(vi::VarInfo, vn::VarName, spl::Sampler) - -Set `vn`'s `gid` to `Set([spl.selector])`, if `vn` does not have a sampler selector linked -and `vn`'s symbol is in the space of `spl`. -""" -function updategid!(vi::VarInfoOrThreadSafeVarInfo, vn::VarName, spl::Sampler) - if inspace(vn, getspace(spl)) - setgid!(vi, spl.selector, vn) - end -end - # TODO: Maybe rename or something? """ _apply!(kernel!, vi::VarInfo, values, keys) diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl index 3b3f0ce42..c6b54fdc4 100644 --- a/src/varnamedvector.jl +++ b/src/varnamedvector.jl @@ -766,9 +766,9 @@ function update_internal!( return nothing end -# TODO(mhauru) The gidset and num_produce arguments are used by the old Gibbs sampler. +# TODO(mhauru) The num_produce argument is used by Particle Gibbs. # Remove this method as soon as possible. -function BangBang.push!(vnv::VarNamedVector, vn, val, dist, gidset, num_produce) +function BangBang.push!(vnv::VarNamedVector, vn, val, dist, num_produce) f = from_vec_transform(dist) return setindex_internal!(vnv, tovec(val), vn, f) end @@ -963,9 +963,9 @@ function BangBang.push!!(vnv::VarNamedVector, pair::Pair) return setindex!!(vnv, val, vn) end -# TODO(mhauru) The gidset and num_produce arguments are used by the old Gibbs sampler. +# TODO(mhauru) The num_produce argument is used by Particle Gibbs. # Remove this method as soon as possible. -function BangBang.push!!(vnv::VarNamedVector, vn, val, dist, gidset, num_produce) +function BangBang.push!!(vnv::VarNamedVector, vn, val, dist, num_produce) f = from_vec_transform(dist) return setindex_internal!!(vnv, tovec(val), vn, f) end diff --git a/test/runtests.jl b/test/runtests.jl index 29a148789..25cd2fb40 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -33,7 +33,7 @@ using JET: JET using Combinatorics: combinations using OrderedCollections: OrderedSet -using DynamicPPL: getargs_dottilde, getargs_tilde, Selector +using DynamicPPL: getargs_dottilde, getargs_tilde const GROUP = get(ENV, "GROUP", "All") Random.seed!(100) From b51e5649405fff0c393143a647784fe4c25a9de9 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 13 Feb 2025 12:42:46 +0000 Subject: [PATCH 2/3] Remove getspace --- HISTORY.md | 2 +- src/DynamicPPL.jl | 3 --- src/sampler.jl | 2 -- test/ad.jl | 3 +-- test/sampler.jl | 1 - 5 files changed, 2 insertions(+), 9 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 7ad8f526a..90db022e7 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -16,7 +16,7 @@ This release removes the feature of `VarInfo` where it kept track of which varia - `keys(::VarInfo)` no longer accepts a sampler as an argument - `VarInfo(::VarInfo, ::Sampler, ::AbstactVector)` no longer accepts the sampler argument. - `push!!` and `push!` no longer accept samplers or `Selector`s as arguments - - `getgid`, `setgid!`, `updategid!`, and `inspace` no longer exist + - `getgid`, `setgid!`, `updategid!`, `getspace`, and `inspace` no longer exist ### Reverse prefixing order diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 7d3d9508a..8fea43e50 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -149,9 +149,6 @@ macro prob_str(str) )) end -# Used here and overloaded in Turing -function getspace end - """ AbstractVarInfo diff --git a/src/sampler.jl b/src/sampler.jl index 31d364daf..aa3a637ee 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -18,8 +18,6 @@ Sampling algorithm that samples unobserved random variables from their prior dis """ struct SampleFromPrior <: AbstractSampler end -getspace(::Union{SampleFromPrior,SampleFromUniform}) = () - # Initializations. init(rng, dist, ::SampleFromPrior) = rand(rng, dist) function init(rng, dist, ::SampleFromUniform) diff --git a/test/ad.jl b/test/ad.jl index 17981cf2a..87c7f22c3 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -41,7 +41,7 @@ σ = 0.3 y = @. rand(sin(t) + Normal(0, σ)) @model function state_space(y, TT, ::Type{T}=Float64) where {T} - # Priors + # Priors α ~ Normal(y[1], 0.001) τ ~ Exponential(1) η ~ filldist(Normal(0, 1), TT - 1) @@ -63,7 +63,6 @@ # overload assume so that model evaluation doesn't fail due to a lack # of implementation struct MyEmptyAlg end - DynamicPPL.getspace(::DynamicPPL.Sampler{MyEmptyAlg}) = () DynamicPPL.assume(rng, ::DynamicPPL.Sampler{MyEmptyAlg}, dist, vn, vi) = DynamicPPL.assume(dist, vn, vi) diff --git a/test/sampler.jl b/test/sampler.jl index 50111b1fd..8c4f1ed96 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -67,7 +67,6 @@ ) return vi, nothing end - DynamicPPL.getspace(::Sampler{<:OnlyInitAlg}) = () # initial samplers DynamicPPL.initialsampler(::Sampler{OnlyInitAlgUniform}) = SampleFromUniform() From 6c26841b277bef2ddc03339b70ea7900816fd1dd Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 13 Feb 2025 12:44:19 +0000 Subject: [PATCH 3/3] Remove a dead VNV method --- src/varnamedvector.jl | 9 --------- 1 file changed, 9 deletions(-) diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl index c6b54fdc4..6b7c82859 100644 --- a/src/varnamedvector.jl +++ b/src/varnamedvector.jl @@ -1023,15 +1023,6 @@ julia> ForwardDiff.gradient(f, [1.0]) """ replace_raw_storage(vnv::VarNamedVector, vals) = Accessors.@set vnv.vals = vals -# TODO(mhauru) The space argument is used by the old Gibbs sampler. To be removed. -function replace_raw_storage(vnv::VarNamedVector, ::Val{space}, vals) where {space} - if length(space) > 0 - msg = "Selecting values in a VarNamedVector with a space is not supported." - throw(ArgumentError(msg)) - end - return replace_raw_storage(vnv, vals) -end - vector_length(vnv::VarNamedVector) = length(vnv.vals) - num_inactive(vnv) """