From a34fb044798844156f6ffba86db39517a45e590c Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 16 Jan 2025 15:52:35 +0000 Subject: [PATCH 01/14] Init release 0.35 --- .github/workflows/CI.yml | 2 ++ Project.toml | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index fce8d9e30..722993c11 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -5,10 +5,12 @@ on: branches: - master - backport-* + - release-0.35 pull_request: branches: - master - backport-* + - release-0.35 merge_group: types: [checks_requested] diff --git a/Project.toml b/Project.toml index 2bf60214f..e0561147b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.33.1" +version = "0.35.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From 161edf83c6374dbca617c4fc7eeaa033d06488af Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 30 Jan 2025 11:42:36 +0000 Subject: [PATCH 02/14] Reverse order of prefixing & add changelog (#792) * Reverse order of prefixing * Simplify generated function (the non-generated path wasn't being used) * Expand on submodel behaviour in changelog --- HISTORY.md | 93 ++++++++++++++++++++++++++++++++++++++++++++++++ src/contexts.jl | 27 ++++++-------- test/contexts.jl | 14 ++++---- 3 files changed, 111 insertions(+), 23 deletions(-) create mode 100644 HISTORY.md diff --git a/HISTORY.md b/HISTORY.md new file mode 100644 index 000000000..f77d3fa74 --- /dev/null +++ b/HISTORY.md @@ -0,0 +1,93 @@ +# DynamicPPL Changelog + +## 0.35.0 + +**Breaking** + +- For submodels constructed using `to_submodel`, the order in which nested prefixes are applied has been changed. + Previously, the order was that outer prefixes were applied first, then inner ones. + This version reverses that. + To illustrate: + + ```julia + using DynamicPPL, Distributions + + @model function subsubmodel() + x ~ Normal() + end + + @model function submodel() + x ~ to_submodel(prefix(subsubmodel(), :c), false) + return x + end + + @model function parentmodel() + x1 ~ to_submodel(prefix(submodel(), :a), false) + x2 ~ to_submodel(prefix(submodel(), :b), false) + end + + keys(VarInfo(parentmodel())) + ``` + + Previously, the final line would return the variable names `c.a.x` and `c.b.x`. + With this version, it will return `a.c.x` and `b.c.x`, which is more intuitive. + (Note that this change brings `to_submodel`'s behaviour in line with the now-deprecated `@submodel` macro.) + + This change also affects sampling in Turing.jl. + + +## 0.34.2 + +- Fixed bugs in ValuesAsInModelContext as well as DebugContext where underlying PrefixContexts were not being applied. + From a user-facing perspective, this means that for models which use manually prefixed submodels, e.g. + + ```julia + using DynamicPPL, Distributions + + @model inner() = x ~ Normal() + + @model function outer() + x1 ~ to_submodel(prefix(inner(), :a), false) + x2 ~ to_submodel(prefix(inner(), :b), false) + end + ``` + + will: (1) no longer error when sampling due to `check_model_and_trace`; and (2) contain both submodel's variables in the resulting chain (the behaviour before this patch was that the second `x` would override the first `x`). + +- More broadly, implemented a general `prefix(ctx::AbstractContext, ::VarName)` which traverses the context tree in `ctx` to apply all necessary prefixes. This was a necessary step in fixing the above issues, but it also means that `prefix` is now capable of handling context trees with e.g. multiple prefixes at different levels of nesting. + +## 0.34.1 + +- Fix an issue that prevented merging two VarInfos if they had different dimensions for a variable. + +- Upper bound the compat version of KernelAbstractions to work around an issue in determining the right VarInfo type to use. + +## 0.34.0 + +**Breaking** + +- `rng` argument removed from `values_as_in_model`, and `varinfo` made non-optional. This means that the only signatures allowed are + + ``` + values_as_in_model(::Model, ::Bool, ::AbstractVarInfo) + values_as_in_model(::Model, ::Bool, ::AbstractVarInfo, ::AbstractContext) + ``` + + If you aren't using this function (it's probably only used in Turing.jl) then this won't affect you. + +## 0.33.1 + +Reworked internals of `condition` and `decondition`. +There are no changes to the public-facing API, but internally you can no longer use `condition` and `decondition` on an `AbstractContext`, you can only use it on a `DynamicPPL.Model`. If you want to modify a context, use `ConditionContext` and `decondition_context`. + +## 0.33.0 + +**Breaking** + +- `values_as_in_model()` now requires an extra boolean parameter, specifying whether variables on the lhs of `:=` statements are to be included in the resulting `OrderedDict` of values. + The type signature is now `values_as_in_model([rng,] model, include_colon_eq::Bool [, varinfo, context])` + +**Other** + +- Moved the implementation of `predict` from Turing.jl to DynamicPPL.jl; the user-facing behaviour is otherwise the same +- Improved error message when a user tries to initialise a model with parameters that don't correspond strictly to the underlying VarInfo used diff --git a/src/contexts.jl b/src/contexts.jl index 99b2136f3..0b4633283 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -253,31 +253,26 @@ function PrefixContext{Prefix}(context::AbstractContext) where {Prefix} return PrefixContext{Prefix,typeof(context)}(context) end -NodeTrait(context::PrefixContext) = IsParent() +NodeTrait(::PrefixContext) = IsParent() childcontext(context::PrefixContext) = context.context -function setchildcontext(parent::PrefixContext{Prefix}, child) where {Prefix} +function setchildcontext(::PrefixContext{Prefix}, child) where {Prefix} return PrefixContext{Prefix}(child) end const PREFIX_SEPARATOR = Symbol(".") -# TODO(penelopeysm): Prefixing arguably occurs the wrong way round here -function PrefixContext{PrefixInner}( - context::PrefixContext{PrefixOuter} -) where {PrefixInner,PrefixOuter} - if @generated - :(PrefixContext{$(QuoteNode(Symbol(PrefixOuter, PREFIX_SEPARATOR, PrefixInner)))}( - context.context - )) - else - PrefixContext{Symbol(PrefixOuter, PREFIX_SEPARATOR, PrefixInner)}(context.context) - end +@generated function PrefixContext{PrefixOuter}( + context::PrefixContext{PrefixInner} +) where {PrefixOuter,PrefixInner} + return :(PrefixContext{$(QuoteNode(Symbol(PrefixOuter, PREFIX_SEPARATOR, PrefixInner)))}( + context.context + )) end -# TODO(penelopeysm): Prefixing arguably occurs the wrong way round here function prefix(ctx::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} - return prefix( - childcontext(ctx), VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(getoptic(vn)) + vn_prefixed_inner = prefix(childcontext(ctx), vn) + return VarName{Symbol(Prefix, PREFIX_SEPARATOR, getsym(vn_prefixed_inner))}( + getoptic(vn_prefixed_inner) ) end prefix(ctx::AbstractContext, vn::VarName) = prefix(NodeTrait(ctx), ctx, vn) diff --git a/test/contexts.jl b/test/contexts.jl index ef55335d0..faa831cc1 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -142,11 +142,11 @@ end @testset "PrefixContext" begin @testset "prefixing" begin - ctx = @inferred PrefixContext{:f}( - PrefixContext{:e}( - PrefixContext{:d}( - PrefixContext{:c}( - PrefixContext{:b}(PrefixContext{:a}(DefaultContext())) + ctx = @inferred PrefixContext{:a}( + PrefixContext{:b}( + PrefixContext{:c}( + PrefixContext{:d}( + PrefixContext{:e}(PrefixContext{:f}(DefaultContext())) ), ), ), @@ -174,8 +174,8 @@ end vn_prefixed4 = prefix(ctx4, vn) @test DynamicPPL.getsym(vn_prefixed1) == Symbol("a.x") @test DynamicPPL.getsym(vn_prefixed2) == Symbol("a.x") - @test DynamicPPL.getsym(vn_prefixed3) == Symbol("a.b.x") - @test DynamicPPL.getsym(vn_prefixed4) == Symbol("a.b.x") + @test DynamicPPL.getsym(vn_prefixed3) == Symbol("b.a.x") + @test DynamicPPL.getsym(vn_prefixed4) == Symbol("b.a.x") @test DynamicPPL.getoptic(vn_prefixed1) === DynamicPPL.getoptic(vn) @test DynamicPPL.getoptic(vn_prefixed2) === DynamicPPL.getoptic(vn) @test DynamicPPL.getoptic(vn_prefixed3) === DynamicPPL.getoptic(vn) From 7140f3dba5a961d331d940e0e1b3826c33acec60 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 30 Jan 2025 11:45:17 +0000 Subject: [PATCH 03/14] Format HISTORY.md --- HISTORY.md | 123 ++++++++++++++++++++++++++--------------------------- 1 file changed, 61 insertions(+), 62 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index f77d3fa74..0b9e8091b 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -4,76 +4,75 @@ **Breaking** -- For submodels constructed using `to_submodel`, the order in which nested prefixes are applied has been changed. - Previously, the order was that outer prefixes were applied first, then inner ones. - This version reverses that. - To illustrate: - - ```julia - using DynamicPPL, Distributions - - @model function subsubmodel() - x ~ Normal() - end - - @model function submodel() - x ~ to_submodel(prefix(subsubmodel(), :c), false) - return x - end - - @model function parentmodel() - x1 ~ to_submodel(prefix(submodel(), :a), false) - x2 ~ to_submodel(prefix(submodel(), :b), false) - end - - keys(VarInfo(parentmodel())) - ``` - - Previously, the final line would return the variable names `c.a.x` and `c.b.x`. - With this version, it will return `a.c.x` and `b.c.x`, which is more intuitive. - (Note that this change brings `to_submodel`'s behaviour in line with the now-deprecated `@submodel` macro.) - - This change also affects sampling in Turing.jl. - + - For submodels constructed using `to_submodel`, the order in which nested prefixes are applied has been changed. + Previously, the order was that outer prefixes were applied first, then inner ones. + This version reverses that. + To illustrate: + + ```julia + using DynamicPPL, Distributions + + @model function subsubmodel() + return x ~ Normal() + end + + @model function submodel() + x ~ to_submodel(prefix(subsubmodel(), :c), false) + return x + end + + @model function parentmodel() + x1 ~ to_submodel(prefix(submodel(), :a), false) + return x2 ~ to_submodel(prefix(submodel(), :b), false) + end + + keys(VarInfo(parentmodel())) + ``` + + Previously, the final line would return the variable names `c.a.x` and `c.b.x`. + With this version, it will return `a.c.x` and `b.c.x`, which is more intuitive. + (Note that this change brings `to_submodel`'s behaviour in line with the now-deprecated `@submodel` macro.) + + This change also affects sampling in Turing.jl. ## 0.34.2 -- Fixed bugs in ValuesAsInModelContext as well as DebugContext where underlying PrefixContexts were not being applied. - From a user-facing perspective, this means that for models which use manually prefixed submodels, e.g. - - ```julia - using DynamicPPL, Distributions - - @model inner() = x ~ Normal() - - @model function outer() - x1 ~ to_submodel(prefix(inner(), :a), false) - x2 ~ to_submodel(prefix(inner(), :b), false) - end - ``` - - will: (1) no longer error when sampling due to `check_model_and_trace`; and (2) contain both submodel's variables in the resulting chain (the behaviour before this patch was that the second `x` would override the first `x`). - -- More broadly, implemented a general `prefix(ctx::AbstractContext, ::VarName)` which traverses the context tree in `ctx` to apply all necessary prefixes. This was a necessary step in fixing the above issues, but it also means that `prefix` is now capable of handling context trees with e.g. multiple prefixes at different levels of nesting. + - Fixed bugs in ValuesAsInModelContext as well as DebugContext where underlying PrefixContexts were not being applied. + From a user-facing perspective, this means that for models which use manually prefixed submodels, e.g. + + ```julia + using DynamicPPL, Distributions + + @model inner() = x ~ Normal() + + @model function outer() + x1 ~ to_submodel(prefix(inner(), :a), false) + return x2 ~ to_submodel(prefix(inner(), :b), false) + end + ``` + + will: (1) no longer error when sampling due to `check_model_and_trace`; and (2) contain both submodel's variables in the resulting chain (the behaviour before this patch was that the second `x` would override the first `x`). + + - More broadly, implemented a general `prefix(ctx::AbstractContext, ::VarName)` which traverses the context tree in `ctx` to apply all necessary prefixes. This was a necessary step in fixing the above issues, but it also means that `prefix` is now capable of handling context trees with e.g. multiple prefixes at different levels of nesting. ## 0.34.1 -- Fix an issue that prevented merging two VarInfos if they had different dimensions for a variable. + - Fix an issue that prevented merging two VarInfos if they had different dimensions for a variable. -- Upper bound the compat version of KernelAbstractions to work around an issue in determining the right VarInfo type to use. + - Upper bound the compat version of KernelAbstractions to work around an issue in determining the right VarInfo type to use. ## 0.34.0 **Breaking** -- `rng` argument removed from `values_as_in_model`, and `varinfo` made non-optional. This means that the only signatures allowed are - - ``` - values_as_in_model(::Model, ::Bool, ::AbstractVarInfo) - values_as_in_model(::Model, ::Bool, ::AbstractVarInfo, ::AbstractContext) - ``` - - If you aren't using this function (it's probably only used in Turing.jl) then this won't affect you. + - `rng` argument removed from `values_as_in_model`, and `varinfo` made non-optional. This means that the only signatures allowed are + + ``` + values_as_in_model(::Model, ::Bool, ::AbstractVarInfo) + values_as_in_model(::Model, ::Bool, ::AbstractVarInfo, ::AbstractContext) + ``` + + If you aren't using this function (it's probably only used in Turing.jl) then this won't affect you. ## 0.33.1 @@ -84,10 +83,10 @@ There are no changes to the public-facing API, but internally you can no longer **Breaking** -- `values_as_in_model()` now requires an extra boolean parameter, specifying whether variables on the lhs of `:=` statements are to be included in the resulting `OrderedDict` of values. - The type signature is now `values_as_in_model([rng,] model, include_colon_eq::Bool [, varinfo, context])` + - `values_as_in_model()` now requires an extra boolean parameter, specifying whether variables on the lhs of `:=` statements are to be included in the resulting `OrderedDict` of values. + The type signature is now `values_as_in_model([rng,] model, include_colon_eq::Bool [, varinfo, context])` **Other** -- Moved the implementation of `predict` from Turing.jl to DynamicPPL.jl; the user-facing behaviour is otherwise the same -- Improved error message when a user tries to initialise a model with parameters that don't correspond strictly to the underlying VarInfo used + - Moved the implementation of `predict` from Turing.jl to DynamicPPL.jl; the user-facing behaviour is otherwise the same + - Improved error message when a user tries to initialise a model with parameters that don't correspond strictly to the underlying VarInfo used From c5f2f7a14566e9c73884404f57f8fde6165c7aed Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 30 Jan 2025 12:46:56 +0000 Subject: [PATCH 04/14] Remove selector stuff from VarInfo tests and link/invlink (#780) * Remove selector stuff from varinfo tests * Implement link and invlink for varnames rather than samplers * Replace set_retained_vns_del_by_spl! with set_retained_vns_del! * Make linking tests more extensive * Remove sampler indexing from link methods (but not invlink) * Remove indexing by samplers from invlink * Work towards removing sampler indexing with StaticTransformation * Fix invlink/link for TypedVarInfo and StaticTransformation * Fix a test in models.jl * Move some functions to utils.jl, add tests and docstrings * Fix a docstring typo * Various simplification to link/invlink * Improve a docstring * Style improvements * Fix broken link/invlink dispatch cascade for VectorVarInfo * Fix some more broken dispatch cascades * Apply suggestions from code review Co-authored-by: Xianda Sun <5433119+sunxd3@users.noreply.github.com> * Remove comments that messed with docstrings * Apply suggestions from code review Co-authored-by: Penelope Yong * Fix issues surfaced in code review * Simplify link/invlink arguments * Fix a bug in unflatten VarNamedVector * Rename VarNameCollection -> VarNameTuple * Remove test of a removed varname_namedtuple method * Apply suggestions from code review Co-authored-by: Penelope Yong * Respond to review feedback * Remove _default_sampler and a dead argument of maybe_invlink_before_eval * Fix a typo in a comment * Add HISTORY entry, fix one set_retained_vns_del! method --------- Co-authored-by: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Co-authored-by: Penelope Yong --- HISTORY.md | 9 + docs/src/api.md | 2 +- src/DynamicPPL.jl | 2 +- src/abstract_varinfo.jl | 130 +++++++------- src/model.jl | 8 +- src/simple_varinfo.jl | 6 +- src/threadsafe.jl | 61 +++---- src/transforming.jl | 20 +-- src/utils.jl | 47 ++++++ src/varinfo.jl | 364 ++++++++++++++++++++++++++-------------- src/varnamedvector.jl | 7 +- test/model.jl | 2 +- test/simple_varinfo.jl | 4 +- test/utils.jl | 23 +++ test/varinfo.jl | 329 ++++++++++-------------------------- 15 files changed, 513 insertions(+), 501 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 0b9e8091b..03c564b64 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -4,6 +4,15 @@ **Breaking** +### Remove indexing by samplers + +This release removes the feature of `VarInfo` where it kept track of which variable was associated with which sampler. This means removing all user-facing methods where `VarInfo`s where being indexed with samplers. In particular, + + - `link` and `invlink`, and their `!!` versions, no longer accept a sampler as an argument to specify which variables to (inv)link. The `link(varinfo, model)` methods remain in place, and as a new addition one can give a `Tuple` of `VarName`s to (inv)link only select variables, as in `link(varinfo, varname_tuple, model)`. + - `set_retained_vns_del_by_spl!` has been replaced by `set_retained_vns_del!` which applies to all variables. + +### Reverse prefixing order + - For submodels constructed using `to_submodel`, the order in which nested prefixes are applied has been changed. Previously, the order was that outer prefixes were applied first, then inner ones. This version reverses that. diff --git a/docs/src/api.md b/docs/src/api.md index 093cb06a6..36dd24250 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -304,7 +304,7 @@ set_num_produce! increment_num_produce! reset_num_produce! setorder! -set_retained_vns_del_by_spl! +set_retained_vns_del! ``` ```@docs diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index c1cdbd94e..55e1f7e88 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -59,7 +59,7 @@ export AbstractVarInfo, set_num_produce!, reset_num_produce!, increment_num_produce!, - set_retained_vns_del_by_spl!, + set_retained_vns_del!, is_flagged, set_flag!, unset_flag!, diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 3f513d71d..26c4268d8 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -537,117 +537,118 @@ If `vn` is not specified, then `istrans(vi)` evaluates to `true` for all variabl """ function settrans!! end +# For link!!, invlink!!, link, and invlink, we deliberately do not provide a fallback +# method for the case when no `vns` is provided, that would get all the keys from the +# `VarInfo`. Hence each subtype of `AbstractVarInfo` needs to implement separately the case +# where `vns` is provided and the one where it is not. This is because having separate +# implementations is typically much more performant, and because not all AbstractVarInfo +# types support partial linking. + """ link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) - link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model) + link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::NTuple{N,VarName}, model::Model) + +Transform variables in `vi` to their linked space, mutating `vi` if possible. -Transform the variables in `vi` to their linked space, using the transformation `t`, -mutating `vi` if possible. +Either transform all variables, or only ones specified in `vns`. -If `t` is not provided, `default_transformation(model, vi)` will be used. +Use the transformation `t`, or `default_transformation(model, vi)` if one is not provided. See also: [`default_transformation`](@ref), [`invlink!!`](@ref). """ -link!!(vi::AbstractVarInfo, model::Model) = link!!(vi, SampleFromPrior(), model) -function link!!(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) - return link!!(t, vi, SampleFromPrior(), model) +function link!!(vi::AbstractVarInfo, model::Model) + return link!!(default_transformation(model, vi), vi, model) end -function link!!(vi::AbstractVarInfo, spl::AbstractSampler, model::Model) - # Use `default_transformation` to decide which transformation to use if none is specified. - return link!!(default_transformation(model, vi), vi, spl, model) +function link!!(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) + return link!!(default_transformation(model, vi), vi, vns, model) end """ link([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) - link([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model) + link([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::NTuple{N,VarName}, model::Model) + +Transform variables in `vi` to their linked space without mutating `vi`. -Transform the variables in `vi` to their linked space without mutating `vi`, using the transformation `t`. +Either transform all variables, or only ones specified in `vns`. -If `t` is not provided, `default_transformation(model, vi)` will be used. +Use the transformation `t`, or `default_transformation(model, vi)` if one is not provided. See also: [`default_transformation`](@ref), [`invlink`](@ref). """ -link(vi::AbstractVarInfo, model::Model) = link(vi, SampleFromPrior(), model) -function link(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) - return link(t, deepcopy(vi), SampleFromPrior(), model) +function link(vi::AbstractVarInfo, model::Model) + return link(default_transformation(model, vi), vi, model) end -function link(vi::AbstractVarInfo, spl::AbstractSampler, model::Model) - # Use `default_transformation` to decide which transformation to use if none is specified. - return link(default_transformation(model, vi), deepcopy(vi), spl, model) +function link(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) + return link(default_transformation(model, vi), vi, vns, model) end """ invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) - invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model) + invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::NTuple{N,VarName}, model::Model) -Transform the variables in `vi` to their constrained space, using the (inverse of) -transformation `t`, mutating `vi` if possible. +Transform variables in `vi` to their constrained space, mutating `vi` if possible. -If `t` is not provided, `default_transformation(model, vi)` will be used. +Either transform all variables, or only ones specified in `vns`. + +Use the (inverse of) transformation `t`, or `default_transformation(model, vi)` if one is +not provided. See also: [`default_transformation`](@ref), [`link!!`](@ref). """ -invlink!!(vi::AbstractVarInfo, model::Model) = invlink!!(vi, SampleFromPrior(), model) -function invlink!!(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) - return invlink!!(t, vi, SampleFromPrior(), model) +function invlink!!(vi::AbstractVarInfo, model::Model) + return invlink!!(default_transformation(model, vi), vi, model) end -function invlink!!(vi::AbstractVarInfo, spl::AbstractSampler, model::Model) - # Here we extract the `transformation` from `vi` rather than using the default one. - return invlink!!(transformation(vi), vi, spl, model) +function invlink!!(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) + return invlink!!(default_transformation(model, vi), vi, vns, model) end # Vector-based ones. function link!!( - t::StaticTransformation{<:Bijectors.Transform}, - vi::AbstractVarInfo, - spl::AbstractSampler, - model::Model, + t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model ) b = inverse(t.bijector) - x = vi[spl] + x = vi[:] y, logjac = with_logabsdet_jacobian(b, x) lp_new = getlogp(vi) - logjac - vi_new = setlogp!!(unflatten(vi, spl, y), lp_new) + vi_new = setlogp!!(unflatten(vi, y), lp_new) return settrans!!(vi_new, t) end function invlink!!( - t::StaticTransformation{<:Bijectors.Transform}, - vi::AbstractVarInfo, - spl::AbstractSampler, - model::Model, + t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model ) b = t.bijector - y = vi[spl] + y = vi[:] x, logjac = with_logabsdet_jacobian(b, y) lp_new = getlogp(vi) + logjac - vi_new = setlogp!!(unflatten(vi, spl, x), lp_new) + vi_new = setlogp!!(unflatten(vi, x), lp_new) return settrans!!(vi_new, NoTransformation()) end """ invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) - invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model) + invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::NTuple{N,VarName}, model::Model) + +Transform variables in `vi` to their constrained space without mutating `vi`. -Transform the variables in `vi` to their constrained space without mutating `vi`, using the (inverse of) -transformation `t`. +Either transform all variables, or only ones specified in `vns`. -If `t` is not provided, `default_transformation(model, vi)` will be used. +Use the (inverse of) transformation `t`, or `default_transformation(model, vi)` if one is +not provided. See also: [`default_transformation`](@ref), [`link`](@ref). """ -invlink(vi::AbstractVarInfo, model::Model) = invlink(vi, SampleFromPrior(), model) -function invlink(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) - return invlink(t, vi, SampleFromPrior(), model) +function invlink(vi::AbstractVarInfo, model::Model) + return invlink(default_transformation(model, vi), vi, model) end -function invlink(vi::AbstractVarInfo, spl::AbstractSampler, model::Model) - return invlink(transformation(vi), vi, spl, model) +function invlink(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) + return invlink(default_transformation(model, vi), vi, vns, model) end """ - maybe_invlink_before_eval!!([t::Transformation,] vi, context, model) + maybe_invlink_before_eval!!([t::Transformation,] vi, model) Return a possibly invlinked version of `vi`. @@ -698,34 +699,23 @@ julia> # Now performs a single `invlink!!` before model evaluation. -1001.4189385332047 ``` """ -function maybe_invlink_before_eval!!( - vi::AbstractVarInfo, context::AbstractContext, model::Model -) - return maybe_invlink_before_eval!!(transformation(vi), vi, context, model) +function maybe_invlink_before_eval!!(vi::AbstractVarInfo, model::Model) + return maybe_invlink_before_eval!!(transformation(vi), vi, model) end -function maybe_invlink_before_eval!!( - ::NoTransformation, vi::AbstractVarInfo, context::AbstractContext, model::Model -) +function maybe_invlink_before_eval!!(::NoTransformation, vi::AbstractVarInfo, model::Model) return vi end function maybe_invlink_before_eval!!( - ::DynamicTransformation, vi::AbstractVarInfo, context::AbstractContext, model::Model + ::DynamicTransformation, vi::AbstractVarInfo, model::Model ) - # `DynamicTransformation` is meant to _not_ do the transformation statically, hence we do nothing. + # `DynamicTransformation` is meant to _not_ do the transformation statically, hence we + # do nothing. return vi end function maybe_invlink_before_eval!!( - t::StaticTransformation, vi::AbstractVarInfo, context::AbstractContext, model::Model + t::StaticTransformation, vi::AbstractVarInfo, model::Model ) - return invlink!!(t, vi, _default_sampler(context), model) -end - -function _default_sampler(context::AbstractContext) - return _default_sampler(NodeTrait(_default_sampler, context), context) -end -_default_sampler(::IsLeaf, context::AbstractContext) = SampleFromPrior() -function _default_sampler(::IsParent, context::AbstractContext) - return _default_sampler(childcontext(context)) + return invlink!!(t, vi, model) end # Utilities diff --git a/src/model.jl b/src/model.jl index 6fb0b40b0..462db7397 100644 --- a/src/model.jl +++ b/src/model.jl @@ -971,7 +971,7 @@ Return the arguments and keyword arguments to be passed to the evaluator of the # lazy `invlink`-ing of the parameters. This can be useful for # speeding up computation. See docs for `maybe_invlink_before_eval!!` # for more information. - maybe_invlink_before_eval!!(varinfo, context_new, model), + maybe_invlink_before_eval!!(varinfo, model), context_new, $(unwrap_args...), ) @@ -1169,10 +1169,10 @@ end """ predict([rng::AbstractRNG,] model::Model, chain::AbstractVector{<:AbstractVarInfo}) -Generate samples from the posterior predictive distribution by evaluating `model` at each set -of parameter values provided in `chain`. The number of posterior predictive samples matches +Generate samples from the posterior predictive distribution by evaluating `model` at each set +of parameter values provided in `chain`. The number of posterior predictive samples matches the length of `chain`. The returned `AbstractVarInfo`s will contain both the posterior parameter values -and the predicted values. +and the predicted values. """ function predict( rng::Random.AbstractRNG, model::Model, chain::AbstractArray{<:AbstractVarInfo} diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index b6a84238e..57b167077 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -680,8 +680,7 @@ Distributions.loglikelihood(model::Model, θ) = loglikelihood(model, SimpleVarIn function link!!( t::StaticTransformation{<:Bijectors.NamedTransform}, vi::SimpleVarInfo{<:NamedTuple}, - spl::AbstractSampler, - model::Model, + ::Model, ) # TODO: Make sure that `spl` is respected. b = inverse(t.bijector) @@ -695,8 +694,7 @@ end function invlink!!( t::StaticTransformation{<:Bijectors.NamedTransform}, vi::SimpleVarInfo{<:NamedTuple}, - spl::AbstractSampler, - model::Model, + ::Model, ) # TODO: Make sure that `spl` is respected. b = t.bijector diff --git a/src/threadsafe.jl b/src/threadsafe.jl index cedb0efad..69be5dcb1 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -81,70 +81,51 @@ haskey(vi::ThreadSafeVarInfo, vn::VarName) = haskey(vi.varinfo, vn) islinked(vi::ThreadSafeVarInfo, spl::AbstractSampler) = islinked(vi.varinfo, spl) -function link!!( - t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model -) - return Accessors.@set vi.varinfo = link!!(t, vi.varinfo, spl, model) +function link!!(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...) + return Accessors.@set vi.varinfo = link!!(t, vi.varinfo, args...) end -function invlink!!( - t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model -) - return Accessors.@set vi.varinfo = invlink!!(t, vi.varinfo, spl, model) +function invlink!!(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...) + return Accessors.@set vi.varinfo = invlink!!(t, vi.varinfo, args...) end -function link( - t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model -) - return Accessors.@set vi.varinfo = link(t, vi.varinfo, spl, model) +function link(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...) + return Accessors.@set vi.varinfo = link(t, vi.varinfo, args...) end -function invlink( - t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model -) - return Accessors.@set vi.varinfo = invlink(t, vi.varinfo, spl, model) +function invlink(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...) + return Accessors.@set vi.varinfo = invlink(t, vi.varinfo, args...) end # Need to define explicitly for `DynamicTransformation` to avoid method ambiguity. # NOTE: We also can't just defer to the wrapped varinfo, because we need to ensure # consistency between `vi.logps` field and `getlogp(vi.varinfo)`, which accumulates # to define `getlogp(vi)`. -function link!!( - t::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model -) +function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) return settrans!!(last(evaluate!!(model, vi, DynamicTransformationContext{false}())), t) end -function invlink!!( - ::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model -) +function invlink!!(::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) return settrans!!( last(evaluate!!(model, vi, DynamicTransformationContext{true}())), NoTransformation(), ) end -function link( - t::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model -) - return link!!(t, deepcopy(vi), spl, model) +function link(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) + return link!!(t, deepcopy(vi), model) end -function invlink( - t::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model -) - return invlink!!(t, deepcopy(vi), spl, model) +function invlink(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) + return invlink!!(t, deepcopy(vi), model) end -function maybe_invlink_before_eval!!( - vi::ThreadSafeVarInfo, context::AbstractContext, model::Model -) +function maybe_invlink_before_eval!!(vi::ThreadSafeVarInfo, model::Model) # Defer to the wrapped `AbstractVarInfo` object. - # NOTE: When computing `getlogp` for `ThreadSafeVarInfo` we do include the `getlogp(vi.varinfo)` - # hence the log-absdet-jacobian term will correctly be included in the `getlogp(vi)`. - return Accessors.@set vi.varinfo = maybe_invlink_before_eval!!( - vi.varinfo, context, model - ) + # NOTE: When computing `getlogp` for `ThreadSafeVarInfo` we do include the + # `getlogp(vi.varinfo)` hence the log-absdet-jacobian term will correctly be included in + # the `getlogp(vi)`. + return Accessors.@set vi.varinfo = maybe_invlink_before_eval!!(vi.varinfo, model) end # `getindex` @@ -182,8 +163,8 @@ function vector_getranges(vi::ThreadSafeVarInfo, vns::Vector{<:VarName}) return vector_getranges(vi.varinfo, vns) end -function set_retained_vns_del_by_spl!(vi::ThreadSafeVarInfo, spl::Sampler) - return set_retained_vns_del_by_spl!(vi.varinfo, spl) +function set_retained_vns_del!(vi::ThreadSafeVarInfo) + return set_retained_vns_del!(vi.varinfo) end isempty(vi::ThreadSafeVarInfo) = isempty(vi.varinfo) diff --git a/src/transforming.jl b/src/transforming.jl index 1f6c55e24..1a26d212f 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -91,29 +91,21 @@ function dot_tilde_assume( return r, lp, vi end -function link!!( - t::DynamicTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model -) +function link!!(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) return settrans!!(last(evaluate!!(model, vi, DynamicTransformationContext{false}())), t) end -function invlink!!( - ::DynamicTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model -) +function invlink!!(::DynamicTransformation, vi::AbstractVarInfo, model::Model) return settrans!!( last(evaluate!!(model, vi, DynamicTransformationContext{true}())), NoTransformation(), ) end -function link( - t::DynamicTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model -) - return link!!(t, deepcopy(vi), spl, model) +function link(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) + return link!!(t, deepcopy(vi), model) end -function invlink( - t::DynamicTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model -) - return invlink!!(t, deepcopy(vi), spl, model) +function invlink(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) + return invlink!!(t, deepcopy(vi), model) end diff --git a/src/utils.jl b/src/utils.jl index 5fedd3039..2539b7179 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -2,6 +2,9 @@ struct NoDefault end const NO_DEFAULT = NoDefault() +# A short-hand for a type commonly used in type signatures for VarInfo methods. +VarNameTuple = NTuple{N,VarName} where {N} + """ @addlogprob!(ex) @@ -1268,3 +1271,47 @@ _merge(left::NamedTuple, right::NamedTuple) = merge(left, right) _merge(left::AbstractDict, right::AbstractDict) = merge(left, right) _merge(left::AbstractDict, right::NamedTuple{()}) = left _merge(left::NamedTuple{()}, right::AbstractDict) = right + +""" + unique_syms(vns::T) where {T<:NTuple{N,VarName}} + +Return the unique symbols of the variables in `vns`. + +Note that `unique_syms` is only defined for `Tuple`s of `VarName`s and, unlike +`Base.unique`, returns a `Tuple`. The point of `unique_syms` is that it supports constant +propagating the result, which is possible only when the argument and the return value are +`Tuple`s. +""" +@generated function unique_syms(::T) where {T<:VarNameTuple} + retval = Expr(:tuple) + syms = [first(vn.parameters) for vn in T.parameters] + for sym in unique(syms) + push!(retval.args, QuoteNode(sym)) + end + return retval +end + +""" + group_varnames_by_symbol(vns::NTuple{N,VarName}) where {N} + +Return a `NamedTuple` of the variables in `vns` grouped by symbol. + +Note that `group_varnames_by_symbol` only accepts a `Tuple` of `VarName`s. This allows it to +be type stable. + +Example: +```julia +julia> vns_tuple = (@varname(x), @varname(y[1]), @varname(x.a), @varname(z[15]), @varname(y[2])) +(x, y[1], x.a, z[15], y[2]) + +julia> vns_nt = (; x=[@varname(x), @varname(x.a)], y=[@varname(y[1]), @varname(y[2])], z=[@varname(z[15])]) +(x = VarName{:x}[x, x.a], y = VarName{:y, IndexLens{Tuple{Int64}}}[y[1], y[2]], z = VarName{:z, IndexLens{Tuple{Int64}}}[z[15]]) + +julia> group_varnames_by_symbol(vns_tuple) == vns_nt +``` +""" +function group_varnames_by_symbol(vns::VarNameTuple) + syms = unique_syms(vns) + elements = map(collect, tuple((filter(vn -> getsym(vn) == s, vns) for s in syms)...)) + return NamedTuple{syms}(elements) +end diff --git a/src/varinfo.jl b/src/varinfo.jl index 3f36cc293..09f5960c1 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -791,6 +791,9 @@ Returns a tuple of the unique symbols of random variables sampled in `vi`. syms(vi::UntypedVarInfo) = Tuple(unique!(map(getsym, vi.metadata.vns))) # get all symbols syms(vi::TypedVarInfo) = keys(vi.metadata) +_getidcs(vi::UntypedVarInfo) = 1:length(vi.metadata.idcs) +_getidcs(vi::TypedVarInfo) = _getidcs(vi.metadata) + # Get all indices of variables belonging to SampleFromPrior: # if the gid/selector of a var is an empty Set, then that var is assumed to be assigned to # the SampleFromPrior sampler @@ -897,6 +900,22 @@ end return :($(exprs...),) end +""" + all_varnames_grouped_by_symbol(vi::TypedVarInfo) + +Return a `NamedTuple` of the variables in `vi` grouped by symbol. +""" +all_varnames_grouped_by_symbol(vi::TypedVarInfo) = + all_varnames_grouped_by_symbol(vi.metadata) + +@generated function all_varnames_grouped_by_symbol(md::NamedTuple{names}) where {names} + expr = Expr(:tuple) + for f in names + push!(expr.args, :($f = keys(md.$f))) + end + return expr +end + # Get the index (in vals) ranges of all the vns of variables belonging to spl @inline function _getranges(vi::VarInfo, spl::Sampler) ## Uncomment the spl.info stuff when it is concretely typed, not Dict{Symbol, Any} @@ -1150,29 +1169,50 @@ _isempty(vnv::VarNamedVector) = isempty(vnv) return Expr(:&&, (:(_isempty(metadata.$f)) for f in names)...) end +function link!!(::DynamicTransformation, vi::TypedVarInfo, model::Model) + vns = all_varnames_grouped_by_symbol(vi) + # If we're working with a `VarNamedVector`, we always use immutable. + has_varnamedvector(vi) && return _link(model, vi, vns) + _link!(vi, vns) + return vi +end + +function link!!(::DynamicTransformation, vi::VarInfo, model::Model) + vns = keys(vi) + # If we're working with a `VarNamedVector`, we always use immutable. + has_varnamedvector(vi) && return _link(model, vi, vns) + _link!(vi, vns) + return vi +end + +function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, model::Model) + # By default this will simply evaluate the model with `DynamicTransformationContext`, + # and so we need to specialize to avoid this. + return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, model) +end + # X -> R for all variables associated with given sampler -function link!!(t::DynamicTransformation, vi::VarInfo, spl::AbstractSampler, model::Model) +function link!!(::DynamicTransformation, vi::VarInfo, vns::VarNameTuple, model::Model) # If we're working with a `VarNamedVector`, we always use immutable. - has_varnamedvector(vi) && return link(t, vi, spl, model) + has_varnamedvector(vi) && return _link(model, vi, vns) # Call `_link!` instead of `link!` to avoid deprecation warning. - _link!(vi, spl) + _link!(vi, vns) return vi end function link!!( t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, - spl::AbstractSampler, + vns::VarNameTuple, model::Model, ) - # By default this will simply evaluate the model with `DynamicTransformationContext`, and so - # we need to specialize to avoid this. - return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, spl, model) + # By default this will simply evaluate the model with `DynamicTransformationContext`, + # and so we need to specialize to avoid this. + return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, vns, model) end -function _link!(vi::UntypedVarInfo, spl::AbstractSampler) +function _link!(vi::UntypedVarInfo, vns) # TODO: Change to a lazy iterator over `vns` - vns = _getvns(vi, spl) if ~istrans(vi, vns[1]) for vn in vns f = internal_to_linked_internal_transform(vi, vn) @@ -1183,24 +1223,41 @@ function _link!(vi::UntypedVarInfo, spl::AbstractSampler) @warn("[DynamicPPL] attempt to link a linked vi") end end -function _link!(vi::TypedVarInfo, spl::AbstractSampler) - return _link!(vi, spl, Val(getspace(spl))) + +# If we try to _link! a TypedVarInfo with a Tuple of VarNames, first convert it to a +# NamedTuple that matches the structure of the TypedVarInfo. +function _link!(vi::TypedVarInfo, vns::VarNameTuple) + return _link!(vi, group_varnames_by_symbol(vns)) end -function _link!(vi::TypedVarInfo, spl::AbstractSampler, spaceval::Val) - vns = _getvns(vi, spl) - return _link!(vi.metadata, vi, vns, spaceval) + +function _link!(vi::TypedVarInfo, vns::NamedTuple) + return _link!(vi.metadata, vi, vns) +end + +""" + filter_subsumed(filter_vns, filtered_vns) + +Return the subset of `filtered_vns` that are subsumed by any variable in `filter_vns`. +""" +function filter_subsumed(filter_vns, filtered_vns) + return filter(x -> any(subsumes(y, x) for y in filter_vns), filtered_vns) end + @generated function _link!( - metadata::NamedTuple{names}, vi, vns, ::Val{space} -) where {names,space} + ::NamedTuple{metadata_names}, vi, vns::NamedTuple{vns_names} +) where {metadata_names,vns_names} expr = Expr(:block) - for f in names - if inspace(f, space) || length(space) == 0 - push!( - expr.args, - quote - f_vns = vi.metadata.$f.vns - if ~istrans(vi, f_vns[1]) + for f in metadata_names + if !(f in vns_names) + continue + end + push!( + expr.args, + quote + f_vns = vi.metadata.$f.vns + f_vns = filter_subsumed(vns.$f, f_vns) + if !isempty(f_vns) + if !istrans(vi, f_vns[1]) # Iterate over all `f_vns` and transform for vn in f_vns f = internal_to_linked_internal_transform(vi, vn) @@ -1210,45 +1267,65 @@ end else @warn("[DynamicPPL] attempt to link a linked vi") end - end, - ) - end + end + end, + ) end return expr end +function invlink!!(::DynamicTransformation, vi::TypedVarInfo, model::Model) + vns = all_varnames_grouped_by_symbol(vi) + # If we're working with a `VarNamedVector`, we always use immutable. + has_varnamedvector(vi) && return _invlink(model, vi, vns) + # Call `_invlink!` instead of `invlink!` to avoid deprecation warning. + _invlink!(vi, vns) + return vi +end + +function invlink!!(::DynamicTransformation, vi::VarInfo, model::Model) + vns = keys(vi) + # If we're working with a `VarNamedVector`, we always use immutable. + has_varnamedvector(vi) && return _invlink(model, vi, vns) + _invlink!(vi, vns) + return vi +end + +function invlink!!(t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, model::Model) + # By default this will simply evaluate the model with `DynamicTransformationContext`, + # and so we need to specialize to avoid this. + return Accessors.@set vi.varinfo = DynamicPPL.invlink!!(t, vi.varinfo, model) +end + # R -> X for all variables associated with given sampler -function invlink!!( - t::DynamicTransformation, vi::VarInfo, spl::AbstractSampler, model::Model -) +function invlink!!(::DynamicTransformation, vi::VarInfo, vns::VarNameTuple, model::Model) # If we're working with a `VarNamedVector`, we always use immutable. - has_varnamedvector(vi) && return invlink(t, vi, spl, model) + has_varnamedvector(vi) && return _invlink(model, vi, vns) # Call `_invlink!` instead of `invlink!` to avoid deprecation warning. - _invlink!(vi, spl) + _invlink!(vi, vns) return vi end function invlink!!( ::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, - spl::AbstractSampler, + vns::VarNameTuple, model::Model, ) # By default this will simply evaluate the model with `DynamicTransformationContext`, and so # we need to specialize to avoid this. - return Accessors.@set vi.varinfo = DynamicPPL.invlink!!(vi.varinfo, spl, model) + return Accessors.@set vi.varinfo = DynamicPPL.invlink!!(vi.varinfo, vns, model) end -function maybe_invlink_before_eval!!(vi::VarInfo, context::AbstractContext, model::Model) +function maybe_invlink_before_eval!!(vi::VarInfo, model::Model) # Because `VarInfo` does not contain any information about what the transformation # other than whether or not it has actually been transformed, the best we can do # is just assume that `default_transformation` is the correct one if `istrans(vi)`. t = istrans(vi) ? default_transformation(model, vi) : NoTransformation() - return maybe_invlink_before_eval!!(t, vi, context, model) + return maybe_invlink_before_eval!!(t, vi, model) end -function _invlink!(vi::UntypedVarInfo, spl::AbstractSampler) - vns = _getvns(vi, spl) +function _invlink!(vi::UntypedVarInfo, vns) if istrans(vi, vns[1]) for vn in vns f = linked_internal_to_internal_transform(vi, vn) @@ -1259,36 +1336,43 @@ function _invlink!(vi::UntypedVarInfo, spl::AbstractSampler) @warn("[DynamicPPL] attempt to invlink an invlinked vi") end end -function _invlink!(vi::TypedVarInfo, spl::AbstractSampler) - return _invlink!(vi, spl, Val(getspace(spl))) + +# If we try to _invlink! a TypedVarInfo with a Tuple of VarNames, first convert it to a +# NamedTuple that matches the structure of the TypedVarInfo. +function _invlink!(vi::TypedVarInfo, vns::VarNameTuple) + return _invlink!(vi.metadata, vi, group_varnames_by_symbol(vns)) end -function _invlink!(vi::TypedVarInfo, spl::AbstractSampler, spaceval::Val) - vns = _getvns(vi, spl) - return _invlink!(vi.metadata, vi, vns, spaceval) + +function _invlink!(vi::TypedVarInfo, vns::NamedTuple) + return _invlink!(vi.metadata, vi, vns) end + @generated function _invlink!( - metadata::NamedTuple{names}, vi, vns, ::Val{space} -) where {names,space} + ::NamedTuple{metadata_names}, vi, vns::NamedTuple{vns_names} +) where {metadata_names,vns_names} expr = Expr(:block) - for f in names - if inspace(f, space) || length(space) == 0 - push!( - expr.args, - quote - f_vns = vi.metadata.$f.vns - if istrans(vi, f_vns[1]) - # Iterate over all `f_vns` and transform - for vn in f_vns - f = linked_internal_to_internal_transform(vi, vn) - _inner_transform!(vi, vn, f) - settrans!!(vi, false, vn) - end - else - @warn("[DynamicPPL] attempt to invlink an invlinked vi") - end - end, - ) + for f in metadata_names + if !(f in vns_names) + continue end + + push!( + expr.args, + quote + f_vns = vi.metadata.$f.vns + f_vns = filter_subsumed(vns.$f, f_vns) + if istrans(vi, f_vns[1]) + # Iterate over all `f_vns` and transform + for vn in f_vns + f = linked_internal_to_internal_transform(vi, vn) + _inner_transform!(vi, vn, f) + settrans!!(vi, false, vn) + end + else + @warn("[DynamicPPL] attempt to invlink an invlinked vi") + end + end, + ) end return expr end @@ -1320,59 +1404,72 @@ function _getvns_link(varinfo::TypedVarInfo, spl::SampleFromPrior) return map(Returns(nothing), varinfo.metadata) end -function link(::DynamicTransformation, varinfo::VarInfo, spl::AbstractSampler, model::Model) - return _link(model, varinfo, spl) +function link(::DynamicTransformation, vi::TypedVarInfo, model::Model) + return _link(model, vi, all_varnames_grouped_by_symbol(vi)) +end + +function link(::DynamicTransformation, varinfo::VarInfo, model::Model) + return _link(model, varinfo, keys(varinfo)) +end + +function link(::DynamicTransformation, varinfo::ThreadSafeVarInfo{<:VarInfo}, model::Model) + # By default this will simply evaluate the model with `DynamicTransformationContext`, and so + # we need to specialize to avoid this. + return Accessors.@set varinfo.varinfo = link(varinfo.varinfo, model) +end + +function link(::DynamicTransformation, varinfo::VarInfo, vns::VarNameTuple, model::Model) + return _link(model, varinfo, vns) end function link( ::DynamicTransformation, varinfo::ThreadSafeVarInfo{<:VarInfo}, - spl::AbstractSampler, + vns::VarNameTuple, model::Model, ) - # By default this will simply evaluate the model with `DynamicTransformationContext`, and so - # we need to specialize to avoid this. - return Accessors.@set varinfo.varinfo = link(varinfo.varinfo, spl, model) + # By default this will simply evaluate the model with `DynamicTransformationContext`, + # and so we need to specialize to avoid this. + return Accessors.@set varinfo.varinfo = link(varinfo.varinfo, vns, model) end -function _link( - model::Model, varinfo::Union{UntypedVarInfo,VectorVarInfo}, spl::AbstractSampler -) +function _link(model::Model, varinfo::VarInfo, vns) varinfo = deepcopy(varinfo) - return VarInfo( - _link_metadata!!(model, varinfo, varinfo.metadata, _getvns_link(varinfo, spl)), - Base.Ref(getlogp(varinfo)), - Ref(get_num_produce(varinfo)), - ) + md = _link_metadata!!(model, varinfo, varinfo.metadata, vns) + return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) +end + +# If we try to _link a TypedVarInfo with a Tuple of VarNames, first convert it to a +# NamedTuple that matches the structure of the TypedVarInfo. +function _link(model::Model, varinfo::TypedVarInfo, vns::VarNameTuple) + return _link(model, varinfo, group_varnames_by_symbol(vns)) end -function _link(model::Model, varinfo::TypedVarInfo, spl::AbstractSampler) +function _link(model::Model, varinfo::TypedVarInfo, vns::NamedTuple) varinfo = deepcopy(varinfo) - md = _link_metadata_namedtuple!( - model, varinfo, varinfo.metadata, _getvns_link(varinfo, spl), Val(getspace(spl)) - ) + md = _link_metadata!(model, varinfo, varinfo.metadata, vns) return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) end -@generated function _link_metadata_namedtuple!( +@generated function _link_metadata!( model::Model, varinfo::VarInfo, - metadata::NamedTuple{names}, - vns::NamedTuple, - ::Val{space}, -) where {names,space} + metadata::NamedTuple{metadata_names}, + vns::NamedTuple{vns_names}, +) where {metadata_names,vns_names} vals = Expr(:tuple) - for f in names - if inspace(f, space) || length(space) == 0 + for f in metadata_names + if f in vns_names push!(vals.args, :(_link_metadata!!(model, varinfo, metadata.$f, vns.$f))) else push!(vals.args, :(metadata.$f)) end end - return :(NamedTuple{$names}($vals)) + return :(NamedTuple{$metadata_names}($vals)) end -function _link_metadata!!(model::Model, varinfo::VarInfo, metadata::Metadata, target_vns) + +function _link_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_vns) vns = metadata.vns # Construct the new transformed values, and keep track of their lengths. @@ -1444,57 +1541,76 @@ function _link_metadata!!( return metadata end +function invlink(::DynamicTransformation, vi::TypedVarInfo, model::Model) + return _invlink(model, vi, all_varnames_grouped_by_symbol(vi)) +end + +function invlink(::DynamicTransformation, vi::VarInfo, model::Model) + return _invlink(model, vi, keys(vi)) +end + function invlink( - ::DynamicTransformation, varinfo::VarInfo, spl::AbstractSampler, model::Model + ::DynamicTransformation, varinfo::ThreadSafeVarInfo{<:VarInfo}, model::Model ) - return _invlink(model, varinfo, spl) + # By default this will simply evaluate the model with `DynamicTransformationContext`, and so + # we need to specialize to avoid this. + return Accessors.@set varinfo.varinfo = invlink(varinfo.varinfo, model) +end + +function invlink(::DynamicTransformation, varinfo::VarInfo, vns::VarNameTuple, model::Model) + return _invlink(model, varinfo, vns) end + function invlink( ::DynamicTransformation, varinfo::ThreadSafeVarInfo{<:VarInfo}, - spl::AbstractSampler, + vns::VarNameTuple, model::Model, ) # By default this will simply evaluate the model with `DynamicTransformationContext`, and so # we need to specialize to avoid this. - return Accessors.@set varinfo.varinfo = invlink(varinfo.varinfo, spl, model) + return Accessors.@set varinfo.varinfo = invlink(varinfo.varinfo, vns, model) end -function _invlink(model::Model, varinfo::VarInfo, spl::AbstractSampler) +function _invlink(model::Model, varinfo::VarInfo, vns) varinfo = deepcopy(varinfo) return VarInfo( - _invlink_metadata!!(model, varinfo, varinfo.metadata, _getvns_link(varinfo, spl)), + _invlink_metadata!!(model, varinfo, varinfo.metadata, vns), Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)), ) end -function _invlink(model::Model, varinfo::TypedVarInfo, spl::AbstractSampler) +# If we try to _invlink a TypedVarInfo with a Tuple of VarNames, first convert it to a +# NamedTuple that matches the structure of the TypedVarInfo. +function _invlink(model::Model, varinfo::TypedVarInfo, vns::VarNameTuple) + return _invlink(model, varinfo, group_varnames_by_symbol(vns)) +end + +function _invlink(model::Model, varinfo::TypedVarInfo, vns::NamedTuple) varinfo = deepcopy(varinfo) - md = _invlink_metadata_namedtuple!( - model, varinfo, varinfo.metadata, _getvns_link(varinfo, spl), Val(getspace(spl)) - ) + md = _invlink_metadata!(model, varinfo, varinfo.metadata, vns) return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) end -@generated function _invlink_metadata_namedtuple!( +@generated function _invlink_metadata!( model::Model, varinfo::VarInfo, - metadata::NamedTuple{names}, - vns::NamedTuple, - ::Val{space}, -) where {names,space} + metadata::NamedTuple{metadata_names}, + vns::NamedTuple{vns_names}, +) where {metadata_names,vns_names} vals = Expr(:tuple) - for f in names - if inspace(f, space) || length(space) == 0 + for f in metadata_names + if (f in vns_names) push!(vals.args, :(_invlink_metadata!!(model, varinfo, metadata.$f, vns.$f))) else push!(vals.args, :(metadata.$f)) end end - return :(NamedTuple{$names}($vals)) + return :(NamedTuple{$metadata_names}($vals)) end + function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_vns) vns = metadata.vns @@ -1545,7 +1661,7 @@ function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, targ end function _invlink_metadata!!( - model::Model, varinfo::VarInfo, metadata::VarNamedVector, target_vns + ::Model, varinfo::VarInfo, metadata::VarNamedVector, target_vns ) vns = target_vns === nothing ? keys(metadata) : target_vns for vn in vns @@ -1966,37 +2082,35 @@ function unset_flag!(vnv::VarNamedVector, ::VarName, flag::String, ignorable::Bo end """ - set_retained_vns_del_by_spl!(vi::VarInfo, spl::Sampler) + set_retained_vns_del!(vi::VarInfo) Set the `"del"` flag of variables in `vi` with `order > vi.num_produce[]` to `true`. """ -function set_retained_vns_del_by_spl!(vi::UntypedVarInfo, spl::Sampler) - # Get the indices of `vns` that belong to `spl` as a vector - gidcs = _getidcs(vi, spl) +function set_retained_vns_del!(vi::UntypedVarInfo) + idcs = _getidcs(vi) if get_num_produce(vi) == 0 - for i in length(gidcs):-1:1 - vi.metadata.flags["del"][gidcs[i]] = true + for i in length(idcs):-1:1 + vi.metadata.flags["del"][idcs[i]] = true end else for i in 1:length(vi.orders) - if i in gidcs && vi.orders[i] > get_num_produce(vi) + if i in idcs && vi.orders[i] > get_num_produce(vi) vi.metadata.flags["del"][i] = true end end end return nothing end -function set_retained_vns_del_by_spl!(vi::TypedVarInfo, spl::Sampler) - # Get the indices of `vns` that belong to `spl` as a NamedTuple, one entry for each symbol - gidcs = _getidcs(vi, spl) - return _set_retained_vns_del_by_spl!(vi.metadata, gidcs, get_num_produce(vi)) +function set_retained_vns_del!(vi::TypedVarInfo) + idcs = _getidcs(vi) + return _set_retained_vns_del!(vi.metadata, idcs, get_num_produce(vi)) end -@generated function _set_retained_vns_del_by_spl!( - metadata, gidcs::NamedTuple{names}, num_produce +@generated function _set_retained_vns_del!( + metadata, idcs::NamedTuple{names}, num_produce ) where {names} expr = Expr(:block) for f in names - f_gidcs = :(gidcs.$f) + f_idcs = :(idcs.$f) f_orders = :(metadata.$f.orders) f_flags = :(metadata.$f.flags) push!( @@ -2004,12 +2118,12 @@ end quote # Set the flag for variables with symbol `f` if num_produce == 0 - for i in length($f_gidcs):-1:1 - $f_flags["del"][$f_gidcs[i]] = true + for i in length($f_idcs):-1:1 + $f_flags["del"][$f_idcs[i]] = true end else for i in 1:length($f_orders) - if i in $f_gidcs && $f_orders[i] > num_produce + if i in $f_idcs && $f_orders[i] > num_produce $f_flags["del"][i] = true end end diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl index 565e82480..7da126321 100644 --- a/src/varnamedvector.jl +++ b/src/varnamedvector.jl @@ -1068,7 +1068,12 @@ function unflatten(vnv::VarNamedVector, vals::AbstractVector) new_ranges = deepcopy(vnv.ranges) recontiguify_ranges!(new_ranges) return VarNamedVector( - vnv.varname_to_index, vnv.varnames, new_ranges, vals, vnv.transforms + vnv.varname_to_index, + vnv.varnames, + new_ranges, + vals, + vnv.transforms, + vnv.is_unconstrained, ) end diff --git a/test/model.jl b/test/model.jl index 118f60a40..e91de4bd2 100644 --- a/test/model.jl +++ b/test/model.jl @@ -226,7 +226,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() model = DynamicPPL.TestUtils.demo_dynamic_constraint() spl = SampleFromPrior() vi = VarInfo(model, spl, DefaultContext(), DynamicPPL.Metadata()) - link!!(vi, spl, model) + vi = link!!(vi, model) for i in 1:10 # Sample with large variations. diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 4343563eb..137c791c2 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -275,9 +275,7 @@ # Make sure `maybe_invlink_before_eval!!` results in `invlink!!`. @test !DynamicPPL.istrans( - DynamicPPL.maybe_invlink_before_eval!!( - deepcopy(vi), SamplingContext(), model - ), + DynamicPPL.maybe_invlink_before_eval!!(deepcopy(vi), model) ) # Resulting varinfo should no longer be transformed. diff --git a/test/utils.jl b/test/utils.jl index 3f435dca4..d683f132d 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -48,4 +48,27 @@ x = rand(dist) @test DynamicPPL.tovec(x) == vec(x.UL) end + + @testset "unique_syms" begin + vns = (@varname(x), @varname(y[1]), @varname(x.a), @varname(z[15]), @varname(y[2])) + @inferred DynamicPPL.unique_syms(vns) + @inferred DynamicPPL.unique_syms(()) + @test DynamicPPL.unique_syms(vns) == (:x, :y, :z) + @test DynamicPPL.unique_syms(()) == () + end + + @testset "group_varnames_by_symbol" begin + vns_tuple = ( + @varname(x), @varname(y[1]), @varname(x.a), @varname(z[15]), @varname(y[2]) + ) + vns_vec = collect(vns_tuple) + vns_nt = (; + x=[@varname(x), @varname(x.a)], + y=[@varname(y[1]), @varname(y[2])], + z=[@varname(z[15])], + ) + vns_vec_single_symbol = [@varname(x.a), @varname(x.b), @varname(x[1])] + @inferred DynamicPPL.group_varnames_by_symbol(vns_tuple) + @test DynamicPPL.group_varnames_by_symbol(vns_tuple) == vns_nt + end end diff --git a/test/varinfo.jl b/test/varinfo.jl index fce87b2f3..d689a1bf4 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -1,8 +1,3 @@ -# Dummy algorithm for testing -# Invoke with: DynamicPPL.Sampler(MyAlg{(:x, :y)}(), ...) -struct MyAlg{space} end -DynamicPPL.getspace(::DynamicPPL.Sampler{MyAlg{space}}) where {space} = space - function check_varinfo_keys(varinfo, vns) if varinfo isa DynamicPPL.SimpleOrThreadSafeSimple{<:NamedTuple} # NOTE: We can't compare the `keys(varinfo_merged)` directly with `vns`, @@ -19,16 +14,13 @@ function check_varinfo_keys(varinfo, vns) end end -function randr( - vi::DynamicPPL.VarInfo, - vn::VarName, - dist::Distribution, - spl::DynamicPPL.Sampler, - count::Bool=false, -) +""" +Return the value of `vn` in `vi`. If one doesn't exist, sample and set it. +""" +function randr(vi::DynamicPPL.VarInfo, vn::VarName, dist::Distribution) if !haskey(vi, vn) r = rand(dist) - push!!(vi, vn, r, dist, spl) + push!!(vi, vn, r, dist) r elseif DynamicPPL.is_flagged(vi, vn, "del") DynamicPPL.unset_flag!(vi, vn, "del") @@ -37,8 +29,6 @@ function randr( DynamicPPL.setorder!(vi, vn, DynamicPPL.get_num_produce(vi)) r else - count && checkindex(vn, vi, spl) - DynamicPPL.updategid!(vi, vn, spl) vi[vn] end end @@ -66,7 +56,6 @@ end tind = fmeta.idcs[vn] @test meta.dists[ind] == fmeta.dists[tind] @test meta.orders[ind] == fmeta.orders[tind] - @test meta.gids[ind] == fmeta.gids[tind] for flag in keys(meta.flags) @test meta.flags[flag][ind] == fmeta.flags[flag][tind] end @@ -89,22 +78,6 @@ end vn2 = @varname x[1][2] @test vn2 == vn1 @test hash(vn2) == hash(vn1) - @test inspace(vn1, (:x,)) - - # Tests for `inspace` - space = (:x, :y, @varname(z[1]), @varname(M[1:10, :])) - - @test inspace(@varname(x), space) - @test inspace(@varname(y), space) - @test inspace(@varname(x[1]), space) - @test inspace(@varname(z[1][1]), space) - @test inspace(@varname(z[1][:]), space) - @test inspace(@varname(z[1][2:3:10]), space) - @test inspace(@varname(M[[2, 3], 1]), space) - @test_throws ErrorException inspace(@varname(M[:, 1:4]), space) - @test inspace(@varname(M[1, [2, 4, 6]]), space) - @test !inspace(@varname(z[2]), space) - @test !inspace(@varname(z), space) function test_base!!(vi_original) vi = empty!!(vi_original) @@ -114,38 +87,31 @@ end vn = @varname x dist = Normal(0, 1) r = rand(dist) - gid = DynamicPPL.Selector() @test isempty(vi) @test ~haskey(vi, vn) @test !(vn in keys(vi)) - vi = push!!(vi, vn, r, dist, gid) + vi = push!!(vi, vn, r, dist) @test ~isempty(vi) @test haskey(vi, vn) @test vn in keys(vi) @test length(vi[vn]) == 1 - @test length(vi[SampleFromPrior()]) == 1 - @test vi[vn] == r - @test vi[SampleFromPrior()][1] == r vi = DynamicPPL.setindex!!(vi, 2 * r, vn) @test vi[vn] == 2 * r - @test vi[SampleFromPrior()][1] == 2 * r - vi = DynamicPPL.setindex!!(vi, [3 * r], SampleFromPrior()) - @test vi[vn] == 3 * r - @test vi[SampleFromPrior()][1] == 3 * r # TODO(mhauru) Implement these functions for other VarInfo types too. if vi isa DynamicPPL.VectorVarInfo delete!(vi, vn) @test isempty(vi) - vi = push!!(vi, vn, r, dist, gid) + vi = push!!(vi, vn, r, dist) end vi = empty!!(vi) @test isempty(vi) - return push!!(vi, vn, r, dist, gid) + vi = push!!(vi, vn, r, dist) + @test ~isempty(vi) end vi = VarInfo() @@ -182,9 +148,8 @@ end vn_x = @varname x dist = Normal(0, 1) r = rand(dist) - gid = Selector() - push!!(vi, vn_x, r, dist, gid) + push!!(vi, vn_x, r, dist) # del is set by default @test !is_flagged(vi, vn_x, "del") @@ -204,35 +169,13 @@ end vn_x = @varname x vn_y = @varname y untyped_vi = VarInfo() - untyped_vi = push!!(untyped_vi, vn_x, 1.0, Normal(0, 1), Selector()) + untyped_vi = push!!(untyped_vi, vn_x, 1.0, Normal(0, 1)) typed_vi = TypedVarInfo(untyped_vi) - typed_vi = push!!(typed_vi, vn_y, 2.0, Normal(0, 1), Selector()) + typed_vi = push!!(typed_vi, vn_y, 2.0, Normal(0, 1)) @test typed_vi[vn_x] == 1.0 @test typed_vi[vn_y] == 2.0 end - @testset "setgid!" begin - vi = VarInfo(DynamicPPL.Metadata()) - meta = vi.metadata - vn = @varname x - dist = Normal(0, 1) - r = rand(dist) - gid1 = Selector() - gid2 = Selector(2, :HMC) - - push!!(vi, vn, r, dist, gid1) - @test meta.gids[meta.idcs[vn]] == Set([gid1]) - setgid!(vi, gid2, vn) - @test meta.gids[meta.idcs[vn]] == Set([gid1, gid2]) - - vi = empty!!(TypedVarInfo(vi)) - meta = vi.metadata - push!!(vi, vn, r, dist, gid1) - @test meta.x.gids[meta.x.idcs[vn]] == Set([gid1]) - setgid!(vi, gid2, vn) - @test meta.x.gids[meta.x.idcs[vn]] == Set([gid1, gid2]) - end - @testset "setval! & setval_and_resample!" begin @model function testmodel(x) n = length(x) @@ -397,10 +340,9 @@ end """ function test_setval!(model, chain; sample_idx=1, chain_idx=1) var_info = VarInfo(model) - spl = SampleFromPrior() - θ_old = var_info[spl] + θ_old = var_info[:] DynamicPPL.setval!(var_info, chain, sample_idx, chain_idx) - θ_new = var_info[spl] + θ_new = var_info[:] @test θ_old != θ_new vals = DynamicPPL.values_as(var_info, OrderedDict) iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)) @@ -432,13 +374,21 @@ end end @testset "link!! and invlink!!" begin - @model gdemo(x, y) = begin + @model gdemo(a, b, ::Type{T}=Float64) where {T} = begin s ~ InverseGamma(2, 3) m ~ Uniform(0, 2) - x ~ Normal(m, sqrt(s)) - y ~ Normal(m, sqrt(s)) + x = Vector{T}(undef, length(a)) + x .~ Normal(m, sqrt(s)) + y = Vector{T}(undef, length(a)) + for i in eachindex(y) + y[i] ~ Normal(m, sqrt(s)) + end + a .~ Normal(m, sqrt(s)) + for i in eachindex(b) + b[i] ~ Normal(x[i] * y[i], sqrt(s)) + end end - model = gdemo(1.0, 2.0) + model = gdemo([1.0, 1.5], [2.0, 2.5]) # Check that instantiating the model does not perform linking vi = VarInfo() @@ -448,38 +398,55 @@ end # Check that linking and invlinking set the `trans` flag accordingly v = copy(meta.vals) - link!!(vi, model) + vi = link!!(vi, model) @test all(x -> istrans(vi, x), meta.vns) - invlink!!(vi, model) + vi = invlink!!(vi, model) @test all(x -> !istrans(vi, x), meta.vns) @test meta.vals ≈ v atol = 1e-10 # Check that linking and invlinking preserves the values vi = TypedVarInfo(vi) meta = vi.metadata - @test all(x -> !istrans(vi, x), meta.s.vns) - @test all(x -> !istrans(vi, x), meta.m.vns) v_s = copy(meta.s.vals) v_m = copy(meta.m.vals) - link!!(vi, model) - @test all(x -> istrans(vi, x), meta.s.vns) - @test all(x -> istrans(vi, x), meta.m.vns) - invlink!!(vi, model) + v_x = copy(meta.x.vals) + v_y = copy(meta.y.vals) + @test all(x -> !istrans(vi, x), meta.s.vns) @test all(x -> !istrans(vi, x), meta.m.vns) - @test meta.s.vals ≈ v_s atol = 1e-10 - @test meta.m.vals ≈ v_m atol = 1e-10 - - # Transform only one variable (`s`) but not the others (`m`) - spl = DynamicPPL.Sampler(MyAlg{(:s,)}(), model) - link!!(vi, spl, model) + vi = link!!(vi, model) @test all(x -> istrans(vi, x), meta.s.vns) - @test all(x -> !istrans(vi, x), meta.m.vns) - invlink!!(vi, spl, model) + @test all(x -> istrans(vi, x), meta.m.vns) + vi = invlink!!(vi, model) @test all(x -> !istrans(vi, x), meta.s.vns) @test all(x -> !istrans(vi, x), meta.m.vns) @test meta.s.vals ≈ v_s atol = 1e-10 @test meta.m.vals ≈ v_m atol = 1e-10 + + # Transform only one variable + all_vns = vcat(meta.s.vns, meta.m.vns, meta.x.vns, meta.y.vns) + for vn in [ + @varname(s), + @varname(m), + @varname(x), + @varname(y), + @varname(x[2]), + @varname(y[2]) + ] + target_vns = filter(x -> subsumes(vn, x), all_vns) + other_vns = filter(x -> !subsumes(vn, x), all_vns) + @test !isempty(target_vns) + @test !isempty(other_vns) + vi = link!!(vi, (vn,), model) + @test all(x -> istrans(vi, x), target_vns) + @test all(x -> !istrans(vi, x), other_vns) + vi = invlink!!(vi, (vn,), model) + @test all(x -> !istrans(vi, x), all_vns) + @test meta.s.vals ≈ v_s atol = 1e-10 + @test meta.m.vals ≈ v_m atol = 1e-10 + @test meta.x.vals ≈ v_x atol = 1e-10 + @test meta.y.vals ≈ v_y atol = 1e-10 + end end @testset "istrans" begin @@ -856,73 +823,17 @@ end @test varinfo_merged[@varname(x)] == varinfo_right[@varname(x)] @test DynamicPPL.istrans(varinfo_merged, @varname(x)) end - - # The below used to error, testing to avoid regression. - @testset "merge gids" begin - gidset_left = Set([Selector(1)]) - vi_left = VarInfo() - vi_left = push!!(vi_left, @varname(x), 1.0, Normal(), gidset_left) - gidset_right = Set([Selector(2)]) - vi_right = VarInfo() - vi_right = push!!(vi_right, @varname(y), 2.0, Normal(), gidset_right) - varinfo_merged = merge(vi_left, vi_right) - @test DynamicPPL.getgid(varinfo_merged, @varname(x)) == gidset_left - @test DynamicPPL.getgid(varinfo_merged, @varname(y)) == gidset_right - end - - # The below used to error, testing to avoid regression. - @testset "merge different dimensions" begin - vn = @varname(x) - vi_single = VarInfo() - vi_single = push!!(vi_single, vn, 1.0, Normal()) - vi_double = VarInfo() - vi_double = push!!(vi_double, vn, [0.5, 0.6], Dirichlet(2, 1.0)) - @test merge(vi_single, vi_double)[vn] == [0.5, 0.6] - @test merge(vi_double, vi_single)[vn] == 1.0 - end end - @testset "VarInfo with selectors" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - varinfo = VarInfo( - model, - DynamicPPL.SampleFromPrior(), - DynamicPPL.DefaultContext(), - DynamicPPL.Metadata(), - ) - selector = DynamicPPL.Selector() - spl = Sampler(MyAlg{(:s,)}(), model, selector) - - vns = DynamicPPL.TestUtils.varnames(model) - vns_s = filter(vn -> DynamicPPL.getsym(vn) === :s, vns) - vns_m = filter(vn -> DynamicPPL.getsym(vn) === :m, vns) - for vn in vns_s - DynamicPPL.updategid!(varinfo, vn, spl) - end - - # Should only get the variables subsumed by `@varname(s)`. - @test varinfo[spl] == - mapreduce(Base.Fix1(DynamicPPL.getindex_internal, varinfo), vcat, vns_s) - - # `link` - varinfo_linked = DynamicPPL.link(varinfo, spl, model) - # `s` variables should be linked - @test any(Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_s) - # `m` variables should NOT be linked - @test any(!Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_m) - # And `varinfo` should be unchanged - @test all(!Base.Fix1(DynamicPPL.istrans, varinfo), vns) - - # `invlink` - varinfo_invlinked = DynamicPPL.invlink(varinfo_linked, spl, model) - # `s` variables should no longer be linked - @test all(!Base.Fix1(DynamicPPL.istrans, varinfo_invlinked), vns_s) - # `m` variables should still not be linked - @test all(!Base.Fix1(DynamicPPL.istrans, varinfo_invlinked), vns_m) - # And `varinfo_linked` should be unchanged - @test any(Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_s) - @test any(!Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_m) - end + # The below used to error, testing to avoid regression. + @testset "merge different dimensions" begin + vn = @varname(x) + vi_single = VarInfo() + vi_single = push!!(vi_single, vn, 1.0, Normal()) + vi_double = VarInfo() + vi_double = push!!(vi_double, vn, [0.5, 0.6], Dirichlet(2, 1.0)) + @test merge(vi_single, vi_double)[vn] == [0.5, 0.6] + @test merge(vi_double, vi_single)[vn] == 1.0 end @testset "sampling from linked varinfo" begin @@ -1025,25 +936,22 @@ end vi = DynamicPPL.VarInfo() dists = [Categorical([0.7, 0.3]), Normal()] - spl1 = DynamicPPL.Sampler(MyAlg{()}(), empty_model()) - spl2 = DynamicPPL.Sampler(MyAlg{()}(), empty_model()) - # First iteration, variables are added to vi # variables samples in order: z1,a1,z2,a2,z3 DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z1, dists[1], spl1) - randr(vi, vn_a1, dists[2], spl1) + randr(vi, vn_z1, dists[1]) + randr(vi, vn_a1, dists[2]) DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_b, dists[2], spl2) - randr(vi, vn_z2, dists[1], spl1) - randr(vi, vn_a2, dists[2], spl1) + randr(vi, vn_b, dists[2]) + randr(vi, vn_z2, dists[1]) + randr(vi, vn_a2, dists[2]) DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z3, dists[1], spl1) + randr(vi, vn_z3, dists[1]) @test vi.metadata.orders == [1, 1, 2, 2, 2, 3] @test DynamicPPL.get_num_produce(vi) == 3 DynamicPPL.reset_num_produce!(vi) - DynamicPPL.set_retained_vns_del_by_spl!(vi, spl1) + DynamicPPL.set_retained_vns_del!(vi) @test DynamicPPL.is_flagged(vi, vn_z1, "del") @test DynamicPPL.is_flagged(vi, vn_a1, "del") @test DynamicPPL.is_flagged(vi, vn_z2, "del") @@ -1051,13 +959,13 @@ end @test DynamicPPL.is_flagged(vi, vn_z3, "del") DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z1, dists[1], spl1) - randr(vi, vn_a1, dists[2], spl1) + randr(vi, vn_z1, dists[1]) + randr(vi, vn_a1, dists[2]) DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z2, dists[1], spl1) + randr(vi, vn_z2, dists[1]) DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z3, dists[1], spl1) - randr(vi, vn_a2, dists[2], spl1) + randr(vi, vn_z3, dists[1]) + randr(vi, vn_a2, dists[2]) @test vi.metadata.orders == [1, 1, 2, 2, 3, 3] @test DynamicPPL.get_num_produce(vi) == 3 @@ -1065,21 +973,21 @@ end # First iteration, variables are added to vi # variables samples in order: z1,a1,z2,a2,z3 DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z1, dists[1], spl1) - randr(vi, vn_a1, dists[2], spl1) + randr(vi, vn_z1, dists[1]) + randr(vi, vn_a1, dists[2]) DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_b, dists[2], spl2) - randr(vi, vn_z2, dists[1], spl1) - randr(vi, vn_a2, dists[2], spl1) + randr(vi, vn_b, dists[2]) + randr(vi, vn_z2, dists[1]) + randr(vi, vn_a2, dists[2]) DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z3, dists[1], spl1) + randr(vi, vn_z3, dists[1]) @test vi.metadata.z.orders == [1, 2, 3] @test vi.metadata.a.orders == [1, 2] @test vi.metadata.b.orders == [2] @test DynamicPPL.get_num_produce(vi) == 3 DynamicPPL.reset_num_produce!(vi) - DynamicPPL.set_retained_vns_del_by_spl!(vi, spl1) + DynamicPPL.set_retained_vns_del!(vi) @test DynamicPPL.is_flagged(vi, vn_z1, "del") @test DynamicPPL.is_flagged(vi, vn_a1, "del") @test DynamicPPL.is_flagged(vi, vn_z2, "del") @@ -1087,69 +995,16 @@ end @test DynamicPPL.is_flagged(vi, vn_z3, "del") DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z1, dists[1], spl1) - randr(vi, vn_a1, dists[2], spl1) + randr(vi, vn_z1, dists[1]) + randr(vi, vn_a1, dists[2]) DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z2, dists[1], spl1) + randr(vi, vn_z2, dists[1]) DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z3, dists[1], spl1) - randr(vi, vn_a2, dists[2], spl1) + randr(vi, vn_z3, dists[1]) + randr(vi, vn_a2, dists[2]) @test vi.metadata.z.orders == [1, 2, 3] @test vi.metadata.a.orders == [1, 3] @test vi.metadata.b.orders == [2] @test DynamicPPL.get_num_produce(vi) == 3 end - - @testset "varinfo ranges" begin - @model empty_model() = x = 1 - dists = [Normal(0, 1), MvNormal(zeros(2), I), Wishart(7, [1 0.5; 0.5 1])] - - function test_varinfo!(vi) - spl2 = DynamicPPL.Sampler(MyAlg{(:w, :u)}(), empty_model()) - vn_w = @varname w - randr(vi, vn_w, dists[1], spl2, true) - - vn_x = @varname x - vn_y = @varname y - vn_z = @varname z - vns = [vn_x, vn_y, vn_z] - - spl1 = DynamicPPL.Sampler(MyAlg{(:x, :y, :z)}(), empty_model()) - for i in 1:3 - r = randr(vi, vns[i], dists[i], spl1, false) - val = vi[vns[i]] - @test sum(val - r) <= 1e-9 - end - - idcs = DynamicPPL._getidcs(vi, spl1) - if idcs isa NamedTuple - @test sum(length(getfield(idcs, f)) for f in fieldnames(typeof(idcs))) == 3 - else - @test length(idcs) == 3 - end - @test length(vi[spl1]) == 7 - - idcs = DynamicPPL._getidcs(vi, spl2) - if idcs isa NamedTuple - @test sum(length(getfield(idcs, f)) for f in fieldnames(typeof(idcs))) == 1 - else - @test length(idcs) == 1 - end - @test length(vi[spl2]) == 1 - - vn_u = @varname u - randr(vi, vn_u, dists[1], spl2, true) - - idcs = DynamicPPL._getidcs(vi, spl2) - if idcs isa NamedTuple - @test sum(length(getfield(idcs, f)) for f in fieldnames(typeof(idcs))) == 2 - else - @test length(idcs) == 2 - end - @test length(vi[spl2]) == 2 - end - vi = DynamicPPL.VarInfo() - test_varinfo!(vi) - test_varinfo!(empty!!(DynamicPPL.TypedVarInfo(vi))) - end end From 136644074415910aaed773cb250f0da0e5d5586b Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 5 Feb 2025 11:23:22 +0000 Subject: [PATCH 05/14] Remove samplers from VarInfo - indexing (#793) * Remove selector stuff from varinfo tests * Implement link and invlink for varnames rather than samplers * Replace set_retained_vns_del_by_spl! with set_retained_vns_del! * Make linking tests more extensive * Remove sampler indexing from link methods (but not invlink) * Remove indexing by samplers from invlink * Work towards removing sampler indexing with StaticTransformation * Fix invlink/link for TypedVarInfo and StaticTransformation * Fix a test in models.jl * Move some functions to utils.jl, add tests and docstrings * Fix a docstring typo * Various simplification to link/invlink * Improve a docstring * Style improvements * Fix broken link/invlink dispatch cascade for VectorVarInfo * Fix some more broken dispatch cascades * Apply suggestions from code review Co-authored-by: Xianda Sun <5433119+sunxd3@users.noreply.github.com> * Remove comments that messed with docstrings * Apply suggestions from code review Co-authored-by: Penelope Yong * Fix issues surfaced in code review * Simplify link/invlink arguments * Fix a bug in unflatten VarNamedVector * Rename VarNameCollection -> VarNameTuple * Remove test of a removed varname_namedtuple method * Apply suggestions from code review Co-authored-by: Penelope Yong * Respond to review feedback * Remove _default_sampler and a dead argument of maybe_invlink_before_eval * Fix a typo in a comment * Add HISTORY entry, fix one set_retained_vns_del! method * Remove some VarInfo getindex with samplers stuff * Remove some index setting with samplers * Remove more sampler indexing * Remove unflatten with samplers * Clean up some setindex stuff * Remove a bunch of varinfo.jl internal functions that used samplers/space, update HISTORY.md * Fix HISTORY.md * Miscalleanous small fixes * Fix a bug in VarInfo constructor * Fix getparams(::LogDensityFunction) * Apply suggestions from code review Co-authored-by: Penelope Yong --------- Co-authored-by: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Co-authored-by: Penelope Yong --- HISTORY.md | 5 + src/abstract_varinfo.jl | 32 ++--- src/compiler.jl | 65 ++++----- src/logdensityfunction.jl | 9 +- src/model.jl | 4 +- src/sampler.jl | 31 ++-- src/simple_varinfo.jl | 13 +- src/threadsafe.jl | 20 +-- src/utils.jl | 4 +- src/varinfo.jl | 296 +++++--------------------------------- src/varnamedvector.jl | 15 -- test/model.jl | 4 +- test/sampler.jl | 4 +- 13 files changed, 101 insertions(+), 401 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 03c564b64..6b7247c8d 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -10,6 +10,11 @@ This release removes the feature of `VarInfo` where it kept track of which varia - `link` and `invlink`, and their `!!` versions, no longer accept a sampler as an argument to specify which variables to (inv)link. The `link(varinfo, model)` methods remain in place, and as a new addition one can give a `Tuple` of `VarName`s to (inv)link only select variables, as in `link(varinfo, varname_tuple, model)`. - `set_retained_vns_del_by_spl!` has been replaced by `set_retained_vns_del!` which applies to all variables. + - `getindex`, `setindex!`, and `setindex!!` no longer accept samplers as arguments + - `unflatten` no longer accepts a sampler as an argument + - `eltype(::VarInfo)` no longer accepts a sampler as an argument + - `keys(::VarInfo)` no longer accepts a sampler as an argument + - `VarInfo(::VarInfo, ::Sampler, ::AbstactVector)` no longer accepts the sampler argument. ### Reverse prefixing order diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 26c4268d8..4e9e5c554 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -149,7 +149,6 @@ If `dist` is specified, the value(s) will be massaged into the representation ex """ getindex(vi::AbstractVarInfo, ::Colon) - getindex(vi::AbstractVarInfo, ::AbstractSampler) Return the current value(s) of `vn` (`vns`) in `vi` in the support of its (their) distribution(s) as a flattened `Vector`. @@ -159,7 +158,6 @@ The default implementation is to call [`values_as`](@ref) with `Vector` as the t See also: [`getindex(vi::AbstractVarInfo, vn::VarName, dist::Distribution)`](@ref) """ Base.getindex(vi::AbstractVarInfo, ::Colon) = values_as(vi, Vector) -Base.getindex(vi::AbstractVarInfo, ::AbstractSampler) = vi[:] """ getindex_internal(vi::AbstractVarInfo, vn::VarName) @@ -341,9 +339,9 @@ julia> values_as(vi, Vector) function values_as end """ - eltype(vi::AbstractVarInfo, spl::Union{AbstractSampler,SampleFromPrior} + eltype(vi::AbstractVarInfo) -Determine the default `eltype` of the values returned by `vi[spl]`. +Return the `eltype` of the values returned by `vi[:]`. !!! warning This should generally not be called explicitly, as it's only used in @@ -352,13 +350,13 @@ Determine the default `eltype` of the values returned by `vi[spl]`. This method is considered legacy, and is likely to be deprecated in the future. """ -function Base.eltype(vi::AbstractVarInfo, spl::Union{AbstractSampler,SampleFromPrior}) - T = Base.promote_op(getindex, typeof(vi), typeof(spl)) +function Base.eltype(vi::AbstractVarInfo) + T = Base.promote_op(getindex, typeof(vi), Colon) if T === Union{} - # In this case `getindex(vi, spl)` errors + # In this case `getindex(vi, :)` errors # Let us throw a more descriptive error message # Ref https://github.com/TuringLang/Turing.jl/issues/2151 - return eltype(vi[spl]) + return eltype(vi[:]) end return eltype(T) end @@ -720,25 +718,11 @@ end # Utilities """ - unflatten(vi::AbstractVarInfo[, context::AbstractContext], x::AbstractVector) + unflatten(vi::AbstractVarInfo, x::AbstractVector) Return a new instance of `vi` with the values of `x` assigned to the variables. - -If `context` is provided, `x` is assumed to be realizations only for variables not -filtered out by `context`. """ -function unflatten(varinfo::AbstractVarInfo, context::AbstractContext, θ) - if hassampler(context) - unflatten(getsampler(context), varinfo, context, θ) - else - DynamicPPL.unflatten(varinfo, θ) - end -end - -# TODO: deprecate this once `sampler` is no longer the main way of filtering out variables. -function unflatten(sampler::AbstractSampler, varinfo::AbstractVarInfo, ::AbstractContext, θ) - return unflatten(varinfo, sampler, θ) -end +function unflatten end """ to_maybe_linked_internal(vi::AbstractVarInfo, vn::VarName, dist, val) diff --git a/src/compiler.jl b/src/compiler.jl index c67da6f95..8743641af 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -3,7 +3,7 @@ const INTERNALNAMES = (:__model__, :__context__, :__varinfo__) """ need_concretize(expr) -Return `true` if `expr` needs to be concretized, i.e., if it contains a colon `:` or +Return `true` if `expr` needs to be concretized, i.e., if it contains a colon `:` or requires a dynamic optic. # Examples @@ -730,19 +730,19 @@ function warn_empty(body) return nothing end +# TODO(mhauru) matchingvalue has methods that can accept both types and values. Why? +# TODO(mhauru) This function needs a more comprehensive docstring. """ - matchingvalue(sampler, vi, value) - matchingvalue(context::AbstractContext, vi, value) + matchingvalue(vi, value) -Convert the `value` to the correct type for the `sampler` or `context` and the `vi` object. - -For a `context` that is _not_ a `SamplingContext`, we fall back to -`matchingvalue(SampleFromPrior(), vi, value)`. +Convert the `value` to the correct type for the `vi` object. """ -function matchingvalue(sampler, vi, value) +function matchingvalue(vi, value) T = typeof(value) if hasmissing(T) - _value = convert(get_matching_type(sampler, vi, T), value) + _value = convert(get_matching_type(vi, T), value) + # TODO(mhauru) Why do we make a deepcopy, even though in the !hasmissing branch we + # are happy to return `value` as-is? if _value === value return deepcopy(_value) else @@ -752,45 +752,30 @@ function matchingvalue(sampler, vi, value) return value end end -# If we hit `Type` or `TypeWrap`, we immediately jump to `get_matching_type`. -function matchingvalue(sampler::AbstractSampler, vi, value::FloatOrArrayType) - return get_matching_type(sampler, vi, value) -end -function matchingvalue(sampler::AbstractSampler, vi, value::TypeWrap{T}) where {T} - return TypeWrap{get_matching_type(sampler, vi, T)}() -end -function matchingvalue(context::AbstractContext, vi, value) - return matchingvalue(NodeTrait(matchingvalue, context), context, vi, value) +function matchingvalue(vi, value::FloatOrArrayType) + return get_matching_type(vi, value) end -function matchingvalue(::IsLeaf, context::AbstractContext, vi, value) - return matchingvalue(SampleFromPrior(), vi, value) -end -function matchingvalue(::IsParent, context::AbstractContext, vi, value) - return matchingvalue(childcontext(context), vi, value) -end -function matchingvalue(context::SamplingContext, vi, value) - return matchingvalue(context.sampler, vi, value) +function matchingvalue(vi, ::TypeWrap{T}) where {T} + return TypeWrap{get_matching_type(vi, T)}() end +# TODO(mhauru) This function needs a more comprehensive docstring. What is it for? """ - get_matching_type(spl::AbstractSampler, vi, ::TypeWrap{T}) where {T} - -Get the specialized version of type `T` for sampler `spl`. + get_matching_type(vi, ::TypeWrap{T}) where {T} -For example, if `T === Float64` and `spl::Hamiltonian`, the matching type is -`eltype(vi[spl])`. +Get the specialized version of type `T` for `vi`. """ -get_matching_type(spl::AbstractSampler, vi, ::Type{T}) where {T} = T -function get_matching_type(spl::AbstractSampler, vi, ::Type{<:Union{Missing,AbstractFloat}}) - return Union{Missing,float_type_with_fallback(eltype(vi, spl))} +get_matching_type(_, ::Type{T}) where {T} = T +function get_matching_type(vi, ::Type{<:Union{Missing,AbstractFloat}}) + return Union{Missing,float_type_with_fallback(eltype(vi))} end -function get_matching_type(spl::AbstractSampler, vi, ::Type{<:AbstractFloat}) - return float_type_with_fallback(eltype(vi, spl)) +function get_matching_type(vi, ::Type{<:AbstractFloat}) + return float_type_with_fallback(eltype(vi)) end -function get_matching_type(spl::AbstractSampler, vi, ::Type{<:Array{T,N}}) where {T,N} - return Array{get_matching_type(spl, vi, T),N} +function get_matching_type(vi, ::Type{<:Array{T,N}}) where {T,N} + return Array{get_matching_type(vi, T),N} end -function get_matching_type(spl::AbstractSampler, vi, ::Type{<:Array{T}}) where {T} - return Array{get_matching_type(spl, vi, T)} +function get_matching_type(vi, ::Type{<:Array{T}}) where {T} + return Array{get_matching_type(vi, T)} end diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 214369ab0..29f591cc3 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -121,22 +121,17 @@ end getsampler(f::LogDensityFunction) = getsampler(getcontext(f)) hassampler(f::LogDensityFunction) = hassampler(getcontext(f)) -_get_indexer(ctx::AbstractContext) = _get_indexer(NodeTrait(ctx), ctx) -_get_indexer(ctx::SamplingContext) = ctx.sampler -_get_indexer(::IsParent, ctx::AbstractContext) = _get_indexer(childcontext(ctx)) -_get_indexer(::IsLeaf, ctx::AbstractContext) = Colon() - """ getparams(f::LogDensityFunction) Return the parameters of the wrapped varinfo as a vector. """ -getparams(f::LogDensityFunction) = f.varinfo[_get_indexer(getcontext(f))] +getparams(f::LogDensityFunction) = f.varinfo[:] # LogDensityProblems interface function LogDensityProblems.logdensity(f::LogDensityFunction, θ::AbstractVector) context = getcontext(f) - vi_new = unflatten(f.varinfo, context, θ) + vi_new = unflatten(f.varinfo, θ) return getlogp(last(evaluate!!(f.model, vi_new, context))) end function LogDensityProblems.capabilities(::Type{<:LogDensityFunction}) diff --git a/src/model.jl b/src/model.jl index 462db7397..3601d77fd 100644 --- a/src/model.jl +++ b/src/model.jl @@ -948,9 +948,9 @@ Return the arguments and keyword arguments to be passed to the evaluator of the ) where {_F,argnames} unwrap_args = [ if is_splat_symbol(var) - :($matchingvalue(context_new, varinfo, model.args.$var)...) + :($matchingvalue(varinfo, model.args.$var)...) else - :($matchingvalue(context_new, varinfo, model.args.$var)) + :($matchingvalue(varinfo, model.args.$var)) end for var in argnames ] diff --git a/src/sampler.jl b/src/sampler.jl index 974828e8b..56cd8404e 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -118,7 +118,7 @@ function AbstractMCMC.step( # Update the parameters if provided. if initial_params !== nothing - vi = initialize_parameters!!(vi, initial_params, spl, model) + vi = initialize_parameters!!(vi, initial_params, model) # Update joint log probability. # This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588 @@ -156,9 +156,7 @@ By default, it returns an instance of [`SampleFromPrior`](@ref). """ initialsampler(spl::Sampler) = SampleFromPrior() -function set_values!!( - varinfo::AbstractVarInfo, initial_params::AbstractVector, spl::AbstractSampler -) +function set_values!!(varinfo::AbstractVarInfo, initial_params::AbstractVector) throw( ArgumentError( "`initial_params` must be a vector of type `Union{Real,Missing}`. " * @@ -168,11 +166,9 @@ function set_values!!( end function set_values!!( - varinfo::AbstractVarInfo, - initial_params::AbstractVector{<:Union{Real,Missing}}, - spl::AbstractSampler, + varinfo::AbstractVarInfo, initial_params::AbstractVector{<:Union{Real,Missing}} ) - flattened_param_vals = varinfo[spl] + flattened_param_vals = varinfo[:] length(flattened_param_vals) == length(initial_params) || throw( DimensionMismatch( "Provided initial value size ($(length(initial_params))) doesn't match " * @@ -189,12 +185,11 @@ function set_values!!( end # Update in `varinfo`. - return setindex!!(varinfo, flattened_param_vals, spl) + setall!(varinfo, flattened_param_vals) + return varinfo end -function set_values!!( - varinfo::AbstractVarInfo, initial_params::NamedTuple, spl::AbstractSampler -) +function set_values!!(varinfo::AbstractVarInfo, initial_params::NamedTuple) vars_in_varinfo = keys(varinfo) for v in keys(initial_params) vn = VarName{v}() @@ -219,23 +214,21 @@ function set_values!!( ) end -function initialize_parameters!!( - vi::AbstractVarInfo, initial_params, spl::AbstractSampler, model::Model -) +function initialize_parameters!!(vi::AbstractVarInfo, initial_params, model::Model) @debug "Using passed-in initial variable values" initial_params # `link` the varinfo if needed. - linked = islinked(vi, spl) + linked = islinked(vi) if linked - vi = invlink!!(vi, spl, model) + vi = invlink!!(vi, model) end # Set the values in `vi`. - vi = set_values!!(vi, initial_params, spl) + vi = set_values!!(vi, initial_params) # `invlink` if needed. if linked - vi = link!!(vi, spl, model) + vi = link!!(vi, model) end return vi diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 57b167077..07296c3f7 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -258,7 +258,6 @@ function typed_simple_varinfo(model::Model) return last(evaluate!!(model, varinfo, SamplingContext())) end -unflatten(svi::SimpleVarInfo, spl::AbstractSampler, x::AbstractVector) = unflatten(svi, x) function unflatten(svi::SimpleVarInfo, x::AbstractVector) logp = getlogp(svi) vals = unflatten(svi.values, x) @@ -342,10 +341,6 @@ function BangBang.setindex!!(vi::SimpleVarInfo, val, vn::VarName) return Accessors.@set vi.values = set!!(vi.values, vn, val) end -function BangBang.setindex!!(vi::SimpleVarInfo, val, spl::AbstractSampler) - return unflatten(vi, spl, val) -end - # TODO: Specialize to handle certain cases, e.g. a collection of `VarName` with # same symbol and same type of, say, `IndexLens`, for improved `.~` performance. function BangBang.setindex!!(vi::SimpleVarInfo, vals, vns::AbstractVector{<:VarName}) @@ -428,11 +423,7 @@ const SimpleOrThreadSafeSimple{T,V,C} = Union{ } # Necessary for `matchingvalue` to work properly. -function Base.eltype( - vi::SimpleOrThreadSafeSimple{<:Any,V}, spl::Union{AbstractSampler,SampleFromPrior} -) where {V} - return V -end +Base.eltype(::SimpleOrThreadSafeSimple{<:Any,V}) where {V} = V # `subset` function subset(varinfo::SimpleVarInfo, vns::AbstractVector{<:VarName}) @@ -562,7 +553,7 @@ istrans(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation) istrans(vi::SimpleVarInfo, vn::VarName) = istrans(vi) istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) = istrans(vi.varinfo, vn) -islinked(vi::SimpleVarInfo, ::Union{Sampler,SampleFromPrior}) = istrans(vi) +islinked(vi::SimpleVarInfo) = istrans(vi) values_as(vi::SimpleVarInfo) = vi.values values_as(vi::SimpleVarInfo{<:T}, ::Type{T}) where {T} = vi.values diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 69be5dcb1..4367ff06d 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -79,7 +79,7 @@ setval!(vi::ThreadSafeVarInfo, val, vn::VarName) = setval!(vi.varinfo, val, vn) keys(vi::ThreadSafeVarInfo) = keys(vi.varinfo) haskey(vi::ThreadSafeVarInfo, vn::VarName) = haskey(vi.varinfo, vn) -islinked(vi::ThreadSafeVarInfo, spl::AbstractSampler) = islinked(vi.varinfo, spl) +islinked(vi::ThreadSafeVarInfo) = islinked(vi.varinfo) function link!!(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...) return Accessors.@set vi.varinfo = link!!(t, vi.varinfo, args...) @@ -138,17 +138,6 @@ end function getindex(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}, dist::Distribution) return getindex(vi.varinfo, vns, dist) end -getindex(vi::ThreadSafeVarInfo, spl::AbstractSampler) = getindex(vi.varinfo, spl) - -function BangBang.setindex!!(vi::ThreadSafeVarInfo, val, spl::AbstractSampler) - return Accessors.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, val, spl) -end -function BangBang.setindex!!(vi::ThreadSafeVarInfo, val, spl::SampleFromPrior) - return Accessors.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, val, spl) -end -function BangBang.setindex!!(vi::ThreadSafeVarInfo, val, spl::SampleFromUniform) - return Accessors.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, val, spl) -end function BangBang.setindex!!(vi::ThreadSafeVarInfo, vals, vn::VarName) return Accessors.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, vals, vn) @@ -184,13 +173,9 @@ function is_flagged(vi::ThreadSafeVarInfo, vn::VarName, flag::String) return is_flagged(vi.varinfo, vn, flag) end -# Transformations. function settrans!!(vi::ThreadSafeVarInfo, trans::Bool, vn::VarName) return Accessors.@set vi.varinfo = settrans!!(vi.varinfo, trans, vn) end -function settrans!!(vi::ThreadSafeVarInfo, spl::AbstractSampler, dist::Distribution) - return Accessors.@set vi.varinfo = settrans!!(vi.varinfo, spl, dist) -end istrans(vi::ThreadSafeVarInfo, vn::VarName) = istrans(vi.varinfo, vn) istrans(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) = istrans(vi.varinfo, vns) @@ -200,9 +185,6 @@ getindex_internal(vi::ThreadSafeVarInfo, vn::VarName) = getindex_internal(vi.var function unflatten(vi::ThreadSafeVarInfo, x::AbstractVector) return Accessors.@set vi.varinfo = unflatten(vi.varinfo, x) end -function unflatten(vi::ThreadSafeVarInfo, spl::AbstractSampler, x::AbstractVector) - return Accessors.@set vi.varinfo = unflatten(vi.varinfo, spl, x) -end function subset(varinfo::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) return Accessors.@set varinfo.varinfo = subset(varinfo.varinfo, vns) diff --git a/src/utils.jl b/src/utils.jl index 2539b7179..d64f6dc66 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -942,9 +942,9 @@ function update_values!!(vi::AbstractVarInfo, vals::NamedTuple, vns) end """ - float_type_with_fallback(x) + float_type_with_fallback(T::DataType) -Return type corresponding to `float(typeof(x))` if possible; otherwise return `float(Real)`. +Return `float(T)` if possible; otherwise return `float(Real)`. """ float_type_with_fallback(::Type) = float(Real) float_type_with_fallback(::Type{Union{}}) = float(Real) diff --git a/src/varinfo.jl b/src/varinfo.jl index 09f5960c1..8f7f7b6c1 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -111,10 +111,11 @@ const VarInfoOrThreadSafeVarInfo{Tmeta} = Union{ # NOTE: This is kind of weird, but it effectively preserves the "old" # behavior where we're allowed to call `link!` on the same `VarInfo` # multiple times. -transformation(vi::VarInfo) = DynamicTransformation() +transformation(::VarInfo) = DynamicTransformation() -function VarInfo(old_vi::VarInfo, spl, x::AbstractVector) - md = replace_values(old_vi.metadata, Val(getspace(spl)), x) +# TODO(mhauru) Isn't this the same as unflatten and/or replace_values? +function VarInfo(old_vi::VarInfo, x::AbstractVector) + md = replace_values(old_vi.metadata, x) return VarInfo( md, Base.RefValue{eltype(x)}(getlogp(old_vi)), Ref(get_num_produce(old_vi)) ) @@ -217,53 +218,42 @@ vector_length(varinfo::VarInfo) = length(varinfo.metadata) vector_length(varinfo::TypedVarInfo) = sum(length, varinfo.metadata) vector_length(md::Metadata) = sum(length, md.ranges) -unflatten(vi::VarInfo, x::AbstractVector) = unflatten(vi, SampleFromPrior(), x) - -# TODO: deprecate. -function unflatten(vi::VarInfo, spl::AbstractSampler, x::AbstractVector) - md = unflatten(vi.metadata, spl, x) +function unflatten(vi::VarInfo, x::AbstractVector) + md = unflatten_metadata(vi.metadata, x) + # Note that use of RefValue{eltype(x)} rather than Ref is necessary to deal with cases + # where e.g. x is a type gradient of some AD backend. return VarInfo(md, Base.RefValue{eltype(x)}(getlogp(vi)), Ref(get_num_produce(vi))) end -# The Val(getspace(spl)) is used to dispatch into the below generated function. -function unflatten(metadata::NamedTuple, spl::AbstractSampler, x::AbstractVector) - return unflatten(metadata, Val(getspace(spl)), x) -end - -@generated function unflatten( - metadata::NamedTuple{names}, ::Val{space}, x -) where {names,space} +# We would call this `unflatten` if not for `unflatten` having a method for NamedTuples in +# utils.jl. +@generated function unflatten_metadata( + metadata::NamedTuple{names}, x::AbstractVector +) where {names} exprs = [] offset = :(0) for f in names mdf = :(metadata.$f) - if inspace(f, space) || length(space) == 0 - len = :(sum(length, $mdf.ranges)) - push!(exprs, :($f = unflatten($mdf, x[($offset + 1):($offset + $len)]))) - offset = :($offset + $len) - else - push!(exprs, :($f = $mdf)) - end + len = :(sum(length, $mdf.ranges)) + push!(exprs, :($f = unflatten_metadata($mdf, x[($offset + 1):($offset + $len)]))) + offset = :($offset + $len) end length(exprs) == 0 && return :(NamedTuple()) return :($(exprs...),) end # For Metadata unflatten and replace_values are the same. For VarNamedVector they are not. -function unflatten(md::Metadata, x::AbstractVector) +function unflatten_metadata(md::Metadata, x::AbstractVector) return replace_values(md, x) end -function unflatten(md::Metadata, spl::AbstractSampler, x::AbstractVector) - return replace_values(md, spl, x) -end + +unflatten_metadata(vnv::VarNamedVector, x::AbstractVector) = unflatten(vnv, x) # without AbstractSampler function VarInfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) return VarInfo(rng, model, SampleFromPrior(), context) end -# TODO: Remove `space` argument when no longer needed. Ref: https://github.com/TuringLang/DynamicPPL.jl/issues/573 -replace_values(metadata::Metadata, space, x) = replace_values(metadata, x) function replace_values(metadata::Metadata, x) return Metadata( metadata.idcs, @@ -277,20 +267,14 @@ function replace_values(metadata::Metadata, x) ) end -@generated function replace_values( - metadata::NamedTuple{names}, ::Val{space}, x -) where {names,space} +@generated function replace_values(metadata::NamedTuple{names}, x) where {names} exprs = [] offset = :(0) for f in names mdf = :(metadata.$f) - if inspace(f, space) || length(space) == 0 - len = :(sum(length, $mdf.ranges)) - push!(exprs, :($f = replace_values($mdf, x[($offset + 1):($offset + $len)]))) - offset = :($offset + $len) - else - push!(exprs, :($f = $mdf)) - end + len = :(sum(length, $mdf.ranges)) + push!(exprs, :($f = replace_values($mdf, x[($offset + 1):($offset + $len)]))) + offset = :($offset + $len) end length(exprs) == 0 && return :(NamedTuple()) return :($(exprs...),) @@ -786,7 +770,7 @@ settrans!!(vi::VarInfo, trans::AbstractTransformation) = settrans!!(vi, true) """ syms(vi::VarInfo) -Returns a tuple of the unique symbols of random variables sampled in `vi`. +Returns a tuple of the unique symbols of random variables in `vi`. """ syms(vi::UntypedVarInfo) = Tuple(unique!(map(getsym, vi.metadata.vns))) # get all symbols syms(vi::TypedVarInfo) = keys(vi.metadata) @@ -794,16 +778,6 @@ syms(vi::TypedVarInfo) = keys(vi.metadata) _getidcs(vi::UntypedVarInfo) = 1:length(vi.metadata.idcs) _getidcs(vi::TypedVarInfo) = _getidcs(vi.metadata) -# Get all indices of variables belonging to SampleFromPrior: -# if the gid/selector of a var is an empty Set, then that var is assumed to be assigned to -# the SampleFromPrior sampler -@inline function _getidcs(vi::UntypedVarInfo, ::SampleFromPrior) - return filter(i -> isempty(vi.metadata.gids[i]), 1:length(vi.metadata.gids)) -end -# Get a NamedTuple of all the indices belonging to SampleFromPrior, one for each symbol -@inline function _getidcs(vi::TypedVarInfo, ::SampleFromPrior) - return _getidcs(vi.metadata) -end @generated function _getidcs(metadata::NamedTuple{names}) where {names} exprs = [] for f in names @@ -813,93 +787,15 @@ end return :($(exprs...),) end -# Get all indices of variables belonging to a given sampler -@inline function _getidcs(vi::VarInfo, spl::Sampler) - # NOTE: 0b00 is the sanity flag for - # |\____ getidcs (mask = 0b10) - # \_____ getranges (mask = 0b01) - #if ~haskey(spl.info, :cache_updated) spl.info[:cache_updated] = CACHERESET end - # Checks if cache is valid, i.e. no new pushes were made, to return the cached idcs - # Otherwise, it recomputes the idcs and caches it - #if haskey(spl.info, :idcs) && (spl.info[:cache_updated] & CACHEIDCS) > 0 - # spl.info[:idcs] - #else - #spl.info[:cache_updated] = spl.info[:cache_updated] | CACHEIDCS - idcs = _getidcs(vi, spl.selector, Val(getspace(spl))) - #spl.info[:idcs] = idcs - #end - return idcs -end -@inline _getidcs(vi::UntypedVarInfo, s::Selector, space) = findinds(vi.metadata, s, space) -@inline _getidcs(vi::TypedVarInfo, s::Selector, space) = _getidcs(vi.metadata, s, space) -# Get a NamedTuple for all the indices belonging to a given selector for each symbol -@generated function _getidcs( - metadata::NamedTuple{names}, s::Selector, ::Val{space} -) where {names,space} - exprs = [] - # Iterate through each varname in metadata. - for f in names - # If the varname is in the sampler space - # or the sample space is empty (all variables) - # then return the indices for that variable. - if inspace(f, space) || length(space) == 0 - push!(exprs, :($f = findinds(metadata.$f, s, Val($space)))) - end - end - length(exprs) == 0 && return :(NamedTuple()) - return :($(exprs...),) -end -@inline function findinds(f_meta::Metadata, s, ::Val{space}) where {space} - # Get all the idcs of the vns in `space` and that belong to the selector `s` - return filter( - (i) -> - (s in f_meta.gids[i] || isempty(f_meta.gids[i])) && - (isempty(space) || inspace(f_meta.vns[i], space)), - 1:length(f_meta.gids), - ) -end @inline function findinds(f_meta::Metadata) # Get all the idcs of the vns return filter((i) -> isempty(f_meta.gids[i]), 1:length(f_meta.gids)) end -function findinds(vnv::VarNamedVector, ::Selector, ::Val{space}) where {space} - # New Metadata objects are created with an empty list of gids, which is intrepreted as - # all Selectors applying to all variables. We assume the same behavior for - # VarNamedVector, and thus ignore the Selector argument. - if space !== () - msg = "VarNamedVector does not support selecting variables based on samplers" - throw(ErrorException(msg)) - else - return findinds(vnv) - end -end - function findinds(vnv::VarNamedVector) return 1:length(vnv.varnames) end -# Get all vns of variables belonging to spl -_getvns(vi::VarInfo, spl::Sampler) = _getvns(vi, spl.selector, Val(getspace(spl))) -function _getvns(vi::VarInfo, spl::Union{SampleFromPrior,SampleFromUniform}) - return _getvns(vi, Selector(), Val(())) -end -function _getvns(vi::UntypedVarInfo, s::Selector, space) - return view(vi.metadata.vns, _getidcs(vi, s, space)) -end -function _getvns(vi::TypedVarInfo, s::Selector, space) - return _getvns(vi.metadata, _getidcs(vi, s, space)) -end -# Get a NamedTuple for all the `vns` of indices `idcs`, one entry for each symbol -@generated function _getvns(metadata, idcs::NamedTuple{names}) where {names} - exprs = [] - for f in names - push!(exprs, :($f = Base.keys(metadata.$f)[idcs.$f])) - end - length(exprs) == 0 && return :(NamedTuple()) - return :($(exprs...),) -end - """ all_varnames_grouped_by_symbol(vi::TypedVarInfo) @@ -916,47 +812,6 @@ all_varnames_grouped_by_symbol(vi::TypedVarInfo) = return expr end -# Get the index (in vals) ranges of all the vns of variables belonging to spl -@inline function _getranges(vi::VarInfo, spl::Sampler) - ## Uncomment the spl.info stuff when it is concretely typed, not Dict{Symbol, Any} - #if ~haskey(spl.info, :cache_updated) spl.info[:cache_updated] = CACHERESET end - #if haskey(spl.info, :ranges) && (spl.info[:cache_updated] & CACHERANGES) > 0 - # spl.info[:ranges] - #else - #spl.info[:cache_updated] = spl.info[:cache_updated] | CACHERANGES - ranges = _getranges(vi, spl.selector, Val(getspace(spl))) - #spl.info[:ranges] = ranges - return ranges - #end -end -# Get the index (in vals) ranges of all the vns of variables belonging to selector `s` in `space` -@inline function _getranges(vi::VarInfo, s::Selector, space) - return _getranges(vi, _getidcs(vi, s, space)) -end -@inline function _getranges(vi::VarInfo, idcs::Vector{Int}) - return mapreduce(i -> vi.metadata.ranges[i], vcat, idcs; init=Int[]) -end -@inline _getranges(vi::TypedVarInfo, idcs::NamedTuple) = _getranges(vi.metadata, idcs) - -@generated function _getranges(metadata::NamedTuple, idcs::NamedTuple{names}) where {names} - exprs = [] - for f in names - push!(exprs, :($f = findranges(metadata.$f.ranges, idcs.$f))) - end - length(exprs) == 0 && return :(NamedTuple()) - return :($(exprs...),) -end - -@inline function findranges(f_ranges, f_idcs) - # Old implementation was using `mapreduce` but turned out - # to be type-unstable. - results = Int[] - for i in f_idcs - append!(results, f_ranges[i]) - end - return results -end - # TODO(mhauru) These set_flag! methods return the VarInfo. They should probably be called # set_flag!!. """ @@ -1096,12 +951,6 @@ Base.keys(vi::TypedVarInfo{<:NamedTuple{()}}) = VarName[] return expr end -# FIXME(torfjelde): Don't use `_getvns`. -Base.keys(vi::UntypedVarInfo, spl::AbstractSampler) = _getvns(vi, spl) -function Base.keys(vi::TypedVarInfo, spl::AbstractSampler) - return mapreduce(values, vcat, _getvns(vi, spl)) -end - """ setgid!(vi::VarInfo, gid::Selector, vn::VarName) @@ -1191,7 +1040,6 @@ function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, mode return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, model) end -# X -> R for all variables associated with given sampler function link!!(::DynamicTransformation, vi::VarInfo, vns::VarNameTuple, model::Model) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _link(model, vi, vns) @@ -1297,7 +1145,6 @@ function invlink!!(t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, m return Accessors.@set vi.varinfo = DynamicPPL.invlink!!(t, vi.varinfo, model) end -# R -> X for all variables associated with given sampler function invlink!!(::DynamicTransformation, vi::VarInfo, vns::VarNameTuple, model::Model) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _invlink(model, vi, vns) @@ -1394,16 +1241,6 @@ function _inner_transform!(md::Metadata, vi::VarInfo, vn::VarName, f) return vi end -# HACK: We need `SampleFromPrior` to result in ALL values which are in need -# of a transformation to be transformed. `_getvns` will by default return -# an empty iterable for `SampleFromPrior`, so we need to override it here. -# This is quite hacky, but seems safer than changing the behavior of `_getvns`. -_getvns_link(varinfo::VarInfo, spl::AbstractSampler) = _getvns(varinfo, spl) -_getvns_link(varinfo::VarInfo, spl::SampleFromPrior) = nothing -function _getvns_link(varinfo::TypedVarInfo, spl::SampleFromPrior) - return map(Returns(nothing), varinfo.metadata) -end - function link(::DynamicTransformation, vi::TypedVarInfo, model::Model) return _link(model, vi, all_varnames_grouped_by_symbol(vi)) end @@ -1617,7 +1454,7 @@ function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, targ # Construct the new transformed values, and keep track of their lengths. vals_new = map(vns) do vn # Return early if we're already in constrained space OR if we're not - # supposed to touch this `vn`, e.g. when `vn` does not belong to the current sampler. + # supposed to touch this `vn`. # HACK: if `target_vns` is `nothing`, we ignore the `target_vns` check. if !istrans(varinfo, vn) || (target_vns !== nothing && vn ∉ target_vns) return metadata.vals[getrange(metadata, vn)] @@ -1677,30 +1514,26 @@ function _invlink_metadata!!( return metadata end +# TODO(mhauru) The treatment of the case when some variables are linked and others are not +# should be revised. It used to be the case that for UntypedVarInfo `islinked` returned +# whether the first variable was linked. For TypedVarInfo we did an OR over the first +# variables under each symbol. We now more consistently use OR, but I'm not convinced this +# is really the right thing to do. """ - islinked(vi::VarInfo, spl::Union{Sampler, SampleFromPrior}) + islinked(vi::VarInfo) -Check whether `vi` is in the transformed space for a particular sampler `spl`. +Check whether `vi` is in the transformed space. Turing's Hamiltonian samplers use the `link` and `invlink` functions from [Bijectors.jl](https://github.com/TuringLang/Bijectors.jl) to map a constrained variable (for example, one bounded to the space `[0, 1]`) from its constrained space to the set of real numbers. `islinked` checks if the number is in the constrained space or the real space. + +If some but only some of the variables in `vi` are linked, this function will return `true`. +This behavior will likely change in the future. """ -function islinked(vi::UntypedVarInfo, spl::Union{Sampler,SampleFromPrior}) - vns = _getvns(vi, spl) - return istrans(vi, vns[1]) -end -function islinked(vi::TypedVarInfo, spl::Union{Sampler,SampleFromPrior}) - vns = _getvns(vi, spl) - return _islinked(vi, vns) -end -@generated function _islinked(vi, vns::NamedTuple{names}) where {names} - out = [] - for f in names - push!(out, :(isempty(vns.$f) ? false : istrans(vi, vns.$f[1]))) - end - return Expr(:||, false, out...) +function islinked(vi::VarInfo) + return any(istrans(vi, vn) for vn in keys(vi)) end function nested_setindex_maybe!(vi::UntypedVarInfo, val, vn::VarName) @@ -1788,22 +1621,6 @@ function getindex(vi::VarInfo, vns::Vector{<:VarName}, dist::Distribution) return recombine(dist, vals_linked, length(vns)) end -""" - getindex(vi::VarInfo, spl::Union{SampleFromPrior, Sampler}) - -Return the current value(s) of the random variables sampled by `spl` in `vi`. - -The value(s) may or may not be transformed to Euclidean space. -""" -getindex(vi::UntypedVarInfo, spl::Sampler) = - copy(getindex(vi.metadata.vals, _getranges(vi, spl))) -getindex(vi::VarInfo, spl::Sampler) = copy(getindex_internal(vi, _getranges(vi, spl))) -function getindex(vi::TypedVarInfo, spl::Sampler) - # Gets the ranges as a NamedTuple - ranges = _getranges(vi, spl) - # Calling getfield(ranges, f) gives all the indices in `vals` of the `vn`s with symbol `f` sampled by `spl` in `vi` - return reduce(vcat, _getindex(vi.metadata, ranges)) -end # Recursively builds a tuple of the `vals` of all the symbols @generated function _getindex(metadata, ranges::NamedTuple{names}) where {names} expr = Expr(:tuple) @@ -1828,43 +1645,6 @@ function BangBang.setindex!!(vi::VarInfo, val, vn::VarName) return vi end -""" - setindex!(vi::VarInfo, val, spl::Union{SampleFromPrior, Sampler}) - -Set the current value(s) of the random variables sampled by `spl` in `vi` to `val`. - -The value(s) may or may not be transformed to Euclidean space. -""" -setindex!(vi::VarInfo, val, spl::SampleFromPrior) = setall!(vi, val) -setindex!(vi::UntypedVarInfo, val, spl::Sampler) = setval!(vi, val, _getranges(vi, spl)) -function setindex!(vi::TypedVarInfo, val, spl::Sampler) - # Gets a `NamedTuple` mapping each symbol to the indices in the symbol's `vals` field sampled from the sampler `spl` - ranges = _getranges(vi, spl) - _setindex!(vi.metadata, val, ranges) - return nothing -end - -function BangBang.setindex!!(vi::VarInfo, val, spl::AbstractSampler) - setindex!(vi, val, spl) - return vi -end - -# Recursively writes the entries of `val` to the `vals` fields of all the symbols as if they were a contiguous vector. -@generated function _setindex!(metadata, val, ranges::NamedTuple{names}) where {names} - expr = Expr(:block) - offset = :(0) - for f in names - f_vals = :(metadata.$f.vals) - f_range = :(ranges.$f) - start = :($offset + 1) - len = :(length($f_range)) - finish = :($offset + $len) - push!(expr.args, :(@views $f_vals[$f_range] .= val[($start):($finish)])) - offset = :($offset + $len) - end - return expr -end - @inline function findvns(vi, f_vns) if length(f_vns) == 0 throw("Unidentified error, please report this error in an issue.") @@ -1877,7 +1657,7 @@ Base.haskey(metadata::Metadata, vn::VarName) = haskey(metadata.idcs, vn) """ haskey(vi::VarInfo, vn::VarName) -Check whether `vn` has been sampled in `vi`. +Check whether `vn` has a value in `vi`. """ Base.haskey(vi::VarInfo, vn::VarName) = haskey(getmetadata(vi, vn), vn) function Base.haskey(vi::TypedVarInfo, vn::VarName) diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl index 7da126321..3b3f0ce42 100644 --- a/src/varnamedvector.jl +++ b/src/varnamedvector.jl @@ -510,12 +510,6 @@ function getindex_internal(vnv::VarNamedVector, ::Colon) end end -# TODO(mhauru): Remove this as soon as possible. Only needed because of the old Gibbs -# sampler. -function Base.getindex(vnv::VarNamedVector, spl::AbstractSampler) - throw(ErrorException("Cannot index a VarNamedVector with a sampler.")) -end - function Base.setindex!(vnv::VarNamedVector, val, vn::VarName) if haskey(vnv, vn) return update!(vnv, val, vn) @@ -1077,15 +1071,6 @@ function unflatten(vnv::VarNamedVector, vals::AbstractVector) ) end -# TODO(mhauru) To be removed once the old Gibbs sampler is removed. -function unflatten(vnv::VarNamedVector, spl::AbstractSampler, vals::AbstractVector) - if length(getspace(spl)) > 0 - msg = "Selecting values in a VarNamedVector with a space is not supported." - throw(ArgumentError(msg)) - end - return unflatten(vnv, vals) -end - function Base.merge(left_vnv::VarNamedVector, right_vnv::VarNamedVector) # Return early if possible. isempty(left_vnv) && return deepcopy(right_vnv) diff --git a/test/model.jl b/test/model.jl index e91de4bd2..256ada0ad 100644 --- a/test/model.jl +++ b/test/model.jl @@ -230,8 +230,8 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() for i in 1:10 # Sample with large variations. - r_raw = randn(length(vi[spl])) * 10 - vi[spl] = r_raw + r_raw = randn(length(vi[:])) * 10 + DynamicPPL.setall!(vi, r_raw) @test vi[@varname(m)] == r_raw[1] @test vi[@varname(x)] != r_raw[2] model(vi) diff --git a/test/sampler.jl b/test/sampler.jl index 3b5424671..50111b1fd 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -196,11 +196,11 @@ vi = VarInfo(model) @test_throws ArgumentError DynamicPPL.initialize_parameters!!( - vi, [initial_z, initial_x], DynamicPPL.SampleFromPrior(), model + vi, [initial_z, initial_x], model ) @test_throws ArgumentError DynamicPPL.initialize_parameters!!( - vi, (X=initial_x, Z=initial_z), DynamicPPL.SampleFromPrior(), model + vi, (X=initial_x, Z=initial_z), model ) end end From 4b9665a7b0d0142b0c7aba27795c5a1c6141745b Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 17 Feb 2025 18:20:12 +0000 Subject: [PATCH 06/14] Remove samplers from VarInfo - Selectors and GIDs (#808) * Remove Selectors and Gibbs IDs * Remove getspace * Remove a dead VNV method --------- Co-authored-by: Xianda Sun <5433119+sunxd3@users.noreply.github.com> --- HISTORY.md | 2 + docs/src/api.md | 7 --- ext/DynamicPPLChainRulesCoreExt.jl | 10 +--- src/DynamicPPL.jl | 12 ---- src/abstract_varinfo.jl | 46 --------------- src/context_implementations.jl | 12 ++-- src/sampler.jl | 9 +-- src/selector.jl | 13 ---- src/simple_varinfo.jl | 28 ++------- src/threadsafe.jl | 9 +-- src/varinfo.jl | 95 +++++------------------------- src/varnamedvector.jl | 17 ++---- test/ad.jl | 3 +- test/runtests.jl | 2 +- test/sampler.jl | 1 - 15 files changed, 40 insertions(+), 226 deletions(-) delete mode 100644 src/selector.jl diff --git a/HISTORY.md b/HISTORY.md index 6b7247c8d..90db022e7 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -15,6 +15,8 @@ This release removes the feature of `VarInfo` where it kept track of which varia - `eltype(::VarInfo)` no longer accepts a sampler as an argument - `keys(::VarInfo)` no longer accepts a sampler as an argument - `VarInfo(::VarInfo, ::Sampler, ::AbstactVector)` no longer accepts the sampler argument. + - `push!!` and `push!` no longer accept samplers or `Selector`s as arguments + - `getgid`, `setgid!`, `updategid!`, `getspace`, and `inspace` no longer exist ### Reverse prefixing order diff --git a/docs/src/api.md b/docs/src/api.md index 6c58264fe..b9cafaaf4 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -289,13 +289,6 @@ unset_flag! is_flagged ``` -For Gibbs sampling the following functions were added. - -```@docs -setgid! -updategid! -``` - The following functions were used for sequential Monte Carlo methods. ```@docs diff --git a/ext/DynamicPPLChainRulesCoreExt.jl b/ext/DynamicPPLChainRulesCoreExt.jl index 1559467f8..12b816c60 100644 --- a/ext/DynamicPPLChainRulesCoreExt.jl +++ b/ext/DynamicPPLChainRulesCoreExt.jl @@ -10,15 +10,7 @@ end # See https://github.com/TuringLang/Turing.jl/issues/1199 ChainRulesCore.@non_differentiable BangBang.push!!( - vi::DynamicPPL.VarInfo, - vn::DynamicPPL.VarName, - r, - dist::Distributions.Distribution, - gidset::Set{DynamicPPL.Selector}, -) - -ChainRulesCore.@non_differentiable DynamicPPL.updategid!( - vi::DynamicPPL.AbstractVarInfo, vn::DynamicPPL.VarName, spl::DynamicPPL.Sampler + vi::DynamicPPL.VarInfo, vn::DynamicPPL.VarName, r, dist::Distributions.Distribution ) # No need + causes issues for some AD backends, e.g. Zygote. diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 55e1f7e88..8fea43e50 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -63,8 +63,6 @@ export AbstractVarInfo, is_flagged, set_flag!, unset_flag!, - setgid!, - updategid!, setorder!, istrans, link, @@ -74,7 +72,6 @@ export AbstractVarInfo, values_as, # VarName (reexport from AbstractPPL) VarName, - inspace, subsumes, @varname, # Compiler @@ -152,9 +149,6 @@ macro prob_str(str) )) end -# Used here and overloaded in Turing -function getspace end - """ AbstractVarInfo @@ -166,14 +160,8 @@ See also: [`VarInfo`](@ref), [`SimpleVarInfo`](@ref). """ abstract type AbstractVarInfo <: AbstractModelTrace end -const LEGACY_WARNING = """ -!!! warning - This method is considered legacy, and is likely to be deprecated in the future. -""" - # Necessary forward declarations include("utils.jl") -include("selector.jl") include("chains.jl") include("model.jl") include("sampler.jl") diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 4e9e5c554..66b098370 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -169,51 +169,6 @@ See also: [`getindex(vi::AbstractVarInfo, vn::VarName, dist::Distribution)`](@re """ function getindex_internal end -""" - push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution) - -Push a new random variable `vn` with a sampled value `r` from a distribution `dist` to -the `VarInfo` `vi`, mutating if it makes sense. -""" -function BangBang.push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution) - return BangBang.push!!(vi, vn, r, dist, Set{Selector}([])) -end - -""" - push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::AbstractSampler) - -Push a new random variable `vn` with a sampled value `r` sampled with a sampler `spl` -from a distribution `dist` to `VarInfo` `vi`, if it makes sense. - -The sampler is passed here to invalidate its cache where defined. - -$(LEGACY_WARNING) -""" -function BangBang.push!!( - vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::Sampler -) - return BangBang.push!!(vi, vn, r, dist, spl.selector) -end -function BangBang.push!!( - vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::AbstractSampler -) - return BangBang.push!!(vi, vn, r, dist) -end - -""" - push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector) - -Push a new random variable `vn` with a sampled value `r` sampled with a sampler of -selector `gid` from a distribution `dist` to `VarInfo` `vi`. - -$(LEGACY_WARNING) -""" -function BangBang.push!!( - vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector -) - return BangBang.push!!(vi, vn, r, dist, Set([gid])) -end - @doc """ empty!!(vi::AbstractVarInfo) @@ -768,7 +723,6 @@ end # Legacy code that is currently overloaded for the sake of simplicity. # TODO: Remove when possible. increment_num_produce!(::AbstractVarInfo) = nothing -setgid!(vi::AbstractVarInfo, gid::Selector, vn::VarName) = nothing """ from_internal_transform(varinfo::AbstractVarInfo, vn::VarName[, dist]) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 462012676..4594902dc 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -239,11 +239,11 @@ function assume( r = init(rng, dist, sampler) if istrans(vi) f = to_linked_internal_transform(vi, vn, dist) - push!!(vi, vn, f(r), dist, sampler) + push!!(vi, vn, f(r), dist) # By default `push!!` sets the transformed flag to `false`. settrans!!(vi, true, vn) else - push!!(vi, vn, r, dist, sampler) + push!!(vi, vn, r, dist) end end @@ -466,11 +466,11 @@ function get_and_set_val!( vn = vns[i] if istrans(vi) ri_linked = _link_broadcast_new(vi, vn, dist, r[:, i]) - push!!(vi, vn, ri_linked, dist, spl) + push!!(vi, vn, ri_linked, dist) # `push!!` sets the trans-flag to `false` by default. settrans!!(vi, true, vn) else - push!!(vi, vn, r[:, i], dist, spl) + push!!(vi, vn, r[:, i], dist) end end end @@ -513,14 +513,14 @@ function get_and_set_val!( # 2. Define an anonymous function which returns `nothing`, which # we then broadcast. This will allocate a vector of `nothing` though. if istrans(vi) - push!!.((vi,), vns, _link_broadcast_new.((vi,), vns, dists, r), dists, (spl,)) + push!!.((vi,), vns, _link_broadcast_new.((vi,), vns, dists, r), dists) # NOTE: Need to add the correction. # FIXME: This is not great. acclogp!!(vi, sum(logabsdetjac.(link_transform.(dists), r))) # `push!!` sets the trans-flag to `false` by default. settrans!!.((vi,), true, vns) else - push!!.((vi,), vns, r, dists, (spl,)) + push!!.((vi,), vns, r, dists) end end return r diff --git a/src/sampler.jl b/src/sampler.jl index 56cd8404e..aa3a637ee 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -18,8 +18,6 @@ Sampling algorithm that samples unobserved random variables from their prior dis """ struct SampleFromPrior <: AbstractSampler end -getspace(::Union{SampleFromPrior,SampleFromUniform}) = () - # Initializations. init(rng, dist, ::SampleFromPrior) = rand(rng, dist) function init(rng, dist, ::SampleFromUniform) @@ -31,6 +29,8 @@ function init(rng, dist, ::SampleFromUniform, n::Int) return istransformable(dist) ? inittrans(rng, dist, n) : rand(rng, dist, n) end +# TODO(mhauru) Could we get rid of Sampler now that it's just a wrapper around `alg`? +# (Selector has been removed). """ Sampler{T} @@ -47,12 +47,7 @@ By default, values are sampled from the prior. """ struct Sampler{T} <: AbstractSampler alg::T - selector::Selector # Can we remove it? - # TODO: add space such that we can integrate existing external samplers in DynamicPPL end -Sampler(alg) = Sampler(alg, Selector()) -Sampler(alg, model::Model) = Sampler(alg, model, Selector()) -Sampler(alg, model::Model, s::Selector) = Sampler(alg, s) # AbstractMCMC interface for SampleFromUniform and SampleFromPrior function AbstractMCMC.step( diff --git a/src/selector.jl b/src/selector.jl deleted file mode 100644 index fd4aa6d1c..000000000 --- a/src/selector.jl +++ /dev/null @@ -1,13 +0,0 @@ -struct Selector - gid::UInt64 - tag::Symbol # :default, :invalid, :Gibbs, :HMC, etc. - rerun::Bool -end -function Selector(tag::Symbol=:default, rerun=tag != :default) - return Selector(time_ns(), tag, rerun) -end -function Selector(gid::Integer, tag::Symbol=:default) - return Selector(gid, tag, tag != :default) -end -hash(s::Selector) = hash(s.gid) -==(s1::Selector, s2::Selector) = s1.gid == s2.gid diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 07296c3f7..00d6b3437 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -374,42 +374,26 @@ end # `NamedTuple` function BangBang.push!!( - vi::SimpleVarInfo{<:NamedTuple}, - vn::VarName{sym,typeof(identity)}, - value, - dist::Distribution, - gidset::Set{Selector}, + vi::SimpleVarInfo{<:NamedTuple}, ::VarName{sym,typeof(identity)}, value, ::Distribution ) where {sym} return Accessors.@set vi.values = merge(vi.values, NamedTuple{(sym,)}((value,))) end function BangBang.push!!( - vi::SimpleVarInfo{<:NamedTuple}, - vn::VarName{sym}, - value, - dist::Distribution, - gidset::Set{Selector}, + vi::SimpleVarInfo{<:NamedTuple}, vn::VarName{sym}, value, ::Distribution ) where {sym} return Accessors.@set vi.values = set!!(vi.values, vn, value) end # `AbstractDict` function BangBang.push!!( - vi::SimpleVarInfo{<:AbstractDict}, - vn::VarName, - value, - dist::Distribution, - gidset::Set{Selector}, + vi::SimpleVarInfo{<:AbstractDict}, vn::VarName, value, ::Distribution ) vi.values[vn] = value return vi end function BangBang.push!!( - vi::SimpleVarInfo{<:VarNamedVector}, - vn::VarName, - value, - dist::Distribution, - gidset::Set{Selector}, + vi::SimpleVarInfo{<:VarNamedVector}, vn::VarName, value, ::Distribution ) # The semantics of push!! for SimpleVarInfo and VarNamedVector are different. For # SimpleVarInfo, push!! allows the key to exist already, for VarNamedVector it does not. @@ -483,7 +467,7 @@ function assume( value = init(rng, dist, sampler) # Transform if we're working in unconstrained space. value_raw = to_maybe_linked_internal(vi, vn, dist, value) - vi = BangBang.push!!(vi, vn, value_raw, dist, sampler) + vi = BangBang.push!!(vi, vn, value_raw, dist) return value, Bijectors.logpdf_with_trans(dist, value, istrans(vi, vn)), vi end @@ -550,7 +534,7 @@ function settrans!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans) end istrans(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation) -istrans(vi::SimpleVarInfo, vn::VarName) = istrans(vi) +istrans(vi::SimpleVarInfo, ::VarName) = istrans(vi) istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) = istrans(vi.varinfo, vn) islinked(vi::SimpleVarInfo) = istrans(vi) diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 4367ff06d..539c1e9d6 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -57,10 +57,8 @@ end has_varnamedvector(vi::DynamicPPL.ThreadSafeVarInfo) = has_varnamedvector(vi.varinfo) -function BangBang.push!!( - vi::ThreadSafeVarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector} -) - return Accessors.@set vi.varinfo = push!!(vi.varinfo, vn, r, dist, gidset) +function BangBang.push!!(vi::ThreadSafeVarInfo, vn::VarName, r, dist::Distribution) + return Accessors.@set vi.varinfo = push!!(vi.varinfo, vn, r, dist) end get_num_produce(vi::ThreadSafeVarInfo) = get_num_produce(vi.varinfo) @@ -70,9 +68,6 @@ set_num_produce!(vi::ThreadSafeVarInfo, n::Int) = set_num_produce!(vi.varinfo, n syms(vi::ThreadSafeVarInfo) = syms(vi.varinfo) -function setgid!(vi::ThreadSafeVarInfo, gid::Selector, vn::VarName) - return setgid!(vi.varinfo, gid, vn) -end setorder!(vi::ThreadSafeVarInfo, vn::VarName, index::Int) = setorder!(vi.varinfo, vn, index) setval!(vi::ThreadSafeVarInfo, val, vn::VarName) = setval!(vi.varinfo, val, vn) diff --git a/src/varinfo.jl b/src/varinfo.jl index 8f7f7b6c1..ca143ea63 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -18,8 +18,6 @@ Let `md` be an instance of `Metadata`: `md.vns`, `md.ranges` `md.dists`, `md.orders` and `md.flags`. - `md.vns[md.idcs[vn]] == vn`. - `md.dists[md.idcs[vn]]` is the distribution of `vn`. -- `md.gids[md.idcs[vn]]` is the set of algorithms used to sample `vn`. This is used in - the Gibbs sampling process. - `md.orders[md.idcs[vn]]` is the number of `observe` statements before `vn` is sampled. - `md.ranges[md.idcs[vn]]` is the index range of `vn` in `md.vals`. - `md.vals[md.ranges[md.idcs[vn]]]` is the vector of values of corresponding to `vn`. @@ -41,7 +39,6 @@ struct Metadata{ TDists<:AbstractVector{<:Distribution}, TVN<:AbstractVector{<:VarName}, TVal<:AbstractVector{<:Real}, - TGIds<:AbstractVector{Set{Selector}}, } # Mapping from the `VarName` to its integer index in `vns`, `ranges` and `dists` idcs::TIdcs # Dict{<:VarName,Int} @@ -60,10 +57,6 @@ struct Metadata{ # Vector of distributions correpsonding to `vns` dists::TDists # AbstractVector{<:Distribution} - # Vector of sampler ids corresponding to `vns` - # Each random variable can be sampled using multiple samplers, e.g. in Gibbs, hence the `Set` - gids::TGIds # AbstractVector{Set{Selector}} - # Number of `observe` statements before each random variable is sampled orders::Vector{Int} @@ -261,7 +254,6 @@ function replace_values(metadata::Metadata, x) metadata.ranges, x, metadata.dists, - metadata.gids, metadata.orders, metadata.flags, ) @@ -301,7 +293,6 @@ function Metadata() Vector{UnitRange{Int}}(), vals, Vector{Distribution}(), - Vector{Set{Selector}}(), Vector{Int}(), flags, ) @@ -320,7 +311,6 @@ function empty!(meta::Metadata) empty!(meta.ranges) empty!(meta.vals) empty!(meta.dists) - empty!(meta.gids) empty!(meta.orders) for k in keys(meta.flags) empty!(meta.flags[k]) @@ -418,7 +408,6 @@ function subset(metadata::Metadata, vns_given::AbstractVector{VN}) where {VN<:Va ranges, vals, metadata.dists[indices_for_vns], - metadata.gids[indices_for_vns], metadata.orders[indices_for_vns], flags, ) @@ -493,7 +482,6 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) ranges = Vector{UnitRange{Int}}() vals = T[] dists = D[] - gids = Set{Selector}[] orders = Int[] flags = Dict{String,BitVector}() # Initialize the `flags`. @@ -516,15 +504,13 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) offset = r[end] dist = getdist(metadata_for_vn, vn) push!(dists, dist) - gid = metadata_for_vn.gids[getidx(metadata_for_vn, vn)] - push!(gids, gid) push!(orders, getorder(metadata_for_vn, vn)) for k in keys(flags) push!(flags[k], is_flagged(metadata_for_vn, vn, k)) end end - return Metadata(idcs, vns, ranges, vals, dists, gids, orders, flags) + return Metadata(idcs, vns, ranges, vals, dists, orders, flags) end const VarView = Union{Int,UnitRange,Vector{Int}} @@ -730,13 +716,6 @@ end return expr end -""" - getgid(vi::VarInfo, vn::VarName) - -Return the set of sampler selectors associated with `vn` in `vi`. -""" -getgid(vi::VarInfo, vn::VarName) = getmetadata(vi, vn).gids[getidx(vi, vn)] - function settrans!!(vi::VarInfo, trans::Bool, vn::VarName) settrans!!(getmetadata(vi, vn), trans, vn) return vi @@ -787,14 +766,8 @@ _getidcs(vi::TypedVarInfo) = _getidcs(vi.metadata) return :($(exprs...),) end -@inline function findinds(f_meta::Metadata) - # Get all the idcs of the vns - return filter((i) -> isempty(f_meta.gids[i]), 1:length(f_meta.gids)) -end - -function findinds(vnv::VarNamedVector) - return 1:length(vnv.varnames) -end +@inline findinds(f_meta::Metadata) = eachindex(f_meta.vns) +findinds(vnv::VarNamedVector) = 1:length(vnv.varnames) """ all_varnames_grouped_by_symbol(vi::TypedVarInfo) @@ -876,9 +849,6 @@ function TypedVarInfo(vi::UntypedVarInfo) sym_idcs = Dict(a => i for (i, a) in enumerate(sym_vns)) # New dists sym_dists = getindex.((meta.dists,), inds) - # New gids, can make a resizeable FillArray - sym_gids = getindex.((meta.gids,), inds) - @assert length(sym_gids) <= 1 || all(x -> x == sym_gids[1], @view sym_gids[2:end]) # New orders sym_orders = getindex.((meta.orders,), inds) # New flags @@ -899,14 +869,7 @@ function TypedVarInfo(vi::UntypedVarInfo) push!( new_metas, Metadata( - sym_idcs, - sym_vns, - sym_ranges, - sym_vals, - sym_dists, - sym_gids, - sym_orders, - sym_flags, + sym_idcs, sym_vns, sym_ranges, sym_vals, sym_dists, sym_orders, sym_flags ), ) end @@ -951,21 +914,6 @@ Base.keys(vi::TypedVarInfo{<:NamedTuple{()}}) = VarName[] return expr end -""" - setgid!(vi::VarInfo, gid::Selector, vn::VarName) - -Add `gid` to the set of sampler selectors associated with `vn` in `vi`. -""" -setgid!(vi::VarInfo, gid::Selector, vn::VarName) = setgid!(getmetadata(vi, vn), gid, vn) - -function setgid!(m::Metadata, gid::Selector, vn::VarName) - return push!(m.gids[getidx(m, vn)], gid) -end - -function setgid!(vnv::VarNamedVector, gid::Selector, vn::VarName) - throw(ErrorException("Calling setgid! on a VarNamedVector isn't valid.")) -end - istrans(vi::VarInfo, vn::VarName) = istrans(getmetadata(vi, vn), vn) istrans(md::Metadata, vn::VarName) = is_flagged(md, vn, "trans") @@ -1348,7 +1296,6 @@ function _link_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_ ranges_new, reduce(vcat, vals_new), metadata.dists, - metadata.gids, metadata.orders, metadata.flags, ) @@ -1491,7 +1438,6 @@ function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, targ ranges_new, reduce(vcat, vals_new), metadata.dists, - metadata.gids, metadata.orders, metadata.flags, ) @@ -1675,7 +1621,6 @@ function Base.show(io::IO, ::MIME"text/plain", vi::UntypedVarInfo) | Varnames : $(string(vi.metadata.vns)) | Range : $(vi.metadata.ranges) | Vals : $(vi.metadata.vals) - | GIDs : $(vi.metadata.gids) | Orders : $(vi.metadata.orders) | Logp : $(getlogp(vi)) | #produce : $(get_num_produce(vi)) @@ -1715,13 +1660,17 @@ function Base.show(io::IO, vi::UntypedVarInfo) return print(io, ")") end -function BangBang.push!!( - vi::VarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector} -) +""" + push!!(vi::VarInfo, vn::VarName, r, dist::Distribution) + +Push a new random variable `vn` with a sampled value `r` from a distribution `dist` to +the `VarInfo` `vi`, mutating if it makes sense. +""" +function BangBang.push!!(vi::VarInfo, vn::VarName, r, dist::Distribution) if vi isa UntypedVarInfo - @assert ~(vn in keys(vi)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist, gid=$gidset" + @assert ~(vn in keys(vi)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist" elseif vi isa TypedVarInfo - @assert ~(haskey(vi, vn)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to TypedVarInfo of syms $(syms(vi)) with dist=$dist, gid=$gidset" + @assert ~(haskey(vi, vn)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to TypedVarInfo of syms $(syms(vi)) with dist=$dist" end sym = getsym(vn) @@ -1734,14 +1683,13 @@ function BangBang.push!!( [1:length(val)], val, [dist], - [gidset], [get_num_produce(vi)], Dict{String,BitVector}("trans" => [false], "del" => [false]), ) vi = Accessors.@set vi.metadata[sym] = md else meta = getmetadata(vi, vn) - push!(meta, vn, r, dist, gidset, get_num_produce(vi)) + push!(meta, vn, r, dist, get_num_produce(vi)) end return vi @@ -1761,7 +1709,7 @@ end # exist in the TypedVarInfo already. We could implement it in the cases where it it does # exist, but that feels a bit pointless. I think we should rather rely on `push!!`. -function Base.push!(meta::Metadata, vn, r, dist, gidset, num_produce) +function Base.push!(meta::Metadata, vn, r, dist, num_produce) val = tovec(r) meta.idcs[vn] = length(meta.idcs) + 1 push!(meta.vns, vn) @@ -1770,7 +1718,6 @@ function Base.push!(meta::Metadata, vn, r, dist, gidset, num_produce) push!(meta.ranges, (l + 1):(l + n)) append!(meta.vals, val) push!(meta.dists, dist) - push!(meta.gids, gidset) push!(meta.orders, num_produce) push!(meta.flags["del"], false) push!(meta.flags["trans"], false) @@ -1914,18 +1861,6 @@ end return expr end -""" - updategid!(vi::VarInfo, vn::VarName, spl::Sampler) - -Set `vn`'s `gid` to `Set([spl.selector])`, if `vn` does not have a sampler selector linked -and `vn`'s symbol is in the space of `spl`. -""" -function updategid!(vi::VarInfoOrThreadSafeVarInfo, vn::VarName, spl::Sampler) - if inspace(vn, getspace(spl)) - setgid!(vi, spl.selector, vn) - end -end - # TODO: Maybe rename or something? """ _apply!(kernel!, vi::VarInfo, values, keys) diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl index 3b3f0ce42..6b7c82859 100644 --- a/src/varnamedvector.jl +++ b/src/varnamedvector.jl @@ -766,9 +766,9 @@ function update_internal!( return nothing end -# TODO(mhauru) The gidset and num_produce arguments are used by the old Gibbs sampler. +# TODO(mhauru) The num_produce argument is used by Particle Gibbs. # Remove this method as soon as possible. -function BangBang.push!(vnv::VarNamedVector, vn, val, dist, gidset, num_produce) +function BangBang.push!(vnv::VarNamedVector, vn, val, dist, num_produce) f = from_vec_transform(dist) return setindex_internal!(vnv, tovec(val), vn, f) end @@ -963,9 +963,9 @@ function BangBang.push!!(vnv::VarNamedVector, pair::Pair) return setindex!!(vnv, val, vn) end -# TODO(mhauru) The gidset and num_produce arguments are used by the old Gibbs sampler. +# TODO(mhauru) The num_produce argument is used by Particle Gibbs. # Remove this method as soon as possible. -function BangBang.push!!(vnv::VarNamedVector, vn, val, dist, gidset, num_produce) +function BangBang.push!!(vnv::VarNamedVector, vn, val, dist, num_produce) f = from_vec_transform(dist) return setindex_internal!!(vnv, tovec(val), vn, f) end @@ -1023,15 +1023,6 @@ julia> ForwardDiff.gradient(f, [1.0]) """ replace_raw_storage(vnv::VarNamedVector, vals) = Accessors.@set vnv.vals = vals -# TODO(mhauru) The space argument is used by the old Gibbs sampler. To be removed. -function replace_raw_storage(vnv::VarNamedVector, ::Val{space}, vals) where {space} - if length(space) > 0 - msg = "Selecting values in a VarNamedVector with a space is not supported." - throw(ArgumentError(msg)) - end - return replace_raw_storage(vnv, vals) -end - vector_length(vnv::VarNamedVector) = length(vnv.vals) - num_inactive(vnv) """ diff --git a/test/ad.jl b/test/ad.jl index 17981cf2a..87c7f22c3 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -41,7 +41,7 @@ σ = 0.3 y = @. rand(sin(t) + Normal(0, σ)) @model function state_space(y, TT, ::Type{T}=Float64) where {T} - # Priors + # Priors α ~ Normal(y[1], 0.001) τ ~ Exponential(1) η ~ filldist(Normal(0, 1), TT - 1) @@ -63,7 +63,6 @@ # overload assume so that model evaluation doesn't fail due to a lack # of implementation struct MyEmptyAlg end - DynamicPPL.getspace(::DynamicPPL.Sampler{MyEmptyAlg}) = () DynamicPPL.assume(rng, ::DynamicPPL.Sampler{MyEmptyAlg}, dist, vn, vi) = DynamicPPL.assume(dist, vn, vi) diff --git a/test/runtests.jl b/test/runtests.jl index 29a148789..25cd2fb40 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -33,7 +33,7 @@ using JET: JET using Combinatorics: combinations using OrderedCollections: OrderedSet -using DynamicPPL: getargs_dottilde, getargs_tilde, Selector +using DynamicPPL: getargs_dottilde, getargs_tilde const GROUP = get(ENV, "GROUP", "All") Random.seed!(100) diff --git a/test/sampler.jl b/test/sampler.jl index 50111b1fd..8c4f1ed96 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -67,7 +67,6 @@ ) return vi, nothing end - DynamicPPL.getspace(::Sampler{<:OnlyInitAlg}) = () # initial samplers DynamicPPL.initialsampler(::Sampler{OnlyInitAlgUniform}) = SampleFromUniform() From f5e84f46d7ceadbc393062598e52f334143ec4be Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 18 Feb 2025 17:15:30 +0000 Subject: [PATCH 07/14] Remove dot_tilde pipeline (#804) * Refactor dot_tilde, work in progress * Restrict dot_tilde to univariate dists on the RHS * Remove tests with multivariates or arrays as RHS of .~ * emove dot_tilde pipeline * Fix a .~ bug * Update HISTORY.md * Fix a tiny test bug * Re-enable some SimpleVarInfo tests * Improve changelog entry * Improve error message * Fix trivial typos * Fix pointwise_logdensity test * Remove pointless check_dot_tilde_rhs method * Add tests for old .~ syntax * Bump Mooncake patch version to v0.4.90 * Bump Mooncake to 0.4.95 --------- Co-authored-by: Penelope Yong --- HISTORY.md | 70 +++++- Project.toml | 6 +- docs/src/api.md | 2 - ext/DynamicPPLZygoteRulesExt.jl | 25 --- src/DynamicPPL.jl | 4 - src/compiler.jl | 91 ++++---- src/context_implementations.jl | 381 -------------------------------- src/debug_utils.jl | 145 ------------ src/extract_priors.jl | 5 - src/model.jl | 2 +- src/pointwise_logdensities.jl | 86 +------ src/simple_varinfo.jl | 51 ----- src/test_utils/contexts.jl | 12 - src/test_utils/models.jl | 121 ++++------ src/transforming.jl | 61 ----- src/values_as_in_model.jl | 23 -- test/compat/ad.jl | 28 --- test/compiler.jl | 27 +++ test/context_implementations.jl | 53 ++--- test/pointwise_logdensities.jl | 4 +- test/simple_varinfo.jl | 9 +- 21 files changed, 198 insertions(+), 1008 deletions(-) delete mode 100644 ext/DynamicPPLZygoteRulesExt.jl diff --git a/HISTORY.md b/HISTORY.md index 90db022e7..fa9e58e99 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -4,6 +4,74 @@ **Breaking** +### `.~` right hand side must be a univariate distribution + +Previously we allowed statements like + +```julia +x .~ [Normal(), Gamma()] +``` + +where the right hand side of a `.~` was an array of distributions, and ones like + +```julia +x .~ MvNormal(fill(0.0, 2), I) +``` + +where the right hand side was a multivariate distribution. + +These are no longer allowed. The only things allowed on the right hand side of a `.~` statement are univariate distributions, such as + +```julia +x = Array{Float64,3}(undef, 2, 3, 4) +x .~ Normal() +``` + +The reasons for this are internal code simplification and the fact that broadcasting where both sides are multidimensional but of different dimensions is typically confusing to read. + +If the right hand side and the left hand side have the same dimension, one can simply use `~`. Arrays of distributions can be replaced with `product_distribution`. So instead of + +```julia +x .~ [Normal(), Gamma()] +x .~ Normal.(y) +x .~ MvNormal(fill(0.0, 2), I) +``` + +do + +```julia +x ~ product_distribution([Normal(), Gamma()]) +x ~ product_distribution(Normal.(y)) +x ~ MvNormal(fill(0.0, 2), I) +``` + +This is often more performant as well. Note that using `~` rather than `.~` does change the internal storage format a bit: With `.~` `x[i]` are stored as separate variables, with `~` as a single multivariate variable `x`. In most cases this does not change anything for the user, but if it does cause issues, e.g. if you are dealing with `VarInfo` objects directly and need to keep the old behavior, you can always expand into a loop, such as + +```julia +dists = Normal.(y) +for i in 1:length(dists) + x[i] ~ dists[i] +end +``` + +Cases where the right hand side is of a different dimension than the left hand side, and neither is a scalar, must be replaced with a loop. For example, + +```julia +x = Array{Float64,3}(undef, 2, 3, 4) +x .~ MvNormal(fill(0, 2), I) +``` + +should be replaced with something like + +```julia +x = Array{Float64,3}(2, 3, 4) +for i in 1:3, j in 1:4 + x[:, i, j] ~ MvNormal(fill(0, 2), I) +end +``` + +This release also completely rewrites the internal implementation of `.~`, where from now on all `.~` statements are turned into loops over `~` statements at macro time. However, the only breaking aspect of this change is the above change to what's allowed on the right hand side. + ### Remove indexing by samplers This release removes the feature of `VarInfo` where it kept track of which variable was associated with which sampler. This means removing all user-facing methods where `VarInfo`s where being indexed with samplers. In particular, @@ -14,7 +82,7 @@ This release removes the feature of `VarInfo` where it kept track of which varia - `unflatten` no longer accepts a sampler as an argument - `eltype(::VarInfo)` no longer accepts a sampler as an argument - `keys(::VarInfo)` no longer accepts a sampler as an argument - - `VarInfo(::VarInfo, ::Sampler, ::AbstactVector)` no longer accepts the sampler argument. + - `VarInfo(::VarInfo, ::Sampler, ::AbstractVector)` no longer accepts the sampler argument. - `push!!` and `push!` no longer accept samplers or `Selector`s as arguments - `getgid`, `setgid!`, `updategid!`, `getspace`, and `inspace` no longer exist diff --git a/Project.toml b/Project.toml index be4586246..7cd47fdbb 100644 --- a/Project.toml +++ b/Project.toml @@ -24,7 +24,6 @@ OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -33,7 +32,6 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" -ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [extensions] DynamicPPLChainRulesCoreExt = ["ChainRulesCore"] @@ -42,7 +40,6 @@ DynamicPPLForwardDiffExt = ["ForwardDiff"] DynamicPPLJETExt = ["JET"] DynamicPPLMCMCChainsExt = ["MCMCChains"] DynamicPPLMooncakeExt = ["Mooncake"] -DynamicPPLZygoteRulesExt = ["ZygoteRules"] [compat] ADTypes = "1" @@ -65,10 +62,9 @@ LogDensityProblems = "2" LogDensityProblemsAD = "1.7.0" MCMCChains = "6" MacroTools = "0.5.6" -Mooncake = "0.4.59" +Mooncake = "0.4.95" OrderedCollections = "1" Random = "1.6" Requires = "1" Test = "1.6" -ZygoteRules = "0.2" julia = "1.10" diff --git a/docs/src/api.md b/docs/src/api.md index b9cafaaf4..4d3c6bc97 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -440,10 +440,8 @@ DynamicPPL.Experimental.is_suitable_varinfo ```@docs tilde_assume -dot_tilde_assume ``` ```@docs tilde_observe -dot_tilde_observe ``` diff --git a/ext/DynamicPPLZygoteRulesExt.jl b/ext/DynamicPPLZygoteRulesExt.jl deleted file mode 100644 index 78831fdc4..000000000 --- a/ext/DynamicPPLZygoteRulesExt.jl +++ /dev/null @@ -1,25 +0,0 @@ -module DynamicPPLZygoteRulesExt - -if isdefined(Base, :get_extension) - using DynamicPPL: DynamicPPL, Distributions - using ZygoteRules: ZygoteRules -else - using ..DynamicPPL: DynamicPPL, Distributions - using ..ZygoteRules: ZygoteRules -end - -# https://github.com/TuringLang/Turing.jl/issues/1595 -ZygoteRules.@adjoint function DynamicPPL.dot_observe( - spl::Union{DynamicPPL.SampleFromPrior,DynamicPPL.SampleFromUniform}, - dists::AbstractArray{<:Distributions.Distribution}, - value::AbstractArray, - vi, -) - function dot_observe_fallback(spl, dists, value, vi) - DynamicPPL.increment_num_produce!(vi) - return sum(map(Distributions.loglikelihood, dists, value)), vi - end - return ZygoteRules.pullback(dot_observe_fallback, __context__, spl, dists, value, vi) -end - -end # module diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 8fea43e50..f0d42f98c 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -98,13 +98,9 @@ export AbstractVarInfo, PrefixContext, ConditionContext, assume, - dot_assume, observe, - dot_observe, tilde_assume, tilde_observe, - dot_tilde_assume, - dot_tilde_observe, # Pseudo distributions NamedDist, NoDist, diff --git a/src/compiler.jl b/src/compiler.jl index 8743641af..8bde5e784 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -161,7 +161,16 @@ Return `true` if `expr` is a literal, e.g. `1.0` or `[1.0, ]`, and `false` other """ isliteral(e) = false isliteral(::Number) = true -isliteral(e::Expr) = !isempty(e.args) && all(isliteral, e.args) +function isliteral(e::Expr) + # In the special case that the expression is of the form `abc[blahblah]`, we consider it + # to be a literal if `abc` is a literal. This is necessary for cases like + # [1.0, 2.0][idx...] ~ Normal() + # which are generated when turning `.~` expressions into loops over `~` expressions. + if e.head == :ref + return isliteral(e.args[1]) + end + return !isempty(e.args) && all(isliteral, e.args) +end """ check_tilde_rhs(x) @@ -172,7 +181,7 @@ Check if the right-hand side `x` of a `~` is a `Distribution` or an array of function check_tilde_rhs(@nospecialize(x)) return throw( ArgumentError( - "the right-hand side of a `~` must be a `Distribution` or an array of `Distribution`s", + "the right-hand side of a `~` must be a `Distribution`, an array of `Distribution`s, or a submodel", ), ) end @@ -184,6 +193,27 @@ function check_tilde_rhs(x::Sampleable{<:Any,AutoPrefix}) where {AutoPrefix} return Sampleable{typeof(model),AutoPrefix}(model) end +""" + check_dot_tilde_rhs(x) + +Check if the right-hand side `x` of a `.~` is a `UnivariateDistribution`, then return `x`. +""" +function check_dot_tilde_rhs(@nospecialize(x)) + return throw( + ArgumentError("the right-hand side of a `.~` must be a `UnivariateDistribution`") + ) +end +function check_dot_tilde_rhs(::AbstractArray{<:Distribution}) + msg = """ + As of v0.35, DynamicPPL does not allow arrays of distributions in `.~`. \ + Please use `product_distribution` instead, or write a loop if necessary. \ + See https://github.com/TuringLang/DynamicPPL.jl/releases/tag/v0.35.0 for more \ + details.\ + """ + return throw(ArgumentError(msg)) +end +check_dot_tilde_rhs(x::UnivariateDistribution) = x + """ unwrap_right_vn(right, vn) @@ -356,11 +386,8 @@ function generate_mainbody!(mod, found, expr::Expr, warn) args_dottilde = getargs_dottilde(expr) if args_dottilde !== nothing L, R = args_dottilde - return Base.remove_linenums!( - generate_dot_tilde( - generate_mainbody!(mod, found, L, warn), - generate_mainbody!(mod, found, R, warn), - ), + return generate_mainbody!( + mod, found, Base.remove_linenums!(generate_dot_tilde(L, R)), warn ) end @@ -487,56 +514,16 @@ end Generate the expression that replaces `left .~ right` in the model body. """ function generate_dot_tilde(left, right) - isliteral(left) && return generate_tilde_literal(left, right) - - # Otherwise it is determined by the model or its value, - # if the LHS represents an observation - @gensym vn isassumption value + @gensym dist left_axes idx return quote - $vn = $(DynamicPPL.resolve_varnames)( - $(AbstractPPL.drop_escape(varname(left, need_concretize(left)))), $right - ) - $isassumption = $(DynamicPPL.isassumption(left, vn)) - if $(DynamicPPL.isfixed(left, vn)) - $left .= $(DynamicPPL.getfixed_nested)(__context__, $vn) - elseif $isassumption - $(generate_dot_tilde_assume(left, right, vn)) - else - # If `vn` is not in `argnames`, we need to make sure that the variable is defined. - if !$(DynamicPPL.inargnames)($vn, __model__) - $left .= $(DynamicPPL.getconditioned_nested)(__context__, $vn) - end - - $value, __varinfo__ = $(DynamicPPL.dot_tilde_observe!!)( - __context__, - $(DynamicPPL.check_tilde_rhs)($right), - $(maybe_view(left)), - $vn, - __varinfo__, - ) - $value + $dist = DynamicPPL.check_dot_tilde_rhs($right) + $left_axes = axes($left) + for $idx in Iterators.product($left_axes...) + $left[$idx...] ~ $dist end end end -function generate_dot_tilde_assume(left, right, vn) - # We don't need to use `Setfield.@set` here since - # `.=` is always going to be inplace + needs `left` to - # be something that supports `.=`. - @gensym value - return quote - $value, __varinfo__ = $(DynamicPPL.dot_tilde_assume!!)( - __context__, - $(DynamicPPL.unwrap_right_left_vns)( - $(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn - )..., - __varinfo__, - ) - $left .= $value - $value - end -end - # Note that we cannot use `MacroTools.isdef` because # of https://github.com/FluxML/MacroTools.jl/issues/154. """ diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 4594902dc..af04d0f57 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -258,384 +258,3 @@ function observe(right::Distribution, left, vi) increment_num_produce!(vi) return Distributions.loglikelihood(right, left), vi end - -# .~ functions - -# assume -""" - dot_tilde_assume(context::SamplingContext, right, left, vn, vi) - -Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the -model inputs), accumulate the log probability, and return the sampled value for a context -associated with a sampler. - -Falls back to -```julia -dot_tilde_assume(context.rng, context.context, context.sampler, right, left, vn, vi) -``` -""" -function dot_tilde_assume(context::SamplingContext, right, left, vn, vi) - return dot_tilde_assume( - context.rng, context.context, context.sampler, right, left, vn, vi - ) -end - -# `DefaultContext` -function dot_tilde_assume(context::AbstractContext, args...) - return dot_tilde_assume(NodeTrait(dot_tilde_assume, context), context, args...) -end -function dot_tilde_assume(rng::Random.AbstractRNG, context::AbstractContext, args...) - return dot_tilde_assume(NodeTrait(dot_tilde_assume, context), rng, context, args...) -end - -function dot_tilde_assume(::IsLeaf, ::AbstractContext, right, left, vns, vi) - return dot_assume(right, left, vns, vi) -end -function dot_tilde_assume(::IsLeaf, rng, ::AbstractContext, sampler, right, left, vns, vi) - return dot_assume(rng, sampler, right, vns, left, vi) -end - -function dot_tilde_assume(::IsParent, context::AbstractContext, args...) - return dot_tilde_assume(childcontext(context), args...) -end -function dot_tilde_assume(::IsParent, rng, context::AbstractContext, args...) - return dot_tilde_assume(rng, childcontext(context), args...) -end - -function dot_tilde_assume( - rng::Random.AbstractRNG, ::DefaultContext, sampler, right, left, vns, vi -) - return dot_assume(rng, sampler, right, vns, left, vi) -end - -# `LikelihoodContext` -function dot_tilde_assume(context::LikelihoodContext, right, left, vn, vi) - return dot_assume(nodist(right), left, vn, vi) -end -function dot_tilde_assume( - rng::Random.AbstractRNG, context::LikelihoodContext, sampler, right, left, vn, vi -) - return dot_assume(rng, sampler, nodist(right), vn, left, vi) -end - -# `PrefixContext` -function dot_tilde_assume(context::PrefixContext, right, left, vn, vi) - return dot_tilde_assume(context.context, right, left, prefix.(Ref(context), vn), vi) -end - -function dot_tilde_assume( - rng::Random.AbstractRNG, context::PrefixContext, sampler, right, left, vn, vi -) - return dot_tilde_assume( - rng, context.context, sampler, right, left, prefix.(Ref(context), vn), vi - ) -end - -""" - dot_tilde_assume!!(context, right, left, vn, vi) - -Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the -model inputs), accumulate the log probability, and return the sampled value and updated `vi`. - -Falls back to `dot_tilde_assume(context, right, left, vn, vi)`. -""" -function dot_tilde_assume!!(context, right, left, vn, vi) - is_rhs_model(right) && throw( - ArgumentError( - "`.~` with a model on the right-hand side is not supported; please use `~`" - ), - ) - value, logp, vi = dot_tilde_assume(context, right, left, vn, vi) - return value, acclogp_assume!!(context, vi, logp) -end - -# `dot_assume` -function dot_assume( - dist::MultivariateDistribution, - var::AbstractMatrix, - vns::AbstractVector{<:VarName}, - vi::AbstractVarInfo, -) - @assert length(dist) == size(var, 1) "dimensionality of `var` ($(size(var, 1))) is incompatible with dimensionality of `dist` $(length(dist))" - # NOTE: We cannot work with `var` here because we might have a model of the form - # - # m = Vector{Float64}(undef, n) - # m .~ Normal() - # - # in which case `var` will have `undef` elements, even if `m` is present in `vi`. - r = vi[vns, dist] - lp = sum(zip(vns, eachcol(r))) do (vn, ri) - return Bijectors.logpdf_with_trans(dist, ri, istrans(vi, vn)) - end - return r, lp, vi -end - -function dot_assume( - rng, - spl::Union{SampleFromPrior,SampleFromUniform}, - dist::MultivariateDistribution, - vns::AbstractVector{<:VarName}, - var::AbstractMatrix, - vi::AbstractVarInfo, -) - @assert length(dist) == size(var, 1) - r = get_and_set_val!(rng, vi, vns, dist, spl) - lp = sum(Bijectors.logpdf_with_trans(dist, r, istrans(vi, vns[1]))) - return r, lp, vi -end - -function dot_assume( - dist::Distribution, var::AbstractArray, vns::AbstractArray{<:VarName}, vi -) - r = getindex.((vi,), vns, (dist,)) - lp = sum(Bijectors.logpdf_with_trans.((dist,), r, istrans.((vi,), vns))) - return r, lp, vi -end - -function dot_assume( - dists::AbstractArray{<:Distribution}, - var::AbstractArray, - vns::AbstractArray{<:VarName}, - vi, -) - r = getindex.((vi,), vns, dists) - lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans.((vi,), vns))) - return r, lp, vi -end - -function dot_assume( - rng, - spl::Union{SampleFromPrior,SampleFromUniform}, - dists::Union{Distribution,AbstractArray{<:Distribution}}, - vns::AbstractArray{<:VarName}, - var::AbstractArray, - vi::AbstractVarInfo, -) - r = get_and_set_val!(rng, vi, vns, dists, spl) - # Make sure `r` is not a matrix for multivariate distributions - lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans.((vi,), vns))) - return r, lp, vi -end -function dot_assume(rng, spl::Sampler, ::Any, ::AbstractArray{<:VarName}, ::Any, ::Any) - return error( - "[DynamicPPL] $(alg_str(spl)) doesn't support vectorizing assume statement" - ) -end - -# HACK: These methods are only used in the `get_and_set_val!` methods below. -# FIXME: Remove these. -function _link_broadcast_new(vi, vn, dist, r) - b = to_linked_internal_transform(vi, vn, dist) - return b(r) -end - -function _maybe_invlink_broadcast(vi, vn, dist) - xvec = getindex_internal(vi, vn) - b = from_maybe_linked_internal_transform(vi, vn, dist) - return b(xvec) -end - -function get_and_set_val!( - rng, - vi::VarInfoOrThreadSafeVarInfo, - vns::AbstractVector{<:VarName}, - dist::MultivariateDistribution, - spl::Union{SampleFromPrior,SampleFromUniform}, -) - n = length(vns) - if haskey(vi, vns[1]) - # Always overwrite the parameters with new ones for `SampleFromUniform`. - if spl isa SampleFromUniform || is_flagged(vi, vns[1], "del") - # TODO(mhauru) Is it important to unset the flag here? The `true` allows us - # to ignore the fact that for VarNamedVector this does nothing, but I'm unsure if - # that's okay. - unset_flag!(vi, vns[1], "del", true) - r = init(rng, dist, spl, n) - for i in 1:n - vn = vns[i] - f_link_maybe = to_maybe_linked_internal_transform(vi, vn, dist) - setindex!!(vi, f_link_maybe(r[:, i]), vn) - setorder!(vi, vn, get_num_produce(vi)) - end - else - r = vi[vns, dist] - end - else - r = init(rng, dist, spl, n) - for i in 1:n - vn = vns[i] - if istrans(vi) - ri_linked = _link_broadcast_new(vi, vn, dist, r[:, i]) - push!!(vi, vn, ri_linked, dist) - # `push!!` sets the trans-flag to `false` by default. - settrans!!(vi, true, vn) - else - push!!(vi, vn, r[:, i], dist) - end - end - end - return r -end - -function get_and_set_val!( - rng, - vi::VarInfoOrThreadSafeVarInfo, - vns::AbstractArray{<:VarName}, - dists::Union{Distribution,AbstractArray{<:Distribution}}, - spl::Union{SampleFromPrior,SampleFromUniform}, -) - if haskey(vi, vns[1]) - # Always overwrite the parameters with new ones for `SampleFromUniform`. - if spl isa SampleFromUniform || is_flagged(vi, vns[1], "del") - # TODO(mhauru) Is it important to unset the flag here? The `true` allows us - # to ignore the fact that for VarNamedVector this does nothing, but I'm unsure if - # that's okay. - unset_flag!(vi, vns[1], "del", true) - f = (vn, dist) -> init(rng, dist, spl) - r = f.(vns, dists) - for i in eachindex(vns) - vn = vns[i] - dist = dists isa AbstractArray ? dists[i] : dists - f_link_maybe = to_maybe_linked_internal_transform(vi, vn, dist) - setindex!!(vi, f_link_maybe(r[i]), vn) - setorder!(vi, vn, get_num_produce(vi)) - end - else - rs = _maybe_invlink_broadcast.((vi,), vns, dists) - r = reshape(rs, size(vns)) - end - else - f = (vn, dist) -> init(rng, dist, spl) - r = f.(vns, dists) - # TODO: This will inefficient since it will allocate an entire vector. - # We could either: - # 1. Figure out the broadcast size and use a `foreach`. - # 2. Define an anonymous function which returns `nothing`, which - # we then broadcast. This will allocate a vector of `nothing` though. - if istrans(vi) - push!!.((vi,), vns, _link_broadcast_new.((vi,), vns, dists, r), dists) - # NOTE: Need to add the correction. - # FIXME: This is not great. - acclogp!!(vi, sum(logabsdetjac.(link_transform.(dists), r))) - # `push!!` sets the trans-flag to `false` by default. - settrans!!.((vi,), true, vns) - else - push!!.((vi,), vns, r, dists) - end - end - return r -end - -function set_val!( - vi::VarInfoOrThreadSafeVarInfo, - vns::AbstractVector{<:VarName}, - dist::MultivariateDistribution, - val::AbstractMatrix, -) - @assert size(val, 2) == length(vns) - foreach(enumerate(vns)) do (i, vn) - setindex!!(vi, val[:, i], vn) - end - return val -end -function set_val!( - vi::VarInfoOrThreadSafeVarInfo, - vns::AbstractArray{<:VarName}, - dists::Union{Distribution,AbstractArray{<:Distribution}}, - val::AbstractArray, -) - @assert size(val) == size(vns) - foreach(CartesianIndices(val)) do ind - setindex!!(vi, tovec(val[ind]), vns[ind]) - end - return val -end - -# observe -""" - dot_tilde_observe(context::SamplingContext, right, left, vi) - -Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log -probability, and return the observed value for a context associated with a sampler. - -Falls back to `dot_tilde_observe(context.context, context.sampler, right, left, vi)`. -""" -function dot_tilde_observe(context::SamplingContext, right, left, vi) - return dot_tilde_observe(context.context, context.sampler, right, left, vi) -end - -# Leaf contexts -function dot_tilde_observe(context::AbstractContext, args...) - return dot_tilde_observe(NodeTrait(tilde_observe, context), context, args...) -end -dot_tilde_observe(::IsLeaf, ::AbstractContext, args...) = dot_observe(args...) -function dot_tilde_observe(::IsParent, context::AbstractContext, args...) - return dot_tilde_observe(childcontext(context), args...) -end - -dot_tilde_observe(::PriorContext, right, left, vi) = 0, vi -dot_tilde_observe(::PriorContext, sampler, right, left, vi) = 0, vi - -# `MiniBatchContext` -function dot_tilde_observe(context::MiniBatchContext, right, left, vi) - logp, vi = dot_tilde_observe(context.context, right, left, vi) - return context.loglike_scalar * logp, vi -end - -# `PrefixContext` -function dot_tilde_observe(context::PrefixContext, right, left, vi) - return dot_tilde_observe(context.context, right, left, vi) -end - -""" - dot_tilde_observe!!(context, right, left, vname, vi) - -Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur in the model inputs), -accumulate the log probability, and return the observed value and updated `vi`. - -Falls back to `dot_tilde_observe!!(context, right, left, vi)` ignoring the information about variable -name and indices; if needed, these can be accessed through this function, though. -""" -function dot_tilde_observe!!(context, right, left, vn, vi) - is_rhs_model(right) && throw( - ArgumentError( - "`~` with a model on the right-hand side of an observe statement is not supported", - ), - ) - return dot_tilde_observe!!(context, right, left, vi) -end - -""" - dot_tilde_observe!!(context, right, left, vi) - -Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log -probability, and return the observed value and updated `vi`. - -Falls back to `dot_tilde_observe(context, right, left, vi)`. -""" -function dot_tilde_observe!!(context, right, left, vi) - is_rhs_model(right) && throw( - ArgumentError( - "`~` with a model on the right-hand side of an observe statement is not supported", - ), - ) - logp, vi = dot_tilde_observe(context, right, left, vi) - return left, acclogp_observe!!(context, vi, logp) -end - -# Falls back to non-sampler definition. -function dot_observe(::AbstractSampler, dist, value, vi) - return dot_observe(dist, value, vi) -end -function dot_observe(dist::MultivariateDistribution, value::AbstractMatrix, vi) - increment_num_produce!(vi) - return Distributions.loglikelihood(dist, value), vi -end -function dot_observe(dists::Distribution, value::AbstractArray, vi) - increment_num_produce!(vi) - return Distributions.loglikelihood(dists, value), vi -end -function dot_observe(dists::AbstractArray{<:Distribution}, value::AbstractArray, vi) - increment_num_produce!(vi) - return sum(Distributions.loglikelihood.(dists, value)), vi -end diff --git a/src/debug_utils.jl b/src/debug_utils.jl index 43b5054d5..328fe6983 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -113,52 +113,6 @@ function Base.show(io::IO, stmt::ObserveStmt) return print(io, ")") end -Base.@kwdef struct DotAssumeStmt <: Stmt - varname - left - right - value - logp - varinfo = nothing -end - -function Base.show(io::IO, stmt::DotAssumeStmt) - io = add_io_context(io) - print(io, " assume: ") - show_varname(io, stmt.varname) - print(io, " = ") - print(io, stmt.left) - print(io, " .~ ") - show_right(io, stmt.right) - print(io, " ") - print(io, RESULT_SYMBOL) - print(io, " ") - print(io, stmt.value) - print(io, " (logprob = ") - print(io, stmt.logp) - return print(io, ")") -end - -Base.@kwdef struct DotObserveStmt <: Stmt - left - right - logp - varinfo = nothing -end - -function Base.show(io::IO, stmt::DotObserveStmt) - io = add_io_context(io) - print(io, "observe: ") - print(io, stmt.left) - print(io, " .~ ") - show_right(io, stmt.right) - print(io, " ") - print(io, RESULT_SYMBOL) - print(io, " (logprob = ") - print(io, stmt.logp) - return print(io, ")") -end - # Some utility methods for extracting information from a trace. """ varnames_in_trace(trace) @@ -168,24 +122,14 @@ Return all the varnames present in the trace. varnames_in_trace(trace::AbstractVector) = mapreduce(varnames_in_stmt, vcat, trace) varnames_in_stmt(stmt::AssumeStmt) = [stmt.varname] -function varnames_in_stmt(stmt::DotAssumeStmt) - return stmt.varname isa VarName ? [stmt.varname] : stmt.varname -end varnames_in_stmt(::ObserveStmt) = [] -varnames_in_stmt(::DotObserveStmt) = [] function distributions_in_trace(trace::AbstractVector) return mapreduce(distributions_in_stmt, vcat, trace) end distributions_in_stmt(stmt::AssumeStmt) = [stmt.right] -function distributions_in_stmt(stmt::DotAssumeStmt) - return stmt.right isa AbstractArray ? vec(stmt.right) : [stmt.right] -end distributions_in_stmt(stmt::ObserveStmt) = [stmt.right] -function distributions_in_stmt(stmt::DotObserveStmt) - return stmt.right isa AbstractArray ? vec(stmt.right) : [stmt.right] -end """ DebugContext <: AbstractContext @@ -382,95 +326,6 @@ function DynamicPPL.tilde_observe(context::DebugContext, sampler, right, left, v return logp, vi end -# dot-assume -function record_pre_dot_tilde_assume!(context::DebugContext, vn, left, right, varinfo) - # Check for `missing`s; these should not end up here. - if _has_missings(left) - error( - "Variable $(vn) has missing has missing value(s)!\n" * - "Usage of `missing` is not supported for dotted syntax, such as " * - "`@. x ~ dist` or `x .~ dist`", - ) - end - - # TODO: Can we do without the memory allocation here? - record_varname!.(broadcast_safe(context), vn, broadcast_safe(right)) - - # Check that `left` does not contain any `` - return nothing -end - -function record_post_dot_tilde_assume!( - context::DebugContext, vns, left, right, value, logp, varinfo -) - stmt = DotAssumeStmt(; - varname=vns, - left=left, - right=right, - value=value, - logp=logp, - varinfo=context.record_varinfo ? deepcopy(varinfo) : nothing, - ) - if context.record_statements - push!(context.statements, stmt) - end - - return nothing -end - -function DynamicPPL.dot_tilde_assume(context::DebugContext, right, left, vn, vi) - record_pre_dot_tilde_assume!(context, vn, left, right, vi) - value, logp, vi = DynamicPPL.dot_tilde_assume( - childcontext(context), right, left, vn, vi - ) - record_post_dot_tilde_assume!(context, vn, left, right, value, logp, vi) - return value, logp, vi -end - -function DynamicPPL.dot_tilde_assume( - rng::Random.AbstractRNG, context::DebugContext, sampler, right, left, vn, vi -) - record_pre_dot_tilde_assume!(context, vn, left, right, vi) - value, logp, vi = DynamicPPL.dot_tilde_assume( - rng, childcontext(context), sampler, right, left, vn, vi - ) - record_post_dot_tilde_assume!(context, vn, left, right, value, logp, vi) - return value, logp, vi -end - -# dot-observe -function record_pre_dot_tilde_observe!(context::DebugContext, left, right, vi) - # Check for `missing`s; these should not end up here. - if _has_missings(left) - # TODO: Once `observe` statements receive `vn`, refer to this in the - # error message. - error( - "Encountered missing value(s) in observe!\n" * - "Usage of `missing` is not supported for dotted syntax, such as " * - "`@. x ~ dist` or `x .~ dist`", - ) - end -end - -function record_post_dot_tilde_observe!(context::DebugContext, left, right, logp, vi) - stmt = DotObserveStmt(; - left=left, - right=right, - logp=logp, - varinfo=context.record_varinfo ? deepcopy(vi) : nothing, - ) - if context.record_statements - push!(context.statements, stmt) - end - return nothing -end -function DynamicPPL.dot_tilde_observe(context::DebugContext, right, left, vi) - record_pre_dot_tilde_observe!(context, left, right, vi) - logp, vi = DynamicPPL.dot_tilde_observe(childcontext(context), right, left, vi) - record_post_dot_tilde_observe!(context, left, right, logp, vi) - return logp, vi -end - _conditioned_varnames(d::AbstractDict) = keys(d) _conditioned_varnames(d) = map(sym -> VarName{sym}(), keys(d)) function conditioned_varnames(context) diff --git a/src/extract_priors.jl b/src/extract_priors.jl index dd5aeeb04..0f312fa2c 100644 --- a/src/extract_priors.jl +++ b/src/extract_priors.jl @@ -39,11 +39,6 @@ function DynamicPPL.tilde_assume(context::PriorExtractorContext, right, vn, vi) return DynamicPPL.tilde_assume(childcontext(context), right, vn, vi) end -function DynamicPPL.dot_tilde_assume(context::PriorExtractorContext, right, left, vn, vi) - setprior!(context, vn, right) - return DynamicPPL.dot_tilde_assume(childcontext(context), right, left, vn, vi) -end - """ extract_priors([rng::Random.AbstractRNG, ]model::Model) diff --git a/src/model.jl b/src/model.jl index 3601d77fd..0fb18f463 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1,5 +1,5 @@ """ - struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstactContext} + struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext} f::F args::NamedTuple{argnames,Targs} defaults::NamedTuple{defaultnames,Tdefaults} diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index 8c18163e3..cb9ea4894 100644 --- a/src/pointwise_logdensities.jl +++ b/src/pointwise_logdensities.jl @@ -100,52 +100,6 @@ function tilde_observe!!(context::PointwiseLogdensityContext, right, left, vn, v return left, acclogp!!(vi, logp) end -function dot_tilde_observe!!(context::PointwiseLogdensityContext, right, left, vi) - # Defer literal `observe` to child-context. - return dot_tilde_observe!!(context.context, right, left, vi) -end -function dot_tilde_observe!!(context::PointwiseLogdensityContext, right, left, vn, vi) - # Completely defer to child context if we are not tracking likelihoods. - if !(_include_likelihood(context)) - return dot_tilde_observe!!(context.context, right, left, vn, vi) - end - - # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. - # we have to intercept the call to `dot_tilde_observe!`. - - # We want to treat `.~` as a collection of independent observations, - # hence we need the `logp` for each of them. Broadcasting the univariate - # `tilde_observe` does exactly this. - logps = _pointwise_tilde_observe(context.context, right, left, vi) - - # Need to unwrap the `vn`, i.e. get one `VarName` for each entry in `left`. - _, _, vns = unwrap_right_left_vns(right, left, vn) - for (vn, logp) in zip(vns, logps) - # Track loglikelihood value. - push!(context, vn, logp) - end - - return left, acclogp!!(vi, sum(logps)) -end - -# FIXME: This is really not a good approach since it needs to stay in sync with -# the `dot_assume` implementations, but as things are _right now_ this is the best we can do. -function _pointwise_tilde_observe(context, right, left, vi) - # We need to drop the `vi` returned. - return broadcast(right, left) do r, l - return first(tilde_observe(context, r, l, vi)) - end -end - -function _pointwise_tilde_observe( - context, right::MultivariateDistribution, left::AbstractMatrix, vi::AbstractVarInfo -) - # We need to drop the `vi` returned. - return map(eachcol(left)) do l - return first(tilde_observe(context, right, l, vi)) - end -end - # Note on submodels (penelopeysm) # # We don't need to overload tilde_observe!! for Sampleables (yet), because it @@ -174,44 +128,6 @@ function tilde_assume!!(context::PointwiseLogdensityContext, right, vn, vi) return value, acclogp!!(vi, logp) end -function dot_tilde_assume!!(context::PointwiseLogdensityContext, right, left, vns, vi) - !_include_prior(context) && - return (dot_tilde_assume!!(context.context, right, left, vns, vi)) - value, logps = _pointwise_tilde_assume(context, right, left, vns, vi) - # Track loglikelihood values. - for (vn, logp) in zip(vns, logps) - push!(context, vn, logp) - end - return value, acclogp!!(vi, sum(logps)) -end - -function _pointwise_tilde_assume(context, right, left, vns, vi) - # We need to drop the `vi` returned. - values_and_logps = broadcast(right, left, vns) do r, l, vn - # HACK(torfjelde): This drops the `vi` returned, which means the `vi` is not updated - # in case of immutable varinfos. But a) atm we're only using mutable varinfos for this, - # and b) even if the variables aren't stored in the vi correctly, we're not going to use - # this vi for anything downstream anyways, i.e. I don't see a case where this would matter - # for this particular use case. - val, logp, _ = tilde_assume(context, r, vn, vi) - return val, logp - end - return map(first, values_and_logps), map(last, values_and_logps) -end -function _pointwise_tilde_assume( - context, right::MultivariateDistribution, left::AbstractMatrix, vns, vi -) - # We need to drop the `vi` returned. - values_and_logps = map(eachcol(left), vns) do l, vn - val, logp, _ = tilde_assume(context, right, vn, vi) - return val, logp - end - # HACK(torfjelde): Due to the way we handle `.~`, we should use `recombine` to stay consistent. - # But this also means that we need to first flatten the entire `values` component before recombining. - values = recombine(right, mapreduce(vec ∘ first, vcat, values_and_logps), length(vns)) - return values, map(last, values_and_logps) -end - """ pointwise_logdensities(model::Model, chain::Chains, keytype = String) @@ -357,7 +273,7 @@ end """ pointwise_loglikelihoods(model, chain[, keytype, context]) - + Compute the pointwise log-likelihoods of the model given the chain. This is the same as `pointwise_logdensities(model, chain, context)`, but only including the likelihood terms. diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 00d6b3437..173eaa9e1 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -471,57 +471,6 @@ function assume( return value, Bijectors.logpdf_with_trans(dist, value, istrans(vi, vn)), vi end -function dot_assume( - rng, - spl::Union{SampleFromPrior,SampleFromUniform}, - dists::Union{Distribution,AbstractArray{<:Distribution}}, - vns::AbstractArray{<:VarName}, - var::AbstractArray, - vi::SimpleOrThreadSafeSimple, -) - f = (vn, dist) -> init(rng, dist, spl) - value = f.(vns, dists) - - # Transform if we're working in transformed space. - value_raw = if dists isa Distribution - to_maybe_linked_internal.((vi,), vns, (dists,), value) - else - to_maybe_linked_internal.((vi,), vns, dists, value) - end - - # Update `vi` - vi = BangBang.setindex!!(vi, value_raw, vns) - - # Compute logp. - lp = sum(Bijectors.logpdf_with_trans.(dists, value, istrans.((vi,), vns))) - return value, lp, vi -end - -function dot_assume( - rng, - spl::Union{SampleFromPrior,SampleFromUniform}, - dist::MultivariateDistribution, - vns::AbstractVector{<:VarName}, - var::AbstractMatrix, - vi::SimpleOrThreadSafeSimple, -) - @assert length(dist) == size(var, 1) "dimensionality of `var` ($(size(var, 1))) is incompatible with dimensionality of `dist` $(length(dist))" - - # r = get_and_set_val!(rng, vi, vns, dist, spl) - n = length(vns) - value = init(rng, dist, spl, n) - - # Update `vi`. - for (vn, val) in zip(vns, eachcol(value)) - val_linked = to_maybe_linked_internal(vi, vn, dist, val) - vi = BangBang.setindex!!(vi, val_linked, vn) - end - - # Compute logp. - lp = sum(Bijectors.logpdf_with_trans(dist, value, istrans(vi))) - return value, lp, vi -end - # NOTE: We don't implement `settrans!!(vi, trans, vn)`. function settrans!!(vi::SimpleVarInfo, trans) return settrans!!(vi, trans ? DynamicTransformation() : NoTransformation()) diff --git a/src/test_utils/contexts.jl b/src/test_utils/contexts.jl index 93bb02d3b..5150be64b 100644 --- a/src/test_utils/contexts.jl +++ b/src/test_utils/contexts.jl @@ -26,22 +26,10 @@ function DynamicPPL.tilde_assume(context::TestLogModifyingChildContext, right, v value, logp, vi = DynamicPPL.tilde_assume(context.context, right, vn, vi) return value, logp * context.mod, vi end -function DynamicPPL.dot_tilde_assume( - context::TestLogModifyingChildContext, right, left, vn, vi -) - value, logp, vi = DynamicPPL.dot_tilde_assume(context.context, right, left, vn, vi) - return value, logp * context.mod, vi -end function DynamicPPL.tilde_observe(context::TestLogModifyingChildContext, right, left, vi) logp, vi = DynamicPPL.tilde_observe(context.context, right, left, vi) return logp * context.mod, vi end -function DynamicPPL.dot_tilde_observe( - context::TestLogModifyingChildContext, right, left, vi -) - logp, vi = DynamicPPL.dot_tilde_observe(context.context, right, left, vi) - return logp * context.mod, vi -end # Dummy context to test nested behaviors. struct TestParentContext{C<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext diff --git a/src/test_utils/models.jl b/src/test_utils/models.jl index c506e1ba3..e29614982 100644 --- a/src/test_utils/models.jl +++ b/src/test_utils/models.jl @@ -186,31 +186,29 @@ function _demo_logprior_true_with_logabsdet_jacobian(model, s, m) return (s=s_unconstrained, m=m), logprior_true(model, s, m) - Δlogp end -@model function demo_dot_assume_dot_observe( - x=[1.5, 2.0], ::Type{TV}=Vector{Float64} -) where {TV} +@model function demo_dot_assume_observe(x=[1.5, 2.0], ::Type{TV}=Vector{Float64}) where {TV} # `dot_assume` and `observe` s = TV(undef, length(x)) m = TV(undef, length(x)) s .~ InverseGamma(2, 3) - m .~ Normal.(0, sqrt.(s)) + m ~ product_distribution(Normal.(0, sqrt.(s))) x ~ MvNormal(m, Diagonal(s)) return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_dot_assume_dot_observe)}, s, m) +function logprior_true(model::Model{typeof(demo_dot_assume_observe)}, s, m) return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) end -function loglikelihood_true(model::Model{typeof(demo_dot_assume_dot_observe)}, s, m) +function loglikelihood_true(model::Model{typeof(demo_dot_assume_observe)}, s, m) return loglikelihood(MvNormal(m, Diagonal(s)), model.args.x) end function logprior_true_with_logabsdet_jacobian( - model::Model{typeof(demo_dot_assume_dot_observe)}, s, m + model::Model{typeof(demo_dot_assume_observe)}, s, m ) return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end -function varnames(model::Model{typeof(demo_dot_assume_dot_observe)}) - return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] +function varnames(model::Model{typeof(demo_dot_assume_observe)}) + return [@varname(s[1]), @varname(s[2]), @varname(m)] end @model function demo_assume_index_observe( @@ -276,7 +274,7 @@ end s = TV(undef, length(x)) s .~ InverseGamma(2, 3) m = TV(undef, length(x)) - m .~ Normal.(0, sqrt.(s)) + m ~ product_distribution(Normal.(0, sqrt.(s))) for i in eachindex(x) x[i] ~ Normal(m[i], sqrt(s[i])) end @@ -295,7 +293,7 @@ function logprior_true_with_logabsdet_jacobian( return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end function varnames(model::Model{typeof(demo_dot_assume_observe_index)}) - return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] + return [@varname(s[1]), @varname(s[2]), @varname(m)] end # Using vector of `length` 1 here so the posterior of `m` is the same @@ -355,7 +353,7 @@ end s = TV(undef, 2) m = TV(undef, 2) s .~ InverseGamma(2, 3) - m .~ Normal.(0, sqrt.(s)) + m ~ product_distribution(Normal.(0, sqrt.(s))) 1.5 ~ Normal(m[1], sqrt(s[1])) 2.0 ~ Normal(m[2], sqrt(s[2])) @@ -376,7 +374,7 @@ function logprior_true_with_logabsdet_jacobian( return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end function varnames(model::Model{typeof(demo_dot_assume_observe_index_literal)}) - return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] + return [@varname(s[1]), @varname(s[2]), @varname(m)] end @model function demo_assume_observe_literal() @@ -431,7 +429,7 @@ end s = TV(undef, 2) s .~ InverseGamma(2, 3) m = TV(undef, 2) - m .~ Normal.(0, sqrt.(s)) + m ~ product_distribution(Normal.(0, sqrt.(s))) return s, m end @@ -460,7 +458,7 @@ function logprior_true_with_logabsdet_jacobian( return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end function varnames(model::Model{typeof(demo_assume_submodel_observe_index_literal)}) - return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] + return [@varname(s[1]), @varname(s[2]), @varname(m)] end @model function _likelihood_multivariate_observe(s, m, x) @@ -473,7 +471,7 @@ end s = TV(undef, length(x)) s .~ InverseGamma(2, 3) m = TV(undef, length(x)) - m .~ Normal.(0, sqrt.(s)) + m ~ product_distribution(Normal.(0, sqrt.(s))) # Submodel likelihood # With to_submodel, we have to have a left-hand side variable to @@ -494,76 +492,39 @@ function logprior_true_with_logabsdet_jacobian( return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end function varnames(model::Model{typeof(demo_dot_assume_observe_submodel)}) - return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] + return [@varname(s[1]), @varname(s[2]), @varname(m)] end -@model function demo_dot_assume_dot_observe_matrix( +@model function demo_dot_assume_observe_matrix_index( x=transpose([1.5 2.0;]), ::Type{TV}=Vector{Float64} ) where {TV} s = TV(undef, length(x)) s .~ InverseGamma(2, 3) m = TV(undef, length(x)) - m .~ Normal.(0, sqrt.(s)) + m ~ product_distribution(Normal.(0, sqrt.(s))) - # Dotted observe for `Matrix`. - x .~ MvNormal(m, Diagonal(s)) + x[:, 1] ~ MvNormal(m, Diagonal(s)) return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}, s, m) +function logprior_true(model::Model{typeof(demo_dot_assume_observe_matrix_index)}, s, m) return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) end -function loglikelihood_true(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}, s, m) - return sum(logpdf.(Normal.(m, sqrt.(s)), model.args.x)) -end -function logprior_true_with_logabsdet_jacobian( - model::Model{typeof(demo_dot_assume_dot_observe_matrix)}, s, m -) - return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) -end -function varnames(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}) - return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] -end - -@model function demo_dot_assume_matrix_dot_observe_matrix( - x=transpose([1.5 2.0;]), ::Type{TV}=Array{Float64} -) where {TV} - n = length(x) - d = length(x) ÷ 2 - s = TV(undef, d, 2) - s .~ product_distribution([InverseGamma(2, 3) for _ in 1:d]) - s_vec = vec(s) - m ~ MvNormal(zeros(n), Diagonal(s_vec)) - - # Dotted observe for `Matrix`. - x .~ MvNormal(m, Diagonal(s_vec)) - - return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) -end -function logprior_true( - model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, s, m -) - n = length(model.args.x) - s_vec = vec(s) - return loglikelihood(InverseGamma(2, 3), s_vec) + - logpdf(MvNormal(zeros(n), Diagonal(s_vec)), m) -end function loglikelihood_true( - model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, s, m + model::Model{typeof(demo_dot_assume_observe_matrix_index)}, s, m ) - return loglikelihood(MvNormal(m, Diagonal(vec(s))), model.args.x) + return sum(logpdf.(Normal.(m, sqrt.(s)), model.args.x)) end function logprior_true_with_logabsdet_jacobian( - model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, s, m + model::Model{typeof(demo_dot_assume_observe_matrix_index)}, s, m ) return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end -function varnames(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}) - s = zeros(1, 2) # used for varname concretization only - return [@varname(s[:, 1], true), @varname(s[:, 2], true), @varname(m)] +function varnames(model::Model{typeof(demo_dot_assume_observe_matrix_index)}) + return [@varname(s[1]), @varname(s[2]), @varname(m)] end -@model function demo_assume_matrix_dot_observe_matrix( +@model function demo_assume_matrix_observe_matrix_index( x=transpose([1.5 2.0;]), ::Type{TV}=Array{Float64} ) where {TV} n = length(x) @@ -572,33 +533,32 @@ end s_vec = vec(s) m ~ MvNormal(zeros(n), Diagonal(s_vec)) - # Dotted observe for `Matrix`. - x .~ MvNormal(m, Diagonal(s_vec)) + x[:, 1] ~ MvNormal(m, Diagonal(s_vec)) return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_assume_matrix_dot_observe_matrix)}, s, m) +function logprior_true(model::Model{typeof(demo_assume_matrix_observe_matrix_index)}, s, m) n = length(model.args.x) s_vec = vec(s) return loglikelihood(InverseGamma(2, 3), s_vec) + logpdf(MvNormal(zeros(n), Diagonal(s_vec)), m) end function loglikelihood_true( - model::Model{typeof(demo_assume_matrix_dot_observe_matrix)}, s, m + model::Model{typeof(demo_assume_matrix_observe_matrix_index)}, s, m ) return loglikelihood(MvNormal(m, Diagonal(vec(s))), model.args.x) end function logprior_true_with_logabsdet_jacobian( - model::Model{typeof(demo_assume_matrix_dot_observe_matrix)}, s, m + model::Model{typeof(demo_assume_matrix_observe_matrix_index)}, s, m ) return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end -function varnames(model::Model{typeof(demo_assume_matrix_dot_observe_matrix)}) +function varnames(model::Model{typeof(demo_assume_matrix_observe_matrix_index)}) return [@varname(s), @varname(m)] end const DemoModels = Union{ - Model{typeof(demo_dot_assume_dot_observe)}, + Model{typeof(demo_dot_assume_observe)}, Model{typeof(demo_assume_index_observe)}, Model{typeof(demo_assume_multivariate_observe)}, Model{typeof(demo_dot_assume_observe_index)}, @@ -609,9 +569,8 @@ const DemoModels = Union{ Model{typeof(demo_dot_assume_observe_index_literal)}, Model{typeof(demo_assume_submodel_observe_index_literal)}, Model{typeof(demo_dot_assume_observe_submodel)}, - Model{typeof(demo_dot_assume_dot_observe_matrix)}, - Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, - Model{typeof(demo_assume_matrix_dot_observe_matrix)}, + Model{typeof(demo_dot_assume_observe_matrix_index)}, + Model{typeof(demo_assume_matrix_observe_matrix_index)}, } const UnivariateAssumeDemoModels = Union{ @@ -637,7 +596,7 @@ function rand_prior_true(rng::Random.AbstractRNG, model::UnivariateAssumeDemoMod end const MultivariateAssumeDemoModels = Union{ - Model{typeof(demo_dot_assume_dot_observe)}, + Model{typeof(demo_dot_assume_observe)}, Model{typeof(demo_assume_index_observe)}, Model{typeof(demo_assume_multivariate_observe)}, Model{typeof(demo_dot_assume_observe_index)}, @@ -645,8 +604,7 @@ const MultivariateAssumeDemoModels = Union{ Model{typeof(demo_dot_assume_observe_index_literal)}, Model{typeof(demo_assume_submodel_observe_index_literal)}, Model{typeof(demo_dot_assume_observe_submodel)}, - Model{typeof(demo_dot_assume_dot_observe_matrix)}, - Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, + Model{typeof(demo_dot_assume_observe_matrix_index)}, } function posterior_mean(model::MultivariateAssumeDemoModels) # Get some containers to fill. @@ -699,7 +657,7 @@ function rand_prior_true(rng::Random.AbstractRNG, model::MultivariateAssumeDemoM end const MatrixvariateAssumeDemoModels = Union{ - Model{typeof(demo_assume_matrix_dot_observe_matrix)} + Model{typeof(demo_assume_matrix_observe_matrix_index)} } function posterior_mean(model::MatrixvariateAssumeDemoModels) # Get some containers to fill. @@ -786,7 +744,7 @@ And for the multivariate one (the latter one): """ const DEMO_MODELS = ( - demo_dot_assume_dot_observe(), + demo_dot_assume_observe(), demo_assume_index_observe(), demo_assume_multivariate_observe(), demo_dot_assume_observe_index(), @@ -797,7 +755,6 @@ const DEMO_MODELS = ( demo_assume_observe_literal(), demo_assume_submodel_observe_index_literal(), demo_dot_assume_observe_submodel(), - demo_dot_assume_dot_observe_matrix(), - demo_dot_assume_matrix_dot_observe_matrix(), - demo_assume_matrix_dot_observe_matrix(), + demo_dot_assume_observe_matrix_index(), + demo_assume_matrix_observe_matrix_index(), ) diff --git a/src/transforming.jl b/src/transforming.jl index 1a26d212f..0239725ae 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -30,67 +30,6 @@ function tilde_assume( return r, lp, setindex!!(vi, r_transformed, vn) end -function dot_tilde_assume( - ::DynamicTransformationContext{isinverse}, - dist::Distribution, - var::AbstractArray, - vns::AbstractArray{<:VarName}, - vi, -) where {isinverse} - r = getindex.((vi,), vns, (dist,)) - b = link_transform(dist) - - is_trans_uniques = unique(istrans.((vi,), vns)) - @assert length(is_trans_uniques) == 1 "DynamicTransformationContext only supports transforming all variables" - is_trans = first(is_trans_uniques) - if is_trans - @assert isinverse "Trying to link already transformed variables" - else - @assert !isinverse "Trying to invlink non-transformed variables" - end - - # Only transform if `!isinverse` since `vi[vn, right]` - # already performs the inverse transformation if it's transformed. - r_transformed = isinverse ? r : b.(r) - lp = sum(Bijectors.logpdf_with_trans.((dist,), r, (!isinverse,))) - return r, lp, setindex!!(vi, r_transformed, vns) -end - -function dot_tilde_assume( - ::DynamicTransformationContext{isinverse}, - dist::MultivariateDistribution, - var::AbstractMatrix, - vns::AbstractVector{<:VarName}, - vi::AbstractVarInfo, -) where {isinverse} - @assert length(dist) == size(var, 1) "dimensionality of `var` ($(size(var, 1))) is incompatible with dimensionality of `dist` $(length(dist))" - r = vi[vns, dist] - - # Compute `logpdf` with logabsdet-jacobian correction. - lp = sum(zip(vns, eachcol(r))) do (vn, ri) - return Bijectors.logpdf_with_trans(dist, ri, !isinverse) - end - - # Transform _all_ values. - is_trans_uniques = unique(istrans.((vi,), vns)) - @assert length(is_trans_uniques) == 1 "DynamicTransformationContext only supports transforming all variables" - is_trans = first(is_trans_uniques) - if is_trans - @assert isinverse "Trying to link already transformed variables" - else - @assert !isinverse "Trying to invlink non-transformed variables" - end - - b = link_transform(dist) - for (vn, ri) in zip(vns, eachcol(r)) - # Only transform if `!isinverse` since `vi[vn, right]` - # already performs the inverse transformation if it's transformed. - vi = setindex!!(vi, isinverse ? ri : b(ri), vn) - end - - return r, lp, vi -end - function link!!(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) return settrans!!(last(evaluate!!(model, vi, DynamicTransformationContext{false}())), t) end diff --git a/src/values_as_in_model.jl b/src/values_as_in_model.jl index 4cef5fa4e..d3bfd697a 100644 --- a/src/values_as_in_model.jl +++ b/src/values_as_in_model.jl @@ -90,29 +90,6 @@ function tilde_assume( return value, logp, vi end -# `dot_tilde_assume` -function dot_tilde_assume(context::ValuesAsInModelContext, right, left, vn, vi) - value, logp, vi = dot_tilde_assume(childcontext(context), right, left, vn, vi) - - # Save the value. - _right, _left, _vns = unwrap_right_left_vns(right, var, vn) - broadcast_push!(context, _vns, value) - - return value, logp, vi -end -function dot_tilde_assume( - rng::Random.AbstractRNG, context::ValuesAsInModelContext, sampler, right, left, vn, vi -) - value, logp, vi = dot_tilde_assume( - rng, childcontext(context), sampler, right, left, vn, vi - ) - # Save the value. - _right, _left, _vns = unwrap_right_left_vns(right, left, vn) - broadcast_push!(context, _vns, value) - - return value, logp, vi -end - """ values_as_in_model(model::Model, include_colon_eq::Bool, varinfo::AbstractVarInfo[, context::AbstractContext]) diff --git a/test/compat/ad.jl b/test/compat/ad.jl index f76ce6f6e..e6b23f379 100644 --- a/test/compat/ad.jl +++ b/test/compat/ad.jl @@ -26,32 +26,4 @@ test_model_ad(wishart_ad(), logp_wishart_ad) end - - # https://github.com/TuringLang/Turing.jl/issues/1595 - @testset "dot_observe" begin - function f_dot_observe(x) - logp, _ = DynamicPPL.dot_observe( - SampleFromPrior(), [Normal(), Normal(-1.0, 2.0)], x, VarInfo() - ) - return logp - end - function f_dot_observe_manual(x) - return logpdf(Normal(), x[1]) + logpdf(Normal(-1.0, 2.0), x[2]) - end - - # Manual computation of the gradient. - x = randn(2) - val = f_dot_observe_manual(x) - grad = ForwardDiff.gradient(f_dot_observe_manual, x) - - @test ForwardDiff.gradient(f_dot_observe, x) ≈ grad - - y, back = Tracker.forward(f_dot_observe, x) - @test Tracker.data(y) ≈ val - @test Tracker.data(back(1)[1]) ≈ grad - - y, back = Zygote.pullback(f_dot_observe, x) - @test y ≈ val - @test back(1)[1] ≈ grad - end end diff --git a/test/compiler.jl b/test/compiler.jl index 051eba618..8d81c530a 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -288,6 +288,33 @@ module Issue537 end x = vdemo()() @test all((isassigned(x, i) for i in eachindex(x))) end + + # A couple of uses of .~ that are no longer valid as of v0.35. + @testset "old .~ syntax" begin + @model function multivariate_dot_tilde() + x = Vector{Float64}(undef, 2) + x .~ MvNormal(zeros(2), I) + return x + end + expected_error = ArgumentError( + "the right-hand side of a `.~` must be a `UnivariateDistribution`" + ) + @test_throws expected_error (multivariate_dot_tilde()(); true) + + @model function vector_dot_tilde() + x = Vector{Float64}(undef, 2) + x .~ [Normal(), Normal()] + return x + end + expected_error = ArgumentError(""" + As of v0.35, DynamicPPL does not allow arrays of distributions in `.~`. \ + Please use `product_distribution` instead, or write a loop if necessary. \ + See https://github.com/TuringLang/DynamicPPL.jl/releases/tag/v0.35.0 for more \ + details.\ + """) + @test_throws expected_error (vector_dot_tilde()(); true) + end + @testset "nested model" begin function makemodel(p) @model function testmodel(x) diff --git a/test/context_implementations.jl b/test/context_implementations.jl index 8a795320d..0ec88c07c 100644 --- a/test/context_implementations.jl +++ b/test/context_implementations.jl @@ -4,7 +4,7 @@ @model function test(x) μ ~ MvNormal(zeros(2), 4 * I) z = Vector{Int}(undef, length(x)) - z .~ Categorical.(fill([0.5, 0.5], length(x))) + z ~ product_distribution(Categorical.(fill([0.5, 0.5], length(x)))) for i in 1:length(x) x[i] ~ Normal(μ[z[i]], 0.1) end @@ -13,59 +13,36 @@ test([1, 1, -1])(VarInfo(), SampleFromPrior(), LikelihoodContext()) end - # https://github.com/TuringLang/DynamicPPL.jl/issues/28#issuecomment-829223577 - @testset "dot tilde: arrays of distributions" begin + @testset "dot tilde with varying sizes" begin @testset "assume" begin @model function test(x, size) y = Array{Float64,length(size)}(undef, size...) - y .~ Normal.(x) + y .~ Normal(x) return y, getlogp(__varinfo__) end for ysize in ((2,), (2, 3), (2, 3, 4)) - for x in ( - # scalar - randn(), - # drop trailing dimensions - ntuple(i -> randn(ysize[1:i]), length(ysize))..., - # singleton dimensions - ntuple( - i -> randn(ysize[1:(i - 1)]..., 1, ysize[(i + 1):end]...), - length(ysize), - )..., - ) - model = test(x, ysize) - y, lp = model() - @test lp ≈ sum(logpdf.(Normal.(x), y)) + x = randn() + model = test(x, ysize) + y, lp = model() + @test lp ≈ sum(logpdf.(Normal.(x), y)) - ys = [first(model()) for _ in 1:10_000] - @test norm(mean(ys) .- x, Inf) < 0.1 - @test norm(std(ys) .- 1, Inf) < 0.1 - end + ys = [first(model()) for _ in 1:10_000] + @test norm(mean(ys) .- x, Inf) < 0.1 + @test norm(std(ys) .- 1, Inf) < 0.1 end end @testset "observe" begin @model function test(x, y) - return y .~ Normal.(x) + return y .~ Normal(x) end for ysize in ((2,), (2, 3), (2, 3, 4)) - for x in ( - # scalar - randn(), - # drop trailing dimensions - ntuple(i -> randn(ysize[1:i]), length(ysize))..., - # singleton dimensions - ntuple( - i -> randn(ysize[1:(i - 1)]..., 1, ysize[(i + 1):end]...), - length(ysize), - )..., - ) - y = randn(ysize) - z = logjoint(test(x, y), VarInfo()) - @test z ≈ sum(logpdf.(Normal.(x), y)) - end + x = randn() + y = randn(ysize) + z = logjoint(test(x, y), VarInfo()) + @test z ≈ sum(logpdf.(Normal.(x), y)) end end end diff --git a/test/pointwise_logdensities.jl b/test/pointwise_logdensities.jl index 5c0b2e090..61c842638 100644 --- a/test/pointwise_logdensities.jl +++ b/test/pointwise_logdensities.jl @@ -48,8 +48,8 @@ end @testset "pointwise_logdensities chain" begin # We'll just test one, since `pointwise_logdensities(::Model, ::AbstractVarInfo)` is tested extensively, # and this is what is used to implement `pointwise_logdensities(::Model, ::Chains)`. This test suite is just - # to ensure that we don't accidentally break the the version on `Chains`. - model = DynamicPPL.TestUtils.demo_dot_assume_dot_observe() + # to ensure that we don't accidentally break the version on `Chains`. + model = DynamicPPL.TestUtils.demo_assume_index_observe() # FIXME(torfjelde): Make use of `varname_and_value_leaves` once we've introduced # an impl of this for containers. # NOTE(torfjelde): This only returns the varnames of the _random_ variables, i.e. excl. observed. diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 137c791c2..e67b5656a 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -139,8 +139,6 @@ @testset "SimpleVarInfo on $(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS - model = DynamicPPL.TestUtils.demo_dot_assume_matrix_dot_observe_matrix() - # We might need to pre-allocate for the variable `m`, so we need # to see whether this is the case. svi_nt = SimpleVarInfo(DynamicPPL.TestUtils.rand_prior_true(model)) @@ -155,9 +153,10 @@ svi_nt, svi_dict, svi_vnv, - DynamicPPL.settrans!!(deepcopy(svi_nt), true), - DynamicPPL.settrans!!(deepcopy(svi_dict), true), - DynamicPPL.settrans!!(deepcopy(svi_vnv), true), + # TODO(mhauru) Fix linked SimpleVarInfos to work with our test models. + # DynamicPPL.settrans!!(deepcopy(svi_nt), true), + # DynamicPPL.settrans!!(deepcopy(svi_dict), true), + # DynamicPPL.settrans!!(deepcopy(svi_vnv), true), ) # RandOM seed is set in each `@testset`, so we need to sample # a new realization for `m` here. From 90c7b26c852b7d0ae87dee0a5a0010b097a0c1d3 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 19 Feb 2025 18:36:21 +0000 Subject: [PATCH 08/14] Remove LogDensityProblemsAD; wrap adtype in LogDensityFunction (#806) * Remove LogDensityProblemsAD * Implement LogDensityFunctionWithGrad in place of ADgradient * Dynamically decide whether to use closure vs constant * Combine LogDensityFunction{,WithGrad} into one (#811) * Warn if unsupported AD type is used * Update changelog * Update DI compat bound Co-authored-by: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> * Don't store with_closure inside LogDensityFunction Co-authored-by: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> * setadtype --> LogDensityFunction * Re-add ForwardDiffExt (including tests) * Add more tests for coverage --------- Co-authored-by: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> --- HISTORY.md | 48 ++++- Project.toml | 7 +- docs/src/api.md | 2 +- ext/DynamicPPLForwardDiffExt.jl | 72 +++---- src/DynamicPPL.jl | 1 - src/contexts.jl | 1 + src/logdensityfunction.jl | 300 ++++++++++++++++++++------- test/Project.toml | 2 - test/ad.jl | 97 ++++++--- test/ext/DynamicPPLForwardDiffExt.jl | 42 ++-- test/logdensityfunction.jl | 24 +-- test/runtests.jl | 2 +- test/test_util.jl | 9 + 13 files changed, 420 insertions(+), 187 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index fa9e58e99..3f999ccab 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -2,7 +2,7 @@ ## 0.35.0 -**Breaking** +**Breaking changes** ### `.~` right hand side must be a univariate distribution @@ -119,6 +119,52 @@ This release removes the feature of `VarInfo` where it kept track of which varia This change also affects sampling in Turing.jl. +### `LogDensityFunction` argument order + + - The method `LogDensityFunction(varinfo, model, context)` has been removed. + The only accepted order is `LogDensityFunction(model, varinfo, context; adtype)`. + (For an explanation of `adtype`, see below.) + The varinfo and context arguments are both still optional. + +**Other changes** + +### `LogDensityProblems` interface + +LogDensityProblemsAD is now removed as a dependency. +Instead of constructing a `LogDensityProblemAD.ADgradient` object, we now directly use `DifferentiationInterface` to calculate the gradient of the log density with respect to model parameters. + +Note that if you wish, you can still construct an `ADgradient` out of a `LogDensityFunction` object (there is nothing preventing this). + +However, in this version, `LogDensityFunction` now takes an extra AD type argument. +If this argument is not provided, the behaviour is exactly the same as before, i.e. you can calculate `logdensity` but not its gradient. +However, if you do pass an AD type, that will allow you to calculate the gradient as well. +You may thus find that it is easier to instead do this: + +```julia +@model f() = ... + +ldf = LogDensityFunction(f(); adtype=AutoForwardDiff()) +``` + +This will return an object which satisfies the `LogDensityProblems` interface to first-order, i.e. you can now directly call both + +``` +LogDensityProblems.logdensity(ldf, params) +LogDensityProblems.logdensity_and_gradient(ldf, params) +``` + +without having to construct a separate `ADgradient` object. + +If you prefer, you can also construct a new `LogDensityFunction` with a new AD type afterwards. +The model, varinfo, and context will be taken from the original `LogDensityFunction`: + +```julia +@model f() = ... + +ldf = LogDensityFunction(f()) # by default, no adtype set +ldf_with_ad = LogDensityFunction(ldf, AutoForwardDiff()) +``` + ## 0.34.2 - Fixed bugs in ValuesAsInModelContext as well as DebugContext where underlying PrefixContexts were not being applied. diff --git a/Project.toml b/Project.toml index 7cd47fdbb..26ab45425 100644 --- a/Project.toml +++ b/Project.toml @@ -12,13 +12,13 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" -LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -51,15 +51,14 @@ Bijectors = "0.13.18, 0.14, 0.15" ChainRulesCore = "1" Compat = "4" ConstructionBase = "1.5.4" +DifferentiationInterface = "0.6.41" Distributions = "0.25" DocStringExtensions = "0.9" -KernelAbstractions = "0.9.33" EnzymeCore = "0.6 - 0.8" -ForwardDiff = "0.10" JET = "0.9" +KernelAbstractions = "0.9.33" LinearAlgebra = "1.6" LogDensityProblems = "2" -LogDensityProblemsAD = "1.7.0" MCMCChains = "6" MacroTools = "0.5.6" Mooncake = "0.4.95" diff --git a/docs/src/api.md b/docs/src/api.md index 4d3c6bc97..60bdc05d9 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -54,7 +54,7 @@ logjoint ### LogDensityProblems.jl interface -The [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface is also supported by simply wrapping a [`Model`](@ref) in a `DynamicPPL.LogDensityFunction`: +The [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface is also supported by wrapping a [`Model`](@ref) in a `DynamicPPL.LogDensityFunction`. ```@docs DynamicPPL.LogDensityFunction diff --git a/ext/DynamicPPLForwardDiffExt.jl b/ext/DynamicPPLForwardDiffExt.jl index 4bc33e217..6bd7a5d94 100644 --- a/ext/DynamicPPLForwardDiffExt.jl +++ b/ext/DynamicPPLForwardDiffExt.jl @@ -1,54 +1,40 @@ module DynamicPPLForwardDiffExt -if isdefined(Base, :get_extension) - using DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD - using ForwardDiff -else - using ..DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD - using ..ForwardDiff -end - -getchunksize(::ADTypes.AutoForwardDiff{chunk}) where {chunk} = chunk - -standardtag(::ADTypes.AutoForwardDiff{<:Any,Nothing}) = true -standardtag(::ADTypes.AutoForwardDiff) = false - -function LogDensityProblemsAD.ADgradient( - ad::ADTypes.AutoForwardDiff, ℓ::DynamicPPL.LogDensityFunction -) - θ = DynamicPPL.getparams(ℓ) - f = Base.Fix1(LogDensityProblems.logdensity, ℓ) - - # Define configuration for ForwardDiff. - tag = if standardtag(ad) - ForwardDiff.Tag(DynamicPPL.DynamicPPLTag(), eltype(θ)) +using DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems +using ForwardDiff + +# check if the AD type already has a tag +use_dynamicppl_tag(::ADTypes.AutoForwardDiff{<:Any,Nothing}) = true +use_dynamicppl_tag(::ADTypes.AutoForwardDiff) = false + +function DynamicPPL.tweak_adtype( + ad::ADTypes.AutoForwardDiff{chunk_size}, + ::DynamicPPL.Model, + vi::DynamicPPL.AbstractVarInfo, + ::DynamicPPL.AbstractContext, +) where {chunk_size} + params = vi[:] + + # Use DynamicPPL tag to improve stack traces + # https://www.stochasticlifestyle.com/improved-forwarddiff-jl-stacktraces-with-package-tags/ + # NOTE: DifferentiationInterface disables tag checking if the + # tag inside the AutoForwardDiff type is not nothing. See + # https://github.com/JuliaDiff/DifferentiationInterface.jl/blob/1df562180bdcc3e91c885aa5f4162a0be2ced850/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl#L338-L350. + # So we don't currently need to override ForwardDiff.checktag as well. + tag = if use_dynamicppl_tag(ad) + ForwardDiff.Tag(DynamicPPL.DynamicPPLTag(), eltype(params)) else - ForwardDiff.Tag(f, eltype(θ)) + ad.tag end - chunk_size = getchunksize(ad) + + # Optimise chunk size according to size of model chunk = if chunk_size == 0 || chunk_size === nothing - ForwardDiff.Chunk(θ) + ForwardDiff.Chunk(params) else - ForwardDiff.Chunk(length(θ), chunk_size) + ForwardDiff.Chunk(length(params), chunk_size) end - return LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓ; chunk, tag, x=θ) -end - -# Allow Turing tag in gradient etc. calls of the log density function -function ForwardDiff.checktag( - ::Type{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag,V}}, - ::DynamicPPL.LogDensityFunction, - ::AbstractArray{W}, -) where {V,W} - return true -end -function ForwardDiff.checktag( - ::Type{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag,V}}, - ::Base.Fix1{typeof(LogDensityProblems.logdensity),<:DynamicPPL.LogDensityFunction}, - ::AbstractArray{W}, -) where {V,W} - return true + return ADTypes.AutoForwardDiff(; chunksize=ForwardDiff.chunksize(chunk), tag=tag) end end # module diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index f0d42f98c..b413017cf 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -14,7 +14,6 @@ using MacroTools: MacroTools using ConstructionBase: ConstructionBase using Accessors: Accessors using LogDensityProblems: LogDensityProblems -using LogDensityProblemsAD: LogDensityProblemsAD using LinearAlgebra: LinearAlgebra, Cholesky diff --git a/src/contexts.jl b/src/contexts.jl index 0b4633283..87ad8df0b 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -184,6 +184,7 @@ at which point it will return the sampler of that context. getsampler(context::SamplingContext) = context.sampler getsampler(context::AbstractContext) = getsampler(NodeTrait(context), context) getsampler(::IsParent, context::AbstractContext) = getsampler(childcontext(context)) +getsampler(::IsLeaf, ::AbstractContext) = error("No sampler found in context") """ struct DefaultContext <: AbstractContext end diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 29f591cc3..a42855f05 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -1,12 +1,53 @@ +import DifferentiationInterface as DI + """ - LogDensityFunction + is_supported(adtype::AbstractADType) + +Check if the given AD type is formally supported by DynamicPPL. + +AD backends that are not formally supported can still be used for gradient +calculation; it is just that the DynamicPPL developers do not commit to +maintaining compatibility with them. +""" +is_supported(::ADTypes.AbstractADType) = false +is_supported(::ADTypes.AutoForwardDiff) = true +is_supported(::ADTypes.AutoMooncake) = true +is_supported(::ADTypes.AutoReverseDiff) = true + +""" + LogDensityFunction( + model::Model, + varinfo::AbstractVarInfo=VarInfo(model), + context::AbstractContext=DefaultContext(); + adtype::Union{ADTypes.AbstractADType,Nothing}=nothing + ) + +A struct which contains a model, along with all the information necessary to: + + - calculate its log density at a given point; + - and if `adtype` is provided, calculate the gradient of the log density at + that point. -A callable representing a log density function of a `model`. +At its most basic level, a LogDensityFunction wraps the model together with its +the type of varinfo to be used, as well as the evaluation context. These must +be known in order to calculate the log density (using +[`DynamicPPL.evaluate!!`](@ref)). + +If the `adtype` keyword argument is provided, then this struct will also store +the adtype along with other information for efficient calculation of the +gradient of the log density. Note that preparing a `LogDensityFunction` with an +AD type `AutoBackend()` requires the AD backend itself to have been loaded +(e.g. with `import Backend`). + +`DynamicPPL.LogDensityFunction` implements the LogDensityProblems.jl interface. +If `adtype` is nothing, then only `logdensity` is implemented. If `adtype` is a +concrete AD backend type, then `logdensity_and_gradient` is also implemented. # Fields $(FIELDS) # Examples + ```jldoctest julia> using Distributions @@ -42,116 +83,221 @@ julia> # This also respects the context in `model`. julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0) true + +julia> # If we also need to calculate the gradient, we can specify an AD backend. + import ForwardDiff, ADTypes + +julia> f = LogDensityFunction(model, adtype=ADTypes.AutoForwardDiff()); + +julia> LogDensityProblems.logdensity_and_gradient(f, [0.0]) +(-2.3378770664093453, [1.0]) ``` """ -struct LogDensityFunction{V,M,C} - "varinfo used for evaluation" - varinfo::V +struct LogDensityFunction{ + M<:Model,V<:AbstractVarInfo,C<:AbstractContext,AD<:Union{Nothing,ADTypes.AbstractADType} +} "model used for evaluation" model::M + "varinfo used for evaluation" + varinfo::V "context used for evaluation; if `nothing`, `leafcontext(model.context)` will be used when applicable" context::C + "AD type used for evaluation of log density gradient. If `nothing`, no gradient can be calculated" + adtype::AD + "(internal use only) gradient preparation object for the model" + prep::Union{Nothing,DI.GradientPrep} + + function LogDensityFunction( + model::Model, + varinfo::AbstractVarInfo=VarInfo(model), + context::AbstractContext=leafcontext(model.context); + adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, + ) + if adtype === nothing + prep = nothing + else + # Make backend-specific tweaks to the adtype + adtype = tweak_adtype(adtype, model, varinfo, context) + # Check whether it is supported + is_supported(adtype) || + @warn "The AD backend $adtype is not officially supported by DynamicPPL. Gradient calculations may still work, but compatibility is not guaranteed." + # Get a set of dummy params to use for prep + x = map(identity, varinfo[:]) + if use_closure(adtype) + prep = DI.prepare_gradient( + x -> logdensity_at(x, model, varinfo, context), adtype, x + ) + else + prep = DI.prepare_gradient( + logdensity_at, + adtype, + x, + DI.Constant(model), + DI.Constant(varinfo), + DI.Constant(context), + ) + end + end + return new{typeof(model),typeof(varinfo),typeof(context),typeof(adtype)}( + model, varinfo, context, adtype, prep + ) + end end -# TODO: Deprecate. +""" + LogDensityFunction( + ldf::LogDensityFunction, + adtype::Union{Nothing,ADTypes.AbstractADType} + ) + +Create a new LogDensityFunction using the model, varinfo, and context from the given +`ldf` argument, but with the AD type set to `adtype`. To remove the AD type, pass +`nothing` as the second argument. +""" function LogDensityFunction( - varinfo::AbstractVarInfo, - model::Model, - sampler::AbstractSampler, - context::AbstractContext, + f::LogDensityFunction, adtype::Union{Nothing,ADTypes.AbstractADType} ) - return LogDensityFunction(varinfo, model, SamplingContext(sampler, context)) + return if adtype === f.adtype + f # Avoid recomputing prep if not needed + else + LogDensityFunction(f.model, f.varinfo, f.context; adtype=adtype) + end end -function LogDensityFunction( - model::Model, - varinfo::AbstractVarInfo=VarInfo(model), - context::Union{Nothing,AbstractContext}=nothing, +""" + logdensity_at( + x::AbstractVector, + model::Model, + varinfo::AbstractVarInfo, + context::AbstractContext + ) + +Evaluate the log density of the given `model` at the given parameter values `x`, +using the given `varinfo` and `context`. Note that the `varinfo` argument is provided +only for its structure, in the sense that the parameters from the vector `x` are inserted into +it, and its own parameters are discarded. +""" +function logdensity_at( + x::AbstractVector, model::Model, varinfo::AbstractVarInfo, context::AbstractContext ) - return LogDensityFunction(varinfo, model, context) + varinfo_new = unflatten(varinfo, x) + return getlogp(last(evaluate!!(model, varinfo_new, context))) end -# If a `context` has been specified, we use that. Otherwise we just use the leaf context of `model`. -function getcontext(f::LogDensityFunction) - return f.context === nothing ? leafcontext(f.model.context) : f.context +### LogDensityProblems interface + +function LogDensityProblems.capabilities( + ::Type{<:LogDensityFunction{M,V,C,Nothing}} +) where {M,V,C} + return LogDensityProblems.LogDensityOrder{0}() +end +function LogDensityProblems.capabilities( + ::Type{<:LogDensityFunction{M,V,C,AD}} +) where {M,V,C,AD<:ADTypes.AbstractADType} + return LogDensityProblems.LogDensityOrder{1}() +end +function LogDensityProblems.logdensity(f::LogDensityFunction, x::AbstractVector) + return logdensity_at(x, f.model, f.varinfo, f.context) +end +function LogDensityProblems.logdensity_and_gradient( + f::LogDensityFunction{M,V,C,AD}, x::AbstractVector +) where {M,V,C,AD<:ADTypes.AbstractADType} + f.prep === nothing && + error("Gradient preparation not available; this should not happen") + x = map(identity, x) # Concretise type + # Make branching statically inferrable, i.e. type-stable (even if the two + # branches happen to return different types) + return if use_closure(f.adtype) + DI.value_and_gradient( + x -> logdensity_at(x, f.model, f.varinfo, f.context), f.prep, f.adtype, x + ) + else + DI.value_and_gradient( + logdensity_at, + f.prep, + f.adtype, + x, + DI.Constant(f.model), + DI.Constant(f.varinfo), + DI.Constant(f.context), + ) + end end +# TODO: should we instead implement and call on `length(f.varinfo)` (at least in the cases where no sampler is involved)? +LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f)) + +### Utils + +""" + tweak_adtype( + adtype::ADTypes.AbstractADType, + model::Model, + varinfo::AbstractVarInfo, + context::AbstractContext + ) + +Return an 'optimised' form of the adtype. This is useful for doing +backend-specific optimisation of the adtype (e.g., for ForwardDiff, calculating +the chunk size: see the method override in `ext/DynamicPPLForwardDiffExt.jl`). +The model is passed as a parameter in case the optimisation depends on the +model. + +By default, this just returns the input unchanged. +""" +tweak_adtype( + adtype::ADTypes.AbstractADType, ::Model, ::AbstractVarInfo, ::AbstractContext +) = adtype + +""" + use_closure(adtype::ADTypes.AbstractADType) + +In LogDensityProblems, we want to calculate the derivative of logdensity(f, x) +with respect to x, where f is the model (in our case LogDensityFunction) and is +a constant. However, DifferentiationInterface generally expects a +single-argument function g(x) to differentiate. + +There are two ways of dealing with this: + +1. Construct a closure over the model, i.e. let g = Base.Fix1(logdensity, f) + +2. Use a constant context. This lets us pass a two-argument function to + DifferentiationInterface, as long as we also give it the 'inactive argument' + (i.e. the model) wrapped in `DI.Constant`. + +The relative performance of the two approaches, however, depends on the AD +backend used. Some benchmarks are provided here: +https://github.com/TuringLang/DynamicPPL.jl/pull/806#issuecomment-2658061480 + +This function is used to determine whether a given AD backend should use a +closure or a constant. If `use_closure(adtype)` returns `true`, then the +closure approach will be used. By default, this function returns `false`, i.e. +the constant approach will be used. +""" +use_closure(::ADTypes.AbstractADType) = false +use_closure(::ADTypes.AutoForwardDiff) = false +use_closure(::ADTypes.AutoMooncake) = false +use_closure(::ADTypes.AutoReverseDiff) = true + """ getmodel(f) Return the `DynamicPPL.Model` wrapped in the given log-density function `f`. """ -getmodel(f::LogDensityProblemsAD.ADGradientWrapper) = - getmodel(LogDensityProblemsAD.parent(f)) getmodel(f::DynamicPPL.LogDensityFunction) = f.model """ setmodel(f, model[, adtype]) Set the `DynamicPPL.Model` in the given log-density function `f` to `model`. - -!!! warning - Note that if `f` is a `LogDensityProblemsAD.ADGradientWrapper` wrapping a - `DynamicPPL.LogDensityFunction`, performing an update of the `model` in `f` - might require recompilation of the gradient tape, depending on the AD backend. """ -function setmodel( - f::LogDensityProblemsAD.ADGradientWrapper, - model::DynamicPPL.Model, - adtype::ADTypes.AbstractADType, -) - # TODO: Should we handle `SciMLBase.NoAD`? - # For an `ADGradientWrapper` we do the following: - # 1. Update the `Model` in the underlying `LogDensityFunction`. - # 2. Re-construct the `ADGradientWrapper` using `ADgradient` using the provided `adtype` - # to ensure that the recompilation of gradient tapes, etc. also occur. For example, - # ReverseDiff.jl in compiled mode will cache the compiled tape, which means that just - # replacing the corresponding field with the new model won't be sufficient to obtain - # the correct gradients. - return LogDensityProblemsAD.ADgradient( - adtype, setmodel(LogDensityProblemsAD.parent(f), model) - ) -end function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model) - return Accessors.@set f.model = model + return LogDensityFunction(model, f.varinfo, f.context; adtype=f.adtype) end -# HACK: heavy usage of `AbstractSampler` for, well, _everything_, is being phased out. In the mean time -# we need to define these annoying methods to ensure that we stay compatible with everything. -getsampler(f::LogDensityFunction) = getsampler(getcontext(f)) -hassampler(f::LogDensityFunction) = hassampler(getcontext(f)) - """ getparams(f::LogDensityFunction) Return the parameters of the wrapped varinfo as a vector. """ getparams(f::LogDensityFunction) = f.varinfo[:] - -# LogDensityProblems interface -function LogDensityProblems.logdensity(f::LogDensityFunction, θ::AbstractVector) - context = getcontext(f) - vi_new = unflatten(f.varinfo, θ) - return getlogp(last(evaluate!!(f.model, vi_new, context))) -end -function LogDensityProblems.capabilities(::Type{<:LogDensityFunction}) - return LogDensityProblems.LogDensityOrder{0}() -end -# TODO: should we instead implement and call on `length(f.varinfo)` (at least in the cases where no sampler is involved)? -LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f)) - -# This is important for performance -- one needs to provide `ADGradient` with a vector of -# parameters, or DifferentiationInterface will not have sufficient information to e.g. -# compile a rule for Mooncake (because it won't know the type of the input), or pre-allocate -# a tape when using ReverseDiff.jl. -function _make_ad_gradient(ad::ADTypes.AbstractADType, ℓ::LogDensityFunction) - x = map(identity, getparams(ℓ)) # ensure we concretise the elements of the params - return LogDensityProblemsAD.ADgradient(ad, ℓ; x) -end - -function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoMooncake, f::LogDensityFunction) - return _make_ad_gradient(ad, f) -end -function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoReverseDiff, f::LogDensityFunction) - return _make_ad_gradient(ad, f) -end diff --git a/test/Project.toml b/test/Project.toml index c7583c672..420edba94 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -16,7 +16,6 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" -LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" @@ -46,7 +45,6 @@ EnzymeCore = "0.6 - 0.8" ForwardDiff = "0.10.12" JET = "0.9" LogDensityProblems = "2" -LogDensityProblemsAD = "1.7.0" MCMCChains = "6.0.4" MacroTools = "0.5.6" Mooncake = "0.4.59" diff --git a/test/ad.jl b/test/ad.jl index 87c7f22c3..73519c3f5 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -1,35 +1,63 @@ -@testset "AD: ForwardDiff, ReverseDiff, and Mooncake" begin - @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS - f = DynamicPPL.LogDensityFunction(m) - rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m) - vns = DynamicPPL.TestUtils.varnames(m) - varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns) +using DynamicPPL: LogDensityFunction - @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - f = DynamicPPL.LogDensityFunction(m, varinfo) +@testset "Automatic differentiation" begin + @testset "Unsupported backends" begin + @model demo() = x ~ Normal() + @test_logs (:warn, r"not officially supported") LogDensityFunction( + demo(); adtype=AutoZygote() + ) + end + + @testset "Correctness: ForwardDiff, ReverseDiff, and Mooncake" begin + @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS + rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m) + vns = DynamicPPL.TestUtils.varnames(m) + varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns) + + @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos + f = LogDensityFunction(m, varinfo) + x = DynamicPPL.getparams(f) + # Calculate reference logp + gradient of logp using ForwardDiff + ref_adtype = ADTypes.AutoForwardDiff() + ref_ldf = LogDensityFunction(m, varinfo; adtype=ref_adtype) + ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ref_ldf, x) + + @testset "$adtype" for adtype in [ + AutoReverseDiff(; compile=false), + AutoReverseDiff(; compile=true), + AutoMooncake(; config=nothing), + ] + @info "Testing AD on: $(m.f) - $(short_varinfo_name(varinfo)) - $adtype" - # use ForwardDiff result as reference - ad_forwarddiff_f = LogDensityProblemsAD.ADgradient( - ADTypes.AutoForwardDiff(; chunksize=0), f - ) - # convert to `Vector{Float64}` to avoid `ReverseDiff` initializing the gradients to Integer 0 - # reference: https://github.com/TuringLang/DynamicPPL.jl/pull/571#issuecomment-1924304489 - θ = convert(Vector{Float64}, varinfo[:]) - logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ad_forwarddiff_f, θ) + # Put predicates here to avoid long lines + is_mooncake = adtype isa AutoMooncake + is_1_10 = v"1.10" <= VERSION < v"1.11" + is_1_11 = v"1.11" <= VERSION < v"1.12" + is_svi_vnv = varinfo isa SimpleVarInfo{<:DynamicPPL.VarNamedVector} + is_svi_od = varinfo isa SimpleVarInfo{<:OrderedDict} - @testset "$adtype" for adtype in [ - ADTypes.AutoReverseDiff(; compile=false), - ADTypes.AutoReverseDiff(; compile=true), - ADTypes.AutoMooncake(; config=nothing), - ] - # Mooncake can't currently handle something that is going on in - # SimpleVarInfo{<:VarNamedVector}. Disable all SimpleVarInfo tests for now. - if adtype isa ADTypes.AutoMooncake && varinfo isa DynamicPPL.SimpleVarInfo - @test_broken 1 == 0 - else - ad_f = LogDensityProblemsAD.ADgradient(adtype, f) - _, grad = LogDensityProblems.logdensity_and_gradient(ad_f, θ) - @test grad ≈ ref_grad + # Mooncake doesn't work with several combinations of SimpleVarInfo. + if is_mooncake && is_1_11 && is_svi_vnv + # https://github.com/compintell/Mooncake.jl/issues/470 + @test_throws ArgumentError DynamicPPL.LogDensityFunction( + ref_ldf, adtype + ) + elseif is_mooncake && is_1_10 && is_svi_vnv + # TODO: report upstream + @test_throws UndefRefError DynamicPPL.LogDensityFunction( + ref_ldf, adtype + ) + elseif is_mooncake && is_1_10 && is_svi_od + # TODO: report upstream + @test_throws Mooncake.MooncakeRuleCompilationError DynamicPPL.LogDensityFunction( + ref_ldf, adtype + ) + else + ldf = DynamicPPL.LogDensityFunction(ref_ldf, adtype) + logp, grad = LogDensityProblems.logdensity_and_gradient(ldf, x) + @test grad ≈ ref_grad + @test logp ≈ ref_logp + end end end end @@ -63,13 +91,16 @@ # overload assume so that model evaluation doesn't fail due to a lack # of implementation struct MyEmptyAlg end - DynamicPPL.assume(rng, ::DynamicPPL.Sampler{MyEmptyAlg}, dist, vn, vi) = - DynamicPPL.assume(dist, vn, vi) + DynamicPPL.assume( + ::Random.AbstractRNG, ::DynamicPPL.Sampler{MyEmptyAlg}, dist, vn, vi + ) = DynamicPPL.assume(dist, vn, vi) # Compiling the ReverseDiff tape used to fail here spl = Sampler(MyEmptyAlg()) vi = VarInfo(model) - ldf = DynamicPPL.LogDensityFunction(vi, model, SamplingContext(spl)) - @test LogDensityProblemsAD.ADgradient(AutoReverseDiff(; compile=true), ldf) isa Any + ldf = LogDensityFunction( + model, vi, SamplingContext(spl); adtype=AutoReverseDiff(; compile=true) + ) + @test LogDensityProblems.logdensity_and_gradient(ldf, vi[:]) isa Any end end diff --git a/test/ext/DynamicPPLForwardDiffExt.jl b/test/ext/DynamicPPLForwardDiffExt.jl index 8de28046b..73a0510e9 100644 --- a/test/ext/DynamicPPLForwardDiffExt.jl +++ b/test/ext/DynamicPPLForwardDiffExt.jl @@ -1,14 +1,32 @@ -@testset "tag" begin - for chunksize in (nothing, 0, 1, 10) - ad = ADTypes.AutoForwardDiff(; chunksize=chunksize) - standardtag = if !isdefined(Base, :get_extension) - DynamicPPL.DynamicPPLForwardDiffExt.standardtag - else - Base.get_extension(DynamicPPL, :DynamicPPLForwardDiffExt).standardtag - end - @test standardtag(ad) - for tag in (false, 0, 1) - @test !standardtag(AutoForwardDiff(; chunksize=chunksize, tag=tag)) - end +module DynamicPPLForwardDiffExtTests + +using DynamicPPL +using ADTypes: AutoForwardDiff +using ForwardDiff: ForwardDiff +using Distributions: MvNormal +using LinearAlgebra: I +using Test: @test, @testset + +# get_chunksize(ad::AutoForwardDiff{chunk}) where {chunk} = chunk + +@testset "ForwardDiff tweak_adtype" begin + MODEL_SIZE = 10 + @model f() = x ~ MvNormal(zeros(MODEL_SIZE), I) + model = f() + varinfo = VarInfo(model) + context = DefaultContext() + + @testset "Chunk size setting" for chunksize in (nothing, 0) + base_adtype = AutoForwardDiff(; chunksize=chunksize) + new_adtype = DynamicPPL.tweak_adtype(base_adtype, model, varinfo, context) + @test new_adtype isa AutoForwardDiff{MODEL_SIZE} end + + @testset "Tag setting" begin + base_adtype = AutoForwardDiff() + new_adtype = DynamicPPL.tweak_adtype(base_adtype, model, varinfo, context) + @test new_adtype.tag isa ForwardDiff.Tag{DynamicPPL.DynamicPPLTag} + end +end + end diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index beda767e6..d6e66ec59 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -1,4 +1,4 @@ -using Test, DynamicPPL, ADTypes, LogDensityProblems, LogDensityProblemsAD, ReverseDiff +using Test, DynamicPPL, ADTypes, LogDensityProblems, ForwardDiff @testset "`getmodel` and `setmodel`" begin @testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS @@ -6,17 +6,6 @@ using Test, DynamicPPL, ADTypes, LogDensityProblems, LogDensityProblemsAD, Rever ℓ = DynamicPPL.LogDensityFunction(model) @test DynamicPPL.getmodel(ℓ) == model @test DynamicPPL.setmodel(ℓ, model).model == model - - # ReverseDiff related - ∇ℓ = LogDensityProblemsAD.ADgradient(:ReverseDiff, ℓ; compile=Val(false)) - @test DynamicPPL.getmodel(∇ℓ) == model - @test DynamicPPL.getmodel(DynamicPPL.setmodel(∇ℓ, model, AutoReverseDiff())) == - model - ∇ℓ = LogDensityProblemsAD.ADgradient(:ReverseDiff, ℓ; compile=Val(true)) - new_∇ℓ = DynamicPPL.setmodel(∇ℓ, model, AutoReverseDiff()) - @test DynamicPPL.getmodel(new_∇ℓ) == model - # HACK(sunxd): rely on internal implementation detail, i.e., naming of `compiledtape` - @test new_∇ℓ.compiledtape != ∇ℓ.compiledtape end end @@ -33,4 +22,15 @@ end @test LogDensityProblems.dimension(logdensity) == length(θ) end end + + @testset "capabilities" begin + model = DynamicPPL.TestUtils.DEMO_MODELS[1] + ldf = DynamicPPL.LogDensityFunction(model) + @test LogDensityProblems.capabilities(typeof(ldf)) == + LogDensityProblems.LogDensityOrder{0}() + + ldf_with_ad = DynamicPPL.LogDensityFunction(model; adtype=AutoForwardDiff()) + @test LogDensityProblems.capabilities(typeof(ldf_with_ad)) == + LogDensityProblems.LogDensityOrder{1}() + end end diff --git a/test/runtests.jl b/test/runtests.jl index 25cd2fb40..caddef5f9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,7 +9,7 @@ using Distributions using DistributionsAD using Documenter using ForwardDiff -using LogDensityProblems, LogDensityProblemsAD +using LogDensityProblems using MacroTools using MCMCChains using Mooncake: Mooncake diff --git a/test/test_util.jl b/test/test_util.jl index 27a68456c..d831a5ea6 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -56,6 +56,15 @@ function short_varinfo_name(vi::TypedVarInfo) end short_varinfo_name(::UntypedVarInfo) = "UntypedVarInfo" short_varinfo_name(::DynamicPPL.VectorVarInfo) = "VectorVarInfo" +function short_varinfo_name(::SimpleVarInfo{<:NamedTuple,<:Ref}) + return "SimpleVarInfo{<:NamedTuple,<:Ref}" +end +function short_varinfo_name(::SimpleVarInfo{<:OrderedDict,<:Ref}) + return "SimpleVarInfo{<:OrderedDict,<:Ref}" +end +function short_varinfo_name(::SimpleVarInfo{<:DynamicPPL.VarNamedVector,<:Ref}) + return "SimpleVarInfo{<:VarNamedVector,<:Ref}" +end short_varinfo_name(::SimpleVarInfo{<:NamedTuple}) = "SimpleVarInfo{<:NamedTuple}" short_varinfo_name(::SimpleVarInfo{<:OrderedDict}) = "SimpleVarInfo{<:OrderedDict}" function short_varinfo_name(::SimpleVarInfo{<:DynamicPPL.VarNamedVector}) From a765af3bc38009f70d61dc1b0138ebda8646ff62 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 26 Feb 2025 16:32:56 +0000 Subject: [PATCH 09/14] Export `LogDensityFunction` (#820) * Export LogDensityFunction * Remove qualifier in docs --- docs/src/api.md | 2 +- src/DynamicPPL.jl | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/src/api.md b/docs/src/api.md index 60bdc05d9..35baa558c 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -57,7 +57,7 @@ logjoint The [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface is also supported by wrapping a [`Model`](@ref) in a `DynamicPPL.LogDensityFunction`. ```@docs -DynamicPPL.LogDensityFunction +LogDensityFunction ``` ## Condition and decondition diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index b413017cf..ed0803b25 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -88,6 +88,8 @@ export AbstractVarInfo, Sampler, SampleFromPrior, SampleFromUniform, + # LogDensityFunction + LogDensityFunction, # Contexts SamplingContext, DefaultContext, From 1dec7a177cc8d6288e2b55cfc1c437f7600f7a01 Mon Sep 17 00:00:00 2001 From: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Date: Wed, 26 Feb 2025 20:47:34 +0000 Subject: [PATCH 10/14] Export `predict` with 0.35 (#821) * export `predict` * use mcmcchains when build doc * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Add MCMCChainsExt to docs build (#822) --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Penelope Yong --- docs/make.jl | 8 ++++++-- docs/src/api.md | 26 ++++++++++++++++++++++++++ src/DynamicPPL.jl | 1 + 3 files changed, 33 insertions(+), 2 deletions(-) diff --git a/docs/make.jl b/docs/make.jl index 66cc690f0..c69b72fb8 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -9,16 +9,20 @@ using DynamicPPL: AbstractPPL # consistent with that. using Distributions using DocumenterMermaid +# load MCMCChains package extension to make `predict` available +using MCMCChains # Doctest setup -DocMeta.setdocmeta!(DynamicPPL, :DocTestSetup, :(using DynamicPPL); recursive=true) +DocMeta.setdocmeta!( + DynamicPPL, :DocTestSetup, :(using DynamicPPL, MCMCChains); recursive=true +) makedocs(; sitename="DynamicPPL", # The API index.html page is fairly large, and violates the default HTML page size # threshold of 200KiB, so we double that. format=Documenter.HTML(; size_threshold=2^10 * 400), - modules=[DynamicPPL], + modules=[DynamicPPL, Base.get_extension(DynamicPPL, :DynamicPPLMCMCChainsExt)], pages=[ "Home" => "index.md", "API" => "api.md", "Internals" => ["internals/varinfo.md"] ], diff --git a/docs/src/api.md b/docs/src/api.md index 35baa558c..9c8249c97 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -104,6 +104,32 @@ Similarly, we can [`unfix`](@ref) variables, i.e. return them to their original unfix ``` +## Predicting + +DynamicPPL provides functionality for generating samples from the posterior predictive distribution through the `predict` function. This allows you to use posterior parameter samples to generate predictions for unobserved data points. + +The `predict` function has two main methods: + + 1. For `AbstractVector{<:AbstractVarInfo}` - useful when you have a collection of `VarInfo` objects representing posterior samples. + 2. For `MCMCChains.Chains` (only available when `MCMCChains.jl` is loaded) - useful when you have posterior samples in the form of an `MCMCChains.Chains` object. + +```@docs +predict +``` + +### Basic Usage + +The typical workflow for posterior prediction involves: + + 1. Fitting a model to observed data to obtain posterior samples + 2. Creating a new model instance with some variables marked as missing (unobserved) + 3. Using `predict` to generate samples for these missing variables based on the posterior parameter samples + +When using `predict` with `MCMCChains.Chains`, you can control which variables are included in the output with the `include_all` parameter: + + - `include_all=false` (default): Include only newly predicted variables + - `include_all=true`: Include both parameters from the original chain and predicted variables + ## Models within models One can include models and call another model inside the model function with `left ~ to_submodel(model)`. diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index ed0803b25..50fe0edc7 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -115,6 +115,7 @@ export AbstractVarInfo, decondition, fix, unfix, + predict, prefix, returned, to_submodel, From ebece4bbd5143f8bc09db054a64390aafa40fb4d Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 26 Feb 2025 20:56:38 +0000 Subject: [PATCH 11/14] Add note about exports to HISTORY.md --- HISTORY.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/HISTORY.md b/HISTORY.md index 3f999ccab..748b9d506 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -128,6 +128,10 @@ This release removes the feature of `VarInfo` where it kept track of which varia **Other changes** +### New exports + +`LogDensityFunction` and `predict` are now exported from DynamicPPL. + ### `LogDensityProblems` interface LogDensityProblemsAD is now removed as a dependency. From d2f6a8d19114c1fbe97ef01062511e608b69c0bc Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 26 Feb 2025 20:57:20 +0000 Subject: [PATCH 12/14] Remove release-0.35 from list of tested CI branches CI will still run because it's a PR Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> --- .github/workflows/CI.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 23b35f468..ef90920f2 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -5,7 +5,6 @@ on: branches: - master - backport-* - - release-0.35 pull_request: merge_group: types: [checks_requested] From e6a42dd6b410d4ed60d853222c297a391cffe6d1 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 27 Feb 2025 10:52:05 +0000 Subject: [PATCH 13/14] Restore ForwardDiff compat bound, update test/Project bounds --- Project.toml | 1 + test/Project.toml | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index 26ab45425..a9463a821 100644 --- a/Project.toml +++ b/Project.toml @@ -55,6 +55,7 @@ DifferentiationInterface = "0.6.41" Distributions = "0.25" DocStringExtensions = "0.9" EnzymeCore = "0.6 - 0.8" +ForwardDiff = "0.10.12" JET = "0.9" KernelAbstractions = "0.9.33" LinearAlgebra = "1.6" diff --git a/test/Project.toml b/test/Project.toml index 420edba94..e0fbbb8c5 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -37,7 +37,7 @@ Accessors = "0.1" Bijectors = "0.15.1" Combinatorics = "1" Compat = "4.3.0" -DifferentiationInterface = "0.6" +DifferentiationInterface = "0.6.41" Distributions = "0.25" DistributionsAD = "0.6.3" Documenter = "1" @@ -47,10 +47,10 @@ JET = "0.9" LogDensityProblems = "2" MCMCChains = "6.0.4" MacroTools = "0.5.6" -Mooncake = "0.4.59" +Mooncake = "0.4.95" OrderedCollections = "1" ReverseDiff = "1" StableRNGs = "1" Tracker = "0.2.23" Zygote = "0.6" -julia = "1.6" +julia = "1.10" From 6fe46ee2815c9c0055bc3e20f185d94a7e0b86af Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 28 Feb 2025 14:47:47 +0000 Subject: [PATCH 14/14] Update HISTORY.md Co-authored-by: Tor Erlend Fjelde --- HISTORY.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/HISTORY.md b/HISTORY.md index 748b9d506..b36003965 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -121,7 +121,7 @@ This release removes the feature of `VarInfo` where it kept track of which varia ### `LogDensityFunction` argument order - - The method `LogDensityFunction(varinfo, model, context)` has been removed. + - The method `LogDensityFunction(varinfo, model, sampler)` has been removed. The only accepted order is `LogDensityFunction(model, varinfo, context; adtype)`. (For an explanation of `adtype`, see below.) The varinfo and context arguments are both still optional.