@@ -65,16 +65,21 @@ in the function name.
6565 # TODO (penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what
6666 # it _should_ do, but this is wrong regardless.
6767 # https://github.com/TuringLang/DynamicPPL.jl/issues/1086
68- vi = if Threads. nthreads () > 1
69- param_eltype = DynamicPPL. get_param_eltype (strategy)
68+ return if Threads. nthreads () > 1
69+ # WARNING: Do NOT move get_param_eltype(strategy) into an intermediate variable, it
70+ # will cause type instabilities! See also unflatten in src/varinfo.jl.
7071 accs = map (accs) do acc
71- DynamicPPL. convert_eltype (float_type_with_fallback (param_eltype), acc)
72+ DynamicPPL. convert_eltype (
73+ float_type_with_fallback (DynamicPPL. get_param_eltype (strategy)), acc
74+ )
7275 end
73- ThreadSafeVarInfo (OnlyAccsVarInfo (accs))
76+ tsvi = ThreadSafeVarInfo (OnlyAccsVarInfo (accs))
77+ retval, tsvi_new = DynamicPPL. _evaluate!! (model, tsvi)
78+ retval, setaccs!! (tsvi_new. varinfo, getaccs (tsvi_new))
7479 else
75- OnlyAccsVarInfo (accs)
80+ vi = OnlyAccsVarInfo (accs)
81+ DynamicPPL. _evaluate!! (model, vi)
7682 end
77- return DynamicPPL. _evaluate!! (model, vi)
7883end
7984@inline function fast_evaluate!! (
8085 model:: Model , strategy:: AbstractInitStrategy , accs:: AccumulatorTuple
@@ -194,6 +199,9 @@ with such models.** This is a general limitation of vectorised parameters: the o
194199`unflatten` + `evaluate!!` approach also fails with such models.
195200"""
196201struct LogDensityFunction{
202+ # true if all variables are linked; false if all variables are unlinked; nothing if
203+ # mixed
204+ Tlink,
197205 M<: Model ,
198206 AD<: Union{ADTypes.AbstractADType,Nothing} ,
199207 F<: Function ,
@@ -217,6 +225,21 @@ struct LogDensityFunction{
217225 # Figure out which variable corresponds to which index, and
218226 # which variables are linked.
219227 all_iden_ranges, all_ranges = get_ranges_and_linked (varinfo)
228+ # Figure out if all variables are linked, unlinked, or mixed
229+ link_statuses = Bool[]
230+ for ral in all_iden_ranges
231+ push! (link_statuses, ral. is_linked)
232+ end
233+ for (_, ral) in all_ranges
234+ push! (link_statuses, ral. is_linked)
235+ end
236+ Tlink = if all (link_statuses)
237+ true
238+ elseif all (! s for s in link_statuses)
239+ false
240+ else
241+ nothing
242+ end
220243 x = [val for val in varinfo[:]]
221244 dim = length (x)
222245 # Do AD prep if needed
@@ -226,12 +249,13 @@ struct LogDensityFunction{
226249 # Make backend-specific tweaks to the adtype
227250 adtype = DynamicPPL. tweak_adtype (adtype, model, varinfo)
228251 DI. prepare_gradient (
229- LogDensityAt (model, getlogdensity, all_iden_ranges, all_ranges),
252+ LogDensityAt {Tlink} (model, getlogdensity, all_iden_ranges, all_ranges),
230253 adtype,
231254 x,
232255 )
233256 end
234257 return new{
258+ Tlink,
235259 typeof (model),
236260 typeof (adtype),
237261 typeof (getlogdensity),
@@ -263,36 +287,45 @@ end
263287fast_ldf_accs (:: typeof (getlogprior)) = AccumulatorTuple ((LogPriorAccumulator (),))
264288fast_ldf_accs (:: typeof (getloglikelihood)) = AccumulatorTuple ((LogLikelihoodAccumulator (),))
265289
266- struct LogDensityAt{M<: Model ,F<: Function ,N<: NamedTuple }
290+ struct LogDensityAt{Tlink, M<: Model ,F<: Function ,N<: NamedTuple }
267291 model:: M
268292 getlogdensity:: F
269293 iden_varname_ranges:: N
270294 varname_ranges:: Dict{VarName,RangeAndLinked}
295+
296+ function LogDensityAt {Tlink} (
297+ model:: M ,
298+ getlogdensity:: F ,
299+ iden_varname_ranges:: N ,
300+ varname_ranges:: Dict{VarName,RangeAndLinked} ,
301+ ) where {Tlink,M,F,N}
302+ return new {Tlink,M,F,N} (model, getlogdensity, iden_varname_ranges, varname_ranges)
303+ end
271304end
272- function (f:: LogDensityAt )(params:: AbstractVector{<:Real} )
305+ function (f:: LogDensityAt{Tlink} )(params:: AbstractVector{<:Real} ) where {Tlink}
273306 strategy = InitFromParams (
274- VectorWithRanges (f. iden_varname_ranges, f. varname_ranges, params), nothing
307+ VectorWithRanges {Tlink} (f. iden_varname_ranges, f. varname_ranges, params), nothing
275308 )
276309 accs = fast_ldf_accs (f. getlogdensity)
277310 _, vi = fast_evaluate!! (f. model, strategy, accs)
278311 return f. getlogdensity (vi)
279312end
280313
281314function LogDensityProblems. logdensity (
282- ldf:: LogDensityFunction , params:: AbstractVector{<:Real}
283- )
284- return LogDensityAt (
315+ ldf:: LogDensityFunction{Tlink} , params:: AbstractVector{<:Real}
316+ ) where {Tlink}
317+ return LogDensityAt {Tlink} (
285318 ldf. model, ldf. _getlogdensity, ldf. _iden_varname_ranges, ldf. _varname_ranges
286319 )(
287320 params
288321 )
289322end
290323
291324function LogDensityProblems. logdensity_and_gradient (
292- ldf:: LogDensityFunction , params:: AbstractVector{<:Real}
293- )
325+ ldf:: LogDensityFunction{Tlink} , params:: AbstractVector{<:Real}
326+ ) where {Tlink}
294327 return DI. value_and_gradient (
295- LogDensityAt (
328+ LogDensityAt {Tlink} (
296329 ldf. model, ldf. _getlogdensity, ldf. _iden_varname_ranges, ldf. _varname_ranges
297330 ),
298331 ldf. _adprep,
@@ -301,12 +334,14 @@ function LogDensityProblems.logdensity_and_gradient(
301334 )
302335end
303336
304- function LogDensityProblems. capabilities (:: Type{<:LogDensityFunction{M,Nothing}} ) where {M}
337+ function LogDensityProblems. capabilities (
338+ :: Type{<:LogDensityFunction{T,M,Nothing}}
339+ ) where {T,M}
305340 return LogDensityProblems. LogDensityOrder {0} ()
306341end
307342function LogDensityProblems. capabilities (
308- :: Type{<:LogDensityFunction{M,<:ADTypes.AbstractADType}}
309- ) where {M}
343+ :: Type{<:LogDensityFunction{T, M,<:ADTypes.AbstractADType}}
344+ ) where {T, M}
310345 return LogDensityProblems. LogDensityOrder {1} ()
311346end
312347function LogDensityProblems. dimension (ldf:: LogDensityFunction )
0 commit comments