Skip to content

Commit 74c9f12

Browse files
committed
Combine LogDensityFunction{,WithGrad} into one (#811)
1 parent bb832ab commit 74c9f12

File tree

3 files changed

+175
-125
lines changed

3 files changed

+175
-125
lines changed

docs/src/api.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,10 @@ logjoint
5454

5555
### LogDensityProblems.jl interface
5656

57-
The [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface is also supported by wrapping a [`Model`](@ref) in a `DynamicPPL.LogDensityFunction` or `DynamicPPL.LogDensityFunctionWithGrad`.
57+
The [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface is also supported by wrapping a [`Model`](@ref) in a `DynamicPPL.LogDensityFunction`.
5858

5959
```@docs
6060
DynamicPPL.LogDensityFunction
61-
DynamicPPL.LogDensityFunctionWithGrad
6261
```
6362

6463
## Condition and decondition

src/logdensityfunction.jl

Lines changed: 160 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,39 @@
11
import DifferentiationInterface as DI
22

33
"""
4-
LogDensityFunction
5-
6-
A callable representing a log density function of a `model`.
7-
`DynamicPPL.LogDensityFunction` implements the LogDensityProblems.jl interface,
8-
but only to 0th-order, i.e. it is only possible to calculate the log density,
9-
and not its gradient. If you need to calculate the gradient as well, you have
10-
to construct a [`DynamicPPL.LogDensityFunctionWithGrad`](@ref) object.
4+
LogDensityFunction(
5+
model::Model,
6+
varinfo::AbstractVarInfo=VarInfo(model),
7+
context::AbstractContext=DefaultContext();
8+
adtype::Union{ADTypes.AbstractADType,Nothing}=nothing
9+
)
10+
11+
A struct which contains a model, along with all the information necessary to:
12+
13+
- calculate its log density at a given point;
14+
- and if `adtype` is provided, calculate the gradient of the log density at
15+
that point.
16+
17+
At its most basic level, a LogDensityFunction wraps the model together with
18+
the type of varinfo to be used, as well as the evaluation context. These must
19+
be known in order to calculate the log density (using
20+
[`DynamicPPL.evaluate!!`](@ref)).
21+
22+
If the `adtype` keyword argument is provided, then this struct will also store
23+
the adtype along with other information for efficient calculation of the
24+
gradient of the log density. Note that preparing a `LogDensityFunction` with an
25+
AD type `AutoBackend()` requires the AD backend itself to have been loaded
26+
(e.g. with `import Backend`).
27+
28+
`DynamicPPL.LogDensityFunction` implements the LogDensityProblems.jl interface.
29+
If `adtype` is nothing, then only `logdensity` is implemented. If `adtype` is a
30+
concrete AD backend type, then `logdensity_and_gradient` is also implemented.
1131
1232
# Fields
1333
$(FIELDS)
1434
1535
# Examples
36+
1637
```jldoctest
1738
julia> using Distributions
1839
@@ -48,66 +69,150 @@ julia> # This also respects the context in `model`.
4869
4970
julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0)
5071
true
72+
73+
julia> # If we also need to calculate the gradient, we can specify an AD backend.
74+
import ForwardDiff, ADTypes
75+
76+
julia> f = LogDensityFunction(model, adtype=ADTypes.AutoForwardDiff());
77+
78+
julia> LogDensityProblems.logdensity_and_gradient(f, [0.0])
79+
(-2.3378770664093453, [1.0])
5180
```
5281
"""
53-
struct LogDensityFunction{V,M,C}
54-
"varinfo used for evaluation"
55-
varinfo::V
82+
struct LogDensityFunction{
83+
M<:Model,V<:AbstractVarInfo,C<:AbstractContext,AD<:Union{Nothing,ADTypes.AbstractADType}
84+
}
5685
"model used for evaluation"
5786
model::M
87+
"varinfo used for evaluation"
88+
varinfo::V
5889
"context used for evaluation; if `nothing`, `leafcontext(model.context)` will be used when applicable"
5990
context::C
60-
end
61-
62-
function LogDensityFunction(
63-
model::Model,
64-
varinfo::AbstractVarInfo=VarInfo(model),
65-
context::Union{Nothing,AbstractContext}=nothing,
66-
)
67-
return LogDensityFunction(varinfo, model, context)
68-
end
91+
"AD type used for evaluation of log density gradient. If `nothing`, no gradient can be calculated"
92+
adtype::AD
93+
"(internal use only) gradient preparation object for the model"
94+
prep::Union{Nothing,DI.GradientPrep}
95+
"(internal use only) whether a closure was used for the gradient preparation"
96+
with_closure::Bool
6997

70-
# If a `context` has been specified, we use that. Otherwise we just use the leaf context of `model`.
71-
function getcontext(f::LogDensityFunction)
72-
return f.context === nothing ? leafcontext(f.model.context) : f.context
98+
function LogDensityFunction(
99+
model::Model,
100+
varinfo::AbstractVarInfo=VarInfo(model),
101+
context::AbstractContext=leafcontext(model.context);
102+
adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
103+
)
104+
if adtype === nothing
105+
prep = nothing
106+
with_closure = false
107+
else
108+
# Get a set of dummy params to use for prep
109+
x = map(identity, varinfo[:])
110+
with_closure = use_closure(adtype)
111+
if with_closure
112+
prep = DI.prepare_gradient(
113+
x -> logdensity_at(x, model, varinfo, context), adtype, x
114+
)
115+
else
116+
prep = DI.prepare_gradient(
117+
logdensity_at,
118+
adtype,
119+
x,
120+
DI.Constant(model),
121+
DI.Constant(varinfo),
122+
DI.Constant(context),
123+
)
124+
end
125+
with_closure = with_closure
126+
end
127+
return new{typeof(model),typeof(varinfo),typeof(context),typeof(adtype)}(
128+
model, varinfo, context, adtype, prep, with_closure
129+
)
130+
end
73131
end
74132

75133
"""
76-
getmodel(f)
134+
setadtype(f::LogDensityFunction, adtype::Union{Nothing,ADTypes.AbstractADType})
77135
78-
Return the `DynamicPPL.Model` wrapped in the given log-density function `f`.
79-
"""
80-
getmodel(f::DynamicPPL.LogDensityFunction) = f.model
136+
Set the AD type used for evaluation of log density gradient in the given LogDensityFunction.
137+
This function also performs preparation of the gradient, and sets the `prep`
138+
and `with_closure` fields of the LogDensityFunction.
81139
82-
"""
83-
setmodel(f, model[, adtype])
140+
If `adtype` is `nothing`, the `prep` field will be set to `nothing` as well.
84141
85-
Set the `DynamicPPL.Model` in the given log-density function `f` to `model`.
142+
This function returns a new LogDensityFunction with the updated AD type, i.e. it does
143+
not mutate the input LogDensityFunction.
86144
"""
87-
function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model)
88-
return Accessors.@set f.model = model
145+
function setadtype(f::LogDensityFunction, adtype::Union{Nothing,ADTypes.AbstractADType})
146+
return if adtype === f.adtype
147+
f # Avoid recomputing prep if not needed
148+
else
149+
LogDensityFunction(f.model, f.varinfo, f.context; adtype=adtype)
150+
end
89151
end
90152

91153
"""
92-
getparams(f::LogDensityFunction)
93-
94-
Return the parameters of the wrapped varinfo as a vector.
154+
logdensity_at(
155+
x::AbstractVector,
156+
model::Model,
157+
varinfo::AbstractVarInfo,
158+
context::AbstractContext
159+
)
160+
161+
Evaluate the log density of the given `model` at the given parameter values `x`,
162+
using the given `varinfo` and `context`. Note that the `varinfo` argument is provided
163+
only for its structure, in the sense that the parameters from the vector `x` are inserted into
164+
it, and its own parameters are discarded.
95165
"""
96-
getparams(f::LogDensityFunction) = f.varinfo[:]
166+
function logdensity_at(
167+
x::AbstractVector, model::Model, varinfo::AbstractVarInfo, context::AbstractContext
168+
)
169+
varinfo_new = unflatten(varinfo, x)
170+
return getlogp(last(evaluate!!(model, varinfo_new, context)))
171+
end
172+
173+
### LogDensityProblems interface
97174

98-
# LogDensityProblems interface: logp (0th order)
175+
function LogDensityProblems.capabilities(
176+
::Type{<:LogDensityFunction{M,V,C,Nothing}}
177+
) where {M,V,C}
178+
return LogDensityProblems.LogDensityOrder{0}()
179+
end
180+
function LogDensityProblems.capabilities(
181+
::Type{<:LogDensityFunction{M,V,C,AD}}
182+
) where {M,V,C,AD<:ADTypes.AbstractADType}
183+
return LogDensityProblems.LogDensityOrder{1}()
184+
end
99185
function LogDensityProblems.logdensity(f::LogDensityFunction, x::AbstractVector)
100-
context = getcontext(f)
101-
vi_new = unflatten(f.varinfo, x)
102-
return getlogp(last(evaluate!!(f.model, vi_new, context)))
186+
return logdensity_at(x, f.model, f.varinfo, f.context)
103187
end
104-
function LogDensityProblems.capabilities(::Type{<:LogDensityFunction})
105-
return LogDensityProblems.LogDensityOrder{0}()
188+
function LogDensityProblems.logdensity_and_gradient(
189+
f::LogDensityFunction{M,V,C,AD}, x::AbstractVector
190+
) where {M,V,C,AD<:ADTypes.AbstractADType}
191+
f.prep === nothing &&
192+
error("Gradient preparation not available; this should not happen")
193+
x = map(identity, x) # Concretise type
194+
return if f.with_closure
195+
DI.value_and_gradient(
196+
x -> logdensity_at(x, f.model, f.varinfo, f.context), f.prep, f.adtype, x
197+
)
198+
else
199+
DI.value_and_gradient(
200+
logdensity_at,
201+
f.prep,
202+
f.adtype,
203+
x,
204+
DI.Constant(f.model),
205+
DI.Constant(f.varinfo),
206+
DI.Constant(f.context),
207+
)
208+
end
106209
end
210+
107211
# TODO: should we instead implement and call on `length(f.varinfo)` (at least in the cases where no sampler is involved)?
108212
LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f))
109213

110-
# LogDensityProblems interface: gradient (1st order)
214+
### Utils
215+
111216
"""
112217
use_closure(adtype::ADTypes.AbstractADType)
113218
@@ -139,75 +244,24 @@ use_closure(::ADTypes.AutoMooncake) = false
139244
use_closure(::ADTypes.AutoReverseDiff) = true
140245

141246
"""
142-
_flipped_logdensity(f::LogDensityFunction, x::AbstractVector)
247+
getmodel(f)
143248
144-
This function is the same as `LogDensityProblems.logdensity(f, x)` but with the
145-
arguments flipped. It is used in the 'constant' approach to DifferentiationInterface
146-
(see `use_closure` for more information).
249+
Return the `DynamicPPL.Model` wrapped in the given log-density function `f`.
147250
"""
148-
function _flipped_logdensity(x::AbstractVector, f::LogDensityFunction)
149-
return LogDensityProblems.logdensity(f, x)
150-
end
251+
getmodel(f::DynamicPPL.LogDensityFunction) = f.model
151252

152253
"""
153-
LogDensityFunctionWithGrad(ldf::DynamicPPL.LogDensityFunction, adtype::ADTypes.AbstractADType)
154-
155-
A callable representing a log density function of a `model`.
156-
`DynamicPPL.LogDensityFunctionWithGrad` implements the LogDensityProblems.jl
157-
interface to 1st-order, meaning that you can both calculate the log density
158-
using
159-
160-
LogDensityProblems.logdensity(f, x)
161-
162-
and its gradient using
163-
164-
LogDensityProblems.logdensity_and_gradient(f, x)
254+
setmodel(f, model[, adtype])
165255
166-
where `f` is a `LogDensityFunctionWithGrad` object and `x` is a vector of parameters.
256+
Set the `DynamicPPL.Model` in the given log-density function `f` to `model`.
257+
"""
258+
function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model)
259+
return LogDensityFunction(model, f.varinfo, f.context; adtype=f.adtype)
260+
end
167261

168-
# Fields
169-
$(FIELDS)
170262
"""
171-
struct LogDensityFunctionWithGrad{V,M,C,TAD<:ADTypes.AbstractADType}
172-
ldf::LogDensityFunction{V,M,C}
173-
adtype::TAD
174-
prep::DI.GradientPrep
175-
with_closure::Bool
263+
getparams(f::LogDensityFunction)
176264
177-
function LogDensityFunctionWithGrad(
178-
ldf::LogDensityFunction{V,M,C}, adtype::TAD
179-
) where {V,M,C,TAD}
180-
# Get a set of dummy params to use for prep
181-
x = map(identity, getparams(ldf))
182-
with_closure = use_closure(adtype)
183-
if with_closure
184-
prep = DI.prepare_gradient(
185-
Base.Fix1(LogDensityProblems.logdensity, ldf), adtype, x
186-
)
187-
else
188-
prep = DI.prepare_gradient(_flipped_logdensity, adtype, x, DI.Constant(ldf))
189-
end
190-
# Store the prep with the struct. We also store whether a closure was used because
191-
# we need to know this when calling `DI.value_and_gradient`. In practice we could
192-
# recalculate it, but this runs the risk of introducing inconsistencies.
193-
return new{V,M,C,TAD}(ldf, adtype, prep, with_closure)
194-
end
195-
end
196-
function LogDensityProblems.logdensity(f::LogDensityFunctionWithGrad)
197-
return LogDensityProblems.logdensity(f.ldf)
198-
end
199-
function LogDensityProblems.capabilities(::Type{<:LogDensityFunctionWithGrad})
200-
return LogDensityProblems.LogDensityOrder{1}()
201-
end
202-
function LogDensityProblems.logdensity_and_gradient(
203-
f::LogDensityFunctionWithGrad, x::AbstractVector
204-
)
205-
x = map(identity, x) # Concretise type
206-
return if f.with_closure
207-
DI.value_and_gradient(
208-
Base.Fix1(LogDensityProblems.logdensity, f.ldf), f.prep, f.adtype, x
209-
)
210-
else
211-
DI.value_and_gradient(_flipped_logdensity, f.prep, f.adtype, x, DI.Constant(f.ldf))
212-
end
213-
end
265+
Return the parameters of the wrapped varinfo as a vector.
266+
"""
267+
getparams(f::LogDensityFunction) = f.varinfo[:]

test/ad.jl

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using DynamicPPL: LogDensityFunction, LogDensityFunctionWithGrad
1+
using DynamicPPL: LogDensityFunction
22

33
@testset "AD: ForwardDiff, ReverseDiff, and Mooncake" begin
44
@testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
@@ -10,11 +10,9 @@ using DynamicPPL: LogDensityFunction, LogDensityFunctionWithGrad
1010
f = LogDensityFunction(m, varinfo)
1111
x = DynamicPPL.getparams(f)
1212
# Calculate reference logp + gradient of logp using ForwardDiff
13-
default_adtype = ADTypes.AutoForwardDiff()
14-
ldf_with_grad = LogDensityFunctionWithGrad(f, default_adtype)
15-
ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(
16-
ldf_with_grad, x
17-
)
13+
ref_adtype = ADTypes.AutoForwardDiff()
14+
ref_ldf = LogDensityFunction(m, varinfo; adtype=ref_adtype)
15+
ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ref_ldf, x)
1816

1917
@testset "$adtype" for adtype in [
2018
AutoReverseDiff(; compile=false),
@@ -33,20 +31,18 @@ using DynamicPPL: LogDensityFunction, LogDensityFunctionWithGrad
3331
# Mooncake doesn't work with several combinations of SimpleVarInfo.
3432
if is_mooncake && is_1_11 && is_svi_vnv
3533
# https://github.com/compintell/Mooncake.jl/issues/470
36-
@test_throws ArgumentError LogDensityFunctionWithGrad(f, adtype)
34+
@test_throws ArgumentError DynamicPPL.setadtype(ref_ldf, adtype)
3735
elseif is_mooncake && is_1_10 && is_svi_vnv
3836
# TODO: report upstream
39-
@test_throws UndefRefError LogDensityFunctionWithGrad(f, adtype)
37+
@test_throws UndefRefError DynamicPPL.setadtype(ref_ldf, adtype)
4038
elseif is_mooncake && is_1_10 && is_svi_od
4139
# TODO: report upstream
42-
@test_throws Mooncake.MooncakeRuleCompilationError LogDensityFunctionWithGrad(
43-
f, adtype
40+
@test_throws Mooncake.MooncakeRuleCompilationError DynamicPPL.setadtype(
41+
ref_ldf, adtype
4442
)
4543
else
46-
ldf_with_grad = LogDensityFunctionWithGrad(f, adtype)
47-
logp, grad = LogDensityProblems.logdensity_and_gradient(
48-
ldf_with_grad, x
49-
)
44+
ldf = DynamicPPL.setadtype(ref_ldf, adtype)
45+
logp, grad = LogDensityProblems.logdensity_and_gradient(ldf, x)
5046
@test grad ≈ ref_grad
5147
@test logp ≈ ref_logp
5248
end
@@ -89,8 +85,9 @@ using DynamicPPL: LogDensityFunction, LogDensityFunctionWithGrad
8985
# Compiling the ReverseDiff tape used to fail here
9086
spl = Sampler(MyEmptyAlg())
9187
vi = VarInfo(model)
92-
ldf = LogDensityFunction(vi, model, SamplingContext(spl))
93-
ldf_grad = LogDensityFunctionWithGrad(ldf, AutoReverseDiff(; compile=true))
94-
@test LogDensityProblems.logdensity_and_gradient(ldf_grad, vi[:]) isa Any
88+
ldf = LogDensityFunction(
89+
model, vi, SamplingContext(spl); adtype=AutoReverseDiff(; compile=true)
90+
)
91+
@test LogDensityProblems.logdensity_and_gradient(ldf, vi[:]) isa Any
9592
end
9693
end

0 commit comments

Comments
 (0)