Commit d8f79a7

Update for DynamicPPL 0.33
1 parent 24d5556 commit d8f79a7

File tree: Project.toml · src/mcmc/Inference.jl · test/Project.toml

3 files changed: +4, -112 lines

Project.toml

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 name = "Turing"
 uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
-version = "0.36.0"
+version = "0.36.1"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -63,7 +63,7 @@ Distributions = "0.23.3, 0.24, 0.25"
 DistributionsAD = "0.6"
 DocStringExtensions = "0.8, 0.9"
 DynamicHMC = "3.4"
-DynamicPPL = "0.32"
+DynamicPPL = "0.33"
 EllipticalSliceSampling = "0.5, 1, 2"
 ForwardDiff = "0.10.3"
 Libtask = "0.8.8"

src/mcmc/Inference.jl

Lines changed: 1 addition & 109 deletions
@@ -50,7 +50,6 @@ import LogDensityProblems
 import LogDensityProblemsAD
 import Random
 import MCMCChains
-import StatsBase: predict
 
 export InferenceAlgorithm,
     Hamiltonian,
@@ -78,7 +77,6 @@ export InferenceAlgorithm,
     dot_assume,
     observe,
     dot_observe,
-    predict,
     externalsampler
 
 #######################
@@ -396,7 +394,7 @@ function getparams(model::DynamicPPL.Model, vi::DynamicPPL.VarInfo)
     # this means that the code below will work both of linked and invlinked `vi`.
     # Ref: https://github.com/TuringLang/Turing.jl/issues/2195
     # NOTE: We need to `deepcopy` here to avoid modifying the original `vi`.
-    vals = DynamicPPL.values_as_in_model(model, deepcopy(vi))
+    vals = DynamicPPL.values_as_in_model(model, true, deepcopy(vi))
 
     # Obtain an iterator over the flattened parameter names and values.
     iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
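
The only functional change in this file is the extra `true` passed to `DynamicPPL.values_as_in_model` above. A minimal sketch of the new call shape, assuming DynamicPPL 0.33's three-argument form where the `Bool` controls whether values introduced with `:=` are included (that reading of the flag is an assumption; the diff only shows the call site). The `demo` model, `x`, and `y` are made up for illustration:

```julia
using DynamicPPL, Distributions

@model function demo()
    x ~ Normal()
    y := 2 * x   # deterministic quantity tracked alongside the sampled values
    return y
end

model = demo()
vi = DynamicPPL.VarInfo(model)

# Mirrors the call in `getparams`: `deepcopy` avoids mutating the original `vi`.
vals = DynamicPPL.values_as_in_model(model, true, deepcopy(vi))
```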
@@ -612,112 +610,6 @@ end
 DynamicPPL.getspace(spl::Sampler) = getspace(spl.alg)
 DynamicPPL.inspace(vn::VarName, spl::Sampler) = inspace(vn, getspace(spl.alg))
 
-"""
-    predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)
-
-Execute `model` conditioned on each sample in `chain`, and return the resulting `Chains`.
-
-If `include_all` is `false`, the returned `Chains` will contain only those variables
-sampled/not present in `chain`.
-
-# Details
-Internally calls `Turing.Inference.transitions_from_chain` to obtained the samples
-and then converts these into a `Chains` object using `AbstractMCMC.bundle_samples`.
-
-# Example
-```jldoctest
-julia> using Turing; Turing.setprogress!(false);
-[ Info: [Turing]: progress logging is disabled globally
-
-julia> @model function linear_reg(x, y, σ = 0.1)
-           β ~ Normal(0, 1)
-
-           for i ∈ eachindex(y)
-               y[i] ~ Normal(β * x[i], σ)
-           end
-       end;
-
-julia> σ = 0.1; f(x) = 2 * x + 0.1 * randn();
-
-julia> Δ = 0.1; xs_train = 0:Δ:10; ys_train = f.(xs_train);
-
-julia> xs_test = [10 + Δ, 10 + 2 * Δ]; ys_test = f.(xs_test);
-
-julia> m_train = linear_reg(xs_train, ys_train, σ);
-
-julia> chain_lin_reg = sample(m_train, NUTS(100, 0.65), 200);
-┌ Info: Found initial step size
-└   ϵ = 0.003125
-
-julia> m_test = linear_reg(xs_test, Vector{Union{Missing, Float64}}(undef, length(ys_test)), σ);
-
-julia> predictions = predict(m_test, chain_lin_reg)
-Object of type Chains, with data of type 100×2×1 Array{Float64,3}
-
-Iterations        = 1:100
-Thinning interval = 1
-Chains            = 1
-Samples per chain = 100
-parameters        = y[1], y[2]
-
-2-element Array{ChainDataFrame,1}
-
-Summary Statistics
- parameters     mean     std  naive_se     mcse       ess   r_hat
- ──────────  ───────  ──────  ────────  ───────  ────────  ──────
-       y[1]  20.1974  0.1007    0.0101  missing  101.0711  0.9922
-       y[2]  20.3867  0.1062    0.0106  missing  101.4889  0.9903
-
-Quantiles
- parameters     2.5%    25.0%    50.0%    75.0%    97.5%
- ──────────  ───────  ───────  ───────  ───────  ───────
-       y[1]  20.0342  20.1188  20.2135  20.2588  20.4188
-       y[2]  20.1870  20.3178  20.3839  20.4466  20.5895
-
-julia> ys_pred = vec(mean(Array(group(predictions, :y)); dims = 1));
-
-julia> sum(abs2, ys_test - ys_pred) ≤ 0.1
-true
-```
-"""
-function predict(model::Model, chain::MCMCChains.Chains; kwargs...)
-    return predict(Random.default_rng(), model, chain; kwargs...)
-end
-function predict(
-    rng::AbstractRNG, model::Model, chain::MCMCChains.Chains; include_all=false
-)
-    # Don't need all the diagnostics
-    chain_parameters = MCMCChains.get_sections(chain, :parameters)
-
-    spl = DynamicPPL.SampleFromPrior()
-
-    # Sample transitions using `spl` conditioned on values in `chain`
-    transitions = transitions_from_chain(rng, model, chain_parameters; sampler=spl)
-
-    # Let the Turing internals handle everything else for you
-    chain_result = reduce(
-        MCMCChains.chainscat,
-        [
-            AbstractMCMC.bundle_samples(
-                transitions[:, chain_idx], model, spl, nothing, MCMCChains.Chains
-            ) for chain_idx in 1:size(transitions, 2)
-        ],
-    )
-
-    parameter_names = if include_all
-        names(chain_result, :parameters)
-    else
-        filter(
-            k -> !(k in names(chain_parameters, :parameters)),
-            names(chain_result, :parameters),
-        )
-    end
-
-    return chain_result[parameter_names]
-end
-
 """
     transitions_from_chain(
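
With `predict` and its docstring removed from `Turing.Inference`, the user-facing call presumably keeps working unchanged, now supplied by the DynamicPPL 0.33 dependency rather than being defined here (that provenance is an assumption; the diff only shows the removal). A condensed version of the deleted docstring's example, assuming `predict(model, chain)` remains available to Turing users:

```julia
using Turing

@model function linear_reg(x, y, σ=0.1)
    β ~ Normal(0, 1)
    for i in eachindex(y)
        y[i] ~ Normal(β * x[i], σ)
    end
end

xs_train = 0:0.1:10
ys_train = 2 .* xs_train .+ 0.1 .* randn(length(xs_train))
chain = sample(linear_reg(xs_train, ys_train), NUTS(100, 0.65), 200)

# Predict at new inputs by passing `missing` responses.
xs_test = [10.1, 10.2]
m_test = linear_reg(xs_test, Vector{Union{Missing,Float64}}(undef, length(xs_test)))
predictions = predict(m_test, chain)  # a `Chains` with parameters y[1], y[2]
```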

test/Project.toml

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@ Combinatorics = "1"
 Distributions = "0.25"
 DistributionsAD = "0.6.3"
 DynamicHMC = "2.1.6, 3.0"
-DynamicPPL = "0.32.2"
+DynamicPPL = "0.33"
 FiniteDifferences = "0.10.8, 0.11, 0.12"
 ForwardDiff = "0.10.12 - 0.10.32, 0.10"
 HypothesisTests = "0.11"
