Skip to content

Commit cf33cff

Browse files
committed
Use init!! instead of fast_evaluate!!
1 parent e165249 commit cf33cff

File tree

5 files changed

+49
-78
lines changed

5 files changed

+49
-78
lines changed

docs/src/api.md

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,10 @@ The [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) inte
6666
LogDensityFunction
6767
```
6868

69-
Internally, this is accomplished using:
69+
Internally, this is accomplished using [`init!!`](@ref) on:
7070

7171
```@docs
7272
OnlyAccsVarInfo
73-
fast_evaluate!!
7473
```
7574

7675
## Condition and decondition
@@ -517,7 +516,7 @@ The function `init!!` is used to initialise, or overwrite, values in a VarInfo.
517516
It is really a thin wrapper around using `evaluate!!` with an `InitContext`.
518517

519518
```@docs
520-
DynamicPPL.init!!
519+
init!!
521520
```
522521

523522
To accomplish this, an initialisation _strategy_ is required, which defines how new values are to be obtained.

src/DynamicPPL.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,11 @@ export AbstractVarInfo,
9292
getargnames,
9393
extract_priors,
9494
values_as_in_model,
95+
# evaluation
96+
evaluate!!,
97+
init!!,
9598
# LogDensityFunction and fasteval
9699
LogDensityFunction,
97-
fast_evaluate!!,
98100
OnlyAccsVarInfo,
99101
# Leaf contexts
100102
AbstractContext,

src/chains.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ end
137137
"""
138138
ParamsWithStats(
139139
param_vector::AbstractVector,
140-
ldf::DynamicPPL.Experimental.FastLDF,
140+
ldf::DynamicPPL.LogDensityFunction,
141141
stats::NamedTuple=NamedTuple();
142142
include_colon_eq::Bool=true,
143143
include_log_probs::Bool=true,
@@ -156,7 +156,7 @@ via `unflatten` plus re-evaluation. It is faster for two reasons:
156156
"""
157157
function ParamsWithStats(
158158
param_vector::AbstractVector,
159-
ldf::DynamicPPL.Experimental.FastLDF,
159+
ldf::DynamicPPL.LogDensityFunction,
160160
stats::NamedTuple=NamedTuple();
161161
include_colon_eq::Bool=true,
162162
include_log_probs::Bool=true,
@@ -174,9 +174,7 @@ function ParamsWithStats(
174174
else
175175
(DynamicPPL.ValuesAsInModelAccumulator(include_colon_eq),)
176176
end
177-
_, vi = DynamicPPL.Experimental.fast_evaluate!!(
178-
ldf.model, strategy, AccumulatorTuple(accs)
179-
)
177+
_, vi = DynamicPPL.init!!(ldf.model, OnlyAccsVarInfo(AccumulatorTuple(accs)), strategy)
180178
params = DynamicPPL.getacc(vi, Val(:ValuesAsInModel)).values
181179
if include_log_probs
182180
stats = merge(

src/fasteval.jl

Lines changed: 1 addition & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -29,60 +29,6 @@ using LogDensityProblems: LogDensityProblems
2929
import DifferentiationInterface as DI
3030
using Random: Random
3131

32-
"""
33-
DynamicPPL.fast_evaluate!!(
34-
[rng::Random.AbstractRNG,]
35-
model::Model,
36-
strategy::AbstractInitStrategy,
37-
accs::AccumulatorTuple,
38-
)
39-
40-
Evaluate a model using parameters obtained via `strategy`, and only computing the results in
41-
the provided accumulators.
42-
43-
It is assumed that the accumulators passed in have been initialised to appropriate values,
44-
as this function will not reset them. The default constructors for each accumulator will do
45-
this for you correctly.
46-
47-
Returns a tuple of the model's return value, plus an `OnlyAccsVarInfo`. Note that the `accs`
48-
argument may be mutated (depending on how the accumulators are implemented); hence the `!!`
49-
in the function name.
50-
"""
51-
@inline function fast_evaluate!!(
52-
# Note that this `@inline` is mandatory for performance. If it's not inlined, it leads
53-
# to extra allocations (even for trivial models) and much slower runtime.
54-
rng::Random.AbstractRNG,
55-
model::Model,
56-
strategy::AbstractInitStrategy,
57-
accs::AccumulatorTuple,
58-
)
59-
ctx = InitContext(rng, strategy)
60-
model = DynamicPPL.setleafcontext(model, ctx)
61-
# Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!,
62-
# which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!`
63-
# directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic
64-
# here.
65-
# TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what
66-
# it _should_ do, but this is wrong regardless.
67-
# https://github.com/TuringLang/DynamicPPL.jl/issues/1086
68-
vi = if Threads.nthreads() > 1
69-
param_eltype = DynamicPPL.get_param_eltype(strategy)
70-
accs = map(accs) do acc
71-
DynamicPPL.convert_eltype(float_type_with_fallback(param_eltype), acc)
72-
end
73-
ThreadSafeVarInfo(OnlyAccsVarInfo(accs))
74-
else
75-
OnlyAccsVarInfo(accs)
76-
end
77-
return DynamicPPL._evaluate!!(model, vi)
78-
end
79-
@inline function fast_evaluate!!(
80-
model::Model, strategy::AbstractInitStrategy, accs::AccumulatorTuple
81-
)
82-
# This `@inline` is also mandatory for performance
83-
return fast_evaluate!!(Random.default_rng(), model, strategy, accs)
84-
end
85-
8632
"""
8733
DynamicPPL.LogDensityFunction(
8834
model::Model,
@@ -274,7 +220,7 @@ function (f::LogDensityAt)(params::AbstractVector{<:Real})
274220
VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing
275221
)
276222
accs = fast_ldf_accs(f.getlogdensity)
277-
_, vi = DynamicPPL.fast_evaluate!!(f.model, strategy, accs)
223+
_, vi = DynamicPPL.init!!(f.model, OnlyAccsVarInfo(accs), strategy)
278224
return f.getlogdensity(vi)
279225
end
280226

src/model.jl

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -881,30 +881,56 @@ end
881881
[init_strategy::AbstractInitStrategy=InitFromPrior()]
882882
)
883883
884-
Evaluate the `model` and replace the values of the model's random variables
885-
in the given `varinfo` with new values, using a specified initialisation strategy.
886-
If the values in `varinfo` are not set, they will be added
887-
using a specified initialisation strategy.
884+
Evaluate the `model` and replace the values of the model's random variables in the given
885+
`varinfo` with new values, using a specified initialisation strategy. If the values in
886+
`varinfo` are not set, they will be added using a specified initialisation strategy.
888887
889888
If `init_strategy` is not provided, defaults to `InitFromPrior()`.
890889
891890
Returns a tuple of the model's return value, plus the updated `varinfo` object.
892891
"""
893-
function init!!(
892+
@inline function init!!(
893+
# Note that this `@inline` is mandatory for performance, especially for
894+
# LogDensityFunction. If it's not inlined, it leads to extra allocations (even for
895+
# trivial models) and much slower runtime.
894896
rng::Random.AbstractRNG,
895897
model::Model,
896-
varinfo::AbstractVarInfo,
897-
init_strategy::AbstractInitStrategy=InitFromPrior(),
898+
vi::AbstractVarInfo,
899+
strategy::AbstractInitStrategy=InitFromPrior(),
898900
)
899-
new_model = setleafcontext(model, InitContext(rng, init_strategy))
900-
return evaluate!!(new_model, varinfo)
901+
ctx = InitContext(rng, strategy)
902+
model = DynamicPPL.setleafcontext(model, ctx)
903+
# TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what
904+
# it _should_ do, but this is wrong regardless.
905+
# https://github.com/TuringLang/DynamicPPL.jl/issues/1086
906+
return if Threads.nthreads() > 1
907+
# TODO(penelopeysm): The logic for setting eltype of accs is very similar to that
908+
# used in `unflatten`. The reason why we need it here is because the VarInfo `vi`
909+
# won't have been filled with parameters prior to `init!!` being called.
910+
#
911+
# Note that this eltype promotion is only needed for threadsafe evaluation. In an
912+
# ideal world, this code should be handled inside `evaluate_threadsafe!!` or a
913+
# similar method. In other words, it should not be here, and it should not be inside
914+
# `unflatten` either. The problem is performance. Shifting this code around can have
915+
# massive, inexplicable, impacts on performance. This should be investigated
916+
# properly.
917+
param_eltype = DynamicPPL.get_param_eltype(strategy)
918+
accs = map(vi.accs) do acc
919+
DynamicPPL.convert_eltype(float_type_with_fallback(param_eltype), acc)
920+
end
921+
vi = DynamicPPL.setaccs!!(vi, accs)
922+
tsvi = ThreadSafeVarInfo(resetaccs!!(vi))
923+
retval, tsvi_new = DynamicPPL._evaluate!!(model, tsvi)
924+
return retval, setaccs!!(vi, DynamicPPL.getaccs(tsvi_new))
925+
else
926+
return DynamicPPL._evaluate!!(model, resetaccs!!(vi))
927+
end
901928
end
902-
function init!!(
903-
model::Model,
904-
varinfo::AbstractVarInfo,
905-
init_strategy::AbstractInitStrategy=InitFromPrior(),
929+
@inline function init!!(
930+
model::Model, vi::AbstractVarInfo, strategy::AbstractInitStrategy=InitFromPrior()
906931
)
907-
return init!!(Random.default_rng(), model, varinfo, init_strategy)
932+
# This `@inline` is also mandatory for performance
933+
return init!!(Random.default_rng(), model, vi, strategy)
908934
end
909935

910936
"""

0 commit comments

Comments
 (0)