68
68
69
69
Return a default varinfo object for the given `model` and `sampler`.
70
70
71
+ The default method for this returns an empty NTVarInfo (i.e. 'typed varinfo').
72
+
71
73
# Arguments
72
74
- `rng::Random.AbstractRNG`: Random number generator.
73
75
- `model::Model`: Model for which we want to create a varinfo object.
@@ -76,9 +78,10 @@ Return a default varinfo object for the given `model` and `sampler`.
76
78
# Returns
77
79
- `AbstractVarInfo`: Default varinfo object for the given `model` and `sampler`.
78
80
"""
79
- function default_varinfo (rng:: Random.AbstractRNG , model:: Model , sampler:: AbstractSampler )
80
- init_sampler = initialsampler (sampler)
81
- return typed_varinfo (rng, model, init_sampler)
81
+ function default_varinfo (:: Random.AbstractRNG , :: Model , :: AbstractSampler )
82
+ # Note that variable values are unconditionally initialized later, so no
83
+ # point putting them in now.
84
+ return typed_varinfo (VarInfo ())
82
85
end
83
86
84
87
function AbstractMCMC. sample (
@@ -96,24 +99,32 @@ function AbstractMCMC.sample(
96
99
)
97
100
end
98
101
99
- # initial step: general interface for resuming and
102
+ """
103
+ init_strategy(sampler)
104
+
105
+ Define the initialisation strategy used for generating initial values when
106
+ sampling with `sampler`. Defaults to `PriorInit()`, but can be overridden.
107
+ """
108
+ init_strategy (:: Sampler ) = PriorInit ()
109
+
100
110
function AbstractMCMC. step (
101
- rng:: Random.AbstractRNG , model:: Model , spl:: Sampler ; initial_params= nothing , kwargs...
111
+ rng:: Random.AbstractRNG ,
112
+ model:: Model ,
113
+ spl:: Sampler ;
114
+ initial_params:: AbstractInitStrategy = init_strategy (spl),
115
+ kwargs... ,
102
116
)
103
- # Sample initial values.
117
+ # Generate the default varinfo (usually this just makes an empty VarInfo
118
+ # with NamedTuple of Metadata).
104
119
vi = default_varinfo (rng, model, spl)
105
120
106
- # Update the parameters if provided.
107
- if initial_params != = nothing
108
- vi = initialize_parameters!! (vi, initial_params, model)
109
-
110
- # Update joint log probability.
111
- # This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588
112
- # and https://github.com/TuringLang/Turing.jl/issues/1563
113
- # to avoid that existing variables are resampled
114
- vi = last (evaluate!! (model, vi))
115
- end
121
+ # Fill it with initial parameters. Note that, if `ParamsInit` is used, the
122
+ # parameters provided must be in unlinked space (when inserted into the
123
+ # varinfo, they will be adjusted to match the linking status of the
124
+ # varinfo).
125
+ _, vi = init!! (rng, model, vi, initial_params)
116
126
127
+ # Call the actual function that does the first step.
117
128
return initialstep (rng, model, spl, vi; initial_params, kwargs... )
118
129
end
119
130
@@ -131,110 +142,7 @@ loadstate(data) = data
131
142
132
143
Default type of the chain of posterior samples from `sampler`.
133
144
"""
134
- default_chain_type (sampler:: Sampler ) = Any
135
-
136
- """
137
- initialsampler(sampler::Sampler)
138
-
139
- Return the sampler that is used for generating the initial parameters when sampling with
140
- `sampler`.
141
-
142
- By default, it returns an instance of [`SampleFromPrior`](@ref).
143
- """
144
- initialsampler (spl:: Sampler ) = SampleFromPrior ()
145
-
146
- """
147
- set_initial_values(varinfo::AbstractVarInfo, initial_params::AbstractVector)
148
- set_initial_values(varinfo::AbstractVarInfo, initial_params::NamedTuple)
149
-
150
- Take the values inside `initial_params`, replace the corresponding values in
151
- the given VarInfo object, and return a new VarInfo object with the updated values.
152
-
153
- This differs from `DynamicPPL.unflatten` in two ways:
154
-
155
- 1. It works with `NamedTuple` arguments.
156
- 2. For the `AbstractVector` method, if any of the elements are missing, it will not
157
- overwrite the original value in the VarInfo (it will just use the original
158
- value instead).
159
- """
160
- function set_initial_values (varinfo:: AbstractVarInfo , initial_params:: AbstractVector )
161
- throw (
162
- ArgumentError (
163
- " `initial_params` must be a vector of type `Union{Real,Missing}`. " *
164
- " If `initial_params` is a vector of vectors, please flatten it (e.g. using `vcat`) first." ,
165
- ),
166
- )
167
- end
168
-
169
- function set_initial_values (
170
- varinfo:: AbstractVarInfo , initial_params:: AbstractVector{<:Union{Real,Missing}}
171
- )
172
- flattened_param_vals = varinfo[:]
173
- length (flattened_param_vals) == length (initial_params) || throw (
174
- DimensionMismatch (
175
- " Provided initial value size ($(length (initial_params)) ) doesn't match " *
176
- " the model size ($(length (flattened_param_vals)) )." ,
177
- ),
178
- )
179
-
180
- # Update values that are provided.
181
- for i in eachindex (initial_params)
182
- x = initial_params[i]
183
- if x != = missing
184
- flattened_param_vals[i] = x
185
- end
186
- end
187
-
188
- # Update in `varinfo`.
189
- new_varinfo = unflatten (varinfo, flattened_param_vals)
190
- return new_varinfo
191
- end
192
-
193
- function set_initial_values (varinfo:: AbstractVarInfo , initial_params:: NamedTuple )
194
- varinfo = deepcopy (varinfo)
195
- vars_in_varinfo = keys (varinfo)
196
- for v in keys (initial_params)
197
- vn = VarName {v} ()
198
- if ! (vn in vars_in_varinfo)
199
- for vv in vars_in_varinfo
200
- if subsumes (vn, vv)
201
- throw (
202
- ArgumentError (
203
- " The current model contains sub-variables of $v , such as ($vv ). " *
204
- " Using NamedTuple for initial_params is not supported in such a case. " *
205
- " Please use AbstractVector for initial_params instead of NamedTuple." ,
206
- ),
207
- )
208
- end
209
- end
210
- throw (ArgumentError (" Variable $v not found in the model." ))
211
- end
212
- end
213
- initial_params = NamedTuple (k => v for (k, v) in pairs (initial_params) if v != = missing )
214
- return update_values!! (
215
- varinfo, initial_params, map (k -> VarName {k} (), keys (initial_params))
216
- )
217
- end
218
-
219
- function initialize_parameters!! (vi:: AbstractVarInfo , initial_params, model:: Model )
220
- @debug " Using passed-in initial variable values" initial_params
221
-
222
- # `link` the varinfo if needed.
223
- linked = islinked (vi)
224
- if linked
225
- vi = invlink!! (vi, model)
226
- end
227
-
228
- # Set the values in `vi`.
229
- vi = set_initial_values (vi, initial_params)
230
-
231
- # `invlink` if needed.
232
- if linked
233
- vi = link!! (vi, model)
234
- end
235
-
236
- return vi
237
- end
145
+ default_chain_type (:: Sampler ) = Any
238
146
239
147
"""
240
148
initialstep(rng, model, sampler, varinfo; kwargs...)
0 commit comments