@@ -194,6 +194,9 @@ with such models.** This is a general limitation of vectorised parameters: the o
194194`unflatten` + `evaluate!!` approach also fails with such models.
195195"""
196196struct LogDensityFunction{
197+ # true if all variables are linked; false if all variables are unlinked; nothing if
198+ # mixed
199+ Tlink,
197200 M<: Model ,
198201 AD<: Union{ADTypes.AbstractADType,Nothing} ,
199202 F<: Function ,
@@ -217,6 +220,21 @@ struct LogDensityFunction{
217220 # Figure out which variable corresponds to which index, and
218221 # which variables are linked.
219222 all_iden_ranges, all_ranges = get_ranges_and_linked (varinfo)
223+ # Figure out if all variables are linked, unlinked, or mixed
224+ link_statuses = Bool[]
225+ for ral in all_iden_ranges
226+ push! (link_statuses, ral. is_linked)
227+ end
228+ for (_, ral) in all_ranges
229+ push! (link_statuses, ral. is_linked)
230+ end
231+ Tlink = if all (link_statuses)
232+ true
233+ elseif all (! s for s in link_statuses)
234+ false
235+ else
236+ nothing
237+ end
220238 x = [val for val in varinfo[:]]
221239 dim = length (x)
222240 # Do AD prep if needed
@@ -226,12 +244,13 @@ struct LogDensityFunction{
226244 # Make backend-specific tweaks to the adtype
227245 adtype = DynamicPPL. tweak_adtype (adtype, model, varinfo)
228246 DI. prepare_gradient (
229- LogDensityAt (model, getlogdensity, all_iden_ranges, all_ranges),
247+ LogDensityAt {Tlink} (model, getlogdensity, all_iden_ranges, all_ranges),
230248 adtype,
231249 x,
232250 )
233251 end
234252 return new{
253+ Tlink,
235254 typeof (model),
236255 typeof (adtype),
237256 typeof (getlogdensity),
@@ -263,36 +282,45 @@ end
263282fast_ldf_accs (:: typeof (getlogprior)) = AccumulatorTuple ((LogPriorAccumulator (),))
264283fast_ldf_accs (:: typeof (getloglikelihood)) = AccumulatorTuple ((LogLikelihoodAccumulator (),))
265284
266- struct LogDensityAt{M<: Model ,F<: Function ,N<: NamedTuple }
285+ struct LogDensityAt{Tlink, M<: Model ,F<: Function ,N<: NamedTuple }
267286 model:: M
268287 getlogdensity:: F
269288 iden_varname_ranges:: N
270289 varname_ranges:: Dict{VarName,RangeAndLinked}
290+
291+ function LogDensityAt {Tlink} (
292+ model:: M ,
293+ getlogdensity:: F ,
294+ iden_varname_ranges:: N ,
295+ varname_ranges:: Dict{VarName,RangeAndLinked} ,
296+ ) where {Tlink,M,F,N}
297+ return new {Tlink,M,F,N} (model, getlogdensity, iden_varname_ranges, varname_ranges)
298+ end
271299end
272- function (f:: LogDensityAt )(params:: AbstractVector{<:Real} )
300+ function (f:: LogDensityAt{Tlink} )(params:: AbstractVector{<:Real} ) where {Tlink}
273301 strategy = InitFromParams (
274- VectorWithRanges (f. iden_varname_ranges, f. varname_ranges, params), nothing
302+ VectorWithRanges {Tlink} (f. iden_varname_ranges, f. varname_ranges, params), nothing
275303 )
276304 accs = fast_ldf_accs (f. getlogdensity)
277305 _, vi = fast_evaluate!! (f. model, strategy, accs)
278306 return f. getlogdensity (vi)
279307end
280308
281309function LogDensityProblems. logdensity (
282- ldf:: LogDensityFunction , params:: AbstractVector{<:Real}
283- )
284- return LogDensityAt (
310+ ldf:: LogDensityFunction{Tlink} , params:: AbstractVector{<:Real}
311+ ) where {Tlink}
312+ return LogDensityAt {Tlink} (
285313 ldf. model, ldf. _getlogdensity, ldf. _iden_varname_ranges, ldf. _varname_ranges
286314 )(
287315 params
288316 )
289317end
290318
291319function LogDensityProblems. logdensity_and_gradient (
292- ldf:: LogDensityFunction , params:: AbstractVector{<:Real}
293- )
320+ ldf:: LogDensityFunction{Tlink} , params:: AbstractVector{<:Real}
321+ ) where {Tlink}
294322 return DI. value_and_gradient (
295- LogDensityAt (
323+ LogDensityAt {Tlink} (
296324 ldf. model, ldf. _getlogdensity, ldf. _iden_varname_ranges, ldf. _varname_ranges
297325 ),
298326 ldf. _adprep,
@@ -301,12 +329,14 @@ function LogDensityProblems.logdensity_and_gradient(
301329 )
302330end
303331
304- function LogDensityProblems. capabilities (:: Type{<:LogDensityFunction{M,Nothing}} ) where {M}
332+ function LogDensityProblems. capabilities (
333+ :: Type{<:LogDensityFunction{T,M,Nothing}}
334+ ) where {T,M}
305335 return LogDensityProblems. LogDensityOrder {0} ()
306336end
307337function LogDensityProblems. capabilities (
308- :: Type{<:LogDensityFunction{M,<:ADTypes.AbstractADType}}
309- ) where {M}
338+ :: Type{<:LogDensityFunction{T, M,<:ADTypes.AbstractADType}}
339+ ) where {T, M}
310340 return LogDensityProblems. LogDensityOrder {1} ()
311341end
312342function LogDensityProblems. dimension (ldf:: LogDensityFunction )
0 commit comments