@@ -50,7 +50,6 @@ import LogDensityProblems
50
50
import LogDensityProblemsAD
51
51
import Random
52
52
import MCMCChains
53
- import StatsBase: predict
54
53
55
54
export InferenceAlgorithm,
56
55
Hamiltonian,
@@ -78,7 +77,6 @@ export InferenceAlgorithm,
78
77
dot_assume,
79
78
observe,
80
79
dot_observe,
81
- predict,
82
80
externalsampler
83
81
84
82
# ######################
@@ -396,7 +394,7 @@ function getparams(model::DynamicPPL.Model, vi::DynamicPPL.VarInfo)
396
394
# this means that the code below will work both of linked and invlinked `vi`.
397
395
# Ref: https://github.com/TuringLang/Turing.jl/issues/2195
398
396
# NOTE: We need to `deepcopy` here to avoid modifying the original `vi`.
399
- vals = DynamicPPL. values_as_in_model (model, deepcopy (vi))
397
+ vals = DynamicPPL. values_as_in_model (model, true , deepcopy (vi))
400
398
401
399
# Obtain an iterator over the flattened parameter names and values.
402
400
iters = map (DynamicPPL. varname_and_value_leaves, keys (vals), values (vals))
@@ -612,112 +610,6 @@ end
612
610
DynamicPPL. getspace (spl:: Sampler ) = getspace (spl. alg)
613
611
DynamicPPL. inspace (vn:: VarName , spl:: Sampler ) = inspace (vn, getspace (spl. alg))
614
612
615
- """
616
-
617
- predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)
618
-
619
- Execute `model` conditioned on each sample in `chain`, and return the resulting `Chains`.
620
-
621
- If `include_all` is `false`, the returned `Chains` will contain only those variables
622
- sampled/not present in `chain`.
623
-
624
- # Details
625
- Internally calls `Turing.Inference.transitions_from_chain` to obtained the samples
626
- and then converts these into a `Chains` object using `AbstractMCMC.bundle_samples`.
627
-
628
- # Example
629
- ```jldoctest
630
- julia> using Turing; Turing.setprogress!(false);
631
- [ Info: [Turing]: progress logging is disabled globally
632
-
633
- julia> @model function linear_reg(x, y, σ = 0.1)
634
- β ~ Normal(0, 1)
635
-
636
- for i ∈ eachindex(y)
637
- y[i] ~ Normal(β * x[i], σ)
638
- end
639
- end;
640
-
641
- julia> σ = 0.1; f(x) = 2 * x + 0.1 * randn();
642
-
643
- julia> Δ = 0.1; xs_train = 0:Δ:10; ys_train = f.(xs_train);
644
-
645
- julia> xs_test = [10 + Δ, 10 + 2 * Δ]; ys_test = f.(xs_test);
646
-
647
- julia> m_train = linear_reg(xs_train, ys_train, σ);
648
-
649
- julia> chain_lin_reg = sample(m_train, NUTS(100, 0.65), 200);
650
- ┌ Info: Found initial step size
651
- └ ϵ = 0.003125
652
-
653
- julia> m_test = linear_reg(xs_test, Vector{Union{Missing, Float64}}(undef, length(ys_test)), σ);
654
-
655
- julia> predictions = predict(m_test, chain_lin_reg)
656
- Object of type Chains, with data of type 100×2×1 Array{Float64,3}
657
-
658
- Iterations = 1:100
659
- Thinning interval = 1
660
- Chains = 1
661
- Samples per chain = 100
662
- parameters = y[1], y[2]
663
-
664
- 2-element Array{ChainDataFrame,1}
665
-
666
- Summary Statistics
667
- parameters mean std naive_se mcse ess r_hat
668
- ────────── ─────── ────── ──────── ─────── ──────── ──────
669
- y[1] 20.1974 0.1007 0.0101 missing 101.0711 0.9922
670
- y[2] 20.3867 0.1062 0.0106 missing 101.4889 0.9903
671
-
672
- Quantiles
673
- parameters 2.5% 25.0% 50.0% 75.0% 97.5%
674
- ────────── ─────── ─────── ─────── ─────── ───────
675
- y[1] 20.0342 20.1188 20.2135 20.2588 20.4188
676
- y[2] 20.1870 20.3178 20.3839 20.4466 20.5895
677
-
678
-
679
- julia> ys_pred = vec(mean(Array(group(predictions, :y)); dims = 1));
680
-
681
- julia> sum(abs2, ys_test - ys_pred) ≤ 0.1
682
- true
683
- ```
684
- """
685
- function predict (model:: Model , chain:: MCMCChains.Chains ; kwargs... )
686
- return predict (Random. default_rng (), model, chain; kwargs... )
687
- end
688
- function predict (
689
- rng:: AbstractRNG , model:: Model , chain:: MCMCChains.Chains ; include_all= false
690
- )
691
- # Don't need all the diagnostics
692
- chain_parameters = MCMCChains. get_sections (chain, :parameters )
693
-
694
- spl = DynamicPPL. SampleFromPrior ()
695
-
696
- # Sample transitions using `spl` conditioned on values in `chain`
697
- transitions = transitions_from_chain (rng, model, chain_parameters; sampler= spl)
698
-
699
- # Let the Turing internals handle everything else for you
700
- chain_result = reduce (
701
- MCMCChains. chainscat,
702
- [
703
- AbstractMCMC. bundle_samples (
704
- transitions[:, chain_idx], model, spl, nothing , MCMCChains. Chains
705
- ) for chain_idx in 1 : size (transitions, 2 )
706
- ],
707
- )
708
-
709
- parameter_names = if include_all
710
- names (chain_result, :parameters )
711
- else
712
- filter (
713
- k -> ∉ (k, names (chain_parameters, :parameters )),
714
- names (chain_result, :parameters ),
715
- )
716
- end
717
-
718
- return chain_result[parameter_names]
719
- end
720
-
721
613
"""
722
614
723
615
transitions_from_chain(
0 commit comments