Skip to content

Remove samplers from VarInfo - Selectors and GIDs #808

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ This release removes the feature of `VarInfo` where it kept track of which varia
- `eltype(::VarInfo)` no longer accepts a sampler as an argument
- `keys(::VarInfo)` no longer accepts a sampler as an argument
- `VarInfo(::VarInfo, ::Sampler, ::AbstactVector)` no longer accepts the sampler argument.
- `push!!` and `push!` no longer accept samplers or `Selector`s as arguments
- `getgid`, `setgid!`, `updategid!`, `getspace`, and `inspace` no longer exist

### Reverse prefixing order

Expand Down
7 changes: 0 additions & 7 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -289,13 +289,6 @@ unset_flag!
is_flagged
```

For Gibbs sampling the following functions were added.

```@docs
setgid!
updategid!
```

The following functions were used for sequential Monte Carlo methods.

```@docs
Expand Down
10 changes: 1 addition & 9 deletions ext/DynamicPPLChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,7 @@ end

# See https://github.com/TuringLang/Turing.jl/issues/1199
ChainRulesCore.@non_differentiable BangBang.push!!(
vi::DynamicPPL.VarInfo,
vn::DynamicPPL.VarName,
r,
dist::Distributions.Distribution,
gidset::Set{DynamicPPL.Selector},
)

ChainRulesCore.@non_differentiable DynamicPPL.updategid!(
vi::DynamicPPL.AbstractVarInfo, vn::DynamicPPL.VarName, spl::DynamicPPL.Sampler
vi::DynamicPPL.VarInfo, vn::DynamicPPL.VarName, r, dist::Distributions.Distribution
)

# No need + causes issues for some AD backends, e.g. Zygote.
Expand Down
12 changes: 0 additions & 12 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,6 @@ export AbstractVarInfo,
is_flagged,
set_flag!,
unset_flag!,
setgid!,
updategid!,
setorder!,
istrans,
link,
Expand All @@ -74,7 +72,6 @@ export AbstractVarInfo,
values_as,
# VarName (reexport from AbstractPPL)
VarName,
inspace,
subsumes,
@varname,
# Compiler
Expand Down Expand Up @@ -152,9 +149,6 @@ macro prob_str(str)
))
end

# Used here and overloaded in Turing
function getspace end

"""
AbstractVarInfo

Expand All @@ -166,14 +160,8 @@ See also: [`VarInfo`](@ref), [`SimpleVarInfo`](@ref).
"""
abstract type AbstractVarInfo <: AbstractModelTrace end

const LEGACY_WARNING = """
!!! warning
This method is considered legacy, and is likely to be deprecated in the future.
"""

# Necessary forward declarations
include("utils.jl")
include("selector.jl")
include("chains.jl")
include("model.jl")
include("sampler.jl")
Expand Down
46 changes: 0 additions & 46 deletions src/abstract_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -169,51 +169,6 @@ See also: [`getindex(vi::AbstractVarInfo, vn::VarName, dist::Distribution)`](@re
"""
function getindex_internal end

"""
push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution)

Push a new random variable `vn` with a sampled value `r` from a distribution `dist` to
the `VarInfo` `vi`, mutating if it makes sense.
"""
function BangBang.push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution)
return BangBang.push!!(vi, vn, r, dist, Set{Selector}([]))
end

"""
push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::AbstractSampler)

Push a new random variable `vn` with a sampled value `r` sampled with a sampler `spl`
from a distribution `dist` to `VarInfo` `vi`, if it makes sense.

The sampler is passed here to invalidate its cache where defined.

$(LEGACY_WARNING)
"""
function BangBang.push!!(
vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::Sampler
)
return BangBang.push!!(vi, vn, r, dist, spl.selector)
end
function BangBang.push!!(
vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::AbstractSampler
)
return BangBang.push!!(vi, vn, r, dist)
end

"""
push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector)

Push a new random variable `vn` with a sampled value `r` sampled with a sampler of
selector `gid` from a distribution `dist` to `VarInfo` `vi`.

$(LEGACY_WARNING)
"""
function BangBang.push!!(
vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector
)
return BangBang.push!!(vi, vn, r, dist, Set([gid]))
end

@doc """
empty!!(vi::AbstractVarInfo)

Expand Down Expand Up @@ -768,7 +723,6 @@ end
# Legacy code that is currently overloaded for the sake of simplicity.
# TODO: Remove when possible.
increment_num_produce!(::AbstractVarInfo) = nothing
setgid!(vi::AbstractVarInfo, gid::Selector, vn::VarName) = nothing

"""
from_internal_transform(varinfo::AbstractVarInfo, vn::VarName[, dist])
Expand Down
12 changes: 6 additions & 6 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,11 +239,11 @@
r = init(rng, dist, sampler)
if istrans(vi)
f = to_linked_internal_transform(vi, vn, dist)
push!!(vi, vn, f(r), dist, sampler)
push!!(vi, vn, f(r), dist)
# By default `push!!` sets the transformed flag to `false`.
settrans!!(vi, true, vn)
else
push!!(vi, vn, r, dist, sampler)
push!!(vi, vn, r, dist)
end
end

Expand Down Expand Up @@ -466,11 +466,11 @@
vn = vns[i]
if istrans(vi)
ri_linked = _link_broadcast_new(vi, vn, dist, r[:, i])
push!!(vi, vn, ri_linked, dist, spl)
push!!(vi, vn, ri_linked, dist)

Check warning on line 469 in src/context_implementations.jl

View check run for this annotation

Codecov / codecov/patch

src/context_implementations.jl#L469

Added line #L469 was not covered by tests
# `push!!` sets the trans-flag to `false` by default.
settrans!!(vi, true, vn)
else
push!!(vi, vn, r[:, i], dist, spl)
push!!(vi, vn, r[:, i], dist)
end
end
end
Expand Down Expand Up @@ -513,14 +513,14 @@
# 2. Define an anonymous function which returns `nothing`, which
# we then broadcast. This will allocate a vector of `nothing` though.
if istrans(vi)
push!!.((vi,), vns, _link_broadcast_new.((vi,), vns, dists, r), dists, (spl,))
push!!.((vi,), vns, _link_broadcast_new.((vi,), vns, dists, r), dists)
# NOTE: Need to add the correction.
# FIXME: This is not great.
acclogp!!(vi, sum(logabsdetjac.(link_transform.(dists), r)))
# `push!!` sets the trans-flag to `false` by default.
settrans!!.((vi,), true, vns)
else
push!!.((vi,), vns, r, dists, (spl,))
push!!.((vi,), vns, r, dists)
end
end
return r
Expand Down
9 changes: 2 additions & 7 deletions src/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ Sampling algorithm that samples unobserved random variables from their prior dis
"""
struct SampleFromPrior <: AbstractSampler end

getspace(::Union{SampleFromPrior,SampleFromUniform}) = ()

# Initializations.
init(rng, dist, ::SampleFromPrior) = rand(rng, dist)
function init(rng, dist, ::SampleFromUniform)
Expand All @@ -31,6 +29,8 @@ function init(rng, dist, ::SampleFromUniform, n::Int)
return istransformable(dist) ? inittrans(rng, dist, n) : rand(rng, dist, n)
end

# TODO(mhauru) Could we get rid of Sampler now that it's just a wrapper around `alg`?
# (Selector has been removed).
"""
Sampler{T}

Expand All @@ -47,12 +47,7 @@ By default, values are sampled from the prior.
"""
struct Sampler{T} <: AbstractSampler
alg::T
selector::Selector # Can we remove it?
# TODO: add space such that we can integrate existing external samplers in DynamicPPL
end
Sampler(alg) = Sampler(alg, Selector())
Sampler(alg, model::Model) = Sampler(alg, model, Selector())
Sampler(alg, model::Model, s::Selector) = Sampler(alg, s)

# AbstractMCMC interface for SampleFromUniform and SampleFromPrior
function AbstractMCMC.step(
Expand Down
13 changes: 0 additions & 13 deletions src/selector.jl

This file was deleted.

28 changes: 6 additions & 22 deletions src/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -374,42 +374,26 @@ end

# `NamedTuple`
function BangBang.push!!(
vi::SimpleVarInfo{<:NamedTuple},
vn::VarName{sym,typeof(identity)},
value,
dist::Distribution,
gidset::Set{Selector},
vi::SimpleVarInfo{<:NamedTuple}, ::VarName{sym,typeof(identity)}, value, ::Distribution
) where {sym}
return Accessors.@set vi.values = merge(vi.values, NamedTuple{(sym,)}((value,)))
end
function BangBang.push!!(
vi::SimpleVarInfo{<:NamedTuple},
vn::VarName{sym},
value,
dist::Distribution,
gidset::Set{Selector},
vi::SimpleVarInfo{<:NamedTuple}, vn::VarName{sym}, value, ::Distribution
) where {sym}
return Accessors.@set vi.values = set!!(vi.values, vn, value)
end

# `AbstractDict`
function BangBang.push!!(
vi::SimpleVarInfo{<:AbstractDict},
vn::VarName,
value,
dist::Distribution,
gidset::Set{Selector},
vi::SimpleVarInfo{<:AbstractDict}, vn::VarName, value, ::Distribution
)
vi.values[vn] = value
return vi
end

function BangBang.push!!(
vi::SimpleVarInfo{<:VarNamedVector},
vn::VarName,
value,
dist::Distribution,
gidset::Set{Selector},
vi::SimpleVarInfo{<:VarNamedVector}, vn::VarName, value, ::Distribution
)
# The semantics of push!! for SimpleVarInfo and VarNamedVector are different. For
# SimpleVarInfo, push!! allows the key to exist already, for VarNamedVector it does not.
Expand Down Expand Up @@ -483,7 +467,7 @@ function assume(
value = init(rng, dist, sampler)
# Transform if we're working in unconstrained space.
value_raw = to_maybe_linked_internal(vi, vn, dist, value)
vi = BangBang.push!!(vi, vn, value_raw, dist, sampler)
vi = BangBang.push!!(vi, vn, value_raw, dist)
return value, Bijectors.logpdf_with_trans(dist, value, istrans(vi, vn)), vi
end

Expand Down Expand Up @@ -550,7 +534,7 @@ function settrans!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans)
end

istrans(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation)
istrans(vi::SimpleVarInfo, vn::VarName) = istrans(vi)
istrans(vi::SimpleVarInfo, ::VarName) = istrans(vi)
istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) = istrans(vi.varinfo, vn)

islinked(vi::SimpleVarInfo) = istrans(vi)
Expand Down
9 changes: 2 additions & 7 deletions src/threadsafe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,8 @@ end

has_varnamedvector(vi::DynamicPPL.ThreadSafeVarInfo) = has_varnamedvector(vi.varinfo)

function BangBang.push!!(
vi::ThreadSafeVarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector}
)
return Accessors.@set vi.varinfo = push!!(vi.varinfo, vn, r, dist, gidset)
function BangBang.push!!(vi::ThreadSafeVarInfo, vn::VarName, r, dist::Distribution)
return Accessors.@set vi.varinfo = push!!(vi.varinfo, vn, r, dist)
end

get_num_produce(vi::ThreadSafeVarInfo) = get_num_produce(vi.varinfo)
Expand All @@ -70,9 +68,6 @@ set_num_produce!(vi::ThreadSafeVarInfo, n::Int) = set_num_produce!(vi.varinfo, n

syms(vi::ThreadSafeVarInfo) = syms(vi.varinfo)

function setgid!(vi::ThreadSafeVarInfo, gid::Selector, vn::VarName)
return setgid!(vi.varinfo, gid, vn)
end
setorder!(vi::ThreadSafeVarInfo, vn::VarName, index::Int) = setorder!(vi.varinfo, vn, index)
setval!(vi::ThreadSafeVarInfo, val, vn::VarName) = setval!(vi.varinfo, val, vn)

Expand Down
Loading
Loading