Move hasvalue and getvalue from DynamicPPL; implement extra Distributions-based methods #125

Open · wants to merge 22 commits into base: main

Conversation

@penelopeysm penelopeysm commented Jul 5, 2025

Julia minimum version bump

I bumped the minimum Julia version to 1.10, as I don't want to add extra code to handle extensions on pre-1.9 Julia (where package extensions don't exist). Most important packages in TuringLang already require >= 1.10 anyway.

Moving functions from DynamicPPL

This PR moves hasvalue and getvalue from DynamicPPL to AbstractPPL. https://github.com/TuringLang/DynamicPPL.jl/blob/92f6eea8660be2142fa4087e5e025f37026bfa45/src/utils.jl#L763-L954

A lot of the helper functions in DynamicPPL are not actually needed, because existing functionality in here accomplishes much the same thing. I modified the implementations accordingly.

Distributions-based methods

This part is new and warrants more explanation. To begin, notice the default behaviour of hasvalue:

julia> d = Dict(@varname(x[1]) => 1.0, @varname(x[2]) => 1.0)
Dict{VarName{:x, Accessors.IndexLens{Tuple{Int64}}}, Float64} with 2 entries:
  x[1] => 1.0
  x[2] => 1.0

julia> hasvalue(d, @varname(x))
false

This makes sense, because d alone does not give us enough information to reconstruct some arbitrary variable x.

However, let's say that we know x is to be sampled from a given distribution dist. In this case, we do have enough information to determine whether x can be reconstructed. This PR therefore also implements the following methods:

julia> using Distributions, LinearAlgebra

julia> hasvalue(d, @varname(x), MvNormal(zeros(2), I))
true

julia> getvalue(d, @varname(x), MvNormal(zeros(2), I))
[1.0, 1.0]

julia> hasvalue(d, @varname(x), MvNormal(zeros(3), I))
false

The motivation for this is to (properly) fix issues where values for multivariate distributions are specified separately; see e.g. TuringLang/DynamicPPL.jl#712, and also this comment: TuringLang/DynamicPPL.jl#710 (comment).

One might argue that we should force users to specify things properly, i.e., if x ~ MvNormal(zeros(2), I) then the user should condition on Dict(@varname(x) => [1.0, 1.0]) rather than Dict(@varname(x[1]) => 1.0, @varname(x[2]) => 1.0). In an ideal world I would do that, and even now, I would still advocate for making this general guideline clear in e.g. the docs.

However, there remains one specific case where this isn't enough, namely in DynamicPPL's predict(model, chain) or returned(model, chain). These methods require extracting variable values from chain, inserting them into a VarInfo, and rerunning the model with the given values. Unfortunately, chain is a lossy storage format, because array-valued variables like x are split up into x[1] and x[2] and it's not possible to recover the original shape of x.

Up until this PR, this has been handled in DynamicPPL using the setval_and_resample! and nested_setindex_maybe methods, which perform some direct manipulation of VarInfos. I think these methods are slightly dangerous and can lead to subtle bugs: for example, if only part of the variable x is given, the entire variable x is marked as not to be resampled: https://github.com/TuringLang/DynamicPPL.jl/blob/92f6eea8660be2142fa4087e5e025f37026bfa45/src/varinfo.jl#L2177-L2181

The good news, though, is that when evaluating a model we have access to the distribution that x is supposed to be sampled from. Thus, we can determine whether enough of the x[i]'s are given to reconstruct x, which is what these new methods do. This lets us deal with the problem in a more principled fashion: if we can find all the indices needed to reconstruct the value of x, then we can confidently set that value; if we can't, then we don't even attempt to set any of the individual indices, because hasvalue will return false.
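
For concreteness, here is a rough sketch of that check for the vector-valued case (illustrative only: can_reconstruct is a made-up name, and the PR's actual entry point is hasvalue(vals, vn, dist), which also covers matrix- and Cholesky-valued distributions):

using AbstractPPL, Accessors, Distributions, LinearAlgebra

function can_reconstruct(vals::AbstractDict, vn::VarName{sym}, dist::MvNormal) where {sym}
    # If the full value is stored under `vn` itself, we're done.
    hasvalue(vals, vn) && return true
    # Otherwise, every index implied by the distribution's size must be
    # individually present, e.g. x[1] and x[2] for a 2-dimensional MvNormal.
    original_optic = getoptic(vn)
    return all(CartesianIndices(size(dist))) do idx
        sub_vn = VarName{sym}(Accessors.IndexLens(idx.I) ∘ original_optic)
        hasvalue(vals, sub_vn)
    end
end

d = Dict(@varname(x[1]) => 1.0, @varname(x[2]) => 1.0)
can_reconstruct(d, @varname(x), MvNormal(zeros(2), I))  # true
can_reconstruct(d, @varname(x), MvNormal(zeros(3), I))  # false: x[3] is missing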

Remaining questions:

  • I wonder if we can simplify the API? hasvalue and getvalue have extremely similar logic; do we really need two functions with almost the same implementation? I've held off on attempting this because I'm worried about type stability: getvalue is inherently type-unstable, and guarding calls to getvalue behind a call to hasvalue may avoid leaking that type instability into the caller. However, I think this relies on the compiler being able to infer the return value of hasvalue through e.g. constant propagation?!
  • I'm not sure if this should be a minor bump. According to semver, nothing in here is breaking, hence I did a patch bump. But the changes are quite large and maybe it feels more correct to do a minor bump.

TODO

  • hasvalue for other distributions
  • getvalue for distributions
  • Appropriate tests
  • API documentation for the distributions bits
  • Changelog

This PR doesn't support ProductNamedTupleDistribution. It shouldn't be overly complicated to implement, IMO. However, almost nothing else in TuringLang works with ProductNamedTupleDistribution, so I don't feel bad about not implementing it.

Closes #124

This is required for the InitContext PR TuringLang/DynamicPPL.jl#967 as ParamsInit needs to use hasvalue and getvalue. Specifically, I also want to use ParamsInit to handle predict, hence the need for the Distributions-based methods.

@penelopeysm penelopeysm marked this pull request as draft July 5, 2025 22:34

codecov bot commented Jul 5, 2025

Codecov Report

Attention: Patch coverage is 85.38462% with 19 lines in your changes missing coverage. Please review.

Project coverage is 86.28%. Comparing base (7be9556) to head (e20190c).

Files with missing lines           | Patch % | Lines
ext/AbstractPPLDistributionsExt.jl | 76.92%  | 15 Missing ⚠️
src/hasvalue.jl                    | 93.02%  | 3 Missing ⚠️
src/varname.jl                     | 95.45%  | 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #125      +/-   ##
==========================================
+ Coverage   83.56%   86.28%   +2.72%     
==========================================
  Files           2        5       +3     
  Lines         292      401     +109     
==========================================
+ Hits          244      346     +102     
- Misses         48       55       +7     

Base automatically changed from py/composed-assoc to main July 6, 2025 10:58

coveralls commented Jul 6, 2025

Pull Request Test Coverage Report for Build 16221573874

Details

  • 111 of 130 (85.38%) changed or added relevant lines in 3 files are covered.
  • No unchanged relevant lines lost coverage.
  • Overall coverage increased (+2.7%) to 86.284%

Changes Missing Coverage           | Covered Lines | Changed/Added Lines | %
src/varname.jl                     | 21            | 22                  | 95.45%
src/hasvalue.jl                    | 40            | 43                  | 93.02%
ext/AbstractPPLDistributionsExt.jl | 50            | 65                  | 76.92%
Totals Coverage Status
Change from base Build 16098285788: 2.7%
Covered Lines: 346
Relevant Lines: 401

💛 - Coveralls

@TuringLang TuringLang deleted a comment from github-actions bot Jul 6, 2025

github-actions bot commented Jul 6, 2025

AbstractPPL.jl documentation for PR #125 is available at:
https://TuringLang.github.io/AbstractPPL.jl/previews/PR125/

@penelopeysm penelopeysm marked this pull request as ready for review July 6, 2025 19:03
@penelopeysm penelopeysm requested a review from mhauru July 7, 2025 10:34
Comment on lines -968 to -980
Remove identity lenses from composed optics.
"""
_strip_identity(::Base.ComposedFunction{typeof(identity),typeof(identity)}) = identity
function _strip_identity(o::Base.ComposedFunction{Outer,typeof(identity)}) where {Outer}
return _strip_identity(o.outer)
end
function _strip_identity(o::Base.ComposedFunction{typeof(identity),Inner}) where {Inner}
return _strip_identity(o.inner)
end
_strip_identity(o::Base.ComposedFunction) = o
_strip_identity(o::Accessors.PropertyLens) = o
_strip_identity(o::Accessors.IndexLens) = o
_strip_identity(o::typeof(identity)) = o
Member Author

normalise strips identities now, so this function isn't needed any more

Comment on lines +995 to +998
_head(o::ComposedFunction{Outer,Inner}) where {Outer,Inner} = o.inner
_head(o::Accessors.PropertyLens) = o
_head(o::Accessors.IndexLens) = o
_head(::typeof(identity)) = identity
Member Author

_head, _tail, _init, and _last take their names from the equivalent Haskell functions on linked lists:

λ> head [1,2,3]
1
λ> tail [1,2,3]
[2,3]
λ> init [1,2,3]
[1,2]
λ> last [1,2,3]
3

-- empty list is turned into identity in our case
λ> head [1]
1
λ> tail [1]
[]
λ> init [1]
[]
λ> last [1]
1
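
To spell out the analogy in Julia terms (sketch only, using the internal _head helper quoted above; display and normalisation details on this branch may differ slightly):

using Accessors
import AbstractPPL: _head  # internal helper from this PR

# The optic for x.a[1] applies .a first and then [1], i.e. IndexLens ∘ PropertyLens.
optic = Accessors.IndexLens((1,)) ∘ Accessors.PropertyLens{:a}()

_head(optic) == Accessors.PropertyLens{:a}()  # first-applied optic, like Haskell's head
_head(Accessors.PropertyLens{:a}())           # a "one-element" optic is its own head
_head(identity)                               # the empty case maps to identity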

src/hasvalue.jl Outdated
Comment on lines 133 to 135
else
error("getvalue: $(vn) was not found in the values provided")
end
Member Author

I'm actually a bit on the fence about this error message. From the perspective of AbstractPPL, it makes perfect sense. But when we use this in DynamicPPL we get things like this:

  julia> vi[@varname(x)]
- ERROR: KeyError: key x not found
+ ERROR: getvalue: x was not found in the values provided

which seems to unnecessarily leak details of internal implementation.

Previously in DynamicPPL it would throw a KeyError which made sense when calling getindex. However, it doesn't make sense to throw a KeyError here because getvalue isn't indexing into a dictionary. So I'm struggling a bit.

Member

I don't think the error message is bad. Error messages quite often reveal implementation details; I don't think that can be avoided except by using a lot of try-catch higher up (in this case in DPPL), which I think is usually far too much complexity to be worth it. My only alternative would be "$(vn) was not found in the NamedTuple provided".

Also, if we could skip the canview call, I wonder if that could have a (small but) noticeable impact on performance, due to not bounds checking twice. Or maybe the return optic(vals[sym]) call could be done with @inbounds?

Member Author

Error message changed. I think I slightly prefer not mentioning the function name (it's in the stacktrace anyway).

Re. canview: I think this ties into one of the things I wasn't fully sure about (on the main comment).

The idea is that you could use getvalue without having checked hasvalue, so in principle you do need to check bounds in both. That is very annoying not only because of the double bounds check but also because of the severe code duplication (notice that this PR could have been half the lines if we combined the two functions into one).

I don't fully know how to solve this though. The obvious solution would be for getvalue to return a sentinel value if it's not found. However I worry that that would introduce type instability.
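
For illustration, a sentinel-returning variant (not part of this PR) might look roughly like the following, which is where the type-stability worry comes from:

using AbstractPPL

# Sketch only: one lookup function instead of hasvalue + getvalue.
function trygetvalue(vals::NamedTuple, vn::VarName{sym}) where {sym}
    haskey(vals, sym) || return nothing
    optic = getoptic(vn)
    # (canview-style bounds handling elided; this is the part that would no
    # longer be duplicated between hasvalue and getvalue.)
    return optic(vals[sym])
end

# The catch: the return type is now Union{Nothing, T}, so a caller doing
#     val = trygetvalue(vals, vn)
# has to handle the union, whereas
#     hasvalue(vals, vn) && (val = getvalue(vals, vn))
# keeps val concretely typed as long as the hasvalue check can be resolved
# (e.g. via constant propagation).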

@mhauru mhauru left a comment

Happy with the code except for some minor localised comments, but confused about whether we really need to deal with LKJCholesky, and thus whether we could restrict ourselves to just caring about dimensions rather than more complicated forms of distribution outputs.

These functions check whether a given `VarName` has a value in the given `NamedTuple` or `AbstractDict`, and return the value if it exists.

The optional `Distribution` argument allows one to reconstruct a full value from its component indices.
For example, if `container` has `x[1]` and `x[2]`, then `hasvalue(container, @varname(x), dist)` will return true if `size(dist) == (2,)` (for example, `MvNormal(zeros(2), I)`).
Member

Did you consider having the third argument be the dimension, rather than the distribution? I'm not sure at all that this would be better, but it would avoid a dependence on Distributions.jl for hasvalue, and I was wondering if it has merit.

Member Author

We were so close, it's just Cholesky that breaks it.
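
For context (illustration only, not code from the PR): a draw from LKJCholesky is a Cholesky factorisation rather than an array, so a bare dimension wouldn't tell us how its entries are indexed.

using Distributions, LinearAlgebra

c = rand(LKJCholesky(2, 0.5))
c isa Cholesky   # true: not an AbstractArray, so a plain size wouldn't describe it
c.L[2, 1]        # entries are reached through the .L (or .U) field
# c[2, 1]        # errors: Cholesky does not support plain integer indexing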

Member Author

There is actually another option: instead of taking a Distribution, take a value sampled from that Distribution. That means we would only need a dependency on LinearAlgebra rather than Distributions.

However, that would mean an additional call to rand(), which, while quite minor in the grand scheme of things, I'm opposed to in principle.
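
A sketch of what that value-based alternative could look like (illustrative only; the names are made up and this is not what the PR does), needing only LinearAlgebra for the Cholesky case:

using Accessors, LinearAlgebra

# Dispatch on an example value instead of the distribution.
get_optics_from_value(x::AbstractVector) = map(i -> Accessors.IndexLens((i,)), eachindex(x))

function get_optics_from_value(x::Cholesky)
    field = x.uplo == 'U' ? (Accessors.@o _.U) : (Accessors.@o _.L)
    idxs = filter(I -> (x.uplo == 'U' ? I[1] <= I[2] : I[1] >= I[2]), CartesianIndices(size(x)))
    return map(I -> Accessors.IndexLens(I.I) ∘ field, idxs)
end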

Comment on lines +46 to +56
function get_optics(dist::Distributions.LKJCholesky)
is_up = dist.uplo == 'U'
cartesian_indices = filter(CartesianIndices(size(dist))) do cartesian_index
i, j = cartesian_index.I
is_up ? i <= j : i >= j
end
# there is an additional layer as we need to access `.L` or `.U` before we
# can index into it
field_lens = is_up ? (Accessors.@o _.U) : (Accessors.@o _.L)
return map(idx -> Accessors.IndexLens(idx.I) ∘ field_lens, cartesian_indices)
end
Member

I'm confused by the scenario in which we need this. LKJCholesky returns objects of type Cholesky, which are not AbstractArrays and can't be indexed. Would we ever have a situation where we would have something like

getvalue(Dict(@varname(x[1,1]) => 1.0, @varname(x[1,2]) => 0.0, @varname(x[2,2]) => 1.0), @varname(x), LKJCholesky(2, 0.5))

where it should return true?

PS. After reading the tests I now realise you need things like @varname(x.U[1,1]), but I still wonder if this would ever come up.

Member Author

I guess this ambiguity could come up if the user manually mis-specified things without the .L or .U. However, assuming that the varnames come from MCMCChains, they will have been constructed correctly (via varname_and_value_leaves... which ALSO special-cases Cholesky https://github.com/TuringLang/DynamicPPL.jl/blob/ce7c8b1ae48624e12ebf3064b6099e3dfca8c985/src/utils.jl#L1258-L1265)

Comment on lines 161 to 183
) where {sym}
# If `vn` is present as-is, then we are good
AbstractPPL.hasvalue(vals, vn) && return true
# If not, then we need to check inside `vals` to see if a subset of
# `vals` is enough to reconstruct `vn`. For example, if `vals` contains
# `x[1]` and `x[2]`, and `dist` is `MvNormal(zeros(2), I)`, then we
# can reconstruct `x`. If `dist` is `MvNormal(zeros(3), I)`, then we
# can't.
# To do this, we get the size of the distribution and iterate over all
# possible indices. If every index can be found in `subsumed_keys`, then we
# can return true.
optics = get_optics(dist)
original_optic = AbstractPPL.getoptic(vn)
expected_vns = map(o -> VarName{sym}(o ∘ original_optic), optics)
if all(sub_vn -> AbstractPPL.hasvalue(vals, sub_vn), expected_vns)
return true
else
if error_on_incomplete &&
any(sub_vn -> AbstractPPL.hasvalue(vals, sub_vn), expected_vns)
error("hasvalue: only partial values for `$vn` found in the values provided")
end
return false
end
Member

What happens, and what should happen, with the following?

hasvalue(
    Dict(@varname(x) => [1.0], @varname(x[2]) => 2.0),
    @varname(x),
    MvNormal(zeros(2), I),
)

I think this returns true on line 163 even though the shape of @varname(x) => [1.0] doesn't match the distribution. But even if it didn't, I think it would also return true, because @varname(x[1]) is found in @varname(x) => [1.0] and @varname(x[2]) is found in @varname(x[2]) => 2.0. That doesn't quite feel right though.

Member Author

(cf. comment below in main thread)

Comment on lines +167 to +168
@test hasvalue(d, @varname(x), MvNormal(zeros(1), I))
@test getvalue(d, @varname(x), MvNormal(zeros(1), I)) == [1.0]
Member

It's not wrong, but it feels funny.

Member Author

(cf. comment below in main thread)

@penelopeysm penelopeysm commented Jul 11, 2025

Hmm. I guess the biggest overall takeaway is that although the implementation makes sense, there is an overarching theme of ill-defined behaviour for wrongly specified input, which I totally get; I felt the same way writing it.

The problem is that this ill-defined behaviour already existed in e.g. DynamicPPL's nested_setindex_maybe. You would run into it if, for example, you passed in a chain with the wrong variables (like if you had MvNormal(zeros(1)) and passed in a chain with x[1] and x[2]; I'm not sure what it would do with x[2]).

Assuming that such ill-defined behaviour needs to exist somewhere in order for us to keep compatibility with predict(..., ::MCMCChains), how about the following compromise solution?

This is where hasvalue(..., dist) gets (or will get) used upstream (TuringLang/DynamicPPL.jl#981):

https://github.com/TuringLang/DynamicPPL.jl/blob/b55c1e17f97ae518d1d149122e1fb1055557183f/src/contexts/init.jl#L93-L110

Right now, it just unconditionally uses the dist method. Maybe we can include a flag to ParamsInit which tells us when we allow using the dist method, and restrict its use to specific cases where we absolutely require this behaviour, i.e. predict.

That would forbid, for example, people sampling and specifying initial values that don't align with the model.
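
Rough sketch of what that compromise could look like (names and fields here are hypothetical; the actual DynamicPPL ParamsInit API may well differ):

using AbstractPPL, Distributions

struct ParamsInitSketch{P}
    params::P
    allow_dist_lookup::Bool  # only set to true for predict-style use cases
end

function lookup(p::ParamsInitSketch, vn::VarName, dist)
    if hasvalue(p.params, vn)
        return getvalue(p.params, vn)
    elseif p.allow_dist_lookup && hasvalue(p.params, vn, dist)
        # Distribution-based reconstruction is only allowed when explicitly opted in.
        return getvalue(p.params, vn, dist)
    else
        return nothing  # caller falls back to sampling from the prior
    end
end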

@penelopeysm penelopeysm requested a review from mhauru July 11, 2025 15:49