
SampleFromPrior, etc. cleanup #859

@penelopeysm

```julia
# TODO: Make `UniformSampling` and `Prior` algs + just use `Sampler`
# That would let us use all defaults for Sampler, combine it with other samplers etc.
"""
    SampleFromUniform

Sampling algorithm that samples unobserved random variables from a uniform distribution.

# References

[Stan reference manual](https://mc-stan.org/docs/2_28/reference-manual/initialization.html#random-initial-values)
"""
struct SampleFromUniform <: AbstractSampler end

"""
    SampleFromPrior

Sampling algorithm that samples unobserved random variables from their prior distribution.
"""
struct SampleFromPrior <: AbstractSampler end

# TODO: Remove this thing.
# SampleFromPrior and SampleFromUniform
function assume(
    rng::Random.AbstractRNG,
    sampler::Union{SampleFromPrior,SampleFromUniform},
    dist::Distribution,
    vn::VarName,
    vi::VarInfoOrThreadSafeVarInfo,
)
    if haskey(vi, vn)
        # Always overwrite the parameters with new ones for `SampleFromUniform`.
        if sampler isa SampleFromUniform || is_flagged(vi, vn, "del")
            # TODO(mhauru) Is it important to unset the flag here? The `true` allows us
            # to ignore the fact that for VarNamedVector this does nothing, but I'm unsure
            # if that's okay.
            unset_flag!(vi, vn, "del", true)
            r = init(rng, dist, sampler)
            f = to_maybe_linked_internal_transform(vi, vn, dist)
            # TODO(mhauru) This should probably be call a function called setindex_internal!
            # Also, if we use !! we shouldn't ignore the return value.
            BangBang.setindex!!(vi, f(r), vn)
            setorder!(vi, vn, get_num_produce(vi))
        else
            # Otherwise we just extract it.
            r = vi[vn, dist]
        end
    else
        r = init(rng, dist, sampler)
        if istrans(vi)
            f = to_linked_internal_transform(vi, vn, dist)
            push!!(vi, vn, f(r), dist)
            # By default `push!!` sets the transformed flag to `false`.
            settrans!!(vi, true, vn)
        else
            push!!(vi, vn, r, dist)
        end
    end
    # HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct.
    logjac = logabsdetjac(istrans(vi, vn) ? link_transform(dist) : identity, r)
    return r, logpdf(dist, r) - logjac, vi
end

# default fallback (used e.g. by `SampleFromPrior` and `SampleUniform`)
observe(sampler::AbstractSampler, right, left, vi) = observe(right, left, vi)
function observe(right::Distribution, left, vi)
    increment_num_produce!(vi)
    return Distributions.loglikelihood(right, left), vi
end
```
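
The `HACK` at the end of `assume` subtracts the log absolute Jacobian determinant of the link transform whenever `vn` is stored in transformed (linked) space. By the change-of-variables formula, `logpdf(dist, r) - logjac` is then the log density of the linked value rather than of `r`. A hand-worked illustration, assuming a positive-valued variable with an `Exponential(1)` prior and a `log` link (both chosen only for illustration; the helpers below are hypothetical and use nothing beyond Base):

```julia
# Hypothetical helpers, written by hand so the sketch only needs Base.
# Log-density of Exponential(1) at r.
logpdf_exp1(r) = r >= 0 ? -r : -Inf

# For the assumed link y = log(r): |dy/dr| = 1/r, so log|det J| = -log(r).
logabsdetjac_log(r) = -log(r)

# Mirrors `logpdf(dist, r) - logjac` at the end of `assume`:
# the Jacobian term is zero (identity transform) unless `vn` is linked.
function assume_logdensity(r; linked::Bool)
    logjac = linked ? logabsdetjac_log(r) : 0.0
    return logpdf_exp1(r) - logjac
end

r = 2.0
assume_logdensity(r; linked=false)  # → -2.0
assume_logdensity(r; linked=true)   # → -2.0 + log(2.0) ≈ -1.307

# Cross-check via change of variables: if y = log(r) and r ~ Exponential(1),
# then log p(y) = y - exp(y).
y = log(r)
y - exp(y) ≈ assume_logdensity(r; linked=true)  # → true
```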

This is all kinda hacky and also confusing, since `SampleFromPrior` doesn't actually sample if `haskey(vi, vn)`. It's hard to see why: if `haskey(vi, vn)` and `sampler isa SampleFromPrior`, the inner `if` only fires when `vn` has the `"del"` flag set; otherwise we just get `r = vi[vn, dist]`, i.e. the value already stored in `vi` is reused.
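
To make the control flow easier to see, here is a minimal, dependency-free sketch that reduces `assume` to the branch it takes. The stand-in types and the `assume_branch` function are hypothetical and exist only for illustration; `has_vn` plays the role of `haskey(vi, vn)` and `del_flagged` the role of `is_flagged(vi, vn, "del")`.

```julia
# Hypothetical stand-ins for the two sampler types (no DynamicPPL dependency).
abstract type AbstractSampler end
struct SampleFromPrior <: AbstractSampler end
struct SampleFromUniform <: AbstractSampler end

# Returns which branch of `assume` would run for the given sampler and `vi` state.
function assume_branch(sampler::AbstractSampler, has_vn::Bool, del_flagged::Bool)
    if has_vn
        if sampler isa SampleFromUniform || del_flagged
            return :resample_and_overwrite  # init(...) then setindex!! into vi
        else
            return :reuse_existing          # r = vi[vn, dist]; nothing is sampled
        end
    else
        return :sample_and_push             # init(...) then push!! into vi
    end
end

assume_branch(SampleFromPrior(), true, false)    # → :reuse_existing
assume_branch(SampleFromUniform(), true, false)  # → :resample_and_overwrite
assume_branch(SampleFromPrior(), true, true)     # → :resample_and_overwrite
assume_branch(SampleFromPrior(), false, false)   # → :sample_and_push
```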

It should be cleaned up and the behaviour made more consistent. Also, the TODO comment about turning these into algorithms and using `Sampler{Prior}` and `Sampler{Uniform}` makes a lot of sense IMO. Note that this could have further impacts on Turing.jl, because Turing.jl has some type piracy on `SampleFromPrior` / `SampleFromUniform`. It might also affect the implementation of TuringLang/Turing.jl#2476.
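
One way the TODO could look, as a minimal, self-contained sketch: `Prior` and `Uniform` become algorithm types wrapped in a parametric `Sampler{A}`, and the per-algorithm behaviour (whether an existing value is overwritten, and how a fresh value is drawn) is declared in separately dispatched methods instead of an `isa` check inside `assume`. Everything here (`Prior`, `Uniform`, `Sampler`, `overwrite_existing`, `draw`, `toy_assume!`) is hypothetical and is not DynamicPPL's actual API. The sketch assumes `vi` is a plain `Dict`, the prior is a standard normal, and the `"del"` flag is ignored; the uniform draw mimics Stan's default random initial values, Uniform(-2, 2) on the unconstrained scale.

```julia
import Random

# Hypothetical algorithm types and a generic sampler wrapper.
abstract type AbstractAlgorithm end
struct Prior <: AbstractAlgorithm end
struct Uniform <: AbstractAlgorithm end

struct Sampler{A<:AbstractAlgorithm}
    alg::A
end

# Whether an already-present value is overwritten is stated once per algorithm.
overwrite_existing(::Sampler{Prior}) = false
overwrite_existing(::Sampler{Uniform}) = true

# How a fresh value is drawn (toy prior: Normal(0, 1); toy init: Uniform(-2, 2)).
draw(rng, ::Sampler{Prior}) = Random.randn(rng)
draw(rng, ::Sampler{Uniform}) = 4 * Random.rand(rng) - 2

# Toy `assume`: reuse the stored value unless the algorithm says to overwrite it.
function toy_assume!(rng, spl::Sampler, vi::Dict{Symbol,Float64}, vn::Symbol)
    if haskey(vi, vn) && !overwrite_existing(spl)
        return vi[vn]
    end
    vi[vn] = draw(rng, spl)
    return vi[vn]
end

rng = Random.MersenneTwister(1)
vi = Dict(:x => 10.0)
toy_assume!(rng, Sampler(Prior()), vi, :x)    # → 10.0 (existing value reused)
toy_assume!(rng, Sampler(Uniform()), vi, :x)  # new value in [-2, 2), stored in vi[:x]
```

With this shape the "does it resample?" question is answered by looking at one method per algorithm, rather than by tracing nested `if`s in `assume`.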
