@@ -106,8 +106,6 @@ struct LogDensityFunction{
     adtype::AD
     "(internal use only) gradient preparation object for the model"
     prep::Union{Nothing,DI.GradientPrep}
-    "(internal use only) whether a closure was used for the gradient preparation"
-    with_closure::Bool
 
     function LogDensityFunction(
         model::Model,
@@ -117,15 +115,13 @@ struct LogDensityFunction{
     )
         if adtype === nothing
             prep = nothing
-            with_closure = false
         else
             # Check support
             is_supported(adtype) ||
                 @warn "The AD backend $adtype is not officially supported by DynamicPPL. Gradient calculations may still work, but compatibility is not guaranteed."
             # Get a set of dummy params to use for prep
             x = map(identity, varinfo[:])
-            with_closure = use_closure(adtype)
-            if with_closure
+            if use_closure(adtype)
                 prep = DI.prepare_gradient(
                     x -> logdensity_at(x, model, varinfo, context), adtype, x
                 )
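
For context on the two preparation styles the constructor chooses between: DifferentiationInterface can either prepare a closure that captures `model`, `varinfo`, and `context`, or prepare `logdensity_at` directly with those extras passed as `DI.Constant` arguments. A minimal sketch of the difference on a toy objective (the backend choice and names such as `f_closure`/`f_args` are illustrative, not part of this patch):

```julia
using DifferentiationInterface
using ADTypes: AutoForwardDiff
import ForwardDiff

backend = AutoForwardDiff()
x = randn(3)
data = [1.0, 2.0, 3.0]

# Closure style: the extra argument `data` is captured in the function itself.
f_closure = x -> sum(abs2, x .- data)
prep1 = prepare_gradient(f_closure, backend, x)
value_and_gradient(f_closure, prep1, backend, x)

# Constant style: the extra argument is passed as a DI context instead.
f_args(x, data) = sum(abs2, x .- data)
prep2 = prepare_gradient(f_args, backend, x, Constant(data))
value_and_gradient(f_args, prep2, backend, x, Constant(data))
```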
@@ -139,20 +135,19 @@ struct LogDensityFunction{
                     DI.Constant(context),
                 )
             end
-            with_closure = with_closure
         end
         return new{typeof(model),typeof(varinfo),typeof(context),typeof(adtype)}(
-            model, varinfo, context, adtype, prep, with_closure
+            model, varinfo, context, adtype, prep
         )
     end
 end
 
 """
     setadtype(f::LogDensityFunction, adtype::Union{Nothing,ADTypes.AbstractADType})
 
-Set the AD type used for evaluation of log density gradient in the given LogDensityFunction.
-This function also performs preparation of the gradient, and sets the `prep`
-and `with_closure` fields of the LogDensityFunction.
+Set the AD type used for evaluation of log density gradient in the given
+LogDensityFunction. This function also performs preparation of the gradient,
+and sets the `prep` field of the LogDensityFunction.
 
 If `adtype` is `nothing`, the `prep` field will be set to `nothing` as well.
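
As a usage note (not part of the diff), the intended workflow with `setadtype` might look roughly like the following. This sketch assumes that `setadtype` returns an updated `LogDensityFunction` rather than mutating `f` in place, that the remaining constructor arguments have defaults, and that `demo` is a stand-in model:

```julia
using DynamicPPL, Distributions
using ADTypes: AutoForwardDiff
import ForwardDiff

@model function demo()
    x ~ Normal()
end

f = LogDensityFunction(demo())
# Attach an AD backend; this is where gradient preparation happens.
f_ad = DynamicPPL.setadtype(f, AutoForwardDiff())
# Passing `nothing` clears the backend and the stored `prep`.
f_plain = DynamicPPL.setadtype(f_ad, nothing)
```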
@@ -208,7 +203,9 @@ function LogDensityProblems.logdensity_and_gradient(
     f.prep === nothing &&
         error("Gradient preparation not available; this should not happen")
     x = map(identity, x)  # Concretise type
-    return if f.with_closure
+    # Make branching statically inferrable, i.e. type-stable (even if the two
+    # branches happen to return different types)
+    return if use_closure(f.adtype)
         DI.value_and_gradient(
             x -> logdensity_at(x, f.model, f.varinfo, f.context), f.prep, f.adtype, x
         )
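
On the type-stability comment above: `use_closure` presumably depends only on the type of the AD backend, so the branch can be resolved at compile time even without the cached `with_closure::Bool` field. A hypothetical stand-in (`_use_closure` is not DynamicPPL's actual implementation) showing the idea:

```julia
using ADTypes: AutoForwardDiff, AutoReverseDiff

# Hypothetical predicate, dispatching purely on the backend type.
_use_closure(::AutoForwardDiff) = false
_use_closure(::AutoReverseDiff) = true

# The two branches return different types, but for a concrete backend the
# compiler constant-folds `_use_closure`, so only one branch survives inference.
pick(adtype) = _use_closure(adtype) ? "closure" : 0.0

pick(AutoForwardDiff())  # inferred as Float64
pick(AutoReverseDiff())  # inferred as String
```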