Skip to content

Remove dot_tilde pipeline #804

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Feb 18, 2025
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,63 @@

**Breaking**

### `.~` right hand side must be a univariate distribution

Previously we allowed statements like

```julia
x .~ [Normal(), Gamma()]
```

where the right hand side of a `.~` was an array of distributions, and ones like

```julia
x .~ MvNormal(fill(0.0, 2), I)
```

where the right hand was a multivariate distribution.

These are no longer allowed. The only things allowed on the right hand side of a `.~` statement are univariate distributions, such as

```julia
x = Array{Float64,3}(undef, 2, 3, 4)
x .~ Normal()
```

The reasons for this are internal code simplification and the fact that broadcasting where both sides are multidimensional but of different dimensions is typically confusing to read.

Cases where the dimension of the multivariate distribution or the array of distribution is the same as the dimension of the left hand side variable can be replaced with `product_distribution`. For example, instead of

```julia
x .~ [Normal(), Gamma()]
```

do

```julia
x ~ product_distribution([Normal(), Gamma()])
```

This is often more performant as well. Note that using a product distribution will change how a `VarInfo` views the variable: Instead of viewing each `x[i]` as a distinct univariate variable like with `.~`, with `x ~ product_distribution(...)` `x` will be viewed as a single multivariate variable. This was already the case before this release. If, for some reason, you _do_ want each `x[i]` independently in your `VarInfo`, you can always turn the `.~` statement into a loop.

Cases where the right hand side is of a different dimension than the left hand side, and neither is a scalar, must be replaced with a loop. For example,

```julia
x = Array{Float64,3}(undef, 2, 3, 4)
x .~ MvNormal(fill(0, 2), I)
```

should be replaced with something like

```julia
x = Array{Float64,3}(2, 3, 4)
for i in 1:3, j in 1:4
x[:, i, j] ~ MvNormal(fill(0, 2), I)
end
```

This release also completely rewrites the internal implementation of `.~`, where from now on all `.~` statements are turned into loops over `~` statements at macro time. However, the only breaking aspect of this change is the above change to what's allowed on the right hand side.

### Remove indexing by samplers

This release removes the feature of `VarInfo` where it kept track of which variable was associated with which sampler. This means removing all user-facing methods where `VarInfo`s where being indexed with samplers. In particular,
Expand Down
4 changes: 0 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -35,7 +34,6 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[extensions]
DynamicPPLChainRulesCoreExt = ["ChainRulesCore"]
Expand All @@ -44,7 +42,6 @@ DynamicPPLForwardDiffExt = ["ForwardDiff"]
DynamicPPLJETExt = ["JET"]
DynamicPPLMCMCChainsExt = ["MCMCChains"]
DynamicPPLMooncakeExt = ["Mooncake"]
DynamicPPLZygoteRulesExt = ["ZygoteRules"]

[compat]
ADTypes = "1"
Expand Down Expand Up @@ -74,5 +71,4 @@ OrderedCollections = "1"
Random = "1.6"
Requires = "1"
Test = "1.6"
ZygoteRules = "0.2"
julia = "1.10"
2 changes: 0 additions & 2 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -447,10 +447,8 @@ DynamicPPL.Experimental.is_suitable_varinfo

```@docs
tilde_assume
dot_tilde_assume
```

```@docs
tilde_observe
dot_tilde_observe
```
25 changes: 0 additions & 25 deletions ext/DynamicPPLZygoteRulesExt.jl

This file was deleted.

4 changes: 0 additions & 4 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,9 @@ export AbstractVarInfo,
PrefixContext,
ConditionContext,
assume,
dot_assume,
observe,
dot_observe,
tilde_assume,
tilde_observe,
dot_tilde_assume,
dot_tilde_observe,
# Pseudo distributions
NamedDist,
NoDist,
Expand Down
86 changes: 34 additions & 52 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,16 @@
"""
isliteral(e) = false
isliteral(::Number) = true
isliteral(e::Expr) = !isempty(e.args) && all(isliteral, e.args)
function isliteral(e::Expr)
# In the special case that the expression is of the form `abc[blahblah]`, we consider it
# to be a literal if `abc` is a literal. This is necessary for cases like
# [1.0, 2.0][idx...] ~ Normal()
# which are generated when turning `.~` expressions into loops over `~` expressions.
if e.head == :ref
return isliteral(e.args[1])
end
return !isempty(e.args) && all(isliteral, e.args)
end

"""
check_tilde_rhs(x)
Expand All @@ -172,7 +181,7 @@
function check_tilde_rhs(@nospecialize(x))
return throw(
ArgumentError(
"the right-hand side of a `~` must be a `Distribution` or an array of `Distribution`s",
"the right-hand side of a `~` must be a `Distribution`, an array of `Distribution`s, or a submodel",
),
)
end
Expand All @@ -184,6 +193,22 @@
return Sampleable{typeof(model),AutoPrefix}(model)
end

"""
check_dot_tilde_rhs(x)

Check if the right-hand side `x` of a `.~` is a `UnivariateDistribution`, then return `x`.
"""
function check_dot_tilde_rhs(@nospecialize(x))
return throw(

Check warning on line 202 in src/compiler.jl

View check run for this annotation

Codecov / codecov/patch

src/compiler.jl#L201-L202

Added lines #L201 - L202 were not covered by tests
ArgumentError("the right-hand side of a `.~` must be a `UnivariateDistribution`")
)
end
check_dot_tilde_rhs(x::UnivariateDistribution) = x
function check_dot_tilde_rhs(x::Sampleable{<:Any,AutoPrefix}) where {AutoPrefix}
model = check_dot_tilde_rhs(x.model)
return Sampleable{typeof(model),AutoPrefix}(model)

Check warning on line 209 in src/compiler.jl

View check run for this annotation

Codecov / codecov/patch

src/compiler.jl#L207-L209

Added lines #L207 - L209 were not covered by tests
end

"""
unwrap_right_vn(right, vn)

Expand Down Expand Up @@ -356,11 +381,8 @@
args_dottilde = getargs_dottilde(expr)
if args_dottilde !== nothing
L, R = args_dottilde
return Base.remove_linenums!(
generate_dot_tilde(
generate_mainbody!(mod, found, L, warn),
generate_mainbody!(mod, found, R, warn),
),
return generate_mainbody!(
mod, found, Base.remove_linenums!(generate_dot_tilde(L, R)), warn
)
end

Expand Down Expand Up @@ -487,56 +509,16 @@
Generate the expression that replaces `left .~ right` in the model body.
"""
function generate_dot_tilde(left, right)
isliteral(left) && return generate_tilde_literal(left, right)

# Otherwise it is determined by the model or its value,
# if the LHS represents an observation
@gensym vn isassumption value
@gensym dist left_axes idx
return quote
$vn = $(DynamicPPL.resolve_varnames)(
$(AbstractPPL.drop_escape(varname(left, need_concretize(left)))), $right
)
$isassumption = $(DynamicPPL.isassumption(left, vn))
if $(DynamicPPL.isfixed(left, vn))
$left .= $(DynamicPPL.getfixed_nested)(__context__, $vn)
elseif $isassumption
$(generate_dot_tilde_assume(left, right, vn))
else
# If `vn` is not in `argnames`, we need to make sure that the variable is defined.
if !$(DynamicPPL.inargnames)($vn, __model__)
$left .= $(DynamicPPL.getconditioned_nested)(__context__, $vn)
end

$value, __varinfo__ = $(DynamicPPL.dot_tilde_observe!!)(
__context__,
$(DynamicPPL.check_tilde_rhs)($right),
$(maybe_view(left)),
$vn,
__varinfo__,
)
$value
$dist = DynamicPPL.check_dot_tilde_rhs($right)
$left_axes = axes($left)
for $idx in Iterators.product($left_axes...)
$left[$idx...] ~ $dist
end
end
end

function generate_dot_tilde_assume(left, right, vn)
# We don't need to use `Setfield.@set` here since
# `.=` is always going to be inplace + needs `left` to
# be something that supports `.=`.
@gensym value
return quote
$value, __varinfo__ = $(DynamicPPL.dot_tilde_assume!!)(
__context__,
$(DynamicPPL.unwrap_right_left_vns)(
$(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn
)...,
__varinfo__,
)
$left .= $value
$value
end
end

# Note that we cannot use `MacroTools.isdef` because
# of https://github.com/FluxML/MacroTools.jl/issues/154.
"""
Expand Down
Loading
Loading