|
"""
    AbstractInitStrategy

Abstract type representing the possible ways of initialising new values for
the random variables in a model (e.g., when creating a new VarInfo).

A concrete strategy is a subtype of `AbstractInitStrategy` together with a
method of `init(rng, vn, dist, strategy)` for that subtype. The strategies
defined alongside this type are `PriorInit`, `UniformInit` and `ParamsInit`.
"""
abstract type AbstractInitStrategy end
| 8 | + |
"""
    init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, strategy::AbstractInitStrategy)

Generate a new value for a random variable with the given distribution.

# Arguments
- `rng`: random number generator used by strategies that draw random values.
- `vn`: name of the random variable being initialised. `PriorInit` and
  `UniformInit` ignore it; `ParamsInit` uses it to look up a user-supplied value.
- `dist`: distribution of the random variable.
- `strategy`: the initialisation strategy; the method is selected by its type.

!!! warning "Values must be unlinked"
    The values returned by `init` are always in the untransformed space, i.e.,
    they must be within the support of the original distribution. That means that,
    for example, `init(rng, vn, dist, u::UniformInit)` will in general return values
    that are outside the range [u.lower, u.upper].
"""
function init end
| 21 | + |
"""
    PriorInit()

Initialisation strategy that obtains each new value by drawing a sample from
the distribution of the random variable, i.e. from the prior.
"""
struct PriorInit <: AbstractInitStrategy end

# The variable name is not needed: the draw depends only on `rng` and `dist`.
function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, ::PriorInit)
    return rand(rng, dist)
end
| 29 | + |
"""
    UniformInit()
    UniformInit(lower, upper)

Obtain new values by first transforming the distribution of the random variable
to unconstrained space, and then sampling a value uniformly between `lower` and
`upper`.

If unspecified, defaults to `(lower, upper) = (-2, 2)`, which mimics Stan's
default initialisation strategy.

`lower` and `upper` may be any real numbers (e.g. `UniformInit(-2, 2)`); they
are converted to a common floating-point type before being stored.

# Throws

- `ArgumentError` if `lower <= upper` does not hold (this includes `NaN`
  bounds), or if the bounds cannot be converted to floating-point numbers.

# References

[Stan reference manual page on initialization](https://mc-stan.org/docs/reference-manual/execution.html#initialization)
"""
struct UniformInit{T<:AbstractFloat} <: AbstractInitStrategy
    lower::T
    upper::T
    function UniformInit(lower::T, upper::T) where {T<:AbstractFloat}
        # Phrased as `lower <= upper || throw(...)` rather than
        # `lower > upper && throw(...)` so that NaN bounds, for which every
        # comparison is false, are rejected as well.
        lower <= upper ||
            throw(ArgumentError("`lower` must be less than or equal to `upper`"))
        return new{T}(lower, upper)
    end
    # Convenience constructor for integer, rational or mixed-precision bounds.
    function UniformInit(lower::Real, upper::Real)
        lo, hi = promote(float(lower), float(upper))
        # Guard against `Real` types whose `float` is not an `AbstractFloat`;
        # without it this method would call itself indefinitely.
        lo isa AbstractFloat || throw(
            ArgumentError(
                "`lower` and `upper` must be convertible to floating-point numbers"
            ),
        )
        return UniformInit(lo, hi)
    end
    UniformInit() = UniformInit(-2.0, 2.0)
end
# Draw a value uniformly in unconstrained space and map it back to the support
# of `dist`. The variable name is not needed by this strategy.
function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::UniformInit)
    b = Bijectors.bijector(dist)
    # Shape of the unconstrained representation of a draw from `dist`.
    sz = Bijectors.output_size(b, size(dist))
    y = if u.lower == u.upper
        # `UniformInit` accepts `lower == upper`, but `Uniform(a, b)` requires
        # `a < b`; fill with the single admissible value instead.
        fill(u.lower, sz)
    else
        rand(rng, Uniform(u.lower, u.upper), sz)
    end
    b_inv = Bijectors.inverse(b)
    x = b_inv(y)
    # 0-dim arrays: https://github.com/TuringLang/Bijectors.jl/issues/398
    if x isa Array{<:Any,0}
        x = x[]
    end
    return x
end
| 67 | + |
"""
    ParamsInit(params::AbstractDict{<:VarName}, default::AbstractInitStrategy=PriorInit())
    ParamsInit(params::NamedTuple, default::AbstractInitStrategy=PriorInit())

Initialisation strategy that takes new values from the supplied dictionary or
NamedTuple `params`. Whenever a variable has no entry in `params`, or its entry
is `missing`, the value is obtained from the fallback strategy `default`
instead, which is `PriorInit()` unless specified otherwise.

!!! note
    These values must be provided in the space of the untransformed distribution.
"""
struct ParamsInit{P,S<:AbstractInitStrategy} <: AbstractInitStrategy
    params::P
    default::S
    function ParamsInit(
        params::AbstractDict{<:VarName}, default::AbstractInitStrategy=PriorInit()
    )
        return new{typeof(params),typeof(default)}(params, default)
    end
    # A NamedTuple is first converted with `to_varname_dict` and then handed to
    # the dictionary constructor above.
    ParamsInit(params::NamedTuple) = ParamsInit(params, PriorInit())
    function ParamsInit(params::NamedTuple, default::AbstractInitStrategy)
        varname_dict = to_varname_dict(params)
        return ParamsInit(varname_dict, default)
    end
end
# Return the user-supplied value for `vn` if there is one; otherwise (no entry,
# or an entry equal to `missing`) defer to the fallback strategy `p.default`.
function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::ParamsInit)
    # TODO(penelopeysm): We should do a check to make sure that all of the
    # parameters in `p.params` were actually used, and either warn or error if
    # they aren't. This is non-trivial (we need to use something like
    # varname_leaves), so I'm going to defer it to a later PR.
    if hasvalue(p.params, vn, dist)
        supplied = getvalue(p.params, vn, dist)
        # TODO(penelopeysm): We could also check that the type of the supplied
        # value matches the dist?
        supplied === missing || return supplied
    end
    return init(rng, vn, dist, p.default)
end
| 109 | + |
"""
    InitContext(
        [rng::Random.AbstractRNG=Random.default_rng()],
        [strategy::AbstractInitStrategy=PriorInit()],
    )

A leaf context signalling that new values for random variables are currently
being generated, using `rng` and the initialisation `strategy`. Used e.g. when
initialising a fresh VarInfo. Note that, if
`leafcontext(model.context) isa InitContext`, then
`evaluate!!(model, varinfo)` will override all values in the VarInfo.
"""
struct InitContext{R<:Random.AbstractRNG,S<:AbstractInitStrategy} <: AbstractContext
    rng::R
    strategy::S
    # Fully specified form; every other constructor funnels into this one.
    function InitContext(rng::Random.AbstractRNG, strategy::AbstractInitStrategy)
        return new{typeof(rng),typeof(strategy)}(rng, strategy)
    end
    # Only the RNG given: use the default strategy.
    InitContext(rng::Random.AbstractRNG) = InitContext(rng, PriorInit())
    # Only the strategy given: use the default RNG.
    function InitContext(strategy::AbstractInitStrategy)
        return InitContext(Random.default_rng(), strategy)
    end
    # Nothing given: default RNG and default strategy.
    InitContext() = InitContext(PriorInit())
end
# `InitContext` holds only an RNG and a strategy, no child context: it is a leaf.
function NodeTrait(::InitContext)
    return IsLeaf()
end
| 134 | + |
# Handle `vn ~ dist` under an `InitContext`: generate a fresh value with the
# context's strategy, store it in `vi` (overwriting any existing value for `vn`),
# and return the untransformed value together with the updated varinfo.
function tilde_assume(
    ctx::InitContext, dist::Distribution, vn::VarName, vi::AbstractVarInfo
)
    in_varinfo = haskey(vi, vn)
    # `init()` always returns values in original space, i.e. possibly
    # constrained
    x = init(ctx.rng, vn, dist, ctx.strategy)
    # Determine whether to insert a transformed value into the VarInfo.
    # If the VarInfo already had a value for this variable, we will
    # keep the same linked status as in the original VarInfo. If not, we
    # check the rest of the VarInfo to see if other variables are linked.
    # istrans(vi) returns true if vi is nonempty and all variables in vi
    # are linked.
    insert_transformed_value = in_varinfo ? istrans(vi, vn) : istrans(vi)
    f = if insert_transformed_value
        to_linked_internal_transform(vi, vn, dist)
    else
        to_internal_transform(vi, vn, dist)
    end
    # TODO(penelopeysm): We would really like to do:
    #     y, logjac = with_logabsdet_jacobian(f, x)
    # Unfortunately, `to_{linked_}internal_transform` returns a function that
    # always converts x to a vector, i.e., if dist is univariate, f(x) will be
    # a vector of length 1. It would be nice if we could unify these.
    y = f(x)
    logjac = logabsdetjac(insert_transformed_value ? link_transform(dist) : identity, x)
    # Add the new value to the VarInfo. `push!!` errors if the value already
    # exists, hence the need for setindex!!.
    if in_varinfo
        vi = setindex!!(vi, y, vn)
    else
        vi = push!!(vi, vn, y, dist)
    end
    # Neither of these set the `trans` flag so we have to do it manually if
    # necessary. As with the other `!!` calls in this function, the result must
    # be rebound: the updated varinfo is the return value, and the argument is
    # not guaranteed to have been modified in place.
    if insert_transformed_value
        vi = settrans!!(vi, true, vn)
    end
    # `accumulate_assume!!` wants untransformed values as the second argument.
    vi = accumulate_assume!!(vi, x, -logjac, vn, dist)
    # We always return the untransformed value here, as that will determine
    # what the lhs of the tilde-statement is set to.
    return x, vi
end
| 177 | + |
# Observe statements get no special treatment under an `InitContext`: the call
# is forwarded unchanged to the `DefaultContext` method.
function tilde_observe!!(::InitContext, right, left, vn, vi)
    default_ctx = DefaultContext()
    return tilde_observe!!(default_ctx, right, left, vn, vi)
end
0 commit comments