Move hasvalue and getvalue from DynamicPPL; implement extra Distributions-based methods #125

Open
wants to merge 22 commits into base: main

22 commits
56b9f2c
Move hasvalue and getvalue to AbstractPPL; reimplement
penelopeysm Jul 5, 2025
646b9d7
Add hasvalue for (some) distributions
penelopeysm Jul 5, 2025
07975a2
Bump min Julia to 1.10
penelopeysm Jul 5, 2025
332c64a
Make hasvalue and getvalue use the most specific value
penelopeysm Jul 6, 2025
049001e
Specify getvalue semantics in docstring
penelopeysm Jul 6, 2025
e0adba7
Simplify logic (can rely on normalisation)
penelopeysm Jul 6, 2025
9291e07
Add tests for composition of head/tail and init/last
penelopeysm Jul 6, 2025
a736e70
Finish implementing distributions methods
penelopeysm Jul 6, 2025
5a902b0
Document
penelopeysm Jul 6, 2025
0e8d256
Fix LinearAlgebra version bound
penelopeysm Jul 6, 2025
29dc922
Try to fix documentation for extension (why is this so complicated...)
penelopeysm Jul 6, 2025
a998af6
Fix extension documentation
penelopeysm Jul 6, 2025
6a01588
Implement fallback {has,get}value methods for NamedTuple + Distribution
penelopeysm Jul 6, 2025
ecf0eac
Fix wrong way round composition, add more tests
penelopeysm Jul 9, 2025
8dc53d8
Update src/hasvalue.jl
penelopeysm Jul 11, 2025
a85e0e0
Update HISTORY.md
penelopeysm Jul 11, 2025
7fac592
Minor bump
penelopeysm Jul 11, 2025
5fa6a53
Fix test (forgot to push this...)
penelopeysm Jul 11, 2025
b4f69b1
Add extra example for getvalue
penelopeysm Jul 11, 2025
00520c3
Tweak error message when value not found
penelopeysm Jul 11, 2025
d25f0be
Format (?!)
penelopeysm Jul 11, 2025
e20190c
Fix doctests
penelopeysm Jul 11, 2025
14 changes: 14 additions & 0 deletions HISTORY.md
@@ -1,3 +1,17 @@
## 0.13.0

Minimum compatibility has been bumped to Julia 1.10.

Added the new functions `hasvalue(container::T, ::VarName[, ::Distribution])` and `getvalue(container::T, ::VarName[, ::Distribution])`, where `T` is either `NamedTuple` or `AbstractDict{<:VarName}`.

These functions check whether a given `VarName` has a value in the given `NamedTuple` or `AbstractDict`, and return the value if it exists.

The optional `Distribution` argument allows one to reconstruct a full value from its component indices.
For example, if `container` has `x[1]` and `x[2]`, then `hasvalue(container, @varname(x), dist)` will return true if `size(dist) == (2,)` (for example, `MvNormal(zeros(2), I)`).
Member:

Did you consider having the third argument be the dimension, rather than the distribution? I'm not sure at all that this would be better, but it would avoid a dependence on Distributions.jl for hasvalue, and I was wondering if it has merit.

Member Author:

We were so close, it's just Cholesky that breaks it.

Member Author:

There is another option actually, which is instead of taking a Distribution, take a value sampled from that Distribution. That means we would only need a dependency on LinearAlgebra rather than Distributions.

However, that would mean an additional call to rand() which, while quite minor in the grand scheme of things, I feel opposed to in principle.

In this case plain `hasvalue(container, @varname(x))` would return `false`, since we cannot know whether the vector-valued variable `x` has all of its elements specified in `container` (there might be an `x[3]` missing).

These functions (without the `Distribution` argument) were previously in DynamicPPL.jl (albeit unexported).

## 0.12.0

### VarName constructors
13 changes: 11 additions & 2 deletions Project.toml
@@ -3,7 +3,7 @@ uuid = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
keywords = ["probablistic programming"]
license = "MIT"
desc = "Common interfaces for probabilistic programming"
-version = "0.12.0"
+version = "0.13.0"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -13,11 +13,20 @@ JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[weakdeps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[extensions]
AbstractPPLDistributionsExt = ["Distributions", "LinearAlgebra"]

[compat]
AbstractMCMC = "2, 3, 4, 5"
Accessors = "0.1"
DensityInterface = "0.4"
Distributions = "0.25"
LinearAlgebra = "<0.0.1, 1.10"
JSON = "0.19 - 0.21"
Random = "1.6"
StatsBase = "0.32, 0.33, 0.34"
-julia = "~1.6.6, 1.7.3"
+julia = "1.10"
3 changes: 3 additions & 0 deletions docs/Project.toml
@@ -1,2 +1,5 @@
[deps]
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
4 changes: 3 additions & 1 deletion docs/make.jl
@@ -1,12 +1,14 @@
using Documenter
using AbstractPPL
# trigger DistributionsExt loading
using Distributions, LinearAlgebra

# Doctest setup
DocMeta.setdocmeta!(AbstractPPL, :DocTestSetup, :(using AbstractPPL); recursive=true)

makedocs(;
sitename="AbstractPPL",
-modules=[AbstractPPL],
+modules=[AbstractPPL, Base.get_extension(AbstractPPL, :AbstractPPLDistributionsExt)],
pages=["Home" => "index.md", "API" => "api.md"],
checkdocs=:exports,
doctest=false,
7 changes: 7 additions & 0 deletions docs/src/api.md
@@ -21,6 +21,13 @@ prefix
unprefix
```

## Extracting values corresponding to a VarName

```@docs
hasvalue
getvalue
```

## VarName serialisation

```@docs
276 changes: 276 additions & 0 deletions ext/AbstractPPLDistributionsExt.jl
@@ -0,0 +1,276 @@
module AbstractPPLDistributionsExt

using AbstractPPL: AbstractPPL, VarName, Accessors
using Distributions: Distributions
using LinearAlgebra: Cholesky, LowerTriangular, UpperTriangular

#=
This section is copied from Accessors.jl's documentation:
https://juliaobjects.github.io/Accessors.jl/stable/examples/custom_macros/

It defines a wrapper that, when called with `set`, mutates the original value
rather than returning a new value. We need this because the non-mutating optics
don't work for triangular matrices (and hence LKJCholesky): see
https://github.com/JuliaObjects/Accessors.jl/issues/203
=#
struct Lens!{L}
pure::L
end
(l::Lens!)(o) = l.pure(o)
function Accessors.set(o, l::Lens!{<:ComposedFunction}, val)
o_inner = l.pure.inner(o)
return Accessors.set(o_inner, Lens!(l.pure.outer), val)
end
function Accessors.set(o, l::Lens!{Accessors.PropertyLens{prop}}, val) where {prop}
setproperty!(o, prop, val)
return o
end
function Accessors.set(o, l::Lens!{<:Accessors.IndexLens}, val)
o[l.pure.indices...] = val
return o
end

"""
get_optics(dist::MultivariateDistribution)
get_optics(dist::MatrixDistribution)
get_optics(dist::LKJCholesky)

Return a complete set of optics for each element of the type returned by `rand(dist)`.
"""
function get_optics(
dist::Union{Distributions.MultivariateDistribution,Distributions.MatrixDistribution}
)
indices = CartesianIndices(size(dist))
return map(idx -> Accessors.IndexLens(idx.I), indices)
end
function get_optics(dist::Distributions.LKJCholesky)
is_up = dist.uplo == 'U'
cartesian_indices = filter(CartesianIndices(size(dist))) do cartesian_index
i, j = cartesian_index.I
is_up ? i <= j : i >= j
end
# there is an additional layer as we need to access `.L` or `.U` before we
# can index into it
field_lens = is_up ? (Accessors.@o _.U) : (Accessors.@o _.L)
return map(idx -> Accessors.IndexLens(idx.I) ∘ field_lens, cartesian_indices)
end
Comment on lines +46 to +56
Member:

I'm confused by the scenario in which we need this. LKJCholesky returns objects of type Cholesky, which are not AbstractArrays and can't be indexed. Would we ever have a situation where we would have something like

getvalue(Dict(@varname(x[1,1]) => 1.0, @varname(x[1,2]) => 0.0, @varname(x[2,2]) => 1.0), @varname(x), LKJCholesky(2, 0.5))

where it should return true?

PS. After reading the tests I now realise you need things like @varname(x.U[1,1]), but I still wonder if this would ever come up.

Member Author:

I guess this ambiguity could come up if the user manually mis-specified the varnames without the .L or .U. However, assuming that the varnames come from MCMCChains, they will have been constructed correctly (via varname_and_value_leaves, which ALSO special-cases Cholesky: https://github.com/TuringLang/DynamicPPL.jl/blob/ce7c8b1ae48624e12ebf3064b6099e3dfca8c985/src/utils.jl#L1258-L1265)
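
To make the exchange above concrete, here is a minimal stand-alone sketch of which element indices are expected for a triangular factor. It uses only Base Julia. `triangle_indices` is a hypothetical helper written for illustration: it mirrors the filtering done in `get_optics` for `LKJCholesky` but is not part of AbstractPPL, and the `x.U[i,j]` strings are plain labels rather than real `VarName`s.

```julia
# Hypothetical helper (not part of AbstractPPL): list the indices of the
# stored triangle of an n×n factor, in column-major order.
function triangle_indices(n::Int, uplo::Char)
    is_up = uplo == 'U'
    return filter(CartesianIndices((n, n))) do idx
        i, j = Tuple(idx)
        # Upper triangle keeps i <= j, lower triangle keeps i >= j.
        is_up ? i <= j : i >= j
    end
end

upper = triangle_indices(2, 'U')
lower = triangle_indices(2, 'L')

# Labels for the sub-variables a container would need in order to
# reconstruct an upper factor stored under `x`.
names = ["x.U[$i,$j]" for (i, j) in Tuple.(upper)]
```

Under this scheme a key such as `x[1,2]` (without `.U`) does not match any expected name, which is the mis-specification case discussed in the thread.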


"""
make_empty_value(dist::MultivariateDistribution)
make_empty_value(dist::MatrixDistribution)
make_empty_value(dist::LKJCholesky)

Construct a fresh value filled with zeros that corresponds to the size of `dist`.

For all distributions that this function accepts, it should hold that
`o(make_empty_value(dist))` is zero for all `o` in `get_optics(dist)`.
"""
function make_empty_value(
dist::Union{Distributions.MultivariateDistribution,Distributions.MatrixDistribution}
)
return zeros(size(dist))
end
function make_empty_value(dist::Distributions.LKJCholesky)
if dist.uplo == 'U'
return Cholesky(UpperTriangular(zeros(size(dist))))
else
return Cholesky(LowerTriangular(zeros(size(dist))))
end
end
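
As a rough illustration of why a zero-filled `Cholesky` can later be filled in place, the sketch below builds the same kind of value as the `uplo == 'U'` branch of `make_empty_value` and writes into it through `.U`. It needs only the LinearAlgebra standard library. It assumes (as the extension itself does) that `.U` on an upper-stored `Cholesky` wraps the underlying factor storage rather than copying it.

```julia
using LinearAlgebra: Cholesky, UpperTriangular

# Zero-filled upper-triangular Cholesky, constructed as in the
# `uplo == 'U'` branch of `make_empty_value`.
value = Cholesky(UpperTriangular(zeros(2, 2)))

# Writing through `value.U` updates `value` itself, because the
# triangular wrapper shares storage with the factor matrix here.
# Only entries with i <= j belong to the stored triangle.
value.U[1, 1] = 1.0
value.U[1, 2] = 0.5
value.U[2, 2] = 2.0
```

This in-place style is what the `Lens!` wrapper provides for the optics, since the non-mutating `set` is reported not to work for triangular matrices in the Accessors.jl issue linked in the source.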

"""
hasvalue(
vals::Union{AbstractDict,NamedTuple},
vn::VarName,
dist::Distribution;
error_on_incomplete::Bool=false
)

Check if `vals` contains a value for `vn` that is compatible with the
distribution `dist`.

This is a more general version of `hasvalue(vals, vn)`, in that even if
`vn` itself is not inside `vals`, it further checks if `vals` contains
sub-values of `vn` that can be used to reconstruct `vn` given `dist`.

The `error_on_incomplete` flag can be used to detect cases where _some_ of
the values needed for `vn` are present, but others are not. This may help
to detect invalid cases where the user has provided e.g. data of the wrong
shape.

Note that this check is only possible if a Dict is passed, because the key type
of a NamedTuple (i.e., Symbol) is not rich enough to carry indexing
information. If this method is called with a NamedTuple, it will just defer
to `hasvalue(vals, vn)`.

For example:

```jldoctest; setup=:(using Distributions, LinearAlgebra)
julia> d = Dict(@varname(x[1]) => 1.0, @varname(x[2]) => 2.0);

julia> hasvalue(d, @varname(x), MvNormal(zeros(2), I))
true

julia> hasvalue(d, @varname(x), MvNormal(zeros(3), I))
false

julia> hasvalue(d, @varname(x), MvNormal(zeros(3), I); error_on_incomplete=true)
ERROR: only partial values for `x` found in the dictionary provided
[...]
```
"""
function AbstractPPL.hasvalue(
vals::NamedTuple,
vn::VarName,
dist::Distributions.Distribution;
error_on_incomplete::Bool=false,
)
# NamedTuples can't have such complicated hierarchies, so it's safe to
# defer to the simpler `hasvalue(vals, vn)`.
return AbstractPPL.hasvalue(vals, vn)
end
function AbstractPPL.hasvalue(
vals::AbstractDict,
vn::VarName,
dist::Distributions.Distribution;
error_on_incomplete::Bool=false,
)
@warn "`hasvalue(vals, vn, dist)` is not implemented for $(typeof(dist)); falling back to `hasvalue(vals, vn)`."
return AbstractPPL.hasvalue(vals, vn)
end
function AbstractPPL.hasvalue(
vals::AbstractDict,
vn::VarName,
::Distributions.UnivariateDistribution;
error_on_incomplete::Bool=false,
)
# TODO(penelopeysm): We could also implement a check for the type to catch
# invalid values. Unsure if that is worth it. It may be easier to just let
# the user handle it.
return AbstractPPL.hasvalue(vals, vn)
end
function AbstractPPL.hasvalue(
vals::AbstractDict{<:VarName},
vn::VarName{sym},
dist::Union{
Distributions.MultivariateDistribution,
Distributions.MatrixDistribution,
Distributions.LKJCholesky,
};
error_on_incomplete::Bool=false,
) where {sym}
# If `vn` is present as-is, then we are good
AbstractPPL.hasvalue(vals, vn) && return true
# If not, then we need to check inside `vals` to see if a subset of
# `vals` is enough to reconstruct `vn`. For example, if `vals` contains
# `x[1]` and `x[2]`, and `dist` is `MvNormal(zeros(2), I)`, then we
# can reconstruct `x`. If `dist` is `MvNormal(zeros(3), I)`, then we
# can't.
# To do this, we get the size of the distribution and iterate over all
# possible indices. If every index can be found in `vals`, then we can
# return true.
optics = get_optics(dist)
original_optic = AbstractPPL.getoptic(vn)
expected_vns = map(o -> VarName{sym}(o ∘ original_optic), optics)
if all(sub_vn -> AbstractPPL.hasvalue(vals, sub_vn), expected_vns)
return true
else
if error_on_incomplete &&
any(sub_vn -> AbstractPPL.hasvalue(vals, sub_vn), expected_vns)
error("only partial values for `$vn` found in the dictionary provided")
end
return false
end
end
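
The all-or-partial logic above can be sketched without AbstractPPL or Distributions. In this hedged, Base-only illustration, `(symbol, index-tuple)` keys stand in for `VarName`s and a length `n` stands in for the distribution's size. `has_all_elements` is a hypothetical name, not an AbstractPPL function.

```julia
# Hypothetical stand-in for the Dict-based `hasvalue` method: keys are
# (symbol, index-tuple) pairs instead of VarNames, and `n` replaces `dist`.
function has_all_elements(
    vals::AbstractDict, sym::Symbol, n::Int; error_on_incomplete::Bool=false
)
    expected = [(sym, (i,)) for i in 1:n]
    found = [haskey(vals, k) for k in expected]
    # Every expected element is present: the full value can be rebuilt.
    all(found) && return true
    # Some, but not all, elements are present: optionally treat as an error.
    if error_on_incomplete && any(found)
        error("only partial values for `$sym` found in the dictionary provided")
    end
    return false
end

d = Dict((:x, (1,)) => 1.0, (:x, (2,)) => 2.0)
complete = has_all_elements(d, :x, 2)
incomplete = has_all_elements(d, :x, 3)
```

With `error_on_incomplete=true`, the second call would throw instead of returning `false`, while a symbol with no elements at all (say `:y`) still returns `false` quietly.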

"""
getvalue(
vals::Union{AbstractDict,NamedTuple},
vn::VarName,
dist::Distribution
)

Retrieve the value of `vn` from `vals`, using the distribution `dist` to
reconstruct the value if necessary.

This is a more general version of `getvalue(vals, vn)`, in that even if `vn`
itself is not inside `vals`, it can still reconstruct the value of `vn`
from sub-values of `vn` that are present in `vals`.

Note that this reconstruction is only possible if a Dict is passed, because the
key type of a NamedTuple (i.e., Symbol) is not rich enough to carry indexing
information. If this method is called with a NamedTuple, it will just defer
to `getvalue(vals, vn)`.

For example:

```jldoctest; setup=:(using Distributions, LinearAlgebra)
julia> d = Dict(@varname(x[1]) => 1.0, @varname(x[2]) => 2.0);

julia> getvalue(d, @varname(x), MvNormal(zeros(2), I))
2-element Vector{Float64}:
1.0
2.0

julia> # Use `hasvalue` to check for this case before calling `getvalue`.
getvalue(d, @varname(x), MvNormal(zeros(3), I))
ERROR: `x` was not found in the dictionary provided
[...]
```
"""
function AbstractPPL.getvalue(
vals::NamedTuple, vn::VarName, dist::Distributions.Distribution
)
# NamedTuples can't have such complicated hierarchies, so it's safe to
# defer to the simpler `getvalue(vals, vn)`.
return AbstractPPL.getvalue(vals, vn)
end
function AbstractPPL.getvalue(
vals::AbstractDict, vn::VarName, dist::Distributions.Distribution;
)
@warn "`getvalue(vals, vn, dist)` is not implemented for $(typeof(dist)); falling back to `getvalue(vals, vn)`."
return AbstractPPL.getvalue(vals, vn)
end
function AbstractPPL.getvalue(
vals::AbstractDict, vn::VarName, ::Distributions.UnivariateDistribution;
)
# TODO(penelopeysm): We could also implement a check for the type to catch
# invalid values. Unsure if that is worth it. It may be easier to just let
# the user handle it.
return AbstractPPL.getvalue(vals, vn)
end
function AbstractPPL.getvalue(
vals::AbstractDict{<:VarName},
vn::VarName{sym},
dist::Union{
Distributions.MultivariateDistribution,
Distributions.MatrixDistribution,
Distributions.LKJCholesky,
};
) where {sym}
# If `vn` is present as-is, then we can just return that
AbstractPPL.hasvalue(vals, vn) && return AbstractPPL.getvalue(vals, vn)
# If not, then we need to start looking inside `vals`, in exactly the
# same way we did for `hasvalue`.
optics = get_optics(dist)
original_optic = AbstractPPL.getoptic(vn)
expected_vns = map(o -> VarName{sym}(o ∘ original_optic), optics)
if all(sub_vn -> AbstractPPL.hasvalue(vals, sub_vn), expected_vns)
# Reconstruct the value index by index.
value = make_empty_value(dist)
for (o, sub_vn) in zip(optics, expected_vns)
# Retrieve the value of this given index
sub_value = AbstractPPL.getvalue(vals, sub_vn)
# Set it inside the value we're reconstructing.
# Note: `o` is normally non-mutating. We have to wrap it in `Lens!`
# to make it mutating, because Cholesky distributions are broken
# by https://github.com/JuliaObjects/Accessors.jl/issues/203.
Accessors.set(value, Lens!(o), sub_value)
end
return value
else
error("`$(vn)` was not found in the dictionary provided")

end
end

end
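
For readers who want to see the reconstruction step of `getvalue` in isolation, the following Base-only sketch fills a zero array element by element. As an assumption for illustration, `(symbol, index-tuple)` keys replace `VarName`s and a `dims` tuple replaces the distribution. `collect_elements` is a hypothetical name and the real method works through optics rather than raw indices.

```julia
# Hypothetical stand-in for the Dict-based `getvalue` method: rebuild an
# array of size `dims` from per-element entries keyed by (symbol, index).
function collect_elements(vals::AbstractDict, sym::Symbol, dims::Dims)
    indices = CartesianIndices(dims)
    keys_needed = [(sym, Tuple(idx)) for idx in indices]
    # Refuse to reconstruct unless every element is available.
    all(k -> haskey(vals, k), keys_needed) ||
        error("`$sym` was not found in the dictionary provided")
    # Start from zeros and set each element in place.
    value = zeros(dims)
    for (idx, key) in zip(indices, keys_needed)
        value[idx] = vals[key]
    end
    return value
end

d = Dict((:x, (1,)) => 1.0, (:x, (2,)) => 2.0)
x = collect_elements(d, :x, (2,))
```

As in the source, the caller is expected to check availability first (the `hasvalue` counterpart) if an error on missing elements is not wanted.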
5 changes: 4 additions & 1 deletion src/AbstractPPL.jl
@@ -16,7 +16,9 @@ export VarName,
varname_to_string,
string_to_varname,
prefix,
-unprefix
+unprefix,
getvalue,
hasvalue

# Abstract model functions
export AbstractProbabilisticProgram,
Expand All @@ -29,5 +31,6 @@ include("varname.jl")
include("abstractmodeltrace.jl")
include("abstractprobprog.jl")
include("evaluate.jl")
include("hasvalue.jl")

end # module