Move `hasvalue` and `getvalue` from DynamicPPL; implement extra Distributions-based methods #125
Conversation
Codecov Report

```diff
@@            Coverage Diff             @@
##             main     #125      +/-   ##
==========================================
+ Coverage   83.56%   86.28%    +2.72%
==========================================
  Files           2        5        +3
  Lines         292      401      +109
==========================================
+ Hits          244      346      +102
- Misses         48       55        +7
```
```julia
Remove identity lenses from composed optics.
"""
_strip_identity(::Base.ComposedFunction{typeof(identity),typeof(identity)}) = identity
function _strip_identity(o::Base.ComposedFunction{Outer,typeof(identity)}) where {Outer}
    return _strip_identity(o.outer)
end
function _strip_identity(o::Base.ComposedFunction{typeof(identity),Inner}) where {Inner}
    return _strip_identity(o.inner)
end
_strip_identity(o::Base.ComposedFunction) = o
_strip_identity(o::Accessors.PropertyLens) = o
_strip_identity(o::Accessors.IndexLens) = o
_strip_identity(o::typeof(identity)) = o
```
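A quick usage sketch (not part of the PR; assumes the definitions above):

```julia
using Accessors

# identity on either side of the composition is peeled off recursively
_strip_identity((Accessors.@o _.a) ∘ identity)  # == Accessors.@o _.a
_strip_identity(identity ∘ identity)            # == identity
```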
`normalise` strips identities now, so this function isn't needed any more.
```julia
_head(o::ComposedFunction{Outer,Inner}) where {Outer,Inner} = o.inner
_head(o::Accessors.PropertyLens) = o
_head(o::Accessors.IndexLens) = o
_head(::typeof(identity)) = identity
```
`_head`, `_tail`, `_init`, and `_last` take their names from the equivalent Haskell functions on linked lists:
```haskell
λ> head [1,2,3]
1
λ> tail [1,2,3]
[2,3]
λ> init [1,2,3]
[1,2]
λ> last [1,2,3]
3
-- empty list is turned into identity in our case
λ> head [1]
1
λ> tail [1]
[]
λ> init [1]
[]
λ> last [1]
1
```
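In Julia terms, a usage sketch of `_head` (assuming the definition above): since `(f ∘ g)(x) == f(g(x))`, the innermost lens `o.inner` is the first one applied, i.e. the head of the "list" of lenses.

```julia
using Accessors

# PropertyLens is applied first, then IndexLens
o = Accessors.IndexLens((1,)) ∘ Accessors.PropertyLens{:U}()

_head(o)  # == Accessors.PropertyLens{:U}(), like Haskell's `head`
```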
src/hasvalue.jl (outdated):
```julia
else
    error("getvalue: $(vn) was not found in the values provided")
end
```
I'm actually a bit on the fence about this error message. From the perspective of AbstractPPL, it makes perfect sense. But when we use this in DynamicPPL we get things like this:

```diff
julia> vi[@varname(x)]
- ERROR: KeyError: key x not found
+ ERROR: getvalue: x was not found in the values provided
```

which seems to unnecessarily leak details of internal implementation.

Previously in DynamicPPL it would throw a `KeyError`, which made sense when calling `getindex`. However, it doesn't make sense to throw a `KeyError` here, because `getvalue` isn't indexing into a dictionary. So I'm struggling a bit.
I don't think the error message is bad. Error messages quite often reveal implementation details; I don't think that can be avoided except by using a lot of `try`-`catch` higher up (in this case in DPPL), which I think is usually far too much complexity to be worth it. My only alternative would be `"$(vn) was not found in the NamedTuple provided"`.

Also, if we could skip the `canview` call, I wonder if that could have a (small but) noticeable impact on performance, due to not bounds checking twice. Or maybe the `return optic(vals[sym])` call could be done with `@inbounds`?
Error message changed; I think I do slightly prefer not mentioning the function name (it's in the stacktrace anyway).

Re. `canview`: I think this ties into one of the things I wasn't fully sure about (see the main comment). The idea is that you could use `getvalue` without having checked `hasvalue`, so in principle you do need to check bounds in both. That is very annoying, not only because of the double bounds check, but also because of the severe code duplication (notice that this PR could have been half the lines if we had combined the two functions into one).

I don't fully know how to solve this, though. The obvious solution would be for `getvalue` to return a sentinel value if it's not found. However, I worry that that would introduce type instability.
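For what it's worth, a rough sketch of the sentinel idea (all names hypothetical; it only handles a plain `IndexLens`, whereas the real functions also cover composed optics):

```julia
using Accessors
using AbstractPPL: VarName, getoptic

# Return `nothing` when the value is absent, so that `hasvalue` and
# `getvalue` can share a single lookup (and a single bounds check).
function _getvalue_or_nothing(vals::NamedTuple, vn::VarName{sym}) where {sym}
    haskey(vals, sym) || return nothing
    optic = getoptic(vn)
    x = vals[sym]
    if optic isa Accessors.IndexLens && x isa AbstractArray
        checkbounds(Bool, x, optic.indices...) || return nothing
    end
    return optic(x)
end

my_hasvalue(vals, vn) = _getvalue_or_nothing(vals, vn) !== nothing
function my_getvalue(vals, vn)
    v = _getvalue_or_nothing(vals, vn)
    v === nothing && error("$(vn) was not found in the values provided")
    return v
end
# The Union{Nothing,T} return type of the shared helper is exactly the
# type-instability worry mentioned above.
```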
Happy with the code except for some minor localised comments, but I'm confused about whether we really need to deal with LKJCholesky, and thus whether we could restrict ourselves to just caring about dimensions rather than more complicated forms of distribution outputs.
These functions check whether a given `VarName` has a value in the given `NamedTuple` or `AbstractDict`, and return the value if it exists.

The optional `Distribution` argument allows one to reconstruct a full value from its component indices. For example, if `container` has `x[1]` and `x[2]`, then `hasvalue(container, @varname(x), dist)` will return true if `size(dist) == (2,)` (for example, `MvNormal(zeros(2), I)`).
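A usage sketch matching the docstring's example (return values are what I'd expect, not verified output):

```julia
using AbstractPPL, Distributions, LinearAlgebra

vals = Dict(@varname(x[1]) => 1.0, @varname(x[2]) => 1.0)

hasvalue(vals, @varname(x), MvNormal(zeros(2), I))  # true: size(dist) == (2,)
hasvalue(vals, @varname(x), MvNormal(zeros(3), I))  # false: x[3] is missing
getvalue(vals, @varname(x), MvNormal(zeros(2), I))  # [1.0, 1.0]
```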
Did you consider having the third argument be the dimension, rather than the distribution? I'm not sure at all that this would be better, but it would avoid a dependence on Distributions.jl for `hasvalue`, and I was wondering if it has merit.
We were so close, it's just Cholesky that breaks it.
There is another option, actually: instead of taking a Distribution, take a value sampled from that Distribution. That means we would only need a dependency on LinearAlgebra rather than Distributions. However, that would mean an additional call to `rand()`, which, while quite minor in the grand scheme of things, I feel opposed to in principle.
```julia
function get_optics(dist::Distributions.LKJCholesky)
    is_up = dist.uplo == 'U'
    cartesian_indices = filter(CartesianIndices(size(dist))) do cartesian_index
        i, j = cartesian_index.I
        is_up ? i <= j : i >= j
    end
    # there is an additional layer as we need to access `.L` or `.U` before we
    # can index into it
    field_lens = is_up ? (Accessors.@o _.U) : (Accessors.@o _.L)
    return map(idx -> Accessors.IndexLens(idx.I) ∘ field_lens, cartesian_indices)
end
```
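For example, for a 2×2 lower-triangular factor this should produce three optics, one per triangular entry (a sketch; the exact printed representation may differ):

```julia
using Distributions, Accessors

optics = get_optics(LKJCholesky(2, 1.0))  # uplo defaults to 'L'
# optics is roughly [(@o _.L[1, 1]), (@o _.L[2, 1]), (@o _.L[2, 2])]
```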
I'm confused by the scenario in which we need this. LKJCholesky returns objects of type `Cholesky`, which are not `AbstractArray`s and can't be indexed. Would we ever have a situation like

```julia
hasvalue(
    Dict(@varname(x[1, 1]) => 1.0, @varname(x[1, 2]) => 0.0, @varname(x[2, 2]) => 1.0),
    @varname(x),
    LKJCholesky(2, 0.5),
)
```

where it should return `true`?

PS. After reading the tests I now realise you need things like `@varname(x.U[1, 1])`, but I still wonder if this would ever come up.
I guess this ambiguity could come up if the user manually mis-specified the varnames without the `.L` or `.U`. However, assuming that the varnames come from MCMCChains, they will have been constructed correctly (via `varname_and_value_leaves`... which ALSO special-cases Cholesky: https://github.com/TuringLang/DynamicPPL.jl/blob/ce7c8b1ae48624e12ebf3064b6099e3dfca8c985/src/utils.jl#L1258-L1265).
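For example, with correctly constructed varnames (a sketch based on the tests; assumes `uplo == 'L'`):

```julia
using AbstractPPL, Distributions

vals = Dict(
    @varname(x.L[1, 1]) => 1.0,
    @varname(x.L[2, 1]) => 0.5,
    @varname(x.L[2, 2]) => 1.0,
)

hasvalue(vals, @varname(x), LKJCholesky(2, 1.0))  # true (expected)
```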
ext/AbstractPPLDistributionsExt.jl (outdated):
```julia
) where {sym}
    # If `vn` is present as-is, then we are good
    AbstractPPL.hasvalue(vals, vn) && return true
    # If not, then we need to check inside `vals` to see if a subset of
    # `vals` is enough to reconstruct `vn`. For example, if `vals` contains
    # `x[1]` and `x[2]`, and `dist` is `MvNormal(zeros(2), I)`, then we
    # can reconstruct `x`. If `dist` is `MvNormal(zeros(3), I)`, then we
    # can't.
    # To do this, we get the size of the distribution and iterate over all
    # possible indices. If every index can be found in `subsumed_keys`, then we
    # can return true.
    optics = get_optics(dist)
    original_optic = AbstractPPL.getoptic(vn)
    expected_vns = map(o -> VarName{sym}(o ∘ original_optic), optics)
    if all(sub_vn -> AbstractPPL.hasvalue(vals, sub_vn), expected_vns)
        return true
    else
        if error_on_incomplete &&
            any(sub_vn -> AbstractPPL.hasvalue(vals, sub_vn), expected_vns)
            error("hasvalue: only partial values for `$vn` found in the values provided")
        end
        return false
    end
```
What happens, and what should happen, with the following?

```julia
hasvalue(
    Dict(@varname(x) => [1.0], @varname(x[2]) => 2.0),
    @varname(x),
    MvNormal(zeros(2), I),
)
```

I think this returns `true` on line 163, even though the shape of `@varname(x) => [1.0]` doesn't match the distribution. But even if it didn't, I think it would also return `true`, because `@varname(x[1])` is found in `@varname(x) => [1.0]` and `@varname(x[2])` is found in `@varname(x[2]) => 2.0`. That doesn't quite feel right, though.
(cf. comment below in main thread)
```julia
@test hasvalue(d, @varname(x), MvNormal(zeros(1), I))
@test getvalue(d, @varname(x), MvNormal(zeros(1), I)) == [1.0]
```
It's not wrong, but it feels funny.
(cf. comment below in main thread)
Hmm. I guess maybe the biggest overall takeaway is that although the implementation makes sense, there is an overarching theme of ill-defined behaviour for wrongly specified input, which I totally get; I felt the same way writing it. The problem is that this ill-defined behaviour already used to exist in e.g. DynamicPPL's `setval_and_resample!`.

Assuming that such ill-defined behaviour needs to exist somewhere in order for us to keep compatibility with

This is where

Right now, it just unconditionally uses the

That would forbid, for example, people sampling and specifying initial values that don't align with the model.
Co-authored-by: Markus Hauru <markus@mhauru.org>
Julia minimum version bump

I bumped to 1.10 as I don't want to add extra code to handle extensions on pre-1.9 Julia. The most important packages in TuringLang are already using >= 1.10 anyway.
Moving functions from DynamicPPL

This PR moves `hasvalue` and `getvalue` from DynamicPPL to AbstractPPL: https://github.com/TuringLang/DynamicPPL.jl/blob/92f6eea8660be2142fa4087e5e025f37026bfa45/src/utils.jl#L763-L954

A lot of the helper functions in DynamicPPL are not actually needed, because there is existing functionality here that accomplishes much the same. I modified the implementations accordingly.
Distributions-based methods
This part is new and warrants more explanation. To begin, notice the default behaviour of `hasvalue`:
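(The original snippet isn't preserved here; a representative session of the default behaviour, under that assumption:)

```julia
julia> using AbstractPPL

julia> d = Dict(@varname(x[1]) => 1.0, @varname(x[2]) => 1.0);

julia> hasvalue(d, @varname(x))
false
```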
This makes sense, because `d` alone does not give us enough information to reconstruct some arbitrary variable `x`.

However, let's say that we know `x` is to be sampled from a given distribution `dist`. In this case, we do have enough information to determine whether `x` can be reconstructed. This PR therefore also implements methods of `hasvalue` and `getvalue` that take the distribution as an extra argument (sketched below).

The motivation for this is to (properly) fix issues where values for multivariate distributions are specified separately; see e.g. TuringLang/DynamicPPL.jl#712, and also this comment: TuringLang/DynamicPPL.jl#710 (comment).
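(Again a reconstruction; the exact snippet isn't preserved, and the outputs are the expected behaviour per the description:)

```julia
julia> using Distributions, LinearAlgebra

julia> hasvalue(d, @varname(x), MvNormal(zeros(2), I))
true

julia> getvalue(d, @varname(x), MvNormal(zeros(2), I))
2-element Vector{Float64}:
 1.0
 1.0
```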
One might argue that we should force users to specify things properly, i.e., if `x ~ MvNormal(zeros(2), I)`, then the user should condition on `Dict(@varname(x) => [1.0, 1.0])` rather than `Dict(@varname(x[1]) => 1.0, @varname(x[2]) => 1.0)`. In an ideal world I would do that, and even now I would still advocate for making this general guideline clear in e.g. the docs.

However, there remains one specific case where this isn't enough, namely DynamicPPL's `predict(model, chain)` or `returned(model, chain)`. These methods require extracting variable values from `chain`, inserting them into a VarInfo, and rerunning the model with the given values. Unfortunately, `chain` is a lossy storage format, because array-valued variables like `x` are split up into `x[1]` and `x[2]`, and it's not possible to recover the original shape of `x`.

Up until this PR, this has been handled in DynamicPPL using the `setval_and_resample!` and `nested_setindex_maybe` methods, which perform some direct manipulation of VarInfos. I think these methods are slightly dangerous and can lead to subtle bugs: for example, if only part of the variable `x` is given, the entire variable `x` is marked as not-to-be-resampled: https://github.com/TuringLang/DynamicPPL.jl/blob/92f6eea8660be2142fa4087e5e025f37026bfa45/src/varinfo.jl#L2177-L2181

The good news, though, is that when evaluating a model, we have access to the distribution that `x` is supposed to be sampled from. Thus, we can determine whether enough of the `x[i]`'s are given to reconstruct it, which is what these new methods do. So we can deal with this in a more principled fashion: if we can find all the indices needed to reconstruct the value of `x`, then we can confidently set that value; if we can't, then we don't even attempt to set any of the individual indices, because `hasvalue` will return false (a sketch follows below).
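A sketch of that behaviour (treating `error_on_incomplete` from the diff above as a keyword argument, which is an assumption):

```julia
julia> d = Dict(@varname(x[1]) => 1.0);  # x[2] is missing

julia> hasvalue(d, @varname(x), MvNormal(zeros(2), I))
false

julia> hasvalue(d, @varname(x), MvNormal(zeros(2), I); error_on_incomplete=true)
ERROR: hasvalue: only partial values for `x` found in the values provided
```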
Remaining questions:
`hasvalue` and `getvalue` have extremely similar logic; do we really need two functions with almost the same implementation? I've held off on attempting to merge them because I'm worried about type stability: `getvalue` is inherently type-unstable, and maybe guarding calls to `getvalue` behind a call to `hasvalue` avoids leaking type instability into the caller function (see the sketch below). However, I think this is reliant on the compiler being able to infer the return value of `hasvalue` through e.g. constant propagation?!
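To illustrate the guarding pattern in question (hypothetical caller code, not from the PR):

```julia
# Whether the branch is type-stable depends on the compiler proving the
# `hasvalue` result for the given argument types, e.g. via constant
# propagation; otherwise the instability of the lookup leaks into the caller.
function demo_caller(vals, vn, dist)
    if hasvalue(vals, vn, dist)
        return getvalue(vals, vn, dist)
    else
        return rand(dist)
    end
end
```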
TODO

- `hasvalue` for other distributions
- `getvalue` for distributions

This PR doesn't support `ProductNamedTupleDistribution`. It shouldn't be overly complicated to implement, IMO. However, almost nothing else in TuringLang works with `ProductNamedTupleDistribution`, so I don't feel bad about not implementing it.

Closes #124
This is required for the `InitContext` PR TuringLang/DynamicPPL.jl#967, as `ParamsInit` needs to use `hasvalue` and `getvalue`. Specifically, I also want to use `ParamsInit` to handle `predict`, hence the need for the Distributions-based methods.