Skip to content

Commit 7265885

Browse files
committed
Use ParamsInit for predict; remove setval_and_resample! and friends
1 parent 247e53b commit 7265885

File tree

4 files changed

+41
-176
lines changed

4 files changed

+41
-176
lines changed

ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,15 @@ function DynamicPPL.varnames(c::MCMCChains.Chains)
4242
return keys(c.info.varname_to_symbol)
4343
end
4444

45+
function chain_sample_to_varname_dict(c::MCMCChains.Chains, sample_idx, chain_idx)
46+
_check_varname_indexing(c)
47+
d = Dict{DynamicPPL.VarName,Any}()
48+
for vn in DynamicPPL.varnames(c)
49+
d[vn] = DynamicPPL.getindex_varname(c, sample_idx, vn, chain_idx)
50+
end
51+
return d
52+
end
53+
4554
"""
4655
predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)
4756
@@ -114,9 +123,17 @@ function DynamicPPL.predict(
114123

115124
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
116125
predictive_samples = map(iters) do (sample_idx, chain_idx)
117-
DynamicPPL.setval_and_resample!(varinfo, parameter_only_chain, sample_idx, chain_idx)
118-
varinfo = last(DynamicPPL.evaluate_and_sample!!(rng, model, varinfo))
119-
126+
# Extract values from the chain
127+
values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx)
128+
# Resample any variables that are not present in `values_dict`
129+
_, varinfo = last(
130+
DynamicPPL.init!!(
131+
rng,
132+
model,
133+
varinfo,
134+
DynamicPPL.ParamsInit(values_dict, DynamicPPL.PriorInit()),
135+
),
136+
)
120137
vals = DynamicPPL.values_as_in_model(model, false, varinfo)
121138
varname_vals = mapreduce(
122139
collect,
@@ -248,13 +265,15 @@ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Cha
248265
varinfo = DynamicPPL.VarInfo(model)
249266
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
250267
return map(iters) do (sample_idx, chain_idx)
251-
# TODO: Use `fix` once we've addressed https://github.com/TuringLang/DynamicPPL.jl/issues/702.
252-
# Update the varinfo with the current sample and make variables not present in `chain`
253-
# to be sampled.
254-
DynamicPPL.setval_and_resample!(varinfo, chain, sample_idx, chain_idx)
255-
# NOTE: Some of the varialbes can be a view into the `varinfo`, so we need to
256-
# `deepcopy` the `varinfo` before passing it to the `model`.
257-
model(deepcopy(varinfo))
268+
# Extract values from the chain
269+
values_dict = chain_sample_to_varname_dict(chain, sample_idx, chain_idx)
270+
# Resample any variables that are not present in `values_dict`, and
271+
# return the model's retval (`first`).
272+
first(
273+
DynamicPPL.init!!(
274+
model, varinfo, DynamicPPL.ParamsInit(values_dict, DynamicPPL.PriorInit())
275+
),
276+
)
258277
end
259278
end
260279

src/model.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1233,8 +1233,15 @@ function predict(
12331233
varinfo = DynamicPPL.VarInfo(model)
12341234
return map(chain) do params_varinfo
12351235
vi = deepcopy(varinfo)
1236-
DynamicPPL.setval_and_resample!(vi, values_as(params_varinfo, NamedTuple))
1237-
model(rng, vi)
1236+
# TODO(penelopeysm): Requires two model evaluations, one to extract the
1237+
# parameters and one to set them. The reason why we need values_as_in_model
1238+
# is because `params_varinfo` may well have some weird combination of
1239+
# linked/unlinked, whereas `varinfo` is always unlinked since it is
1240+
# freshly constructed.
1241+
# This is quite inefficient. It would of course be alright if
1242+
# ValuesAsInModelAccumulator was a default acc.
1243+
values_nt = values_as_in_model(model, false, params_varinfo)
1244+
_, vi = DynamicPPL.init!!(rng, model, vi, ParamsInit(values_nt, PriorInit()))
12381245
return vi
12391246
end
12401247
end

src/varinfo.jl

Lines changed: 0 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -2045,113 +2045,6 @@ function _setval_kernel!(vi::VarInfoOrThreadSafeVarInfo, vn::VarName, values, ke
20452045
return indices
20462046
end
20472047

2048-
"""
2049-
setval_and_resample!(vi::VarInfo, x)
2050-
setval_and_resample!(vi::VarInfo, values, keys)
2051-
setval_and_resample!(vi::VarInfo, chains::AbstractChains, sample_idx, chain_idx)
2052-
2053-
Set the values in `vi` to the provided values and those which are not present
2054-
in `x` or `chains` to *be* resampled.
2055-
2056-
Note that this does *not* resample the values not provided! It will call
2057-
`setflag!(vi, vn, "del")` for variables `vn` for which no values are provided, which means
2058-
that the next time we call `model(vi)` these variables will be resampled.
2059-
2060-
## Note
2061-
- This suffers from the same limitations as [`setval!`](@ref). See `setval!` for more info.
2062-
2063-
## Example
2064-
```jldoctest
2065-
julia> using DynamicPPL, Distributions, StableRNGs
2066-
2067-
julia> @model function demo(x)
2068-
m ~ Normal()
2069-
for i in eachindex(x)
2070-
x[i] ~ Normal(m, 1)
2071-
end
2072-
end;
2073-
2074-
julia> rng = StableRNG(42);
2075-
2076-
julia> m = demo([missing]);
2077-
2078-
julia> var_info = DynamicPPL.VarInfo(rng, m);
2079-
# Checking the setting of "del" flags only makes sense for VarInfo{<:Metadata}. For VarInfo{<:VarNamedVector} the flag is effectively always set.
2080-
2081-
julia> var_info[@varname(m)]
2082-
-0.6702516921145671
2083-
2084-
julia> var_info[@varname(x[1])]
2085-
-0.22312984965118443
2086-
2087-
julia> DynamicPPL.setval_and_resample!(var_info, (m = 100.0, )); # set `m` and ready `x[1]` for resampling
2088-
2089-
julia> var_info[@varname(m)] # [✓] changed
2090-
100.0
2091-
2092-
julia> var_info[@varname(x[1])] # [✓] unchanged
2093-
-0.22312984965118443
2094-
2095-
julia> m(rng, var_info); # sample `x[1]` conditioned on `m = 100.0`
2096-
2097-
julia> var_info[@varname(m)] # [✓] unchanged
2098-
100.0
2099-
2100-
julia> var_info[@varname(x[1])] # [✓] changed
2101-
101.37363069798343
2102-
```
2103-
2104-
## See also
2105-
- [`setval!`](@ref)
2106-
"""
2107-
function setval_and_resample!(vi::VarInfoOrThreadSafeVarInfo, x)
2108-
return setval_and_resample!(vi, values(x), keys(x))
2109-
end
2110-
function setval_and_resample!(vi::VarInfoOrThreadSafeVarInfo, values, keys)
2111-
return _apply!(_setval_and_resample_kernel!, vi, values, keys)
2112-
end
2113-
function setval_and_resample!(
2114-
vi::VarInfoOrThreadSafeVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int
2115-
)
2116-
if supports_varname_indexing(chains)
2117-
# First we need to set every variable to be resampled.
2118-
for vn in keys(vi)
2119-
set_flag!(vi, vn, "del")
2120-
end
2121-
# Then we set the variables in `varinfo` from `chain`.
2122-
for vn in varnames(chains)
2123-
vn_updated = nested_setindex_maybe!(
2124-
vi, getindex_varname(chains, sample_idx, vn, chain_idx), vn
2125-
)
2126-
2127-
# Unset the `del` flag if we found something.
2128-
if vn_updated !== nothing
2129-
# NOTE: This will be triggered even if only a subset of a variable has been set!
2130-
unset_flag!(vi, vn_updated, "del")
2131-
end
2132-
end
2133-
else
2134-
setval_and_resample!(vi, chains.value[sample_idx, :, chain_idx], keys(chains))
2135-
end
2136-
end
2137-
2138-
function _setval_and_resample_kernel!(
2139-
vi::VarInfoOrThreadSafeVarInfo, vn::VarName, values, keys
2140-
)
2141-
indices = findall(Base.Fix1(subsumes_string, string(vn)), keys)
2142-
if !isempty(indices)
2143-
val = reduce(vcat, values[indices])
2144-
setval!(vi, val, vn)
2145-
settrans!!(vi, false, vn)
2146-
else
2147-
# Ensures that we'll resample the variable corresponding to `vn` if we run
2148-
# the model on `vi` again.
2149-
set_flag!(vi, vn, "del")
2150-
end
2151-
2152-
return indices
2153-
end
2154-
21552048
values_as(vi::VarInfo) = vi.metadata
21562049
values_as(vi::VarInfo, ::Type{Vector}) = copy(getindex_internal(vi, Colon()))
21572050
function values_as(vi::UntypedVarInfo, ::Type{NamedTuple})

test/varinfo.jl

Lines changed: 3 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ end
278278
@test typed_vi[vn_y] == 2.0
279279
end
280280

281-
@testset "setval! & setval_and_resample!" begin
281+
@testset "setval!" begin
282282
@model function testmodel(x)
283283
n = length(x)
284284
s ~ truncated(Normal(); lower=0)
@@ -329,8 +329,8 @@ end
329329
else
330330
DynamicPPL.setval!(vicopy, (m=zeros(5),))
331331
end
332-
# Setting `m` fails for univariate due to limitations of `setval!`
333-
# and `setval_and_resample!`. See docstring of `setval!` for more info.
332+
# Setting `m` fails for univariate due to limitations of `setval!`.
333+
# See docstring of `setval!` for more info.
334334
if model == model_uv && vi in [vi_untyped, vi_typed]
335335
@test_broken vicopy[m_vns] == zeros(5)
336336
else
@@ -355,57 +355,6 @@ end
355355
DynamicPPL.setval!(vicopy, (s=42,))
356356
@test vicopy[m_vns] == 1:5
357357
@test vicopy[s_vns] == 42
358-
359-
### `setval_and_resample!` ###
360-
if model == model_mv && vi == vi_untyped
361-
# Trying to re-run model with `MvNormal` on `vi_untyped` will call
362-
# `MvNormal(μ::Vector{Real}, Σ)` which causes `StackOverflowError`
363-
# so we skip this particular case.
364-
continue
365-
end
366-
367-
if vi in [vi_vnv, vi_vnv_typed]
368-
# `setval_and_resample!` works differently for `VarNamedVector`: All
369-
# values will be resampled when model(vicopy) is called. Hence the below
370-
# tests are not applicable.
371-
continue
372-
end
373-
374-
vicopy = deepcopy(vi)
375-
DynamicPPL.setval_and_resample!(vicopy, (m=zeros(5),))
376-
model(vicopy)
377-
# Setting `m` fails for univariate due to limitations of `subsumes(::String, ::String)`
378-
if model == model_uv
379-
@test_broken vicopy[m_vns] == zeros(5)
380-
else
381-
@test vicopy[m_vns] == zeros(5)
382-
end
383-
@test vicopy[s_vns] != vi[s_vns]
384-
385-
# Ordering is NOT preserved.
386-
DynamicPPL.setval_and_resample!(
387-
vicopy, (; (Symbol("m[$i]") => i for i in (1, 3, 5, 4, 2))...)
388-
)
389-
model(vicopy)
390-
if model == model_uv
391-
@test vicopy[m_vns] == 1:5
392-
else
393-
@test vicopy[m_vns] == [1, 3, 5, 4, 2]
394-
end
395-
@test vicopy[s_vns] != vi[s_vns]
396-
397-
# Correct ordering.
398-
DynamicPPL.setval_and_resample!(
399-
vicopy, (; (Symbol("m[$i]") => i for i in (1, 2, 3, 4, 5))...)
400-
)
401-
model(vicopy)
402-
@test vicopy[m_vns] == 1:5
403-
@test vicopy[s_vns] != vi[s_vns]
404-
405-
DynamicPPL.setval_and_resample!(vicopy, (s=42,))
406-
model(vicopy)
407-
@test vicopy[m_vns] != 1:5
408-
@test vicopy[s_vns] == 42
409358
end
410359
end
411360

@@ -419,9 +368,6 @@ end
419368
ks = [@varname(x[1, 1]), @varname(x[2, 1]), @varname(x[1, 2]), @varname(x[2, 2])]
420369
DynamicPPL.setval!(vi, vi.metadata.x.vals, ks)
421370
@test vals_prev == vi.metadata.x.vals
422-
423-
DynamicPPL.setval_and_resample!(vi, vi.metadata.x.vals, ks)
424-
@test vals_prev == vi.metadata.x.vals
425371
end
426372

427373
@testset "setval! on chain" begin

0 commit comments

Comments
 (0)