Skip to content

Commit f76bb3d

Browse files
penelopeysm and gdalle committed
Don't store with_closure inside LogDensityFunction
Co-authored-by: Guillaume Dalle <22795598+gdalle@users.noreply.github.com>
1 parent 74fbad2 commit f76bb3d

File tree

2 files changed

+62
-11
lines changed

2 files changed

+62
-11
lines changed

ext/DynamicPPLForwardDiffExt.jl

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
module DynamicPPLForwardDiffExt

if isdefined(Base, :get_extension)
    using DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD
    using ForwardDiff
else
    using ..DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD
    using ..ForwardDiff
end

# Pull the chunk-size type parameter out of an AutoForwardDiff backend.
getchunksize(::ADTypes.AutoForwardDiff{chunk}) where {chunk} = chunk

# A backend whose tag type parameter is `Nothing` gets the standard
# DynamicPPL tag; any explicitly tagged backend does not.
standardtag(::ADTypes.AutoForwardDiff{<:Any,Nothing}) = true
standardtag(::ADTypes.AutoForwardDiff) = false

"""
    LogDensityProblemsAD.ADgradient(
        ad::ADTypes.AutoForwardDiff, ℓ::DynamicPPL.LogDensityFunction
    )

Build a ForwardDiff-backed gradient object for `ℓ`, selecting the
ForwardDiff tag and chunk from the settings carried by `ad` and from the
current parameter values of `ℓ`.
"""
function LogDensityProblemsAD.ADgradient(
    ad::ADTypes.AutoForwardDiff, ℓ::DynamicPPL.LogDensityFunction
)
    params = DynamicPPL.getparams(ℓ)
    logdensity_fn = Base.Fix1(LogDensityProblems.logdensity, ℓ)

    # Tag selection: use the shared DynamicPPL tag when the backend carries
    # none of its own, otherwise derive a tag from the log-density callable.
    tag = if standardtag(ad)
        ForwardDiff.Tag(DynamicPPL.DynamicPPLTag(), eltype(params))
    else
        ForwardDiff.Tag(logdensity_fn, eltype(params))
    end

    # Chunk selection: `nothing` or 0 means "pick automatically from the
    # parameter vector"; any other value is an explicit chunk size.
    requested_chunk = getchunksize(ad)
    chunk = if requested_chunk === nothing || requested_chunk == 0
        ForwardDiff.Chunk(params)
    else
        ForwardDiff.Chunk(length(params), requested_chunk)
    end

    return LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓ; chunk, tag, x=params)
end

# Accept the DynamicPPL tag in ForwardDiff's tag check, both when
# differentiating a LogDensityFunction directly and when differentiating
# the Fix1 wrapper built around it above.
function ForwardDiff.checktag(
    ::Type{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag,V}},
    ::DynamicPPL.LogDensityFunction,
    ::AbstractArray{W},
) where {V,W}
    return true
end
function ForwardDiff.checktag(
    ::Type{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag,V}},
    ::Base.Fix1{typeof(LogDensityProblems.logdensity),<:DynamicPPL.LogDensityFunction},
    ::AbstractArray{W},
) where {V,W}
    return true
end

end # module

src/logdensityfunction.jl

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,6 @@ struct LogDensityFunction{
106106
adtype::AD
107107
"(internal use only) gradient preparation object for the model"
108108
prep::Union{Nothing,DI.GradientPrep}
109-
"(internal use only) whether a closure was used for the gradient preparation"
110-
with_closure::Bool
111109

112110
function LogDensityFunction(
113111
model::Model,
@@ -117,15 +115,13 @@ struct LogDensityFunction{
117115
)
118116
if adtype === nothing
119117
prep = nothing
120-
with_closure = false
121118
else
122119
# Check support
123120
is_supported(adtype) ||
124121
@warn "The AD backend $adtype is not officially supported by DynamicPPL. Gradient calculations may still work, but compatibility is not guaranteed."
125122
# Get a set of dummy params to use for prep
126123
x = map(identity, varinfo[:])
127-
with_closure = use_closure(adtype)
128-
if with_closure
124+
if use_closure(adtype)
129125
prep = DI.prepare_gradient(
130126
x -> logdensity_at(x, model, varinfo, context), adtype, x
131127
)
@@ -139,20 +135,19 @@ struct LogDensityFunction{
139135
DI.Constant(context),
140136
)
141137
end
142-
with_closure = with_closure
143138
end
144139
return new{typeof(model),typeof(varinfo),typeof(context),typeof(adtype)}(
145-
model, varinfo, context, adtype, prep, with_closure
140+
model, varinfo, context, adtype, prep
146141
)
147142
end
148143
end
149144

150145
"""
151146
setadtype(f::LogDensityFunction, adtype::Union{Nothing,ADTypes.AbstractADType})
152147
153-
Set the AD type used for evaluation of log density gradient in the given LogDensityFunction.
154-
This function also performs preparation of the gradient, and sets the `prep`
155-
and `with_closure` fields of the LogDensityFunction.
148+
Set the AD type used for evaluation of log density gradient in the given
149+
LogDensityFunction. This function also performs preparation of the gradient,
150+
and sets the `prep` field of the LogDensityFunction.
156151
157152
If `adtype` is `nothing`, the `prep` field will be set to `nothing` as well.
158153
@@ -208,7 +203,9 @@ function LogDensityProblems.logdensity_and_gradient(
208203
f.prep === nothing &&
209204
error("Gradient preparation not available; this should not happen")
210205
x = map(identity, x) # Concretise type
211-
return if f.with_closure
206+
# Make branching statically inferrable, i.e. type-stable (even if the two
207+
# branches happen to return different types)
208+
return if use_closure(f.adtype)
212209
DI.value_and_gradient(
213210
x -> logdensity_at(x, f.model, f.varinfo, f.context), f.prep, f.adtype, x
214211
)

0 commit comments

Comments
 (0)