From 1c1c90735fe9664914240e511c7938f494a178e1 Mon Sep 17 00:00:00 2001
From: Xianda Sun
Date: Tue, 12 Nov 2024 14:00:32 +0000
Subject: [PATCH 01/14] move `predict` from Turing

---
 ext/DynamicPPLMCMCChainsExt.jl      | 301 ++++++++++++++++++++++++++++
 src/DynamicPPL.jl                   |   2 +-
 src/model.jl                        |  14 ++
 test/Project.toml                   |   2 +
 test/ext/DynamicPPLMCMCChainsExt.jl | 158 +++++++++++++++
 test/runtests.jl                    |   1 +
 6 files changed, 477 insertions(+), 1 deletion(-)

diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl
index 8a2679d09..457c85bfb 100644
--- a/ext/DynamicPPLMCMCChainsExt.jl
+++ b/ext/DynamicPPLMCMCChainsExt.jl
@@ -42,6 +42,307 @@ function DynamicPPL.varnames(c::MCMCChains.Chains)
     return keys(c.info.varname_to_symbol)
 end
 
+# this is copied from Turing.jl; the `stats` field is omitted as it is never used
+struct Transition{T,F}
+    θ::T
+    lp::F
+end
+
+function Transition(model::DynamicPPL.Model, vi::DynamicPPL.VarInfo)
+    return Transition(getparams(model, vi), DynamicPPL.getlogp(vi))
+end
+
+# a copy of Turing.Inference.getparams
+getparams(model, t) = t.θ
+function getparams(model::DynamicPPL.Model, vi::DynamicPPL.VarInfo)
+    # NOTE: In the past, `invlink(vi, model)` + `values_as(vi, OrderedDict)` was used.
+    # Unfortunately, using `invlink` can cause issues in scenarios where the constraints
+    # of the parameters change depending on the realizations. Hence we have to use
+    # `values_as_in_model`, which re-runs the model and extracts the parameters
+    # as they are seen in the model, i.e. in the constrained space. Moreover,
+    # this means that the code below will work for both linked and unlinked `vi`.
+    # Ref: https://github.com/TuringLang/Turing.jl/issues/2195
+    # NOTE: We need to `deepcopy` here to avoid modifying the original `vi`.
+    vals = DynamicPPL.values_as_in_model(model, deepcopy(vi))
+
+    # Obtain an iterator over the flattened parameter names and values.
+    iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
+
+    # Materialize the iterators and concatenate.
+    return mapreduce(collect, vcat, iters)
+end
+
+function _params_to_array(model::DynamicPPL.Model, ts::Vector)
+    names_set = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}()
+    # Extract the parameter names and values from each transition.
+    dicts = map(ts) do t
+        nms_and_vs = getparams(model, t)
+        nms = map(first, nms_and_vs)
+        vs = map(last, nms_and_vs)
+        for nm in nms
+            push!(names_set, nm)
+        end
+        # Convert the names and values to a single dictionary.
+        return DynamicPPL.OrderedCollections.OrderedDict(zip(nms, vs))
+    end
+    names = collect(names_set)
+    vals = [
+        get(dicts[i], key, missing) for i in eachindex(dicts), (j, key) in enumerate(names)
+    ]
+
+    return names, vals
+end
+
+"""
+    predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)
+
+Execute `model` conditioned on each sample in `chain`, and return the resulting `Chains`.
+
+If `include_all` is `false`, the returned `Chains` will contain only those variables
+that were not present in `chain`, i.e. the newly sampled variables.
+
+# Details
+Internally calls `transitions_from_chain` to obtain the samples
+and then converts these into a `Chains` object using `AbstractMCMC.bundle_samples`.
+
+# Example
+```jldoctest
+julia> using AbstractMCMC, AdvancedHMC, DynamicPPL, ForwardDiff;
+
+julia> @model function linear_reg(x, y, σ = 0.1)
+           β ~ Normal(0, 1)
+
+           for i ∈ eachindex(y)
+               y[i] ~ Normal(β * x[i], σ)
+           end
+       end;
+
+julia> σ = 0.1; f(x) = 2 * x + 0.1 * randn();
+
+julia> Δ = 0.1; xs_train = 0:Δ:10; ys_train = f.(xs_train);
+
+julia> xs_test = [10 + Δ, 10 + 2 * Δ]; ys_test = f.(xs_test);
+
+julia> m_train = linear_reg(xs_train, ys_train, σ);
+
+julia> n_train_logdensity_function = DynamicPPL.LogDensityFunction(m_train, DynamicPPL.VarInfo(m_train));
+
+julia> chain_lin_reg = AbstractMCMC.sample(n_train_logdensity_function, NUTS(0.65), 200; chain_type=MCMCChains.Chains, param_names=[:β]);
+┌ Info: Found initial step size
+└ ϵ = 0.003125
+
+julia> m_test = linear_reg(xs_test, Vector{Union{Missing, Float64}}(undef, length(ys_test)), σ);
+
+julia> predictions = predict(m_test, chain_lin_reg)
+Object of type Chains, with data of type 100×2×1 Array{Float64,3}
+
+Iterations = 1:100
+Thinning interval = 1
+Chains = 1
+Samples per chain = 100
+parameters = y[1], y[2]
+
+2-element Array{ChainDataFrame,1}
+
+Summary Statistics
+  parameters     mean     std  naive_se     mcse       ess   r_hat
+  ──────────  ───────  ──────  ────────  ───────  ────────  ──────
+        y[1]  20.1974  0.1007    0.0101  missing  101.0711  0.9922
+        y[2]  20.3867  0.1062    0.0106  missing  101.4889  0.9903
+
+Quantiles
+  parameters     2.5%    25.0%    50.0%    75.0%    97.5%
+  ──────────  ───────  ───────  ───────  ───────  ───────
+        y[1]  20.0342  20.1188  20.2135  20.2588  20.4188
+        y[2]  20.1870  20.3178  20.3839  20.4466  20.5895
+
+
+julia> ys_pred = vec(mean(Array(group(predictions, :y)); dims = 1));
+
+julia> sum(abs2, ys_test - ys_pred) ≤ 0.1
+true
+```
+"""
+function DynamicPPL.predict(
+    rng::DynamicPPL.Random.AbstractRNG,
+    model::DynamicPPL.Model,
+    chain::MCMCChains.Chains;
+    include_all=false,
+)
+    # Don't need all the diagnostics
+    chain_parameters = MCMCChains.get_sections(chain, :parameters)
+
+    spl = DynamicPPL.SampleFromPrior()
+
+    # Sample transitions using `spl` conditioned on values in `chain`
+    transitions = transitions_from_chain(rng, model, chain_parameters; sampler=spl)
+
+    # Bundle the samples into one `Chains` per chain and concatenate them
+    chain_result = reduce(
+        MCMCChains.chainscat,
+        [
+            _bundle_samples(transitions[:, chain_idx], model, spl) for
+            chain_idx in 1:size(transitions, 2)
+        ],
+    )
+
+    parameter_names = if include_all
+        MCMCChains.names(chain_result, :parameters)
+    else
+        filter(
+            k -> !(k in MCMCChains.names(chain_parameters, :parameters)),
+            names(chain_result, :parameters),
+        )
+    end
+
+    return chain_result[parameter_names]
+end
+
+getlogp(t::Transition) = t.lp
+
+function get_transition_extras(ts::AbstractVector{<:Transition})
+    valmat = reshape([getlogp(t) for t in ts], :, 1)
+    return [:lp], valmat
+end
+
+function names_values(extra_data::AbstractVector{<:NamedTuple{names}}) where {names}
+    values = [getfield(data, name) for data in extra_data, name in names]
+    return collect(names), values
+end
+
+function names_values(xs::AbstractVector{<:NamedTuple})
+    # Obtain all parameter names.
+    names_set = Set{Symbol}()
+    for x in xs
+        for k in keys(x)
+            push!(names_set, k)
+        end
+    end
+    names_unique = collect(names_set)
+
+    # Extract all values as matrix.
+    values = [haskey(x, name) ? x[name] : missing for x in xs, name in names_unique]
+
+    return names_unique, values
+end
+
+getlogevidence(transitions, sampler, state) = missing
+
+# this is copied from Turing.jl/src/mcmc/Inference.jl; types are more restrictive (types defined in Turing are removed)
+# and the function is simplified: unused arguments are removed
+function _bundle_samples(
+    ts::Vector{<:Transition}, model::DynamicPPL.Model, spl::DynamicPPL.SampleFromPrior
+)
+    # Convert transitions to array format.
+    # Also retrieve the variable names.
+    varnames, vals = _params_to_array(model, ts)
+    varnames_symbol = map(Symbol, varnames)
+
+    # Get the values of the extra parameters in each transition.
+    extra_params, extra_values = get_transition_extras(ts)
+
+    # Extract names & construct param array.
+    nms = [varnames_symbol; extra_params]
+    parray = hcat(vals, extra_values)
+
+    # Set up the info tuple.
+    info = NamedTuple()
+
+    info = merge(
+        info,
+        (
+            varname_to_symbol=DynamicPPL.OrderedCollections.OrderedDict(
+                zip(varnames, varnames_symbol)
+            ),
+        ),
+    )
+
+    # Concretize the array before giving it to MCMCChains.
+    parray = MCMCChains.concretize(parray)
+
+    # Chain construction.
+    chain = MCMCChains.Chains(parray, nms, (internals=extra_params,))
+
+    return chain
+end
+
+"""
+    transitions_from_chain(
+        [rng::AbstractRNG,]
+        model::Model,
+        chain::MCMCChains.Chains;
+        sampler = DynamicPPL.SampleFromPrior()
+    )
+
+Execute `model` conditioned on each sample in `chain`, and return the resulting transitions.
+
+The returned transitions are represented as a `Vector{<:Transition}`.
+
+# Details
+
+In a bit more detail, the process is as follows:
+1. For every `sample` in `chain`
+   1. For every `variable` in `sample`
+      1. Set `variable` in `model` to its value in `sample`
+   2. Execute `model` with variables fixed as above, sampling variables NOT present
+      in `chain` using `SampleFromPrior`
+   3. Return sampled variables and log-joint
+
+# Example
+```julia-repl
+julia> using DynamicPPL, Distributions, MCMCChains
+
+julia> @model function demo()
+           m ~ Normal(0, 1)
+           x ~ Normal(m, 1)
+       end;
+
+julia> m = demo();
+
+julia> chain = Chains(randn(2, 1, 1), ["m"]); # 2 samples of `m`
+
+julia> transitions = transitions_from_chain(m, chain);
+
+julia> [getlogp(t) for t in transitions] # extract the logjoints
+2-element Array{Float64,1}:
+ -3.6294991938628374
+ -2.5697948166987845
+
+julia> [first(t.θ.x) for t in transitions] # extract samples for `x`
+2-element Array{Array{Float64,1},1}:
+ [-2.0844148956440796]
+ [-1.704630494695469]
+```
+"""
+function transitions_from_chain(
+    model::DynamicPPL.Model, chain::MCMCChains.Chains; kwargs...
+)
+    return transitions_from_chain(Random.default_rng(), model, chain; kwargs...)
+end
+
+function transitions_from_chain(
+    rng::DynamicPPL.Random.AbstractRNG,
+    model::DynamicPPL.Model,
+    chain::MCMCChains.Chains;
+    sampler=DynamicPPL.SampleFromPrior(),
+)
+    vi = DynamicPPL.VarInfo(model)
+
+    iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
+    transitions = map(iters) do (sample_idx, chain_idx)
+        # Set variables present in `chain` and mark those NOT present in chain to be resampled.
+        DynamicPPL.setval_and_resample!(vi, chain, sample_idx, chain_idx)
+        model(rng, vi, sampler)
+
+        # Convert the `VarInfo` into a `Transition` and save.
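+        # (The `Transition` pairs the flattened `(varname, value)` tuples from
+        # `getparams` with the log-joint of `vi`, which is all that the bundling
+        # step above needs.)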
+        Transition(model, vi)
+    end
+
+    return transitions
+end
+
 """
     generated_quantities(model::Model, chain::MCMCChains.Chains)
 
diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl
index a5d178125..e8cee3b08 100644
--- a/src/DynamicPPL.jl
+++ b/src/DynamicPPL.jl
@@ -5,7 +5,7 @@ using AbstractPPL
 using Bijectors
 using Compat
 using Distributions
-using OrderedCollections: OrderedDict
+using OrderedCollections: OrderedCollections, OrderedDict
 
 using AbstractMCMC: AbstractMCMC
 using ADTypes: ADTypes
diff --git a/src/model.jl b/src/model.jl
index 2a1a6db88..82654e2ec 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -1203,6 +1203,20 @@ function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractC
     end
 end
 
+"""
+    predict([rng::AbstractRNG,] model::Model, chain; include_all=false)
+
+Sample from the posterior predictive distribution by executing `model` with parameters fixed to each sample
+in `chain`, and return the resulting `Chains`. At the moment, `chain` must be a `MCMCChains.Chains` object.
+
+If `include_all` is `false`, the returned `Chains` will contain only those variables that were not fixed by
+the samples in `chain`. This is useful when you want to sample only new variables from the posterior
+predictive distribution.
+"""
+function predict(model::Model, chain; include_all=false)
+    return predict(Random.default_rng(), model, chain; include_all)
+end
+
 """
     generated_quantities(model::Model, parameters::NamedTuple)
     generated_quantities(model::Model, values, keys)
diff --git a/test/Project.toml b/test/Project.toml
index 36fcd1b69..03a87e5b8 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -2,6 +2,7 @@
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
 AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
 AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
+AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
 Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
 Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
 Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
@@ -32,6 +33,7 @@
 AbstractMCMC = "5"
 AbstractPPL = "0.8.4, 0.9"
 Accessors = "0.1"
 Bijectors = "0.13.9, 0.14"
+AdvancedHMC = "0.3.0, 0.4.0, 0.5.2, 0.6"
 Combinatorics = "1"
 Compat = "4.3.0"
 Distributions = "0.25"
diff --git a/test/ext/DynamicPPLMCMCChainsExt.jl b/test/ext/DynamicPPLMCMCChainsExt.jl
index c19bf6f2d..c5a5c4f74 100644
--- a/test/ext/DynamicPPLMCMCChainsExt.jl
+++ b/test/ext/DynamicPPLMCMCChainsExt.jl
@@ -7,3 +7,161 @@
     @test size(chain_generated) == (1000, 1)
     @test mean(chain_generated) ≈ 0 atol = 0.1
 end
+
+@testset "predict" begin
+    DynamicPPL.Random.seed!(100)
+
+    @model function linear_reg(x, y, σ=0.1)
+        β ~ Normal(0, 1)
+
+        for i in eachindex(y)
+            y[i] ~ Normal(β * x[i], σ)
+        end
+    end
+
+    @model function linear_reg_vec(x, y, σ=0.1)
+        β ~ Normal(0, 1)
+        return y ~ MvNormal(β .* x, σ^2 * I)
+    end
+
+    f(x) = 2 * x + 0.1 * randn()
+
+    Δ = 0.1
+    xs_train = 0:Δ:10
+    ys_train = f.(xs_train)
+    xs_test = [10 + Δ, 10 + 2 * Δ]
+    ys_test = f.(xs_test)
+
+    # Infer
+    m_lin_reg = linear_reg(xs_train, ys_train)
+    chain_lin_reg = sample(
+        DynamicPPL.LogDensityFunction(m_lin_reg, DynamicPPL.VarInfo(m_lin_reg)),
+        AdvancedHMC.NUTS(0.65),
+        200;
+        chain_type=MCMCChains.Chains,
+        param_names=[:β],
+    )
+
+    # Predict on two last indices
+    m_lin_reg_test = linear_reg(xs_test, fill(missing, length(ys_test)))
+    predictions = DynamicPPL.predict(m_lin_reg_test, chain_lin_reg)
+
+    ys_pred = vec(mean(Array(group(predictions, :y)); dims=1))
+
+    @test sum(abs2, ys_test - ys_pred) ≤ 0.1
+
+    # Ensure that `rng` is respected
+    predictions1 = let rng = MersenneTwister(42)
+        DynamicPPL.predict(rng, m_lin_reg_test, chain_lin_reg[1:2])
+    end
+    predictions2 = let rng = MersenneTwister(42)
+        DynamicPPL.predict(rng, m_lin_reg_test, chain_lin_reg[1:2])
+    end
+    @test all(Array(predictions1) .== Array(predictions2))
+
+    # Predict on two last indices for vectorized
+    m_lin_reg_test = linear_reg_vec(xs_test, missing)
+    predictions_vec = DynamicPPL.predict(m_lin_reg_test, chain_lin_reg)
+    ys_pred_vec = vec(mean(Array(group(predictions_vec, :y)); dims=1))
+
+    @test sum(abs2, ys_test - ys_pred_vec) ≤ 0.1
+
+    # Multiple chains
+    chain_lin_reg = sample(
+        DynamicPPL.LogDensityFunction(m_lin_reg, DynamicPPL.VarInfo(m_lin_reg)),
+        AdvancedHMC.NUTS(0.65),
+        MCMCThreads(),
+        200,
+        2;
+        chain_type=MCMCChains.Chains,
+        param_names=[:β],
+    )
+    m_lin_reg_test = linear_reg(xs_test, fill(missing, length(ys_test)))
+    predictions = DynamicPPL.predict(m_lin_reg_test, chain_lin_reg)
+
+    @test size(chain_lin_reg, 3) == size(predictions, 3)
+
+    for chain_idx in MCMCChains.chains(chain_lin_reg)
+        ys_pred = vec(mean(Array(group(predictions[:, :, chain_idx], :y)); dims=1))
+        @test sum(abs2, ys_test - ys_pred) ≤ 0.1
+    end
+
+    # Predict on two last indices for vectorized
+    m_lin_reg_test = linear_reg_vec(xs_test, missing)
+    predictions_vec = DynamicPPL.predict(m_lin_reg_test, chain_lin_reg)
+
+    for chain_idx in MCMCChains.chains(chain_lin_reg)
+        ys_pred_vec = vec(mean(Array(group(predictions_vec[:, :, chain_idx], :y)); dims=1))
+        @test sum(abs2, ys_test - ys_pred_vec) ≤ 0.1
+    end
+
+    # https://github.com/TuringLang/Turing.jl/issues/1352
+    @model function simple_linear1(x, y)
+        intercept ~ Normal(0, 1)
+        coef ~ MvNormal(zeros(2), I)
+        coef = reshape(coef, 1, size(x, 1))
+
+        mu = vec(intercept .+ coef * x)
+        error ~ truncated(Normal(0, 1), 0, Inf)
+        return y ~ MvNormal(mu, error^2 * I)
+    end
+
+    @model function simple_linear2(x, y)
+        intercept ~ Normal(0, 1)
+        coef ~ filldist(Normal(0, 1), 2)
+        coef = reshape(coef, 1, size(x, 1))
+
+        mu = vec(intercept .+ coef * x)
+        error ~ truncated(Normal(0, 1), 0, Inf)
+        return y ~ MvNormal(mu, error^2 * I)
+    end
+
+    @model function simple_linear3(x, y)
+        intercept ~ Normal(0, 1)
+        coef = Vector(undef, 2)
+        for i in axes(coef, 1)
+            coef[i] ~ Normal(0, 1)
+        end
+        coef = reshape(coef, 1, size(x, 1))
+
+        mu = vec(intercept .+ coef * x)
+        error ~ truncated(Normal(0, 1), 0, Inf)
+        return y ~ MvNormal(mu, error^2 * I)
+    end
+
+    @model function simple_linear4(x, y)
+        intercept ~ Normal(0, 1)
+        coef1 ~ Normal(0, 1)
+        coef2 ~ Normal(0, 1)
+        coef = [coef1, coef2]
+        coef = reshape(coef, 1, size(x, 1))
+
+        mu = vec(intercept .+ coef * x)
+        error ~ truncated(Normal(0, 1), 0, Inf)
+        return y ~ MvNormal(mu, error^2 * I)
+    end
+
+    # Some data
+    x = randn(2, 100)
+    y = [1 + 2 * a + 3 * b for (a, b) in eachcol(x)]
+
+    param_names = Dict(
+        simple_linear1 => [:intercept, :coef],
+        simple_linear2 => [:intercept, :coef],
+        simple_linear3 => [:intercept, Symbol.(["coef[$i]" for i in 1:2])...],
+        simple_linear4 => [:intercept, :coef1, :coef2],
+    )
+    for model in [simple_linear1, simple_linear2, simple_linear3, simple_linear4]
+        m = model(x, y)
+        chain = sample(
+            DynamicPPL.LogDensityFunction(m, DynamicPPL.VarInfo(m)),
+            AdvancedHMC.NUTS(0.65),
+            100;
+            chain_type=MCMCChains.Chains,
+            param_names=param_names[model],
+        )
+        chain_predict = DynamicPPL.predict(model(x, missing), chain)
+        mean_prediction = [mean(chain_predict["y[$i]"].data) for i in 1:length(y)]
+        @test mean(abs2, mean_prediction - y) ≤ 1e-3
+    end
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index a832a0f08..9e4b3a446 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,5 +1,6 @@
 using Accessors
 using ADTypes
+using AdvancedHMC: AdvancedHMC
 using DynamicPPL
 using AbstractMCMC
 using AbstractPPL

From bdf90b4e6371a475098aa9ebfd45192a2678a73e Mon Sep 17 00:00:00 2001
From: Xianda Sun
Date: Wed, 13 Nov 2024 09:21:44 +0000
Subject: [PATCH 02/14] minor fixes

---
 ext/DynamicPPLMCMCChainsExt.jl      |  6 ------
 test/ext/DynamicPPLMCMCChainsExt.jl | 11 +++++------
 2 files changed, 5 insertions(+), 12 deletions(-)

diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl
index 457c85bfb..832787ad0 100644
--- a/ext/DynamicPPLMCMCChainsExt.jl
+++ b/ext/DynamicPPLMCMCChainsExt.jl
@@ -316,12 +316,6 @@ julia> [first(t.θ.x) for t in transitions] # extract samples for `x`
  [-1.704630494695469]
 ```
 """
-function transitions_from_chain(
-    model::DynamicPPL.Model, chain::MCMCChains.Chains; kwargs...
-)
-    return transitions_from_chain(Random.default_rng(), model, chain; kwargs...)
-end
-
 function transitions_from_chain(
     rng::DynamicPPL.Random.AbstractRNG,
     model::DynamicPPL.Model,
     chain::MCMCChains.Chains;
diff --git a/test/ext/DynamicPPLMCMCChainsExt.jl b/test/ext/DynamicPPLMCMCChainsExt.jl
index c5a5c4f74..6923cad9c 100644
--- a/test/ext/DynamicPPLMCMCChainsExt.jl
+++ b/test/ext/DynamicPPLMCMCChainsExt.jl
@@ -141,17 +141,16 @@ end
         return y ~ MvNormal(mu, error^2 * I)
     end
 
-    # Some data
     x = randn(2, 100)
     y = [1 + 2 * a + 3 * b for (a, b) in eachcol(x)]
 
     param_names = Dict(
-        simple_linear1 => [:intercept, :coef],
-        simple_linear2 => [:intercept, :coef],
-        simple_linear3 => [:intercept, Symbol.(["coef[$i]" for i in 1:2])...],
-        simple_linear4 => [:intercept, :coef1, :coef2],
+        simple_linear1 => [:intercept, Symbol("coef[1]"), Symbol("coef[2]"), :error],
+        simple_linear2 => [:intercept, Symbol("coef[1]"), Symbol("coef[2]"), :error],
+        simple_linear3 => [:intercept, Symbol("coef[1]"), Symbol("coef[2]"), :error],
+        simple_linear4 => [:intercept, :coef1, :coef2, :error],
     )
-    for model in [simple_linear1, simple_linear2, simple_linear3, simple_linear4]
+    @testset "$model" for model in [simple_linear1, simple_linear2, simple_linear3, simple_linear4]
         m = model(x, y)
         chain = sample(
             DynamicPPL.LogDensityFunction(m, DynamicPPL.VarInfo(m)),

From c7d08b0332d667619292edbfaed07858572d94c7 Mon Sep 17 00:00:00 2001
From: Xianda Sun <5433119+sunxd3@users.noreply.github.com>
Date: Thu, 14 Nov 2024 15:28:38 +0000
Subject: [PATCH 03/14] Update test/ext/DynamicPPLMCMCChainsExt.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
---
 test/ext/DynamicPPLMCMCChainsExt.jl | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/test/ext/DynamicPPLMCMCChainsExt.jl b/test/ext/DynamicPPLMCMCChainsExt.jl
index 6923cad9c..111ee7fbf 100644
--- a/test/ext/DynamicPPLMCMCChainsExt.jl
+++ b/test/ext/DynamicPPLMCMCChainsExt.jl
@@ -150,7 +150,8 @@ end
         simple_linear3 => [:intercept, Symbol("coef[1]"), Symbol("coef[2]"), :error],
         simple_linear4 => [:intercept, :coef1, :coef2, :error],
     )
-    @testset "$model" for model in [simple_linear1, simple_linear2, simple_linear3, simple_linear4]
+    @testset "$model" for model in
+                          [simple_linear1, simple_linear2, simple_linear3, simple_linear4]
         m = model(x, y)
         chain = sample(
             DynamicPPL.LogDensityFunction(m, DynamicPPL.VarInfo(m)),

From a425c41e0274efd408a028851302cf1267f2cea1 Mon Sep 17 00:00:00 2001
From: Xianda Sun
Date: Mon, 18 Nov 2024 11:06:43 +0000
Subject: [PATCH 04/14] fix test error by discarding burn-ins

---
 test/ext/DynamicPPLMCMCChainsExt.jl | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/test/ext/DynamicPPLMCMCChainsExt.jl b/test/ext/DynamicPPLMCMCChainsExt.jl
index 111ee7fbf..b7888acf6 100644
--- a/test/ext/DynamicPPLMCMCChainsExt.jl
+++ b/test/ext/DynamicPPLMCMCChainsExt.jl
@@ -35,11 +35,13 @@ end
     # Infer
     m_lin_reg = linear_reg(xs_train, ys_train)
     chain_lin_reg = sample(
-        DynamicPPL.LogDensityFunction(m_lin_reg, DynamicPPL.VarInfo(m_lin_reg)),
+        DynamicPPL.LogDensityFunction(m_lin_reg),
         AdvancedHMC.NUTS(0.65),
-        200;
+        1000;
         chain_type=MCMCChains.Chains,
         param_names=[:β],
+        discard_initial=100,
+        n_adapt=100,
     )

     # Predict on two last indices

From 41471f688d4edc470270147c1a1573dca8042529 Mon Sep 17 00:00:00 2001
From: Xianda Sun
Date: Mon, 18 Nov 2024 11:08:38 +0000
Subject: [PATCH 05/14] add some comments

---
 test/ext/DynamicPPLMCMCChainsExt.jl | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/test/ext/DynamicPPLMCMCChainsExt.jl b/test/ext/DynamicPPLMCMCChainsExt.jl
index b7888acf6..4284edcfb 100644
--- a/test/ext/DynamicPPLMCMCChainsExt.jl
+++ b/test/ext/DynamicPPLMCMCChainsExt.jl
@@ -50,6 +50,8 @@ end

     ys_pred = vec(mean(Array(group(predictions, :y)); dims=1))

+    # test like this depends on the variance of the posterior
+    # this only makes sense if the posterior variance is about 0.002
     @test sum(abs2, ys_test - ys_pred) ≤ 0.1

     # Ensure that `rng` is respected

From 90d99ca37b646bf19377f570f96dbb04e3688274 Mon Sep 17 00:00:00 2001
From: Xianda Sun
Date: Mon, 18 Nov 2024 12:18:17 +0000
Subject: [PATCH 06/14] fix test error

---
 test/ext/DynamicPPLMCMCChainsExt.jl | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/test/ext/DynamicPPLMCMCChainsExt.jl b/test/ext/DynamicPPLMCMCChainsExt.jl
index 4284edcfb..8cdcbfd92 100644
--- a/test/ext/DynamicPPLMCMCChainsExt.jl
+++ b/test/ext/DynamicPPLMCMCChainsExt.jl
@@ -75,10 +75,12 @@ end
         DynamicPPL.LogDensityFunction(m_lin_reg, DynamicPPL.VarInfo(m_lin_reg)),
         AdvancedHMC.NUTS(0.65),
         MCMCThreads(),
-        200,
+        1000,
         2;
         chain_type=MCMCChains.Chains,
         param_names=[:β],
+        discard_initial=100,
+        n_adapt=100,
     )
     m_lin_reg_test = linear_reg(xs_test, fill(missing, length(ys_test)))
     predictions = DynamicPPL.predict(m_lin_reg_test, chain_lin_reg)
@@ -158,9 +160,10 @@ end
         [simple_linear1, simple_linear2, simple_linear3, simple_linear4]
         m = model(x, y)
         chain = sample(
-            DynamicPPL.LogDensityFunction(m, DynamicPPL.VarInfo(m)),
+            DynamicPPL.LogDensityFunction(m),
             AdvancedHMC.NUTS(0.65),
-            1000;
+            400;
+            initial_params = rand(4),
             chain_type=MCMCChains.Chains,
             param_names=param_names[model],
             discard_initial=100,
             n_adapt=100,
         )

From ea23b7c1b9c353f5e5b8bcbb9cc610d5f756cc7e Mon Sep 17 00:00:00 2001
From: Xianda Sun <5433119+sunxd3@users.noreply.github.com>
Date: Mon, 18 Nov 2024 12:22:07 +0000
Subject: [PATCH 07/14] Update test/ext/DynamicPPLMCMCChainsExt.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
---
 test/ext/DynamicPPLMCMCChainsExt.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/ext/DynamicPPLMCMCChainsExt.jl b/test/ext/DynamicPPLMCMCChainsExt.jl
index 8cdcbfd92..25e0d55ba 100644
--- a/test/ext/DynamicPPLMCMCChainsExt.jl
+++ b/test/ext/DynamicPPLMCMCChainsExt.jl
@@ -163,7 +163,7 @@ end
             DynamicPPL.LogDensityFunction(m),
             AdvancedHMC.NUTS(0.65),
             400;
-            initial_params = rand(4),
+            initial_params=rand(4),
             chain_type=MCMCChains.Chains,
             param_names=param_names[model],
             discard_initial=100,

From 76ef40f40abbb0675c0a0ba18793cfc5ba7e4f18 Mon Sep 17 00:00:00 2001
From: Xianda Sun
Date: Thu, 21 Nov 2024 12:42:01 +0000
Subject: [PATCH 08/14] refactor the code; add a `predict` method that takes an
 array of varinfos

---
 ext/DynamicPPLMCMCChainsExt.jl | 240 ++++++---------------------------
 src/model.jl                   |  19 ++-
 2 files changed, 57 insertions(+), 202 deletions(-)

diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl
index 832787ad0..44651c77e 100644
--- a/ext/DynamicPPLMCMCChainsExt.jl
+++ b/ext/DynamicPPLMCMCChainsExt.jl
@@ -42,78 +42,22 @@ function DynamicPPL.varnames(c::MCMCChains.Chains)
     return keys(c.info.varname_to_symbol)
 end

 """
     predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)

-Execute `model` conditioned on each sample in `chain`, and return the resulting `Chains`.
+Sample from the posterior predictive distribution by executing `model` with parameters fixed to each sample
+in `chain`, and return the resulting `Chains`.

-If `include_all` is `false`, the returned `Chains` will contain only those variables
-that were not present in `chain`, i.e. the newly sampled variables.
+If `include_all` is `false`, the returned `Chains` will contain only those variables that were not fixed by
+the samples in `chain`. This is useful when you want to sample only new variables from the posterior
+predictive distribution.

-# Details
-Internally calls `transitions_from_chain` to obtain the samples
-and then converts these into a `Chains` object using `AbstractMCMC.bundle_samples`.
-
-# Example
+# Examples
 ```jldoctest
-julia> using AbstractMCMC, AdvancedHMC, DynamicPPL, ForwardDiff;
+julia> using DynamicPPL, AbstractMCMC, AdvancedHMC, ForwardDiff;

 julia> @model function linear_reg(x, y, σ = 0.1)
            β ~ Normal(0, 1)
-
            for i ∈ eachindex(y)
                y[i] ~ Normal(β * x[i], σ)
            end
        end;

 julia> σ = 0.1; f(x) = 2 * x + 0.1 * randn();

 julia> Δ = 0.1; xs_train = 0:Δ:10; ys_train = f.(xs_train);

 julia> xs_test = [10 + Δ, 10 + 2 * Δ]; ys_test = f.(xs_test);

 julia> m_train = linear_reg(xs_train, ys_train, σ);

 julia> n_train_logdensity_function = DynamicPPL.LogDensityFunction(m_train, DynamicPPL.VarInfo(m_train));

-julia> chain_lin_reg = AbstractMCMC.sample(n_train_logdensity_function, NUTS(0.65), 200; chain_type=MCMCChains.Chains, param_names=[:β]);
+julia> chain_lin_reg = AbstractMCMC.sample(n_train_logdensity_function, NUTS(0.65), 200; chain_type=MCMCChains.Chains, param_names=[:β], discard_initial=100)
 ┌ Info: Found initial step size
 └ ϵ = 0.003125

 julia> m_test = linear_reg(xs_test, Vector{Union{Missing, Float64}}(undef, length(ys_test)), σ);

 julia> predictions = predict(m_test, chain_lin_reg)
 Object of type Chains, with data of type 100×2×1 Array{Float64,3}

 Iterations = 1:100
 Thinning interval = 1
 Chains = 1
 Samples per chain = 100
 parameters = y[1], y[2]

 2-element Array{ChainDataFrame,1}

 Summary Statistics
   parameters     mean     std  naive_se     mcse       ess   r_hat
   ──────────  ───────  ──────  ────────  ───────  ────────  ──────
         y[1]  20.1974  0.1007    0.0101  missing  101.0711  0.9922
         y[2]  20.3867  0.1062    0.0106  missing  101.4889  0.9903

 Quantiles
   parameters     2.5%    25.0%    50.0%    75.0%    97.5%
   ──────────  ───────  ───────  ───────  ───────  ───────
         y[1]  20.0342  20.1188  20.2135  20.2588  20.4188
         y[2]  20.1870  20.3178  20.3839  20.4466  20.5895
-
 julia> ys_pred = vec(mean(Array(group(predictions, :y)); dims = 1));

 julia> sum(abs2, ys_test - ys_pred) ≤ 0.1
 true
 ```
 """
 function DynamicPPL.predict(
     rng::DynamicPPL.Random.AbstractRNG,
     model::DynamicPPL.Model,
     chain::MCMCChains.Chains;
     include_all=false,
 )
-    # Don't need all the diagnostics
-    chain_parameters = MCMCChains.get_sections(chain, :parameters)
+    parameter_only_chain = MCMCChains.get_sections(chain, :parameters)
+    vi = DynamicPPL.VarInfo(model)
+    iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
+    varinfos = map(iters) do (sample_idx, chain_idx)
+        DynamicPPL.setval_and_resample!(
+            deepcopy(vi), parameter_only_chain, sample_idx, chain_idx
+        )
+    end

-    spl = DynamicPPL.SampleFromPrior()
+    predictive_samples = DynamicPPL.predict(rng, model, varinfos; include_all)

-    # Sample transitions using `spl` conditioned on values in `chain`
-    transitions = transitions_from_chain(rng, model, chain_parameters; sampler=spl)
-
-    # Bundle the samples into one `Chains` per chain and concatenate them
     chain_result = reduce(
         MCMCChains.chainscat,
         [
-            _bundle_samples(transitions[:, chain_idx], model, spl) for
-            chain_idx in 1:size(transitions, 2)
+            _bundle_samples(predictive_samples[:, chain_idx]) for
+            chain_idx in 1:size(predictive_samples, 2)
         ],
     )
     parameter_names = if include_all
         MCMCChains.names(chain_result, :parameters)
     else
         filter(
-            k -> !(k in MCMCChains.names(chain_parameters, :parameters)),
+            k -> !(k in MCMCChains.names(parameter_only_chain, :parameters)),
             names(chain_result, :parameters),
         )
     end
     return chain_result[parameter_names]
 end

-getlogp(t::Transition) = t.lp
-
-function get_transition_extras(ts::AbstractVector{<:Transition})
-    valmat = reshape([getlogp(t) for t in ts], :, 1)
-    return [:lp], valmat
-end
+function _params_to_array(ts::Vector)
+    names_set = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}()

-function names_values(extra_data::AbstractVector{<:NamedTuple{names}}) where {names}
-    values = [getfield(data, name) for data in extra_data, name in names]
-    return collect(names), values
-end
+    dicts = map(ts) do t
+        nms_and_vs = t.values
+        nms = map(first, nms_and_vs)
+        vs = map(last, nms_and_vs)
+        for nm in nms
+            push!(names_set, nm)
+        end
+        return DynamicPPL.OrderedCollections.OrderedDict(zip(nms, vs))
+    end

-function names_values(xs::AbstractVector{<:NamedTuple})
-    # Obtain all parameter names.
-    names_set = Set{Symbol}()
-    for x in xs
-        for k in keys(x)
-            push!(names_set, k)
-        end
-    end
+    names = collect(names_set)
+    vals = [
+        get(dicts[i], key, missing) for i in eachindex(dicts), (j, key) in enumerate(names)
+    ]
-    names_unique = collect(names_set)
-
-    # Extract all values as matrix.
-    values = [haskey(x, name) ? x[name] : missing for x in xs, name in names_unique]
+    return names, vals
-
-    return names_unique, values
 end

-getlogevidence(transitions, sampler, state) = missing
-
-# this is copied from Turing.jl/src/mcmc/Inference.jl; types are more restrictive (types defined in Turing are removed)
-# and the function is simplified: unused arguments are removed
-function _bundle_samples(
-    ts::Vector{<:Transition}, model::DynamicPPL.Model, spl::DynamicPPL.SampleFromPrior
-)
-    # Convert transitions to array format.
-    # Also retrieve the variable names.
-    varnames, vals = _params_to_array(model, ts)
+function _bundle_samples(ts::Vector{<:DynamicPPL.PredictiveSample})
+    varnames, vals = _params_to_array(ts)
     varnames_symbol = map(Symbol, varnames)
-
-    # Get the values of the extra parameters in each transition.
-    extra_params, extra_values = get_transition_extras(ts)
-
-    # Extract names & construct param array.
-    nms = [varnames_symbol; extra_params]
-    parray = hcat(vals, extra_values)
-
-    # Set up the info tuple.
-    info = NamedTuple()
-
-    info = merge(
-        info,
-        (
-            varname_to_symbol=DynamicPPL.OrderedCollections.OrderedDict(
-                zip(varnames, varnames_symbol)
-            ),
-        ),
-    )
-
-    # Concretize the array before giving it to MCMCChains.
-    parray = MCMCChains.concretize(parray)
-
-    # Chain construction.
-    chain = MCMCChains.Chains(parray, nms, (internals=extra_params,))
-
-    return chain
-end
-
-"""
-    transitions_from_chain(
-        [rng::AbstractRNG,]
-        model::Model,
-        chain::MCMCChains.Chains;
-        sampler = DynamicPPL.SampleFromPrior()
-    )
-
-Execute `model` conditioned on each sample in `chain`, and return the resulting transitions.
-
-The returned transitions are represented as a `Vector{<:Transition}`.
-
-# Details
-
-In a bit more detail, the process is as follows:
-1. For every `sample` in `chain`
-   1. For every `variable` in `sample`
-      1. Set `variable` in `model` to its value in `sample`
-   2. Execute `model` with variables fixed as above, sampling variables NOT present
-      in `chain` using `SampleFromPrior`
-   3. Return sampled variables and log-joint
-
-# Example
-```julia-repl
-julia> using DynamicPPL, Distributions, MCMCChains
-
-julia> @model function demo()
-           m ~ Normal(0, 1)
-           x ~ Normal(m, 1)
-       end;
-
-julia> m = demo();
-
-julia> chain = Chains(randn(2, 1, 1), ["m"]); # 2 samples of `m`
-
-julia> transitions = transitions_from_chain(m, chain);
-
-julia> [getlogp(t) for t in transitions] # extract the logjoints
-2-element Array{Float64,1}:
- -3.6294991938628374
- -2.5697948166987845
-
-julia> [first(t.θ.x) for t in transitions] # extract samples for `x`
-2-element Array{Array{Float64,1},1}:
- [-2.0844148956440796]
- [-1.704630494695469]
-```
-"""
-function transitions_from_chain(
-    rng::DynamicPPL.Random.AbstractRNG,
-    model::DynamicPPL.Model,
-    chain::MCMCChains.Chains;
-    sampler=DynamicPPL.SampleFromPrior(),
-)
-    vi = DynamicPPL.VarInfo(model)
-
-    iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
-    transitions = map(iters) do (sample_idx, chain_idx)
-        # Set variables present in `chain` and mark those NOT present in chain to be resampled.
-        DynamicPPL.setval_and_resample!(vi, chain, sample_idx, chain_idx)
-        model(rng, vi, sampler)
-
-        # Convert the `VarInfo` into a `Transition` and save.
-        # (The `Transition` pairs the flattened `(varname, value)` tuples from
-        # `getparams` with the log-joint of `vi`, which is all that the bundling
-        # step above needs.)
-        Transition(model, vi)
-    end
-
-    return transitions
-end
-
 """
     generated_quantities(model::Model, chain::MCMCChains.Chains)
diff --git a/src/model.jl b/src/model.jl
index 82654e2ec..9fe1030f2 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -1203,11 +1203,16 @@ function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractC
     end
 end

+struct PredictiveSample{T,F}
+    values::T
+    logp::F
+end
+
 """
     predict([rng::AbstractRNG,] model::Model, chain; include_all=false)

 Sample from the posterior predictive distribution by executing `model` with parameters fixed to each sample
-in `chain`, and return the resulting `Chains`. At the moment, `chain` must be a `MCMCChains.Chains` object.
+in `chain`.

 If `include_all` is `false`, the returned `Chains` will contain only those variables that were not fixed by
 the samples in `chain`. This is useful when you want to sample only new variables from the posterior
 predictive distribution.
 """
 function predict(model::Model, chain; include_all=false)
     return predict(Random.default_rng(), model, chain; include_all)
 end

+function predict(rng::Random.AbstractRNG, model::Model, varinfos::AbstractArray{<:AbstractVarInfo}; include_all=false)
+    predictive_samples = Array{PredictiveSample}(undef, size(varinfos))
+    for i in eachindex(varinfos)
+        model(rng, varinfos[i], SampleFromPrior())
+        vals = values_as_in_model(model, varinfos[i])
+        iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
+        params = mapreduce(collect, vcat, iters)
+        predictive_samples[i] = PredictiveSample(params, getlogp(varinfos[i]))
+    end
+    return predictive_samples
+end
+
 """
     generated_quantities(model::Model, parameters::NamedTuple)
     generated_quantities(model::Model, values, keys)

From 304b63e60c874e542fa9cc629e073a05666bf13f Mon Sep 17 00:00:00 2001
From: Xianda Sun <5433119+sunxd3@users.noreply.github.com>
Date: Thu, 21 Nov 2024 14:23:08 +0000
Subject: [PATCH 09/14] Update model.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
---
 src/model.jl | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/model.jl b/src/model.jl
index 9fe1030f2..4899a8443 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -1222,7 +1222,12 @@ function predict(model::Model, chain; include_all=false)
     return predict(Random.default_rng(), model, chain; include_all)
 end

-function predict(rng::Random.AbstractRNG, model::Model, varinfos::AbstractArray{<:AbstractVarInfo}; include_all=false)
+function predict(
+    rng::Random.AbstractRNG,
+    model::Model,
+    varinfos::AbstractArray{<:AbstractVarInfo};
+    include_all=false,
+)
     predictive_samples = Array{PredictiveSample}(undef, size(varinfos))
     for i in eachindex(varinfos)
         model(rng, varinfos[i], SampleFromPrior())

From fcd7c3dab3e22e9333878005f861b853ef77a570 Mon Sep 17 00:00:00 2001
From: Xianda Sun
Date: Fri, 29 Nov 2024 11:42:08 +0000
Subject: [PATCH 10/14] stop using `PredictiveSample` type

---
 ext/DynamicPPLMCMCChainsExt.jl | 18 +++++++++++-------
 src/model.jl                   | 11 ++++-------
 2 files changed, 15 insertions(+), 14 deletions(-)

diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl
index 44651c77e..fcf9eca8a 100644
--- a/ext/DynamicPPLMCMCChainsExt.jl
+++ b/ext/DynamicPPLMCMCChainsExt.jl
@@ -128,7 +128,7 @@ function DynamicPPL.predict(
     chain_result = reduce(
         MCMCChains.chainscat,
         [
-            _bundle_samples(predictive_samples[:, chain_idx]) for
+            _bundle_predictive_samples(predictive_samples[:, chain_idx]) for
             chain_idx in 1:size(predictive_samples, 2)
         ],
     )
@@ -143,11 +143,11 @@ end
-function _params_to_array(ts::Vector)
+function _params_to_array(predictive_samples)
     names_set = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}()

-    dicts = map(ts) do t
-        nms_and_vs = t.values
+    dicts = map(predictive_samples) do t
+        nms_and_vs = t[:values]
         nms = map(first, nms_and_vs)
         vs = map(last, nms_and_vs)
         for nm in nms
             push!(names_set, nm)
@@ -164,11 +164,15 @@ end
-function _bundle_samples(ts::Vector{<:DynamicPPL.PredictiveSample})
-    varnames, vals = _params_to_array(ts)
+function _bundle_predictive_samples(
+    predictive_samples::AbstractArray{
+        <:DynamicPPL.OrderedCollections.OrderedDict{Symbol,Any}
+    },
+)
+    varnames, vals = _params_to_array(predictive_samples)
     varnames_symbol = map(Symbol, varnames)
     extra_params = [:lp]
-    extra_values = reshape([t.logp for t in ts], :, 1)
+    extra_values = reshape([t[:logp] for t in predictive_samples], :, 1)
     nms = [varnames_symbol; extra_params]
     parray = hcat(vals, extra_values)
     parray = MCMCChains.concretize(parray)
diff --git a/src/model.jl b/src/model.jl
index 4899a8443..e57379d3e 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -1203,11 +1203,6 @@ function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractC
     end
 end

-struct PredictiveSample{T,F}
-    values::T
-    logp::F
-end
-
 """
     predict([rng::AbstractRNG,] model::Model, chain; include_all=false)
@@ -1228,13 +1223,15 @@ function predict(
-    predictive_samples = Array{PredictiveSample}(undef, size(varinfos))
+    predictive_samples = similar(varinfos, OrderedDict{Symbol,Any})
     for i in eachindex(varinfos)
         model(rng, varinfos[i], SampleFromPrior())
         vals = values_as_in_model(model, varinfos[i])
         iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
         params = mapreduce(collect, vcat, iters)
-        predictive_samples[i] = PredictiveSample(params, getlogp(varinfos[i]))
+        predictive_samples[i] = OrderedDict(
+            :values => params, :logp => getlogp(varinfos[i])
+        )
     end
     return predictive_samples
 end

From 30208eced09ef56fa063e02828a11117ebf0de0f Mon Sep 17 00:00:00 2001
From: Xianda Sun
Date: Fri, 29 Nov 2024 14:22:10 +0000
Subject: [PATCH 11/14] use NamedTuple

---
 ext/DynamicPPLMCMCChainsExt.jl | 14 +++++---------
 src/model.jl                   |  8 +++-----
 2 files changed, 8 insertions(+), 14 deletions(-)

diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl
index fcf9eca8a..8650724af 100644
--- a/ext/DynamicPPLMCMCChainsExt.jl
+++ b/ext/DynamicPPLMCMCChainsExt.jl
@@ -147,9 +147,9 @@ function _params_to_array(predictive_samples)
     names_set = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}()

     dicts = map(predictive_samples) do t
-        nms_and_vs = t[:values]
-        nms = map(first, nms_and_vs)
-        vs = map(last, nms_and_vs)
+        varname_and_values = t.varname_and_values
+        nms = map(first, varname_and_values)
+        vs = map(last, varname_and_values)
         for nm in nms
             push!(names_set, nm)
         end
@@ -164,15 +164,11 @@ end
-function _bundle_predictive_samples(
-    predictive_samples::AbstractArray{
-        <:DynamicPPL.OrderedCollections.OrderedDict{Symbol,Any}
-    },
-)
+function _bundle_predictive_samples(predictive_samples)
     varnames, vals = _params_to_array(predictive_samples)
     varnames_symbol = map(Symbol, varnames)
     extra_params = [:lp]
-    extra_values = reshape([t[:logp] for t in predictive_samples], :, 1)
+    extra_values = reshape([t.logp for t in predictive_samples], :, 1)
     nms = [varnames_symbol; extra_params]
     parray = hcat(vals, extra_values)
     parray = MCMCChains.concretize(parray)
diff --git a/src/model.jl b/src/model.jl
index e57379d3e..4aea7ecd9 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -1223,15 +1223,13 @@ function predict(
     varinfos::AbstractArray{<:AbstractVarInfo};
     include_all=false,
 )
-    predictive_samples = similar(varinfos, OrderedDict{Symbol,Any})
+    predictive_samples = similar(varinfos, NamedTuple{(:varname_and_values, :logp)})
     for i in eachindex(varinfos)
         model(rng, varinfos[i], SampleFromPrior())
         vals = values_as_in_model(model, varinfos[i])
         iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
-        params = mapreduce(collect, vcat, iters)
-        predictive_samples[i] = OrderedDict(
-            :values => params, :logp => getlogp(varinfos[i])
-        )
+        params = mapreduce(collect, vcat, iters) # returns a vector of tuples (varname, value)
+        predictive_samples[i] = (varname_and_values=params, logp=getlogp(varinfos[i]))
     end
     return predictive_samples
 end

From bf3862728b6a023b8d6c16b542fe3069054685be Mon Sep 17 00:00:00 2001
From: Xianda Sun
Date: Sun, 1 Dec 2024 12:10:45 +0000
Subject: [PATCH 12/14] remove predict with varinfos function

---
 ext/DynamicPPLMCMCChainsExt.jl | 73 +++++++++++++++++++++-------------
 src/model.jl                   | 19 +--------
 2 files changed, 47 insertions(+), 45 deletions(-)

diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl
index 8650724af..82f489765 100644
--- a/ext/DynamicPPLMCMCChainsExt.jl
+++ b/ext/DynamicPPLMCMCChainsExt.jl
@@ -115,20 +115,30 @@ function DynamicPPL.predict(
     include_all=false,
 )
     parameter_only_chain = MCMCChains.get_sections(chain, :parameters)
-    vi = DynamicPPL.VarInfo(model)
+    prototypical_varinfo = DynamicPPL.VarInfo(model)
+
     iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
-    varinfos = map(iters) do (sample_idx, chain_idx)
+    predictive_samples = map(iters) do (sample_idx, chain_idx)
+        varinfo = deepcopy(prototypical_varinfo)
         DynamicPPL.setval_and_resample!(
-            deepcopy(vi), parameter_only_chain, sample_idx, chain_idx
+            varinfo, parameter_only_chain, sample_idx, chain_idx
         )
-    end
+        model(rng, varinfo, DynamicPPL.SampleFromPrior())

-    predictive_samples = DynamicPPL.predict(rng, model, varinfos; include_all)
+        vals = DynamicPPL.values_as_in_model(model, varinfo)
+        varname_vals = mapreduce(
+            collect,
+            vcat,
+            map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)),
+        )
+
+        return (varname_and_values=varname_vals, logp=DynamicPPL.getlogp(varinfo))
+    end

     chain_result = reduce(
         MCMCChains.chainscat,
         [
-            _bundle_predictive_samples(predictive_samples[:, chain_idx]) for
+            _predictive_samples_to_chains(predictive_samples[:, chain_idx]) for
             chain_idx in 1:size(predictive_samples, 2)
         ],
     )
-function _params_to_array(predictive_samples)
-    names_set = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}()
+function _predictive_samples_to_arrays(predictive_samples)
+    variable_names_set = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}()

-    dicts = map(predictive_samples) do t
-        varname_and_values = t.varname_and_values
-        nms = map(first, varname_and_values)
-        vs = map(last, varname_and_values)
-        for nm in nms
-            push!(names_set, nm)
+    sample_dicts = map(predictive_samples) do sample
+        varname_value_pairs = sample.varname_and_values
+        varnames = map(first, varname_value_pairs)
+        values = map(last, varname_value_pairs)
+        for varname in varnames
+            push!(variable_names_set, varname)
         end
-        return DynamicPPL.OrderedCollections.OrderedDict(zip(nms, vs))
+
+        return DynamicPPL.OrderedCollections.OrderedDict(zip(varnames, values))
     end

-    names = collect(names_set)
-    vals = [
-        get(dicts[i], key, missing) for i in eachindex(dicts), (j, key) in enumerate(names)
+    variable_names = collect(variable_names_set)
+    variable_values = [
+        get(sample_dicts[i], key, missing) for i in eachindex(sample_dicts),
+        key in variable_names
     ]

-    return names, vals
+    return variable_names, variable_values
 end

-function _bundle_predictive_samples(predictive_samples)
-    varnames, vals = _params_to_array(predictive_samples)
-    varnames_symbol = map(Symbol, varnames)
-    extra_params = [:lp]
-    extra_values = reshape([t.logp for t in predictive_samples], :, 1)
-    nms = [varnames_symbol; extra_params]
-    parray = hcat(vals, extra_values)
-    parray = MCMCChains.concretize(parray)
-    return MCMCChains.Chains(parray, nms, (internals=extra_params,))
+function _predictive_samples_to_chains(predictive_samples)
+    variable_names, variable_values = _predictive_samples_to_arrays(predictive_samples)
+    variable_names_symbols = map(Symbol, variable_names)
+
+    internal_parameters = [:lp]
+    log_probabilities = reshape([sample.logp for sample in predictive_samples], :, 1)
+
+    parameter_names = [variable_names_symbols; internal_parameters]
+    parameter_values = hcat(variable_values, log_probabilities)
+    parameter_values = MCMCChains.concretize(parameter_values)
+
+    return MCMCChains.Chains(
+        parameter_values, parameter_names, (internals=internal_parameters,)
+    )
 end

 """
diff --git a/src/model.jl b/src/model.jl
index 4aea7ecd9..dfae5fb1d 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -1214,26 +1214,11 @@ the samples in `chain`. This is useful when you want to sample only new variable
 predictive distribution.
 """
 function predict(model::Model, chain; include_all=false)
+    # this is only defined in `ext/DynamicPPLMCMCChainsExt.jl`
+    # TODO: add other methods for different types of `chain` arguments: e.g., `VarInfo`, `NamedTuple`, and `OrderedDict`
     return predict(Random.default_rng(), model, chain; include_all)
 end

-function predict(
-    rng::Random.AbstractRNG,
-    model::Model,
-    varinfos::AbstractArray{<:AbstractVarInfo};
-    include_all=false,
-)
-    predictive_samples = similar(varinfos, NamedTuple{(:varname_and_values, :logp)})
-    for i in eachindex(varinfos)
-        model(rng, varinfos[i], SampleFromPrior())
-        vals = values_as_in_model(model, varinfos[i])
-        iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
-        params = mapreduce(collect, vcat, iters) # returns a vector of tuples (varname, value)
-        predictive_samples[i] = (varname_and_values=params, logp=getlogp(varinfos[i]))
-    end
-    return predictive_samples
-end
-
 """
     generated_quantities(model::Model, parameters::NamedTuple)
     generated_quantities(model::Model, values, keys)

From a3fc8b1f96320ffac8e1df04163ca705083e7a0e Mon Sep 17 00:00:00 2001
From: Xianda Sun
Date: Fri, 20 Dec 2024 09:56:44 +0000
Subject: [PATCH 13/14] update implementation and tests; no longer using
 AdvancedHMC

---
 Project.toml                        |   2 +-
 ext/DynamicPPLMCMCChainsExt.jl      |  86 +++++++------
 src/DynamicPPL.jl                   |   2 +
 src/model.jl                        |  26 +++-
 test/Project.toml                   |   4 +-
 test/ext/DynamicPPLMCMCChainsExt.jl | 167 +---------------------------
 test/model.jl                       | 105 +++++++++++++++++
 test/runtests.jl                    |   1 -
 8 files changed, 164 insertions(+), 229 deletions(-)

diff --git a/Project.toml b/Project.toml
index 95342249c..97969944d 100644
--- a/Project.toml
+++ b/Project.toml
@@ -46,7 +46,7 @@ DynamicPPLZygoteRulesExt = ["ZygoteRules"]
 [compat]
 ADTypes = "1"
 AbstractMCMC = "5"
-AbstractPPL = "0.8.4, 0.9"
+AbstractPPL = "0.10.1"
 Accessors = "0.1"
 BangBang = "0.4.1"
 Bijectors = "0.13.18, 0.14, 0.15"
diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl
index 41efcb15c..06cde3bac 100644
--- a/ext/DynamicPPLMCMCChainsExt.jl
+++ b/ext/DynamicPPLMCMCChainsExt.jl
@@ -48,64 +48,59 @@ end
 Sample from the posterior predictive distribution by executing `model` with parameters fixed to each sample
 in `chain`, and return the resulting `Chains`.

+The `model` passed to `predict` is often different from the one used to generate `chain`.
+Typically, the model from which `chain` originated treats certain variables as observed (i.e.,
+data points), while the model you pass to `predict` may mark these same variables as missing
+or unobserved. Calling `predict` then leverages the previously inferred parameter values to
+simulate what new, unobserved data might look like, given your posterior beliefs.
+
+For each parameter configuration in `chain`:
+1. All random variables present in `chain` are fixed to their sampled values.
+2. Any variables not included in `chain` are sampled from their prior distributions.
+
 If `include_all` is `false`, the returned `Chains` will contain only those variables that were not fixed by
 the samples in `chain`. This is useful when you want to sample only new variables from the posterior
 predictive distribution.

 # Examples
 ```jldoctest
+using AbstractMCMC, Distributions, DynamicPPL, Random
+
+@model function linear_reg(x, y, σ = 0.1)
+    β ~ Normal(0, 1)
+    for i in eachindex(y)
+        y[i] ~ Normal(β * x[i], σ)
+    end
+end
+
+# Generate synthetic chain using known ground truth parameter
+ground_truth_β = 2.0
+
+# Create chain of samples from a normal distribution centered on ground truth
+β_chain = MCMCChains.Chains(
+    rand(Normal(ground_truth_β, 0.002), 1000), [:β,]
+)
+
+# Generate predictions for two test points
+xs_test = [10.1, 10.2]
+
+m_test = linear_reg(xs_test, fill(missing, length(xs_test)))
+
+predictions = DynamicPPL.AbstractPPL.predict(
+    Random.default_rng(), m_test, β_chain
+)
+
+ys_pred = vec(mean(Array(predictions); dims=1))
+
+# Check if predictions match expected values within tolerance
+(
+    isapprox(ys_pred[1], ground_truth_β * xs_test[1], atol = 0.01),
+    isapprox(ys_pred[2], ground_truth_β * xs_test[2], atol = 0.01)
+)
+
+# output
+
+(true, true)
 ```
 """
 function DynamicPPL.predict(
@@ -115,14 +110,11 @@ function DynamicPPL.predict(
     include_all=false,
 )
     parameter_only_chain = MCMCChains.get_sections(chain, :parameters)
-    prototypical_varinfo = DynamicPPL.VarInfo(model)
+    varinfo = DynamicPPL.VarInfo(model)

     iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
     predictive_samples = map(iters) do (sample_idx, chain_idx)
-        varinfo = deepcopy(prototypical_varinfo)
-        DynamicPPL.setval_and_resample!(
-            varinfo, parameter_only_chain, sample_idx, chain_idx
-        )
+        DynamicPPL.setval_and_resample!(varinfo, parameter_only_chain, sample_idx, chain_idx)
         model(rng, varinfo, DynamicPPL.SampleFromPrior())

         vals = DynamicPPL.values_as_in_model(model, varinfo)
diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl
index c314e6c6d..c1cdbd94e 100644
--- a/src/DynamicPPL.jl
+++ b/src/DynamicPPL.jl
@@ -40,6 +40,8 @@ import Base:
     keys,
     haskey

+import AbstractPPL: predict
+
 # VarInfo
 export AbstractVarInfo,
     VarInfo,
diff --git a/src/model.jl b/src/model.jl
index 037ed8379..2bad6f1fe 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -1145,19 +1145,23 @@ function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractC
 end

 """
-    predict([rng::AbstractRNG,] model::Model, chain; include_all=false)
+    predict([rng::AbstractRNG,] model::Model, chain::AbstractVector{<:AbstractVarInfo})

-Sample from the posterior predictive distribution by executing `model` with parameters fixed to each sample
-in `chain`.
-
-If `include_all` is `false`, the returned `Chains` will contain only those variables that were not fixed by
-the samples in `chain`. This is useful when you want to sample only new variables from the posterior
-predictive distribution.
+Generate samples from the posterior predictive distribution by evaluating `model` at each set
+of parameter values provided in `chain`. The number of posterior predictive samples matches
+the length of `chain`. The returned `AbstractVarInfo`s will contain both the posterior parameter values
+and the predicted values.
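+
+# Examples
+
+A minimal sketch of how this method can be called (illustrative only, not a doctest;
+prior draws of `β` here stand in for posterior samples):
+
+```julia
+@model function linear_reg(x, y, σ=0.1)
+    β ~ Normal(0, 1)
+    for i in eachindex(y)
+        y[i] ~ Normal(β * x[i], σ)
+    end
+end
+
+xs_train = 1:0.1:10
+m_train = linear_reg(xs_train, 2.0 .* xs_train)
+chain = [last(DynamicPPL.evaluate!!(m_train)) for _ in 1:100]
+
+m_test = linear_reg([10.1, 10.2], fill(missing, 2))
+predictions = predict(Random.default_rng(), m_test, chain)
+predictions[1][@varname(y[1])]  # predicted response under the first draw
+```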
""" -function predict(model::Model, chain; include_all=false) - # this is only defined in `ext/DynamicPPLMCMCChainsExt.jl` - # TODO: add other methods for different type of `chain` arguments: e.g., `VarInfo`, `NamedTuple`, and `OrderedDict` - return predict(Random.default_rng(), model, chain; include_all) +function predict( + rng::Random.AbstractRNG, model::Model, chain::AbstractArray{<:AbstractVarInfo} +) + varinfo = DynamicPPL.VarInfo(model) + return map(chain) do params_varinfo + vi = deepcopy(varinfo) + DynamicPPL.setval_and_resample!(vi, values_as(params_varinfo, NamedTuple)) + model(rng, vi, SampleFromPrior()) + return vi + end end """ diff --git a/test/Project.toml b/test/Project.toml index 11ebeaad8..c7583c672 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -2,7 +2,6 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" -AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d" Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" @@ -34,9 +33,8 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] ADTypes = "1" AbstractMCMC = "5" -AbstractPPL = "0.8.4, 0.9" +AbstractPPL = "0.10.1" Accessors = "0.1" -AdvancedHMC = "0.3.0, 0.4.0, 0.5.2, 0.6" Bijectors = "0.15.1" Combinatorics = "1" Compat = "4.3.0" diff --git a/test/ext/DynamicPPLMCMCChainsExt.jl b/test/ext/DynamicPPLMCMCChainsExt.jl index 8693c3b02..3ba5edfe1 100644 --- a/test/ext/DynamicPPLMCMCChainsExt.jl +++ b/test/ext/DynamicPPLMCMCChainsExt.jl @@ -8,169 +8,4 @@ @test mean(chain_generated) ≈ 0 atol = 0.1 end -@testset "predict" begin - DynamicPPL.Random.seed!(100) - - @model function linear_reg(x, y, σ=0.1) - β ~ Normal(0, 1) - - for i in eachindex(y) - y[i] ~ Normal(β * x[i], σ) - end - end - - @model function linear_reg_vec(x, y, σ=0.1) - β ~ Normal(0, 1) - return y ~ MvNormal(β .* x, σ^2 * I) - end - - f(x) = 2 * x + 0.1 * randn() - - Δ = 0.1 - xs_train = 0:Δ:10 - ys_train = f.(xs_train) - xs_test = [10 + Δ, 10 + 2 * Δ] - ys_test = f.(xs_test) - - # Infer - m_lin_reg = linear_reg(xs_train, ys_train) - chain_lin_reg = sample( - DynamicPPL.LogDensityFunction(m_lin_reg), - AdvancedHMC.NUTS(0.65), - 1000; - chain_type=MCMCChains.Chains, - param_names=[:β], - discard_initial=100, - n_adapt=100, - ) - - # Predict on two last indices - m_lin_reg_test = linear_reg(xs_test, fill(missing, length(ys_test))) - predictions = DynamicPPL.predict(m_lin_reg_test, chain_lin_reg) - - ys_pred = vec(mean(Array(group(predictions, :y)); dims=1)) - - # test like this depends on the variance of the posterior - # this only makes sense if the posterior variance is about 0.002 - @test sum(abs2, ys_test - ys_pred) ≤ 0.1 - - # Ensure that `rng` is respected - predictions1 = let rng = MersenneTwister(42) - DynamicPPL.predict(rng, m_lin_reg_test, chain_lin_reg[1:2]) - end - predictions2 = let rng = MersenneTwister(42) - DynamicPPL.predict(rng, m_lin_reg_test, chain_lin_reg[1:2]) - end - @test all(Array(predictions1) .== Array(predictions2)) - - # Predict on two last indices for vectorized - m_lin_reg_test = linear_reg_vec(xs_test, missing) - predictions_vec = DynamicPPL.predict(m_lin_reg_test, chain_lin_reg) - ys_pred_vec = vec(mean(Array(group(predictions_vec, :y)); dims=1)) - - @test sum(abs2, ys_test - ys_pred_vec) ≤ 0.1 - - # Multiple chains - chain_lin_reg = sample( - DynamicPPL.LogDensityFunction(m_lin_reg, 
-        AdvancedHMC.NUTS(0.65),
-        MCMCThreads(),
-        1000,
-        2;
-        chain_type=MCMCChains.Chains,
-        param_names=[:β],
-        discard_initial=100,
-        n_adapt=100,
-    )
-    m_lin_reg_test = linear_reg(xs_test, fill(missing, length(ys_test)))
-    predictions = DynamicPPL.predict(m_lin_reg_test, chain_lin_reg)
-
-    @test size(chain_lin_reg, 3) == size(predictions, 3)
-
-    for chain_idx in MCMCChains.chains(chain_lin_reg)
-        ys_pred = vec(mean(Array(group(predictions[:, :, chain_idx], :y)); dims=1))
-        @test sum(abs2, ys_test - ys_pred) ≤ 0.1
-    end
-
-    # Predict on two last indices for vectorized
-    m_lin_reg_test = linear_reg_vec(xs_test, missing)
-    predictions_vec = DynamicPPL.predict(m_lin_reg_test, chain_lin_reg)
-
-    for chain_idx in MCMCChains.chains(chain_lin_reg)
-        ys_pred_vec = vec(mean(Array(group(predictions_vec[:, :, chain_idx], :y)); dims=1))
-        @test sum(abs2, ys_test - ys_pred_vec) ≤ 0.1
-    end
-
-    # https://github.com/TuringLang/Turing.jl/issues/1352
-    @model function simple_linear1(x, y)
-        intercept ~ Normal(0, 1)
-        coef ~ MvNormal(zeros(2), I)
-        coef = reshape(coef, 1, size(x, 1))
-
-        mu = vec(intercept .+ coef * x)
-        error ~ truncated(Normal(0, 1), 0, Inf)
-        return y ~ MvNormal(mu, error^2 * I)
-    end
-
-    @model function simple_linear2(x, y)
-        intercept ~ Normal(0, 1)
-        coef ~ filldist(Normal(0, 1), 2)
-        coef = reshape(coef, 1, size(x, 1))
-
-        mu = vec(intercept .+ coef * x)
-        error ~ truncated(Normal(0, 1), 0, Inf)
-        return y ~ MvNormal(mu, error^2 * I)
-    end
-
-    @model function simple_linear3(x, y)
-        intercept ~ Normal(0, 1)
-        coef = Vector(undef, 2)
-        for i in axes(coef, 1)
-            coef[i] ~ Normal(0, 1)
-        end
-        coef = reshape(coef, 1, size(x, 1))
-
-        mu = vec(intercept .+ coef * x)
-        error ~ truncated(Normal(0, 1), 0, Inf)
-        return y ~ MvNormal(mu, error^2 * I)
-    end
-
-    @model function simple_linear4(x, y)
-        intercept ~ Normal(0, 1)
-        coef1 ~ Normal(0, 1)
-        coef2 ~ Normal(0, 1)
-        coef = [coef1, coef2]
-        coef = reshape(coef, 1, size(x, 1))
-
-        mu = vec(intercept .+ coef * x)
-        error ~ truncated(Normal(0, 1), 0, Inf)
-        return y ~ MvNormal(mu, error^2 * I)
-    end
-
-    x = randn(2, 100)
-    y = [1 + 2 * a + 3 * b for (a, b) in eachcol(x)]
-
-    param_names = Dict(
-        simple_linear1 => [:intercept, Symbol("coef[1]"), Symbol("coef[2]"), :error],
-        simple_linear2 => [:intercept, Symbol("coef[1]"), Symbol("coef[2]"), :error],
-        simple_linear3 => [:intercept, Symbol("coef[1]"), Symbol("coef[2]"), :error],
-        simple_linear4 => [:intercept, :coef1, :coef2, :error],
-    )
-    @testset "$model" for model in
-                          [simple_linear1, simple_linear2, simple_linear3, simple_linear4]
-        m = model(x, y)
-        chain = sample(
-            DynamicPPL.LogDensityFunction(m),
-            AdvancedHMC.NUTS(0.65),
-            400;
-            initial_params=rand(4),
-            chain_type=MCMCChains.Chains,
-            param_names=param_names[model],
-            discard_initial=100,
-            n_adapt=100,
-        )
-        chain_predict = DynamicPPL.predict(model(x, missing), chain)
-        mean_prediction = [mean(chain_predict["y[$i]"].data) for i in 1:length(y)]
-        @test mean(abs2, mean_prediction - y) ≤ 1e-3
-    end
-end
+# Tests for `predict` are in `test/model.jl`.
diff --git a/test/model.jl b/test/model.jl
index a19cb29d2..cb1dbc735 100644
--- a/test/model.jl
+++ b/test/model.jl
@@ -429,4 +429,109 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
             @test getlogp(varinfo_linked) ≈ getlogp(varinfo_linked_result)
         end
     end
+
+    @testset "predict" begin
+        @testset "with MCMCChains.Chains" begin
+            DynamicPPL.Random.seed!(100)
+
+            @model function linear_reg(x, y, σ=0.1)
+                β ~ Normal(0, 1)
+                for i in eachindex(y)
+                    y[i] ~ Normal(β * x[i], σ)
+                end
+            end
+
+            @model function linear_reg_vec(x, y, σ=0.1)
+                β ~ Normal(0, 1)
+                return y ~ MvNormal(β .* x, σ^2 * I)
+            end
+
+            ground_truth_β = 2
+            β_chain = MCMCChains.Chains(rand(Normal(ground_truth_β, 0.002), 1000), [:β])
+
+            xs_test = [10 + 0.1, 10 + 2 * 0.1]
+            m_lin_reg_test = linear_reg(xs_test, fill(missing, length(xs_test)))
+            predictions = DynamicPPL.predict(m_lin_reg_test, β_chain)
+
+            ys_pred = vec(mean(Array(group(predictions, :y)); dims=1))
+            @test ys_pred[1] ≈ ground_truth_β * xs_test[1] atol = 0.01
+            @test ys_pred[2] ≈ ground_truth_β * xs_test[2] atol = 0.01
+
+            # Ensure that `rng` is respected
+            rng = MersenneTwister(42)
+            predictions1 = DynamicPPL.predict(rng, m_lin_reg_test, β_chain[1:2])
+            predictions2 = DynamicPPL.predict(
+                MersenneTwister(42), m_lin_reg_test, β_chain[1:2]
+            )
+            @test all(Array(predictions1) .== Array(predictions2))
+
+            # Predict on the last two indices for the vectorized model
+            m_lin_reg_test = linear_reg_vec(xs_test, missing)
+            predictions_vec = DynamicPPL.predict(m_lin_reg_test, β_chain)
+            ys_pred_vec = vec(mean(Array(group(predictions_vec, :y)); dims=1))
+
+            @test ys_pred_vec[1] ≈ ground_truth_β * xs_test[1] atol = 0.01
+            @test ys_pred_vec[2] ≈ ground_truth_β * xs_test[2] atol = 0.01
+
+            # Multiple chains
+            multiple_β_chain = MCMCChains.Chains(
+                reshape(rand(Normal(ground_truth_β, 0.002), 1000, 2), 1000, 1, 2), [:β]
+            )
+            m_lin_reg_test = linear_reg(xs_test, fill(missing, length(xs_test)))
+            predictions = DynamicPPL.predict(m_lin_reg_test, multiple_β_chain)
+            @test size(multiple_β_chain, 3) == size(predictions, 3)
+
+            for chain_idx in MCMCChains.chains(multiple_β_chain)
+                ys_pred = vec(mean(Array(group(predictions[:, :, chain_idx], :y)); dims=1))
+                @test ys_pred[1] ≈ ground_truth_β * xs_test[1] atol = 0.01
+                @test ys_pred[2] ≈ ground_truth_β * xs_test[2] atol = 0.01
+            end
+
+            # Predict on the last two indices for the vectorized model
+            m_lin_reg_test = linear_reg_vec(xs_test, missing)
+            predictions_vec = DynamicPPL.predict(m_lin_reg_test, multiple_β_chain)
+
+            for chain_idx in MCMCChains.chains(multiple_β_chain)
+                ys_pred_vec = vec(
+                    mean(Array(group(predictions_vec[:, :, chain_idx], :y)); dims=1)
+                )
+                @test ys_pred_vec[1] ≈ ground_truth_β * xs_test[1] atol = 0.01
+                @test ys_pred_vec[2] ≈ ground_truth_β * xs_test[2] atol = 0.01
+            end
+        end
+
+        @testset "with AbstractVector{<:AbstractVarInfo}" begin
+            @model function linear_reg(x, y, σ=0.1)
+                β ~ Normal(1, 1)
+                for i in eachindex(y)
+                    y[i] ~ Normal(β * x[i], σ)
+                end
+            end
+
+            ground_truth_β = 2.0
+            # The data will be ignored, since we generate samples from the prior.
+            xs_train = 1:0.1:10
+            ys_train = ground_truth_β .* xs_train + rand(Normal(0, 0.1), length(xs_train))
+            m_lin_reg = linear_reg(xs_train, ys_train)
+            chain = [evaluate!!(m_lin_reg)[2] for _ in 1:10000]
+
+            # `chain` is generated from the prior, whose mean for β is 1
+            @test mean([chain[i][@varname(β)] for i in eachindex(chain)]) ≈ 1.0 atol = 0.1
+
+            xs_test = [10 + 0.1, 10 + 2 * 0.1]
+            m_lin_reg_test = linear_reg(xs_test, fill(missing, length(xs_test)))
+            predicted_vis = DynamicPPL.predict(m_lin_reg_test, chain)
+
+            @test size(predicted_vis) == size(chain)
+            @test Set(keys(predicted_vis[1])) ==
+                Set([@varname(β), @varname(y[1]), @varname(y[2])])
+            # Because the β samples come from the prior, the predictive std will be larger
+            @test mean([
+                predicted_vis[i][@varname(y[1])] for i in eachindex(predicted_vis)
+            ]) ≈ 1.0 * xs_test[1] rtol = 0.1
+            @test mean([
+                predicted_vis[i][@varname(y[2])] for i in eachindex(predicted_vis)
+            ]) ≈ 1.0 * xs_test[2] rtol = 0.1
+        end
+    end
 end
diff --git a/test/runtests.jl b/test/runtests.jl
index 6fd925cae..9f2d21990 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,6 +1,5 @@
 using Accessors
 using ADTypes
-using AdvancedHMC: AdvancedHMC
 using DynamicPPL
 using AbstractMCMC
 using AbstractPPL

From da7fa1cbaedee13a0e995b72e5b7a6aed8704b51 Mon Sep 17 00:00:00 2001
From: Xianda Sun
Date: Fri, 20 Dec 2024 17:14:31 +0000
Subject: [PATCH 14/14] try fixing naming conflict

---
 test/contexts.jl | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/test/contexts.jl b/test/contexts.jl
index 0f6628440..7a7826466 100644
--- a/test/contexts.jl
+++ b/test/contexts.jl
@@ -202,7 +202,7 @@ end
         s, m = retval.s, retval.m
 
         # Keyword approach.
-        model_fixed = fix(model; s=s)
+        model_fixed = DynamicPPL.fix(model; s=s)
         @test model_fixed().s == s
         @test model_fixed().m != m
         # A fixed variable should not contribute at all to the logjoint.
@@ -210,19 +210,19 @@
         @test logprior(model_fixed, (; m)) == logprior(condition(model; s=s), (; m))
 
         # Positional approach.
-        model_fixed = fix(model, (; s))
+        model_fixed = DynamicPPL.fix(model, (; s))
        @test model_fixed().s == s
         @test model_fixed().m != m
         @test logprior(model_fixed, (; m)) == logprior(condition(model; s=s), (; m))
 
         # Pairs approach.
-        model_fixed = fix(model, @varname(s) => s)
+        model_fixed = DynamicPPL.fix(model, @varname(s) => s)
         @test model_fixed().s == s
         @test model_fixed().m != m
         @test logprior(model_fixed, (; m)) == logprior(condition(model; s=s), (; m))
 
         # Dictionary approach.
-        model_fixed = fix(model, Dict(@varname(s) => s))
+        model_fixed = DynamicPPL.fix(model, Dict(@varname(s) => s))
         @test model_fixed().s == s
         @test model_fixed().m != m
         @test logprior(model_fixed, (; m)) == logprior(condition(model; s=s), (; m))
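
As context for the namespace changes above: these tests pin down the semantics of `fix`
versus `condition` — a fixed variable drops out of the log-joint entirely, while a
conditioned variable is still scored as an observation. A minimal sketch of that
behaviour (the `demo` model below is an illustrative assumption, not code from this
patch):

```julia
using DynamicPPL, Distributions

@model function demo()
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
end

model = demo()
fixed = DynamicPPL.fix(model; s=1.0)   # `s` is fixed: dropped from the joint
conditioned = condition(model; s=1.0)  # `s` is conditioned: treated as observed

# The prior over `m` agrees between the two, as the tests above assert:
logprior(fixed, (; m=0.5)) == logprior(conditioned, (; m=0.5))

# But only the conditioned model scores `s` itself:
logjoint(fixed, (; m=0.5))        # logpdf(Normal(0, 1), 0.5) only
logjoint(conditioned, (; m=0.5))  # also includes logpdf(InverseGamma(2, 3), 1.0)
```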