Skip to content

Commit 072da15

Browse files
committed
Improve type stability when all parameters are linked or unlinked
1 parent 9310ec0 commit 072da15

File tree

4 files changed

+106
-25
lines changed

src/chains.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -156,13 +156,15 @@ via `unflatten` plus re-evaluation. It is faster for two reasons:
156156
"""
157157
function ParamsWithStats(
158158
param_vector::AbstractVector,
159-
ldf::DynamicPPL.LogDensityFunction,
159+
ldf::DynamicPPL.LogDensityFunction{Tlink},
160160
stats::NamedTuple=NamedTuple();
161161
include_colon_eq::Bool=true,
162162
include_log_probs::Bool=true,
163-
)
163+
) where {Tlink}
164164
strategy = InitFromParams(
165-
VectorWithRanges(ldf._iden_varname_ranges, ldf._varname_ranges, param_vector),
165+
VectorWithRanges{Tlink}(
166+
ldf._iden_varname_ranges, ldf._varname_ranges, param_vector
167+
),
166168
nothing,
167169
)
168170
accs = if include_log_probs

src/contexts/init.jl

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ struct RangeAndLinked
214214
end
215215

216216
"""
217-
VectorWithRanges(
217+
VectorWithRanges{Tlink}(
218218
iden_varname_ranges::NamedTuple,
219219
varname_ranges::Dict{VarName,RangeAndLinked},
220220
vect::AbstractVector{<:Real},
@@ -231,13 +231,19 @@ non-identity-optic VarNames are stored in the `varname_ranges` Dict.
231231
It would be nice to improve the NamedTuple and Dict approach. See, e.g.
232232
https://github.com/TuringLang/DynamicPPL.jl/issues/1116.
233233
"""
234-
struct VectorWithRanges{N<:NamedTuple,T<:AbstractVector{<:Real}}
234+
struct VectorWithRanges{Tlink,N<:NamedTuple,T<:AbstractVector{<:Real}}
235235
# This NamedTuple stores the ranges for identity VarNames
236236
iden_varname_ranges::N
237237
# This Dict stores the ranges for all other VarNames
238238
varname_ranges::Dict{VarName,RangeAndLinked}
239239
# The full parameter vector which we index into to get variable values
240240
vect::T
241+
242+
function VectorWithRanges{Tlink}(
243+
iden_varname_ranges::N, varname_ranges::Dict{VarName,RangeAndLinked}, vect::T
244+
) where {Tlink,N,T}
245+
return new{Tlink,N,T}(iden_varname_ranges, varname_ranges, vect)
246+
end
241247
end
242248

243249
function _get_range_and_linked(
@@ -252,7 +258,29 @@ function init(
252258
::Random.AbstractRNG,
253259
vn::VarName,
254260
dist::Distribution,
255-
p::InitFromParams{<:VectorWithRanges},
261+
p::InitFromParams{<:VectorWithRanges{true}},
262+
)
263+
vr = p.params
264+
range_and_linked = _get_range_and_linked(vr, vn)
265+
transform = from_linked_vec_transform(dist)
266+
return (@view vr.vect[range_and_linked.range]), transform
267+
end
268+
function init(
269+
::Random.AbstractRNG,
270+
vn::VarName,
271+
dist::Distribution,
272+
p::InitFromParams{<:VectorWithRanges{false}},
273+
)
274+
vr = p.params
275+
range_and_linked = _get_range_and_linked(vr, vn)
276+
transform = from_vec_transform(dist)
277+
return (@view vr.vect[range_and_linked.range]), transform
278+
end
279+
function init(
280+
::Random.AbstractRNG,
281+
vn::VarName,
282+
dist::Distribution,
283+
p::InitFromParams{<:VectorWithRanges{nothing}},
256284
)
257285
vr = p.params
258286
range_and_linked = _get_range_and_linked(vr, vn)

src/fasteval.jl

Lines changed: 54 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -65,16 +65,21 @@ in the function name.
6565
# TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what
6666
# it _should_ do, but this is wrong regardless.
6767
# https://github.com/TuringLang/DynamicPPL.jl/issues/1086
68-
vi = if Threads.nthreads() > 1
69-
param_eltype = DynamicPPL.get_param_eltype(strategy)
68+
return if Threads.nthreads() > 1
69+
# WARNING: Do NOT move get_param_eltype(strategy) into an intermediate variable:
70+
# doing so causes type instabilities! See also unflatten in src/varinfo.jl.
7071
accs = map(accs) do acc
71-
DynamicPPL.convert_eltype(float_type_with_fallback(param_eltype), acc)
72+
DynamicPPL.convert_eltype(
73+
float_type_with_fallback(DynamicPPL.get_param_eltype(strategy)), acc
74+
)
7275
end
73-
ThreadSafeVarInfo(OnlyAccsVarInfo(accs))
76+
tsvi = ThreadSafeVarInfo(OnlyAccsVarInfo(accs))
77+
retval, tsvi_new = DynamicPPL._evaluate!!(model, tsvi)
78+
retval, setaccs!!(tsvi_new.varinfo, getaccs(tsvi_new))
7479
else
75-
OnlyAccsVarInfo(accs)
80+
vi = OnlyAccsVarInfo(accs)
81+
DynamicPPL._evaluate!!(model, vi)
7682
end
77-
return DynamicPPL._evaluate!!(model, vi)
7883
end
7984
@inline function fast_evaluate!!(
8085
model::Model, strategy::AbstractInitStrategy, accs::AccumulatorTuple
@@ -194,6 +199,9 @@ with such models.** This is a general limitation of vectorised parameters: the o
194199
`unflatten` + `evaluate!!` approach also fails with such models.
195200
"""
196201
struct LogDensityFunction{
202+
# true if all variables are linked; false if all variables are unlinked; nothing if
203+
# mixed
204+
Tlink,
197205
M<:Model,
198206
AD<:Union{ADTypes.AbstractADType,Nothing},
199207
F<:Function,
@@ -217,6 +225,21 @@ struct LogDensityFunction{
217225
# Figure out which variable corresponds to which index, and
218226
# which variables are linked.
219227
all_iden_ranges, all_ranges = get_ranges_and_linked(varinfo)
228+
# Figure out if all variables are linked, unlinked, or mixed
229+
link_statuses = Bool[]
230+
for ral in all_iden_ranges
231+
push!(link_statuses, ral.is_linked)
232+
end
233+
for (_, ral) in all_ranges
234+
push!(link_statuses, ral.is_linked)
235+
end
236+
Tlink = if all(link_statuses)
237+
true
238+
elseif all(!s for s in link_statuses)
239+
false
240+
else
241+
nothing
242+
end
220243
x = [val for val in varinfo[:]]
221244
dim = length(x)
222245
# Do AD prep if needed
@@ -226,12 +249,13 @@ struct LogDensityFunction{
226249
# Make backend-specific tweaks to the adtype
227250
adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo)
228251
DI.prepare_gradient(
229-
LogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges),
252+
LogDensityAt{Tlink}(model, getlogdensity, all_iden_ranges, all_ranges),
230253
adtype,
231254
x,
232255
)
233256
end
234257
return new{
258+
Tlink,
235259
typeof(model),
236260
typeof(adtype),
237261
typeof(getlogdensity),
@@ -263,36 +287,45 @@ end
263287
fast_ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),))
264288
fast_ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),))
265289

266-
struct LogDensityAt{M<:Model,F<:Function,N<:NamedTuple}
290+
struct LogDensityAt{Tlink,M<:Model,F<:Function,N<:NamedTuple}
267291
model::M
268292
getlogdensity::F
269293
iden_varname_ranges::N
270294
varname_ranges::Dict{VarName,RangeAndLinked}
295+
296+
function LogDensityAt{Tlink}(
297+
model::M,
298+
getlogdensity::F,
299+
iden_varname_ranges::N,
300+
varname_ranges::Dict{VarName,RangeAndLinked},
301+
) where {Tlink,M,F,N}
302+
return new{Tlink,M,F,N}(model, getlogdensity, iden_varname_ranges, varname_ranges)
303+
end
271304
end
272-
function (f::LogDensityAt)(params::AbstractVector{<:Real})
305+
function (f::LogDensityAt{Tlink})(params::AbstractVector{<:Real}) where {Tlink}
273306
strategy = InitFromParams(
274-
VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing
307+
VectorWithRanges{Tlink}(f.iden_varname_ranges, f.varname_ranges, params), nothing
275308
)
276309
accs = fast_ldf_accs(f.getlogdensity)
277310
_, vi = fast_evaluate!!(f.model, strategy, accs)
278311
return f.getlogdensity(vi)
279312
end
280313

281314
function LogDensityProblems.logdensity(
282-
ldf::LogDensityFunction, params::AbstractVector{<:Real}
283-
)
284-
return LogDensityAt(
315+
ldf::LogDensityFunction{Tlink}, params::AbstractVector{<:Real}
316+
) where {Tlink}
317+
return LogDensityAt{Tlink}(
285318
ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges
286319
)(
287320
params
288321
)
289322
end
290323

291324
function LogDensityProblems.logdensity_and_gradient(
292-
ldf::LogDensityFunction, params::AbstractVector{<:Real}
293-
)
325+
ldf::LogDensityFunction{Tlink}, params::AbstractVector{<:Real}
326+
) where {Tlink}
294327
return DI.value_and_gradient(
295-
LogDensityAt(
328+
LogDensityAt{Tlink}(
296329
ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges
297330
),
298331
ldf._adprep,
@@ -301,12 +334,14 @@ function LogDensityProblems.logdensity_and_gradient(
301334
)
302335
end
303336

304-
function LogDensityProblems.capabilities(::Type{<:LogDensityFunction{M,Nothing}}) where {M}
337+
function LogDensityProblems.capabilities(
338+
::Type{<:LogDensityFunction{T,M,Nothing}}
339+
) where {T,M}
305340
return LogDensityProblems.LogDensityOrder{0}()
306341
end
307342
function LogDensityProblems.capabilities(
308-
::Type{<:LogDensityFunction{M,<:ADTypes.AbstractADType}}
309-
) where {M}
343+
::Type{<:LogDensityFunction{T,M,<:ADTypes.AbstractADType}}
344+
) where {T,M}
310345
return LogDensityProblems.LogDensityOrder{1}()
311346
end
312347
function LogDensityProblems.dimension(ldf::LogDensityFunction)

test/fasteval.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,22 @@ using Mooncake: Mooncake
6969
end
7070
end
7171

72+
@testset "LogDensityFunction: Type stability" begin
73+
@testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
74+
unlinked_vi = DynamicPPL.VarInfo(m)
75+
@testset "$islinked" for islinked in (false, true)
76+
vi = if islinked
77+
DynamicPPL.link!!(unlinked_vi, m)
78+
else
79+
unlinked_vi
80+
end
81+
ldf = DynamicPPL.LogDensityFunction(m, DynamicPPL.getlogjoint_internal, vi)
82+
x = vi[:]
83+
@inferred LogDensityProblems.logdensity(ldf, x)
84+
end
85+
end
86+
end
87+
7288
@testset "Fast evaluation: performance" begin
7389
if Threads.nthreads() == 1
7490
# Evaluating these three models with OnlyAccsVarInfo should not lead to any

0 commit comments

Comments
 (0)