23 changes: 11 additions & 12 deletions HISTORY.md
@@ -4,6 +4,17 @@

### Breaking changes

#### Fast Log Density Functions

This version reimplements `LogDensityFunction`, providing performance improvements on the order of 2–10× for both model evaluation and automatic differentiation.
Exact speedups depend on the model size: larger models see less significant speedups because the bulk of the work is done in calls to `logpdf`.

For more information about how this is accomplished, please see https://github.com/TuringLang/DynamicPPL.jl/pull/1113 as well as the `src/fasteval.jl` file, which contains extensive comments.

As a result of this change, `LogDensityFunction` no longer stores a VarInfo inside it.
In general, if `ldf` is a `LogDensityFunction`, it is now only valid to access `ldf.model` and `ldf.adtype`.
If you were previously relying on this behaviour, you will need to store a VarInfo separately.
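
A minimal sketch of the migration (the `demo` model here is a hypothetical example, not part of the API):

```julia
using DynamicPPL, Distributions

@model function demo()
    x ~ Normal()
end

model = demo()
varinfo = VarInfo(model)   # keep your own reference to the VarInfo
ldf = LogDensityFunction(model, getlogjoint_internal, varinfo)

ldf.model    # still valid
ldf.adtype   # still valid
# `ldf.varinfo` no longer exists; use the `varinfo` binding stored above instead
```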

#### Parent and leaf contexts

The `DynamicPPL.NodeTrait` function has been removed.
@@ -24,18 +35,6 @@ Removed the method `returned(::Model, values, keys)`; please use `returned(::Mod
The method `DynamicPPL.init` (for implementing `AbstractInitStrategy`) now has a different signature: it must return a tuple of the generated value, plus a transform function that maps it back to unlinked space.
This is a generalisation of the previous behaviour, where `init` would always return an unlinked value (in effect forcing the transform to be the identity function).
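
If you implement a custom strategy, the update is mechanical. A hypothetical sketch (the strategy type `MyStrategy` and the exact positional arguments are illustrative, not the precise internal signature):

```julia
struct MyStrategy <: DynamicPPL.AbstractInitStrategy end

function DynamicPPL.init(rng, vn, dist, ::MyStrategy)
    value = rand(rng, dist)
    # The value is generated directly in unlinked space, so the transform that
    # maps it back to unlinked space is simply the identity.
    return value, identity
end
```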

### Other changes

#### FastLDF

Added `DynamicPPL.Experimental.FastLDF`, a version of `LogDensityFunction` that provides performance improvements on the order of 2–10× for both model evaluation as well as automatic differentiation.
Exact speedups depend on the model size: larger models have less significant speedups because the bulk of the work is done in calls to `logpdf`.

Please note that `FastLDF` is currently considered internal and its API may change without warning.
We intend to replace `LogDensityFunction` with `FastLDF` in a release in the near future, but until then we recommend not using it.

For more information about `FastLDF`, please see https://github.com/TuringLang/DynamicPPL.jl/pull/1113 as well as the `src/fasteval.jl` file, which contains extensive comments.

## 0.38.9

Remove warning when using Enzyme as the AD backend.
8 changes: 7 additions & 1 deletion docs/src/api.md
@@ -66,6 +66,12 @@ The [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) inte
LogDensityFunction
```

Internally, this is accomplished by calling [`init!!`](@ref) on an instance of:

```@docs
OnlyAccsVarInfo
```
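
A minimal sketch of this internal pattern (assuming a `model` and an `AbstractInitStrategy` `strategy` are in scope; the accumulator choice is illustrative):

```julia
accs = DynamicPPL.AccumulatorTuple((DynamicPPL.LogPriorAccumulator(),))
_, vi = DynamicPPL.init!!(model, OnlyAccsVarInfo(accs), strategy)
DynamicPPL.getlogprior(vi)
```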

## Condition and decondition

A [`Model`](@ref) can be conditioned on a set of observations with [`AbstractPPL.condition`](@ref) or its alias [`|`](@ref).
@@ -510,7 +516,7 @@ The function `init!!` is used to initialise, or overwrite, values in a VarInfo.
It is really a thin wrapper around `evaluate!!` with an `InitContext`.

```@docs
DynamicPPL.init!!
init!!
```

To accomplish this, an initialisation _strategy_ is required, which defines how new values are to be obtained.
11 changes: 8 additions & 3 deletions ext/DynamicPPLMarginalLogDensitiesExt.jl
@@ -6,8 +6,13 @@ using MarginalLogDensities: MarginalLogDensities
# A thin wrapper to adapt a DynamicPPL.LogDensityFunction to the interface expected by
# MarginalLogDensities. It's helpful to have a struct so that we can dispatch on its type
# below.
struct LogDensityFunctionWrapper{L<:DynamicPPL.LogDensityFunction}
struct LogDensityFunctionWrapper{
L<:DynamicPPL.LogDensityFunction,V<:DynamicPPL.AbstractVarInfo
}
logdensity::L
# This field is used only to reconstruct the VarInfo later on; it's not needed for the
# actual log-density evaluation.
varinfo::V
end
function (lw::LogDensityFunctionWrapper)(x, _)
return LogDensityProblems.logdensity(lw.logdensity, x)
@@ -101,7 +106,7 @@ function DynamicPPL.marginalize(
# Construct the marginal log-density model.
f = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo)
mld = MarginalLogDensities.MarginalLogDensity(
LogDensityFunctionWrapper(f), varinfo[:], varindices, (), method; kwargs...
LogDensityFunctionWrapper(f, varinfo), varinfo[:], varindices, (), method; kwargs...
)
return mld
end
@@ -190,7 +195,7 @@ function DynamicPPL.VarInfo(
unmarginalized_params::Union{AbstractVector,Nothing}=nothing,
)
# Extract the original VarInfo. Its contents will in general be junk.
original_vi = mld.logdensity.logdensity.varinfo
original_vi = mld.logdensity.varinfo
# Extract the stored parameters, which includes the modes for any marginalized
# parameters
full_params = MarginalLogDensities.cached_params(mld)
8 changes: 6 additions & 2 deletions src/DynamicPPL.jl
@@ -92,8 +92,12 @@ export AbstractVarInfo,
getargnames,
extract_priors,
values_as_in_model,
# LogDensityFunction
# evaluation
evaluate!!,
init!!,
# LogDensityFunction and fasteval
LogDensityFunction,
OnlyAccsVarInfo,
# Leaf contexts
AbstractContext,
contextualize,
@@ -198,7 +202,7 @@ include("simple_varinfo.jl")
include("onlyaccs.jl")
include("compiler.jl")
include("pointwise_logdensities.jl")
include("logdensityfunction.jl")
include("fasteval.jl")
include("model_utils.jl")
include("extract_priors.jl")
include("values_as_in_model.jl")
8 changes: 3 additions & 5 deletions src/chains.jl
@@ -137,7 +137,7 @@
"""
ParamsWithStats(
param_vector::AbstractVector,
ldf::DynamicPPL.Experimental.FastLDF,
ldf::DynamicPPL.LogDensityFunction,
stats::NamedTuple=NamedTuple();
include_colon_eq::Bool=true,
include_log_probs::Bool=true,
@@ -156,7 +156,7 @@ via `unflatten` plus re-evaluation. It is faster for two reasons:
"""
function ParamsWithStats(
param_vector::AbstractVector,
ldf::DynamicPPL.Experimental.FastLDF,
ldf::DynamicPPL.LogDensityFunction,
stats::NamedTuple=NamedTuple();
include_colon_eq::Bool=true,
include_log_probs::Bool=true,
@@ -174,9 +174,7 @@ function ParamsWithStats(
else
(DynamicPPL.ValuesAsInModelAccumulator(include_colon_eq),)
end
_, vi = DynamicPPL.Experimental.fast_evaluate!!(
ldf.model, strategy, AccumulatorTuple(accs)
)
_, vi = DynamicPPL.init!!(ldf.model, OnlyAccsVarInfo(AccumulatorTuple(accs)), strategy)
params = DynamicPPL.getacc(vi, Val(:ValuesAsInModel)).values
if include_log_probs
stats = merge(
2 changes: 0 additions & 2 deletions src/experimental.jl
@@ -2,8 +2,6 @@ module Experimental

using DynamicPPL: DynamicPPL

include("fasteval.jl")

# This file only defines the names of the functions, and their docstrings. The actual implementations are in `ext/DynamicPPLJETExt.jl`, since we don't want to depend on JET.jl other than as a weak dependency.
"""
is_suitable_varinfo(model::Model, varinfo::AbstractVarInfo; kwargs...)
147 changes: 56 additions & 91 deletions src/fasteval.jl
@@ -30,61 +30,7 @@ import DifferentiationInterface as DI
using Random: Random

"""
DynamicPPL.Experimental.fast_evaluate!!(
[rng::Random.AbstractRNG,]
model::Model,
strategy::AbstractInitStrategy,
accs::AccumulatorTuple, params::AbstractVector{<:Real}
)

Evaluate a model using parameters obtained via `strategy`, and only computing the results in
the provided accumulators.

It is assumed that the accumulators passed in have been initialised to appropriate values,
as this function will not reset them. The default constructors for each accumulator will do
this for you correctly.

Returns a tuple of the model's return value, plus an `OnlyAccsVarInfo`. Note that the `accs`
argument may be mutated (depending on how the accumulators are implemented); hence the `!!`
in the function name.
"""
@inline function fast_evaluate!!(
# Note that this `@inline` is mandatory for performance. If it's not inlined, it leads
# to extra allocations (even for trivial models) and much slower runtime.
rng::Random.AbstractRNG,
model::Model,
strategy::AbstractInitStrategy,
accs::AccumulatorTuple,
)
ctx = InitContext(rng, strategy)
model = DynamicPPL.setleafcontext(model, ctx)
# Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!,
# which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!`
# directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic
# here.
# TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what
# it _should_ do, but this is wrong regardless.
# https://github.com/TuringLang/DynamicPPL.jl/issues/1086
vi = if Threads.nthreads() > 1
param_eltype = DynamicPPL.get_param_eltype(strategy)
accs = map(accs) do acc
DynamicPPL.convert_eltype(float_type_with_fallback(param_eltype), acc)
end
ThreadSafeVarInfo(OnlyAccsVarInfo(accs))
else
OnlyAccsVarInfo(accs)
end
return DynamicPPL._evaluate!!(model, vi)
end
@inline function fast_evaluate!!(
model::Model, strategy::AbstractInitStrategy, accs::AccumulatorTuple
)
# This `@inline` is also mandatory for performance
return fast_evaluate!!(Random.default_rng(), model, strategy, accs)
end

"""
FastLDF(
DynamicPPL.LogDensityFunction(
model::Model,
getlogdensity::Function=getlogjoint_internal,
varinfo::AbstractVarInfo=VarInfo(model);
@@ -115,26 +61,27 @@ There are several options for `getlogdensity` that are 'supported' out of the bo
since transforms are only applied to random variables)

!!! note
By default, `FastLDF` uses `getlogjoint_internal`, i.e., the result of
`LogDensityProblems.logdensity(f, x)` will depend on whether the `FastLDF` was created
with a linked or unlinked VarInfo. This is done primarily to ease interoperability with
MCMC samplers.
By default, `LogDensityFunction` uses `getlogjoint_internal`, i.e., the result of
`LogDensityProblems.logdensity(f, x)` will depend on whether the `LogDensityFunction`
was created with a linked or unlinked VarInfo. This is done primarily to ease
interoperability with MCMC samplers.

If you provide one of these functions, a `VarInfo` will be automatically created for you. If
you provide a different function, you have to manually create a VarInfo and pass it as the
third argument.
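
For instance (a minimal sketch, assuming a model `demo` defined elsewhere):

```julia
model = demo()

# Supported function: a VarInfo is created automatically.
ldf = LogDensityFunction(model, getlogprior)

# Passing a VarInfo explicitly, e.g. a linked one:
vi = link(VarInfo(model), model)
ldf_linked = LogDensityFunction(model, getlogjoint_internal, vi)
```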

If the `adtype` keyword argument is provided, then this struct will also store the adtype
along with other information for efficient calculation of the gradient of the log density.
Note that preparing a `FastLDF` with an AD type `AutoBackend()` requires the AD backend
itself to have been loaded (e.g. with `import Backend`).
Note that preparing a `LogDensityFunction` with an AD type `AutoBackend()` requires the AD
backend itself to have been loaded (e.g. with `import Backend`).
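
For example (a sketch; ForwardDiff must be loaded first, per the note above):

```julia
import ForwardDiff
using ADTypes: AutoForwardDiff

ldf = LogDensityFunction(model, getlogjoint_internal; adtype=AutoForwardDiff())
```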

## Fields

Note that it is undefined behaviour to access any of a `FastLDF`'s fields, apart from:
Note that it is undefined behaviour to access any of a `LogDensityFunction`'s fields, apart
from:

- `fastldf.model`: The original model from which this `FastLDF` was constructed.
- `fastldf.adtype`: The AD type used for gradient calculations, or `nothing` if no AD
- `ldf.model`: The original model from which this `LogDensityFunction` was constructed.
- `ldf.adtype`: The AD type used for gradient calculations, or `nothing` if no AD
type was provided.

# Extended help
@@ -172,8 +119,9 @@ Traditionally, this problem has been solved by `unflatten`, because that functio
place values into the VarInfo's metadata alongside the information about ranges and linking.
That way, when we evaluate with `DefaultContext`, we can read this information out again.
However, we want to avoid using a metadata. Thus, here, we _extract this information from
the VarInfo_ a single time when constructing a `FastLDF` object. Inside the FastLDF, we
store a mapping from VarNames to ranges in that vector, along with link status.
the VarInfo_ a single time when constructing a `LogDensityFunction` object. Inside the
LogDensityFunction, we store a mapping from VarNames to ranges in that vector, along with
link status.

For VarNames with identity optics, this is stored in a NamedTuple for efficiency. For all
other VarNames, this is stored in a Dict. The internal data structure used to represent this
mapping is `VectorWithRanges`. At evaluation time, the parameter vector is combined with these
ranges to create an `InitFromParams{VectorWithRanges}`, which lets us very quickly look up
parameter values from the vector.
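
Conceptually, the stored mapping looks like this (a sketch only; the `RangeAndLinked`
argument order is an assumption, not the exact internal API):

```julia
# `x` and `y` have identity optics, so they live in a NamedTuple; `z[1]` does
# not, so it lives in a Dict.
iden_ranges = (; x=RangeAndLinked(1:1, false), y=RangeAndLinked(2:4, true))
other_ranges = Dict{VarName,RangeAndLinked}(@varname(z[1]) => RangeAndLinked(5:5, false))
params = randn(5)
strategy = InitFromParams(VectorWithRanges(iden_ranges, other_ranges, params), nothing)
```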

Note that this assumes that the ranges and link status are static throughout the lifetime of
the `FastLDF` object. Therefore, a `FastLDF` object cannot handle models which have variable
numbers of parameters, or models which may visit random variables in different orders depending
on stochastic control flow. **Indeed, silent errors may occur with such models.** This is a
general limitation of vectorised parameters: the original `unflatten` + `evaluate!!`
approach also fails with such models.
the `LogDensityFunction` object. Therefore, a `LogDensityFunction` object cannot handle
models which have variable numbers of parameters, or models which may visit random variables
in different orders depending on stochastic control flow. **Indeed, silent errors may occur
with such models.** This is a general limitation of vectorised parameters: the original
`unflatten` + `evaluate!!` approach also fails with such models.
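
For example, a model along these lines (a sketch) draws its own length and hence has a
parameter count that can change between evaluations:

```julia
@model function varying_length()
    n ~ Poisson(3.0)
    x = Vector{Float64}(undef, n + 1)
    for i in eachindex(x)
        x[i] ~ Normal()
    end
end
```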
"""
struct FastLDF{
struct LogDensityFunction{
M<:Model,
AD<:Union{ADTypes.AbstractADType,Nothing},
F<:Function,
Expand All @@ -206,7 +154,7 @@ struct FastLDF{
_adprep::ADP
_dim::Int

function FastLDF(
function LogDensityFunction(
model::Model,
getlogdensity::Function=getlogjoint_internal,
varinfo::AbstractVarInfo=VarInfo(model);
Expand All @@ -224,7 +172,7 @@ struct FastLDF{
# Make backend-specific tweaks to the adtype
adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo)
DI.prepare_gradient(
FastLogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges),
LogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges),
adtype,
x,
)
@@ -261,56 +209,73 @@ end
fast_ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),))
fast_ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),))

struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple}
struct LogDensityAt{M<:Model,F<:Function,N<:NamedTuple}
model::M
getlogdensity::F
iden_varname_ranges::N
varname_ranges::Dict{VarName,RangeAndLinked}
end
function (f::FastLogDensityAt)(params::AbstractVector{<:Real})
function (f::LogDensityAt)(params::AbstractVector{<:Real})
strategy = InitFromParams(
VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing
)
accs = fast_ldf_accs(f.getlogdensity)
_, vi = fast_evaluate!!(f.model, strategy, accs)
_, vi = DynamicPPL.init!!(f.model, OnlyAccsVarInfo(accs), strategy)
return f.getlogdensity(vi)
end

function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real})
return FastLogDensityAt(
fldf.model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges
function LogDensityProblems.logdensity(
ldf::LogDensityFunction, params::AbstractVector{<:Real}
)
return LogDensityAt(
ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges
)(
params
)
end

function LogDensityProblems.logdensity_and_gradient(
fldf::FastLDF, params::AbstractVector{<:Real}
ldf::LogDensityFunction, params::AbstractVector{<:Real}
)
return DI.value_and_gradient(
FastLogDensityAt(
fldf.model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges
LogDensityAt(
ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges
),
fldf._adprep,
fldf.adtype,
ldf._adprep,
ldf.adtype,
params,
)
end

function LogDensityProblems.capabilities(
::Type{<:DynamicPPL.Experimental.FastLDF{M,Nothing}}
) where {M}
function LogDensityProblems.capabilities(::Type{<:LogDensityFunction{M,Nothing}}) where {M}
return LogDensityProblems.LogDensityOrder{0}()
end
function LogDensityProblems.capabilities(
::Type{<:DynamicPPL.Experimental.FastLDF{M,<:ADTypes.AbstractADType}}
::Type{<:LogDensityFunction{M,<:ADTypes.AbstractADType}}
) where {M}
return LogDensityProblems.LogDensityOrder{1}()
end
function LogDensityProblems.dimension(fldf::FastLDF)
return fldf._dim
function LogDensityProblems.dimension(ldf::LogDensityFunction)
return ldf._dim
end
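
# Taken together, these methods let a `LogDensityFunction` be used anywhere the
# LogDensityProblems.jl interface is expected. A sketch (assumes `model` is an
# instantiated `Model`):
#
#     ldf = LogDensityFunction(model)  # no adtype: zeroth-order capabilities only
#     x = randn(LogDensityProblems.dimension(ldf))
#     LogDensityProblems.logdensity(ldf, x)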

"""
tweak_adtype(
adtype::ADTypes.AbstractADType,
model::Model,
varinfo::AbstractVarInfo,
)

Return an 'optimised' form of the adtype. This is useful for doing
backend-specific optimisation of the adtype (e.g., for ForwardDiff, calculating
the chunk size: see the method override in `ext/DynamicPPLForwardDiffExt.jl`).
The model is passed as a parameter in case the optimisation depends on the
model.

By default, this just returns the input unchanged.
"""
tweak_adtype(adtype::ADTypes.AbstractADType, ::Model, ::AbstractVarInfo) = adtype
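
# A hypothetical override for an imaginary backend `AutoMyBackend` (the name
# and its `batchsize` keyword are illustrative, not a real ADTypes backend)
# would follow this pattern:
#
#     function DynamicPPL.tweak_adtype(
#         adtype::AutoMyBackend, model::Model, varinfo::AbstractVarInfo
#     )
#         # e.g. pick a batch size based on the number of parameters
#         return AutoMyBackend(; batchsize=min(8, length(varinfo[:])))
#     end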

######################################################
# Helper functions to extract ranges and link status #
######################################################