@@ -39,7 +39,7 @@ julia> rng = StableRNG(42);
39
39
julia> # In the `NamedTuple` version we need to provide the place-holder values for
40
40
# the variables which are using "containers", e.g. `Array`.
41
41
# In this case, this means that we need to specify `x` but not `m`.
42
- _, vi = DynamicPPL.evaluate_and_sample !!(rng, m, SimpleVarInfo((x = ones(2), )));
42
+ _, vi = DynamicPPL.init !!(rng, m, SimpleVarInfo((x = ones(2), )));
43
43
44
44
julia> # (✓) Vroom, vroom! FAST!!!
45
45
vi[@varname(x[1])]
@@ -57,12 +57,12 @@ julia> vi[@varname(x[1:2])]
57
57
1.3736306979834252
58
58
59
59
julia> # (×) If we don't provide the container...
60
- _, vi = DynamicPPL.evaluate_and_sample !!(rng, m, SimpleVarInfo()); vi
60
+ _, vi = DynamicPPL.init !!(rng, m, SimpleVarInfo()); vi
61
61
ERROR: type NamedTuple has no field x
62
62
[...]
63
63
64
64
julia> # If one does not know the varnames, we can use a `OrderedDict` instead.
65
- _, vi = DynamicPPL.evaluate_and_sample !!(rng, m, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}()));
65
+ _, vi = DynamicPPL.init !!(rng, m, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}()));
66
66
67
67
julia> # (✓) Sort of fast, but only possible at runtime.
68
68
vi[@varname(x[1])]
@@ -91,28 +91,28 @@ demo_constrained (generic function with 2 methods)
91
91
92
92
julia> m = demo_constrained();
93
93
94
- julia> _, vi = DynamicPPL.evaluate_and_sample !!(rng, m, SimpleVarInfo());
94
+ julia> _, vi = DynamicPPL.init !!(rng, m, SimpleVarInfo());
95
95
96
96
julia> vi[@varname(x)] # (✓) 0 ≤ x < ∞
97
97
1.8632965762164932
98
98
99
- julia> _, vi = DynamicPPL.evaluate_and_sample !!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true));
99
+ julia> _, vi = DynamicPPL.init !!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true));
100
100
101
101
julia> vi[@varname(x)] # (✓) -∞ < x < ∞
102
102
-0.21080155351918753
103
103
104
- julia> xs = [last(DynamicPPL.evaluate_and_sample !!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10];
104
+ julia> xs = [last(DynamicPPL.init !!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10];
105
105
106
106
julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers!
107
107
true
108
108
109
109
julia> # And with `OrderedDict` of course!
110
- _, vi = DynamicPPL.evaluate_and_sample !!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict{VarName,Any}()), true));
110
+ _, vi = DynamicPPL.init !!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict{VarName,Any}()), true));
111
111
112
112
julia> vi[@varname(x)] # (✓) -∞ < x < ∞
113
113
0.6225185067787314
114
114
115
- julia> xs = [last(DynamicPPL.evaluate_and_sample !!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10];
115
+ julia> xs = [last(DynamicPPL.init !!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10];
116
116
117
117
julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers!
118
118
true
@@ -226,24 +226,25 @@ end
226
226
227
227
# Constructor from `Model`.
228
228
function SimpleVarInfo {T} (
229
- rng:: Random.AbstractRNG , model:: Model , sampler :: AbstractSampler = SampleFromPrior ()
229
+ rng:: Random.AbstractRNG , model:: Model , init_strategy :: AbstractInitStrategy = PriorInit ()
230
230
) where {T<: Real }
231
- new_model = contextualize (model, SamplingContext (rng, sampler, model. context))
231
+ new_context = setleafcontext (model. context, InitContext (rng, init_strategy))
232
+ new_model = contextualize (model, new_context)
232
233
return last (evaluate!! (new_model, SimpleVarInfo {T} ()))
233
234
end
234
235
function SimpleVarInfo {T} (
235
- model:: Model , sampler :: AbstractSampler = SampleFromPrior ()
236
+ model:: Model , init_strategy :: AbstractInitStrategy = PriorInit ()
236
237
) where {T<: Real }
237
- return SimpleVarInfo {T} (Random. default_rng (), model, sampler )
238
+ return SimpleVarInfo {T} (Random. default_rng (), model, init_strategy )
238
239
end
239
240
# Constructors without type param
240
241
function SimpleVarInfo (
241
- rng:: Random.AbstractRNG , model:: Model , sampler :: AbstractSampler = SampleFromPrior ()
242
+ rng:: Random.AbstractRNG , model:: Model , init_strategy :: AbstractInitStrategy = PriorInit ()
242
243
)
243
- return SimpleVarInfo {LogProbType} (rng, model, sampler )
244
+ return SimpleVarInfo {LogProbType} (rng, model, init_strategy )
244
245
end
245
- function SimpleVarInfo (model:: Model , sampler :: AbstractSampler = SampleFromPrior ())
246
- return SimpleVarInfo {LogProbType} (Random. default_rng (), model, sampler )
246
+ function SimpleVarInfo (model:: Model , init_strategy :: AbstractInitStrategy = PriorInit ())
247
+ return SimpleVarInfo {LogProbType} (Random. default_rng (), model, init_strategy )
247
248
end
248
249
249
250
# Constructor from `VarInfo`.
@@ -259,12 +260,12 @@ end
259
260
260
261
function untyped_simple_varinfo (model:: Model )
261
262
varinfo = SimpleVarInfo (OrderedDict {VarName,Any} ())
262
- return last (evaluate_and_sample !! (model, varinfo))
263
+ return last (init !! (model, varinfo))
263
264
end
264
265
265
266
function typed_simple_varinfo (model:: Model )
266
267
varinfo = SimpleVarInfo {Float64} ()
267
- return last (evaluate_and_sample !! (model, varinfo))
268
+ return last (init !! (model, varinfo))
268
269
end
269
270
270
271
function unflatten (svi:: SimpleVarInfo , x:: AbstractVector )
@@ -474,7 +475,6 @@ function assume(
474
475
return value, vi
475
476
end
476
477
477
- # NOTE: We don't implement `settrans!!(vi, trans, vn)`.
478
478
function settrans!! (vi:: SimpleVarInfo , trans)
479
479
return settrans!! (vi, trans ? DynamicTransformation () : NoTransformation ())
480
480
end
484
484
function settrans!! (vi:: ThreadSafeVarInfo{<:SimpleVarInfo} , trans)
485
485
return Accessors. @set vi. varinfo = settrans!! (vi. varinfo, trans)
486
486
end
487
+ function settrans!! (vi:: SimpleOrThreadSafeSimple , trans:: Bool , :: VarName )
488
+ # We keep this method around just to obey the AbstractVarInfo interface; however,
489
+ # this is only a valid operation if it would be a no-op.
490
+ if trans != istrans (vi)
491
+ error (
492
+ " Individual variables in SimpleVarInfo cannot have different `settrans` statuses." ,
493
+ )
494
+ end
495
+ end
487
496
488
497
istrans (vi:: SimpleVarInfo ) = ! (vi. transformation isa NoTransformation)
489
498
istrans (vi:: SimpleVarInfo , :: VarName ) = istrans (vi)
0 commit comments