1
- # Uniform random numbers with range 4 for robust initializations
1
+ # UniformInit random numbers with range 4 for robust initializations
2
2
# Reference: https://mc-stan.org/docs/2_19/reference-manual/initialization.html
3
3
randrealuni (rng:: Random.AbstractRNG ) = 4 * rand (rng) - 2
4
4
randrealuni (rng:: Random.AbstractRNG , args... ) = 4 .* rand (rng, args... ) .- 2
5
5
6
- istransformable (dist) = link_transform (dist) != = identity
7
-
8
- # ################################
9
- # Single-sample initialisations #
10
- # ################################
11
- inittrans (rng, dist:: UnivariateDistribution ) = Bijectors. invlink (dist, randrealuni (rng))
12
- function inittrans (rng, dist:: MultivariateDistribution )
13
- # Get the length of the unconstrained vector
14
- b = link_transform (dist)
15
- d = Bijectors. output_length (b, length (dist))
16
- return Bijectors. invlink (dist, randrealuni (rng, d))
17
- end
18
- function inittrans (rng, dist:: MatrixDistribution )
19
- # Get the size of the unconstrained vector
20
- b = link_transform (dist)
21
- sz = Bijectors. output_size (b, size (dist))
22
- return Bijectors. invlink (dist, randrealuni (rng, sz... ))
23
- end
24
- function inittrans (rng, dist:: Distribution{CholeskyVariate} )
25
- # Get the size of the unconstrained vector
26
- b = link_transform (dist)
27
- sz = Bijectors. output_size (b, size (dist))
28
- return Bijectors. invlink (dist, randrealuni (rng, sz... ))
29
- end
30
- # ###############################
31
- # Multi-sample initialisations #
32
- # ###############################
33
- function inittrans (rng, dist:: UnivariateDistribution , n:: Int )
34
- return Bijectors. invlink (dist, randrealuni (rng, n))
35
- end
36
- function inittrans (rng, dist:: MultivariateDistribution , n:: Int )
37
- return Bijectors. invlink (dist, randrealuni (rng, size (dist)[1 ], n))
38
- end
39
- function inittrans (rng, dist:: MatrixDistribution , n:: Int )
40
- return Bijectors. invlink (dist, [randrealuni (rng, size (dist)... ) for _ in 1 : n])
41
- end
42
-
43
6
"""
44
7
AbstractInitStrategy
45
8
@@ -49,15 +12,29 @@ the random variables in a model (e.g., when creating a new VarInfo).
49
12
abstract type AbstractInitStrategy end
50
13
51
14
"""
52
- Prior()
15
+ init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, strategy::AbstractInitStrategy)
16
+
17
+ Generate a new value for a random variable with the given distribution.
18
+
19
+ !!! warning "Values must be unlinked"
20
+ The values returned by `init` are always in the untransformed space, i.e.,
21
+ they must be within the support of the original distribution. That means that,
22
+ for example, `init(rng, dist, u::UniformInit)` will in general return values that
23
+ are outside the range [u.lower, u.upper].
24
+ """
25
+ function init end
53
26
54
- Obtain new values by sampling from the prior.
55
27
"""
56
- struct Prior <: AbstractInitStrategy end
28
+ PriorInit()
57
29
30
+ Obtain new values by sampling from the prior distribution.
58
31
"""
59
- Uniform()
60
- Uniform(lower, upper)
32
+ struct PriorInit <: AbstractInitStrategy end
33
+ init (rng:: Random.AbstractRNG , :: VarName , dist:: Distribution , :: PriorInit ) = rand (rng, dist)
34
+
35
+ """
36
+ UniformInit()
37
+ UniformInit(lower, upper)
61
38
62
39
Obtain new values by first transforming the distribution of the random variable
63
40
to unconstrained space, and then sampling a value uniformly between `lower` and
@@ -70,41 +47,65 @@ default initialisation strategy.
70
47
71
48
[Stan reference manual page on initialization](https://mc-stan.org/docs/reference-manual/execution.html#initialization)
72
49
"""
73
- struct Uniform {T<: AbstractFloat } <: AbstractInitStrategy
50
+ struct UniformInit {T<: AbstractFloat } <: AbstractInitStrategy
74
51
lower:: T
75
52
upper:: T
53
+ function UniformInit (lower:: T , upper:: T ) where {T<: AbstractFloat }
54
+ lower > upper &&
55
+ throw (ArgumentError (" `lower` must be less than or equal to `upper`" ))
56
+ return new {T} (lower, upper)
57
+ end
58
+ UniformInit () = UniformInit (- 2.0 , 2.0 )
59
+ end
60
+ function init (rng:: Random.AbstractRNG , :: VarName , dist:: Distribution , u:: UniformInit )
61
+ b = Bijectors. bijector (dist)
62
+ sz = Bijectors. output_size (b, size (dist))
63
+ y = rand (rng, Uniform (u. lower, u. upper), sz)
64
+ b_inv = Bijectors. inverse (b)
65
+ return b_inv (y)
76
66
end
77
- Uniform () = Uniform (- 2 , 2 )
78
67
79
68
"""
80
- Params (params::AbstractDict{VarName, Any }, default::AbstractInitStrategy)
81
- Params (params::NamedTuple, default::AbstractInitStrategy)
69
+ ParamsInit (params::AbstractDict{<: VarName}, default::AbstractInitStrategy=PriorInit() )
70
+ ParamsInit (params::NamedTuple, default::AbstractInitStrategy=PriorInit() )
82
71
83
72
Obtain new values by extracting them from the given dictionary or NamedTuple.
84
- These values are assumed to be provided in the space of the untransformed
85
- distribution.
86
-
87
73
The parameter `default` specifies how new values are to be obtained if they
88
- cannot be found in `params`. The default for `default` is `Prior()`.
74
+ cannot be found in `params`, or they are specified as `missing`. The default
75
+ for `default` is `PriorInit()`.
76
+
77
+ !!! note
78
+ These values must be provided in the space of the untransformed distribution.
89
79
"""
90
- struct Params {P,S<: AbstractInitStrategy } <: AbstractInitStrategy
80
+ struct ParamsInit {P,S<: AbstractInitStrategy } <: AbstractInitStrategy
91
81
params:: P
92
82
default:: S
93
-
94
- function Params (
95
- params:: AbstractDict{VarName,Any} , default:: AbstractInitStrategy = Prior ()
96
- )
83
+ function ParamsInit (params:: AbstractDict{<:VarName} , default:: AbstractInitStrategy )
97
84
return new {typeof(params),typeof(default)} (params, default)
98
85
end
99
- function Params (params:: NamedTuple , default:: AbstractInitStrategy = Prior ())
100
- return Params (to_varname_dict (params), default)
86
+ ParamsInit (params:: AbstractDict{<:VarName} ) = ParamsInit (params, PriorInit ())
87
+ function ParamsInit (params:: NamedTuple , default:: AbstractInitStrategy = PriorInit ())
88
+ return ParamsInit (to_varname_dict (params), default)
89
+ end
90
+ end
91
+ function init (rng:: Random.AbstractRNG , vn:: VarName , dist:: Distribution , p:: ParamsInit )
92
+ return if hasvalue (p. params, vn)
93
+ x = getvalue (p. params, vn)
94
+ if x === missing
95
+ init (rng, vn, dist, p. default)
96
+ else
97
+ # TODO : Check that the type of x matches the dist?
98
+ x
99
+ end
100
+ else
101
+ init (rng, vn, dist, p. default)
101
102
end
102
103
end
103
104
104
105
"""
105
106
InitContext(
106
107
[rng::Random.AbstractRNG=Random.default_rng()],
107
- [strategy::AbstractInitStrategy=Prior ()],
108
+ [strategy::AbstractInitStrategy=PriorInit ()],
108
109
)
109
110
110
111
A leaf context that indicates that new values for random variables are
@@ -115,95 +116,144 @@ VarInfo. Note that, if `leafcontext(model.context) isa InitContext`, then
115
116
struct InitContext{R<: Random.AbstractRNG ,S<: AbstractInitStrategy } <: AbstractContext
116
117
rng:: R
117
118
strategy:: S
118
- function InitContext (rng:: Random.AbstractRNG , strategy:: AbstractInitStrategy = Prior ())
119
+ function InitContext (
120
+ rng:: Random.AbstractRNG , strategy:: AbstractInitStrategy = PriorInit ()
121
+ )
119
122
return new {typeof(rng),typeof(strategy)} (rng, strategy)
120
123
end
121
- function InitContext (strategy:: AbstractInitStrategy = Prior ())
124
+ function InitContext (strategy:: AbstractInitStrategy = PriorInit ())
122
125
return InitContext (Random. default_rng (), strategy)
123
126
end
124
127
end
125
128
NodeTrait (:: InitContext ) = IsLeaf ()
126
129
127
130
function tilde_assume (
128
- ctx:: InitContext{<:Random.AbstractRNG,Prior} ,
129
- dist:: Distribution ,
130
- vn:: VarName ,
131
- vi:: AbstractVarInfo ,
131
+ ctx:: InitContext , dist:: Distribution , vn:: VarName , vi:: AbstractVarInfo
132
132
)
133
- r = rand (ctx. rng, dist)
134
- vi[vn] = r
135
- # TODO : FIX
136
- logjac = 0
137
- vi = accumulate_assume!! (vi, r, - logjac, vn, dist)
138
- println (" sampled $r from $dist for $vn " )
139
- return r, vi
133
+ in_varinfo = haskey (vi, vn)
134
+ # `init()` always returns values in original space, i.e. possibly
135
+ # constrained
136
+ x = init (ctx. rng, vn, dist, ctx. strategy)
137
+ # There is a function `to_maybe_linked_internal_transform` that does this,
138
+ # but unfortunately it uses `istrans(vi, vn)` which fails if vn is not in
139
+ # vi, so we have to manually check. By default we will insert an unlinked
140
+ # value into the varinfo.
141
+ is_transformed = in_varinfo ? istrans (vi, vn) : false
142
+ f = if is_transformed
143
+ to_linked_internal_transform (vi, vn, dist)
144
+ else
145
+ to_internal_transform (vi, vn, dist)
146
+ end
147
+ # TODO (penelopeysm): We would really like to do:
148
+ # y, logjac = with_logabsdet_jacobian(f, x)
149
+ # Unfortunately, `to_{linked_}internal_transform` returns a function that
150
+ # always converts x to a vector, i.e., if dist is univariate, f(x) will be
151
+ # a vector of length 1. It would be nice if we could unify these.
152
+ y = f (x)
153
+ logjac = logabsdetjac (is_transformed ? Bijectors. bijector (dist) : identity, x)
154
+ # Add the new value to the VarInfo. `push!!` errors if the value already
155
+ # exists, hence the need for setindex!!
156
+ if in_varinfo
157
+ vi = setindex!! (vi, y, vn)
158
+ else
159
+ vi = push!! (vi, vn, y, dist)
160
+ end
161
+ # `accumulate_assume!!` wants untransformed values as the second argument.
162
+ vi = accumulate_assume!! (vi, x, - logjac, vn, dist)
163
+ # We always return the untransformed value here, as that will determine
164
+ # what the lhs of the tilde-statement is set to.
165
+ return x, vi
140
166
end
141
167
142
- # TODO : Remove this thing.
143
- # function assume(
144
- # rng::Random.AbstractRNG,
145
- # init_strategy::AbstractInitStrategy,
146
- # dist::Distribution,
147
- # vn::VarName,
148
- # vi::AbstractVarInfo,
168
+ # """
169
+ # set_initial_values(varinfo::AbstractVarInfo, initial_params::AbstractVector)
170
+ # set_initial_values(varinfo::AbstractVarInfo, initial_params::NamedTuple)
171
+ #
172
+ # Take the values inside `initial_params`, replace the corresponding values in
173
+ # the given VarInfo object, and return a new VarInfo object with the updated values.
174
+ #
175
+ # This differs from `DynamicPPL.unflatten` in two ways:
176
+ #
177
+ # 1. It works with `NamedTuple` arguments.
178
+ # 2. For the `AbstractVector` method, if any of the elements are missing, it will not
179
+ # overwrite the original value in the VarInfo (it will just use the original
180
+ # value instead).
181
+ # """
182
+ # function set_initial_values(varinfo::AbstractVarInfo, initial_params::AbstractVector)
183
+ # throw(
184
+ # ArgumentError(
185
+ # "`initial_params` must be a vector of type `Union{Real,Missing}`. " *
186
+ # "If `initial_params` is a vector of vectors, please flatten it (e.g. using `vcat`) first.",
187
+ # ),
188
+ # )
189
+ # end
190
+ #
191
+ # function set_initial_values(
192
+ # varinfo::AbstractVarInfo, initial_params::AbstractVector{<:Union{Real,Missing}}
149
193
# )
150
- # if haskey(vi, vn)
151
- # # Always overwrite the parameters with new ones for `SampleFromUniform`.
152
- # if sampler isa SampleFromUniform || is_flagged(vi, vn, "del")
153
- # # TODO (mhauru) Is it important to unset the flag here? The `true` allows us
154
- # # to ignore the fact that for VarNamedVector this does nothing, but I'm unsure
155
- # # if that's okay.
156
- # unset_flag!(vi, vn, "del", true)
157
- # r = init(rng, dist, sampler)
158
- # f = to_maybe_linked_internal_transform(vi, vn, dist)
159
- # # TODO (mhauru) This should probably be call a function called setindex_internal!
160
- # vi = BangBang.setindex!!(vi, f(r), vn)
161
- # setorder!(vi, vn, get_num_produce(vi))
162
- # else
163
- # # Otherwise we just extract it.
164
- # r = vi[vn, dist]
165
- # end
166
- # else
167
- # r = init(rng, dist, sampler)
168
- # if istrans(vi)
169
- # f = to_linked_internal_transform(vi, vn, dist)
170
- # vi = push!!(vi, vn, f(r), dist)
171
- # # By default `push!!` sets the transformed flag to `false`.
172
- # vi = settrans!!(vi, true, vn)
173
- # else
174
- # vi = push!!(vi, vn, r, dist)
194
+ # flattened_param_vals = varinfo[:]
195
+ # length(flattened_param_vals) == length(initial_params) || throw(
196
+ # DimensionMismatch(
197
+ # "Provided initial value size ($(length(initial_params))) doesn't match " *
198
+ # "the model size ($(length(flattened_param_vals))).",
199
+ # ),
200
+ # )
201
+ #
202
+ # # Update values that are provided.
203
+ # for i in eachindex(initial_params)
204
+ # x = initial_params[i]
205
+ # if x !== missing
206
+ # flattened_param_vals[i] = x
175
207
# end
176
208
# end
177
209
#
178
- # # HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct.
179
- # logjac = logabsdetjac(istrans(vi, vn) ? link_transform(dist) : identity, r)
180
- # vi = accumulate_assume!!(vi, r, -logjac, vn, dist)
181
- # return r, vi
210
+ # # Update in `varinfo`.
211
+ # new_varinfo = unflatten(varinfo, flattened_param_vals)
212
+ # return new_varinfo
182
213
# end
183
-
184
- # function assume(
185
- # rng::Random.AbstractRNG,
186
- # sampler::Union{SampleFromPrior,SampleFromUniform},
187
- # dist::Distribution,
188
- # vn::VarName,
189
- # vi::SimpleOrThreadSafeSimple,
190
- # )
191
- # value = init(rng, dist, sampler)
192
- # # Transform if we're working in unconstrained space.
193
- # f = to_maybe_linked_internal_transform(vi, vn, dist)
194
- # value_raw, logjac = with_logabsdet_jacobian(f, value)
195
- # vi = BangBang.push!!(vi, vn, value_raw, dist)
196
- # vi = accumulate_assume!!(vi, value, -logjac, vn, dist)
197
- # return value, vi
198
- # end
199
-
200
- # Initializations.
201
- # init(rng, dist, ::SampleFromPrior) = rand(rng, dist)
202
- # function init(rng, dist, ::SampleFromUniform)
203
- # return istransformable(dist) ? inittrans(rng, dist) : rand(rng, dist)
214
+ #
215
+ # function set_initial_values(varinfo::AbstractVarInfo, initial_params::NamedTuple)
216
+ # varinfo = deepcopy(varinfo)
217
+ # vars_in_varinfo = keys(varinfo)
218
+ # for v in keys(initial_params)
219
+ # vn = VarName{v}()
220
+ # if !(vn in vars_in_varinfo)
221
+ # for vv in vars_in_varinfo
222
+ # if subsumes(vn, vv)
223
+ # throw(
224
+ # ArgumentError(
225
+ # "The current model contains sub-variables of $v, such as ($vv). " *
226
+ # "Using NamedTuple for initial_params is not supported in such a case. " *
227
+ # "Please use AbstractVector for initial_params instead of NamedTuple.",
228
+ # ),
229
+ # )
230
+ # end
231
+ # end
232
+ # throw(ArgumentError("Variable $v not found in the model."))
233
+ # end
234
+ # end
235
+ # initial_params = NamedTuple(k => v for (k, v) in pairs(initial_params) if v !== missing)
236
+ # return update_values!!(
237
+ # varinfo, initial_params, map(k -> VarName{k}(), keys(initial_params))
238
+ # )
204
239
# end
205
240
#
206
- # init(rng, dist, ::SampleFromPrior, n::Int) = rand(rng, dist, n)
207
- # function init(rng, dist, ::SampleFromUniform, n::Int)
208
- # return istransformable(dist) ? inittrans(rng, dist, n) : rand(rng, dist, n)
241
+ # function initialize_parameters!!(vi::AbstractVarInfo, initial_params, model::Model)
242
+ # @debug "Using passed-in initial variable values" initial_params
243
+ #
244
+ # # `link` the varinfo if needed.
245
+ # linked = islinked(vi)
246
+ # if linked
247
+ # vi = invlink!!(vi, model)
248
+ # end
249
+ #
250
+ # # Set the values in `vi`.
251
+ # vi = set_initial_values(vi, initial_params)
252
+ #
253
+ # # `invlink` if needed.
254
+ # if linked
255
+ # vi = link!!(vi, model)
256
+ # end
257
+ #
258
+ # return vi
209
259
# end
0 commit comments