Skip to content

Commit f856389

Browse files
committed
Use ParamsInit for predict; remove setval_and_resample! and friends
1 parent 7bf8aba commit f856389

File tree

7 files changed

+49
-216
lines changed

7 files changed

+49
-216
lines changed

ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ end
2828

2929
function _check_varname_indexing(c::MCMCChains.Chains)
3030
return DynamicPPL.supports_varname_indexing(c) ||
31-
error("Chains do not support indexing using `VarName`s.")
31+
error("This `Chains` object does not support indexing using `VarName`s.")
3232
end
3333

3434
function DynamicPPL.getindex_varname(
@@ -42,6 +42,15 @@ function DynamicPPL.varnames(c::MCMCChains.Chains)
4242
return keys(c.info.varname_to_symbol)
4343
end
4444

45+
function chain_sample_to_varname_dict(c::MCMCChains.Chains, sample_idx, chain_idx)
46+
_check_varname_indexing(c)
47+
d = Dict{DynamicPPL.VarName,Any}()
48+
for vn in DynamicPPL.varnames(c)
49+
d[vn] = DynamicPPL.getindex_varname(c, sample_idx, vn, chain_idx)
50+
end
51+
return d
52+
end
53+
4554
"""
4655
predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)
4756
@@ -114,9 +123,15 @@ function DynamicPPL.predict(
114123

115124
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
116125
predictive_samples = map(iters) do (sample_idx, chain_idx)
117-
DynamicPPL.setval_and_resample!(varinfo, parameter_only_chain, sample_idx, chain_idx)
118-
varinfo = last(DynamicPPL.evaluate_and_sample!!(rng, model, varinfo))
119-
126+
# Extract values from the chain
127+
values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx)
128+
# Resample any variables that are not present in `values_dict`
129+
_, varinfo = DynamicPPL.init!!(
130+
rng,
131+
model,
132+
varinfo,
133+
DynamicPPL.ParamsInit(values_dict, DynamicPPL.PriorInit()),
134+
)
120135
vals = DynamicPPL.values_as_in_model(model, false, varinfo)
121136
varname_vals = mapreduce(
122137
collect,
@@ -248,13 +263,14 @@ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Cha
248263
varinfo = DynamicPPL.VarInfo(model)
249264
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
250265
return map(iters) do (sample_idx, chain_idx)
251-
# TODO: Use `fix` once we've addressed https://github.com/TuringLang/DynamicPPL.jl/issues/702.
252-
# Update the varinfo with the current sample and make variables not present in `chain`
253-
# to be sampled.
254-
DynamicPPL.setval_and_resample!(varinfo, chain, sample_idx, chain_idx)
255-
# NOTE: Some of the varialbes can be a view into the `varinfo`, so we need to
256-
# `deepcopy` the `varinfo` before passing it to the `model`.
257-
model(deepcopy(varinfo))
266+
# Extract values from the chain
267+
values_dict = chain_sample_to_varname_dict(chain, sample_idx, chain_idx)
268+
# Resample any variables that are not present in `values_dict`, and
269+
# return the model's retval.
270+
retval, _ = DynamicPPL.init!!(
271+
model, varinfo, DynamicPPL.ParamsInit(values_dict, DynamicPPL.PriorInit())
272+
)
273+
retval
258274
end
259275
end
260276

src/model.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1200,8 +1200,15 @@ function predict(
12001200
varinfo = DynamicPPL.VarInfo(model)
12011201
return map(chain) do params_varinfo
12021202
vi = deepcopy(varinfo)
1203-
DynamicPPL.setval_and_resample!(vi, values_as(params_varinfo, NamedTuple))
1204-
model(rng, vi)
1203+
# TODO(penelopeysm): Requires two model evaluations, one to extract the
1204+
# parameters and one to set them. The reason why we need values_as_in_model
1205+
# is because `params_varinfo` may well have some weird combination of
1206+
# linked/unlinked, whereas `varinfo` is always unlinked since it is
1207+
# freshly constructed.
1208+
# This is quite inefficient. It would of course be alright if
1209+
# ValuesAsInModelAccumulator was a default acc.
1210+
values_nt = values_as_in_model(model, false, params_varinfo)
1211+
_, vi = DynamicPPL.init!!(rng, model, vi, ParamsInit(values_nt, PriorInit()))
12051212
return vi
12061213
end
12071214
end

src/varinfo.jl

Lines changed: 0 additions & 143 deletions
Original file line numberDiff line numberDiff line change
@@ -1506,42 +1506,6 @@ function islinked(vi::VarInfo)
15061506
return any(istrans(vi, vn) for vn in keys(vi))
15071507
end
15081508

1509-
function nested_setindex_maybe!(vi::UntypedVarInfo, val, vn::VarName)
1510-
return _nested_setindex_maybe!(vi, getmetadata(vi, vn), val, vn)
1511-
end
1512-
function nested_setindex_maybe!(
1513-
vi::VarInfo{<:NamedTuple{names}}, val, vn::VarName{sym}
1514-
) where {names,sym}
1515-
return if sym in names
1516-
_nested_setindex_maybe!(vi, getmetadata(vi, vn), val, vn)
1517-
else
1518-
nothing
1519-
end
1520-
end
1521-
function _nested_setindex_maybe!(
1522-
vi::VarInfo, md::Union{Metadata,VarNamedVector}, val, vn::VarName
1523-
)
1524-
# If `vn` is in `vns`, then we can just use the standard `setindex!`.
1525-
vns = Base.keys(md)
1526-
if vn in vns
1527-
setindex!(vi, val, vn)
1528-
return vn
1529-
end
1530-
1531-
# Otherwise, we need to check if either of the `vns` subsumes `vn`.
1532-
i = findfirst(Base.Fix2(subsumes, vn), vns)
1533-
i === nothing && return nothing
1534-
1535-
vn_parent = vns[i]
1536-
val_parent = getindex(vi, vn_parent) # TODO: Ensure that we're working with a view here.
1537-
# Split the varname into its tail optic.
1538-
optic = remove_parent_optic(vn_parent, vn)
1539-
# Update the value for the parent.
1540-
val_parent_updated = set!!(val_parent, optic, val)
1541-
setindex!(vi, val_parent_updated, vn_parent)
1542-
return vn_parent
1543-
end
1544-
15451509
# The default getindex & setindex!() for get & set values
15461510
# NOTE: vi[vn] will always transform the variable to its original space and Julia type
15471511
function getindex(vi::VarInfo, vn::VarName)
@@ -2045,113 +2009,6 @@ function _setval_kernel!(vi::VarInfoOrThreadSafeVarInfo, vn::VarName, values, ke
20452009
return indices
20462010
end
20472011

2048-
"""
2049-
setval_and_resample!(vi::VarInfo, x)
2050-
setval_and_resample!(vi::VarInfo, values, keys)
2051-
setval_and_resample!(vi::VarInfo, chains::AbstractChains, sample_idx, chain_idx)
2052-
2053-
Set the values in `vi` to the provided values and those which are not present
2054-
in `x` or `chains` to *be* resampled.
2055-
2056-
Note that this does *not* resample the values not provided! It will call
2057-
`setflag!(vi, vn, "del")` for variables `vn` for which no values are provided, which means
2058-
that the next time we call `model(vi)` these variables will be resampled.
2059-
2060-
## Note
2061-
- This suffers from the same limitations as [`setval!`](@ref). See `setval!` for more info.
2062-
2063-
## Example
2064-
```jldoctest
2065-
julia> using DynamicPPL, Distributions, StableRNGs
2066-
2067-
julia> @model function demo(x)
2068-
m ~ Normal()
2069-
for i in eachindex(x)
2070-
x[i] ~ Normal(m, 1)
2071-
end
2072-
end;
2073-
2074-
julia> rng = StableRNG(42);
2075-
2076-
julia> m = demo([missing]);
2077-
2078-
julia> var_info = DynamicPPL.VarInfo(rng, m);
2079-
# Checking the setting of "del" flags only makes sense for VarInfo{<:Metadata}. For VarInfo{<:VarNamedVector} the flag is effectively always set.
2080-
2081-
julia> var_info[@varname(m)]
2082-
-0.6702516921145671
2083-
2084-
julia> var_info[@varname(x[1])]
2085-
-0.22312984965118443
2086-
2087-
julia> DynamicPPL.setval_and_resample!(var_info, (m = 100.0, )); # set `m` and ready `x[1]` for resampling
2088-
2089-
julia> var_info[@varname(m)] # [✓] changed
2090-
100.0
2091-
2092-
julia> var_info[@varname(x[1])] # [✓] unchanged
2093-
-0.22312984965118443
2094-
2095-
julia> m(rng, var_info); # sample `x[1]` conditioned on `m = 100.0`
2096-
2097-
julia> var_info[@varname(m)] # [✓] unchanged
2098-
100.0
2099-
2100-
julia> var_info[@varname(x[1])] # [✓] changed
2101-
101.37363069798343
2102-
```
2103-
2104-
## See also
2105-
- [`setval!`](@ref)
2106-
"""
2107-
function setval_and_resample!(vi::VarInfoOrThreadSafeVarInfo, x)
2108-
return setval_and_resample!(vi, values(x), keys(x))
2109-
end
2110-
function setval_and_resample!(vi::VarInfoOrThreadSafeVarInfo, values, keys)
2111-
return _apply!(_setval_and_resample_kernel!, vi, values, keys)
2112-
end
2113-
function setval_and_resample!(
2114-
vi::VarInfoOrThreadSafeVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int
2115-
)
2116-
if supports_varname_indexing(chains)
2117-
# First we need to set every variable to be resampled.
2118-
for vn in keys(vi)
2119-
set_flag!(vi, vn, "del")
2120-
end
2121-
# Then we set the variables in `varinfo` from `chain`.
2122-
for vn in varnames(chains)
2123-
vn_updated = nested_setindex_maybe!(
2124-
vi, getindex_varname(chains, sample_idx, vn, chain_idx), vn
2125-
)
2126-
2127-
# Unset the `del` flag if we found something.
2128-
if vn_updated !== nothing
2129-
# NOTE: This will be triggered even if only a subset of a variable has been set!
2130-
unset_flag!(vi, vn_updated, "del")
2131-
end
2132-
end
2133-
else
2134-
setval_and_resample!(vi, chains.value[sample_idx, :, chain_idx], keys(chains))
2135-
end
2136-
end
2137-
2138-
function _setval_and_resample_kernel!(
2139-
vi::VarInfoOrThreadSafeVarInfo, vn::VarName, values, keys
2140-
)
2141-
indices = findall(Base.Fix1(subsumes_string, string(vn)), keys)
2142-
if !isempty(indices)
2143-
val = reduce(vcat, values[indices])
2144-
setval!(vi, val, vn)
2145-
settrans!!(vi, false, vn)
2146-
else
2147-
# Ensures that we'll resample the variable corresponding to `vn` if we run
2148-
# the model on `vi` again.
2149-
set_flag!(vi, vn, "del")
2150-
end
2151-
2152-
return indices
2153-
end
2154-
21552012
values_as(vi::VarInfo) = vi.metadata
21562013
values_as(vi::VarInfo, ::Type{Vector}) = copy(getindex_internal(vi, Colon()))
21572014
function values_as(vi::UntypedVarInfo, ::Type{NamedTuple})

test/ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,12 @@
22
@model demo() = x ~ Normal()
33
model = demo()
44

5-
chain = MCMCChains.Chains(randn(1000, 2, 1), [:x, :y], Dict(:internals => [:y]))
5+
chain = MCMCChains.Chains(
6+
randn(1000, 2, 1),
7+
[:x, :y],
8+
Dict(:internals => [:y]);
9+
info=(; varname_to_symbol=Dict(@varname(x) => :x)),
10+
)
611
chain_generated = @test_nowarn returned(model, chain)
712
@test size(chain_generated) == (1000, 1)
813
@test mean(chain_generated) 0 atol = 0.1

test/model.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -573,7 +573,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
573573
xs_train = 1:0.1:10
574574
ys_train = ground_truth_β .* xs_train + rand(Normal(0, 0.1), length(xs_train))
575575
m_lin_reg = linear_reg(xs_train, ys_train)
576-
chain = [VarInfo(m_lin_reg) _ in 1:10000]
576+
chain = [VarInfo(m_lin_reg) for _ in 1:10000]
577577

578578
# chain is generated from the prior
579579
@test mean([chain[i][@varname(β)] for i in eachindex(chain)]) 1.0 atol = 0.1

test/test_util.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,10 @@ function make_chain_from_prior(rng::Random.AbstractRNG, model::Model, n_iters::I
8181
varnames = collect(varnames)
8282
# Construct matrix of values
8383
vals = [get(dict, vn, missing) for dict in dicts, vn in varnames]
84+
# Construct dict of varnames -> symbol
85+
vn_to_sym_dict = Dict(zip(varnames, map(Symbol, varnames)))
8486
# Construct and return the Chains object
85-
return Chains(vals, varnames)
87+
return Chains(vals, varnames; info=(; varname_to_symbol=vn_to_sym_dict))
8688
end
8789
function make_chain_from_prior(model::Model, n_iters::Int)
8890
return make_chain_from_prior(Random.default_rng(), model, n_iters)

test/varinfo.jl

Lines changed: 3 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ end
278278
@test typed_vi[vn_y] == 2.0
279279
end
280280

281-
@testset "setval! & setval_and_resample!" begin
281+
@testset "setval!" begin
282282
@model function testmodel(x)
283283
n = length(x)
284284
s ~ truncated(Normal(); lower=0)
@@ -329,8 +329,8 @@ end
329329
else
330330
DynamicPPL.setval!(vicopy, (m=zeros(5),))
331331
end
332-
# Setting `m` fails for univariate due to limitations of `setval!`
333-
# and `setval_and_resample!`. See docstring of `setval!` for more info.
332+
# Setting `m` fails for univariate due to limitations of `setval!`.
333+
# See docstring of `setval!` for more info.
334334
if model == model_uv && vi in [vi_untyped, vi_typed]
335335
@test_broken vicopy[m_vns] == zeros(5)
336336
else
@@ -355,57 +355,6 @@ end
355355
DynamicPPL.setval!(vicopy, (s=42,))
356356
@test vicopy[m_vns] == 1:5
357357
@test vicopy[s_vns] == 42
358-
359-
### `setval_and_resample!` ###
360-
if model == model_mv && vi == vi_untyped
361-
# Trying to re-run model with `MvNormal` on `vi_untyped` will call
362-
# `MvNormal(μ::Vector{Real}, Σ)` which causes `StackOverflowError`
363-
# so we skip this particular case.
364-
continue
365-
end
366-
367-
if vi in [vi_vnv, vi_vnv_typed]
368-
# `setval_and_resample!` works differently for `VarNamedVector`: All
369-
# values will be resampled when model(vicopy) is called. Hence the below
370-
# tests are not applicable.
371-
continue
372-
end
373-
374-
vicopy = deepcopy(vi)
375-
DynamicPPL.setval_and_resample!(vicopy, (m=zeros(5),))
376-
model(vicopy)
377-
# Setting `m` fails for univariate due to limitations of `subsumes(::String, ::String)`
378-
if model == model_uv
379-
@test_broken vicopy[m_vns] == zeros(5)
380-
else
381-
@test vicopy[m_vns] == zeros(5)
382-
end
383-
@test vicopy[s_vns] != vi[s_vns]
384-
385-
# Ordering is NOT preserved.
386-
DynamicPPL.setval_and_resample!(
387-
vicopy, (; (Symbol("m[$i]") => i for i in (1, 3, 5, 4, 2))...)
388-
)
389-
model(vicopy)
390-
if model == model_uv
391-
@test vicopy[m_vns] == 1:5
392-
else
393-
@test vicopy[m_vns] == [1, 3, 5, 4, 2]
394-
end
395-
@test vicopy[s_vns] != vi[s_vns]
396-
397-
# Correct ordering.
398-
DynamicPPL.setval_and_resample!(
399-
vicopy, (; (Symbol("m[$i]") => i for i in (1, 2, 3, 4, 5))...)
400-
)
401-
model(vicopy)
402-
@test vicopy[m_vns] == 1:5
403-
@test vicopy[s_vns] != vi[s_vns]
404-
405-
DynamicPPL.setval_and_resample!(vicopy, (s=42,))
406-
model(vicopy)
407-
@test vicopy[m_vns] != 1:5
408-
@test vicopy[s_vns] == 42
409358
end
410359
end
411360

@@ -419,9 +368,6 @@ end
419368
ks = [@varname(x[1, 1]), @varname(x[2, 1]), @varname(x[1, 2]), @varname(x[2, 2])]
420369
DynamicPPL.setval!(vi, vi.metadata.x.vals, ks)
421370
@test vals_prev == vi.metadata.x.vals
422-
423-
DynamicPPL.setval_and_resample!(vi, vi.metadata.x.vals, ks)
424-
@test vals_prev == vi.metadata.x.vals
425371
end
426372

427373
@testset "setval! on chain" begin

0 commit comments

Comments
 (0)