Skip to content

Commit 3875d41

Browse files
committed
Improve type stability when all parameters are linked or unlinked
1 parent 6f5df1a commit 3875d41

File tree

2 files changed

+74
-16
lines changed

2 files changed

+74
-16
lines changed

src/contexts/init.jl

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ struct RangeAndLinked
214214
end
215215

216216
"""
217-
VectorWithRanges(
217+
VectorWithRanges{Tlink}(
218218
iden_varname_ranges::NamedTuple,
219219
varname_ranges::Dict{VarName,RangeAndLinked},
220220
vect::AbstractVector{<:Real},
@@ -231,13 +231,19 @@ non-identity-optic VarNames are stored in the `varname_ranges` Dict.
231231
It would be nice to improve the NamedTuple and Dict approach. See, e.g.
232232
https://github.com/TuringLang/DynamicPPL.jl/issues/1116.
233233
"""
234-
struct VectorWithRanges{N<:NamedTuple,T<:AbstractVector{<:Real}}
234+
struct VectorWithRanges{Tlink,N<:NamedTuple,T<:AbstractVector{<:Real}}
235235
# This NamedTuple stores the ranges for identity VarNames
236236
iden_varname_ranges::N
237237
# This Dict stores the ranges for all other VarNames
238238
varname_ranges::Dict{VarName,RangeAndLinked}
239239
# The full parameter vector which we index into to get variable values
240240
vect::T
241+
242+
function VectorWithRanges{Tlink}(
243+
iden_varname_ranges::N, varname_ranges::Dict{VarName,RangeAndLinked}, vect::T
244+
) where {Tlink,N,T}
245+
return new{Tlink,N,T}(iden_varname_ranges, varname_ranges, vect)
246+
end
241247
end
242248

243249
function _get_range_and_linked(
@@ -252,7 +258,29 @@ function init(
252258
::Random.AbstractRNG,
253259
vn::VarName,
254260
dist::Distribution,
255-
p::InitFromParams{<:VectorWithRanges},
261+
p::InitFromParams{<:VectorWithRanges{true}},
262+
)
263+
vr = p.params
264+
range_and_linked = _get_range_and_linked(vr, vn)
265+
transform = from_linked_vec_transform(dist)
266+
return (@view vr.vect[range_and_linked.range]), transform
267+
end
268+
function init(
269+
::Random.AbstractRNG,
270+
vn::VarName,
271+
dist::Distribution,
272+
p::InitFromParams{<:VectorWithRanges{false}},
273+
)
274+
vr = p.params
275+
range_and_linked = _get_range_and_linked(vr, vn)
276+
transform = from_vec_transform(dist)
277+
return (@view vr.vect[range_and_linked.range]), transform
278+
end
279+
function init(
280+
::Random.AbstractRNG,
281+
vn::VarName,
282+
dist::Distribution,
283+
p::InitFromParams{<:VectorWithRanges{nothing}},
256284
)
257285
vr = p.params
258286
range_and_linked = _get_range_and_linked(vr, vn)

src/fasteval.jl

Lines changed: 43 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,9 @@ with such models.** This is a general limitation of vectorised parameters: the o
194194
`unflatten` + `evaluate!!` approach also fails with such models.
195195
"""
196196
struct LogDensityFunction{
197+
# true if all variables are linked; false if all variables are unlinked; nothing if
198+
# mixed
199+
Tlink,
197200
M<:Model,
198201
AD<:Union{ADTypes.AbstractADType,Nothing},
199202
F<:Function,
@@ -217,6 +220,21 @@ struct LogDensityFunction{
217220
# Figure out which variable corresponds to which index, and
218221
# which variables are linked.
219222
all_iden_ranges, all_ranges = get_ranges_and_linked(varinfo)
223+
# Figure out if all variables are linked, unlinked, or mixed
224+
link_statuses = Bool[]
225+
for ral in all_iden_ranges
226+
push!(link_statuses, ral.is_linked)
227+
end
228+
for (_, ral) in all_ranges
229+
push!(link_statuses, ral.is_linked)
230+
end
231+
Tlink = if all(link_statuses)
232+
true
233+
elseif all(!s for s in link_statuses)
234+
false
235+
else
236+
nothing
237+
end
220238
x = [val for val in varinfo[:]]
221239
dim = length(x)
222240
# Do AD prep if needed
@@ -226,12 +244,13 @@ struct LogDensityFunction{
226244
# Make backend-specific tweaks to the adtype
227245
adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo)
228246
DI.prepare_gradient(
229-
LogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges),
247+
LogDensityAt{Tlink}(model, getlogdensity, all_iden_ranges, all_ranges),
230248
adtype,
231249
x,
232250
)
233251
end
234252
return new{
253+
Tlink,
235254
typeof(model),
236255
typeof(adtype),
237256
typeof(getlogdensity),
@@ -263,36 +282,45 @@ end
263282
fast_ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),))
264283
fast_ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),))
265284

266-
struct LogDensityAt{M<:Model,F<:Function,N<:NamedTuple}
285+
struct LogDensityAt{Tlink,M<:Model,F<:Function,N<:NamedTuple}
267286
model::M
268287
getlogdensity::F
269288
iden_varname_ranges::N
270289
varname_ranges::Dict{VarName,RangeAndLinked}
290+
291+
function LogDensityAt{Tlink}(
292+
model::M,
293+
getlogdensity::F,
294+
iden_varname_ranges::N,
295+
varname_ranges::Dict{VarName,RangeAndLinked},
296+
) where {Tlink,M,F,N}
297+
return new{Tlink,M,F,N}(model, getlogdensity, iden_varname_ranges, varname_ranges)
298+
end
271299
end
272-
function (f::LogDensityAt)(params::AbstractVector{<:Real})
300+
function (f::LogDensityAt{Tlink})(params::AbstractVector{<:Real}) where {Tlink}
273301
strategy = InitFromParams(
274-
VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing
302+
VectorWithRanges{Tlink}(f.iden_varname_ranges, f.varname_ranges, params), nothing
275303
)
276304
accs = fast_ldf_accs(f.getlogdensity)
277305
_, vi = fast_evaluate!!(f.model, strategy, accs)
278306
return f.getlogdensity(vi)
279307
end
280308

281309
function LogDensityProblems.logdensity(
282-
ldf::LogDensityFunction, params::AbstractVector{<:Real}
283-
)
284-
return LogDensityAt(
310+
ldf::LogDensityFunction{Tlink}, params::AbstractVector{<:Real}
311+
) where {Tlink}
312+
return LogDensityAt{Tlink}(
285313
ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges
286314
)(
287315
params
288316
)
289317
end
290318

291319
function LogDensityProblems.logdensity_and_gradient(
292-
ldf::LogDensityFunction, params::AbstractVector{<:Real}
293-
)
320+
ldf::LogDensityFunction{Tlink}, params::AbstractVector{<:Real}
321+
) where {Tlink}
294322
return DI.value_and_gradient(
295-
LogDensityAt(
323+
LogDensityAt{Tlink}(
296324
ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges
297325
),
298326
ldf._adprep,
@@ -301,12 +329,14 @@ function LogDensityProblems.logdensity_and_gradient(
301329
)
302330
end
303331

304-
function LogDensityProblems.capabilities(::Type{<:LogDensityFunction{M,Nothing}}) where {M}
332+
function LogDensityProblems.capabilities(
333+
::Type{<:LogDensityFunction{T,M,Nothing}}
334+
) where {T,M}
305335
return LogDensityProblems.LogDensityOrder{0}()
306336
end
307337
function LogDensityProblems.capabilities(
308-
::Type{<:LogDensityFunction{M,<:ADTypes.AbstractADType}}
309-
) where {M}
338+
::Type{<:LogDensityFunction{T,M,<:ADTypes.AbstractADType}}
339+
) where {T,M}
310340
return LogDensityProblems.LogDensityOrder{1}()
311341
end
312342
function LogDensityProblems.dimension(ldf::LogDensityFunction)

0 commit comments

Comments
 (0)