Skip to content

Commit 38652f7

Browse files
committed
init_strategy
1 parent 20d0909 commit 38652f7

File tree

5 files changed

+38
-12
lines changed

5 files changed

+38
-12
lines changed

src/debug_utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -520,7 +520,7 @@ function has_static_constraints(
520520
traces = map(last, results)
521521
dists_per_trace = map(distributions_in_trace, traces)
522522
transforms = map(dists_per_trace) do dists
523-
map(Bijectors.bijector, dists)
523+
map(DynamicPPL.link_transform, dists)
524524
end
525525

526526
# Check if the distributions are the same across all runs.

src/sampler.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,14 @@ Default type of the chain of posterior samples from `sampler`.
8888
"""
8989
default_chain_type(sampler::Sampler) = Any
9090

91+
"""
92+
init_strategy(sampler)
93+
94+
Define the initialisation strategy used for generating initial values when
95+
sampling with `sampler`. Defaults to `PriorInit()`, but can be overridden.
96+
"""
97+
init_strategy(::Sampler) = PriorInit()
98+
9199
"""
92100
initialstep(rng, model, sampler, varinfo; kwargs...)
93101

src/simple_varinfo.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -620,7 +620,7 @@ from_internal_transform(vi::SimpleVarInfo, ::VarName, dist) = identity
620620
# TODO: Should the following methods specialize on the case where we have a `StaticTransformation{<:Bijectors.NamedTransform}`?
621621
from_linked_internal_transform(vi::SimpleVarInfo, ::VarName) = identity
622622
function from_linked_internal_transform(vi::SimpleVarInfo, ::VarName, dist)
623-
return inverse(Bijectors.bijector(dist))
623+
return invlink_transform(dist)
624624
end
625625

626626
has_varnamedvector(vi::SimpleVarInfo) = vi.values isa VarNamedVector

src/transforming.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ function tilde_assume(
2626

2727
# Only transform if `!isinverse` since `vi[vn, right]`
2828
# already performs the inverse transformation if it's transformed.
29-
r_transformed = isinverse ? r : Bijectors.bijector(right)(r)
29+
r_transformed = isinverse ? r : link_transform(right)(r)
3030
if hasacc(vi, Val(:LogPrior))
3131
vi = acclogprior!!(vi, lp)
3232
end

src/utils.jl

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,28 @@ function to_namedtuple_expr(syms, vals)
197197
return :(NamedTuple{$names_expr}($vals_expr))
198198
end
199199

200+
"""
201+
link_transform(dist)
202+
Return the constrained-to-unconstrained bijector for distribution `dist`.
203+
By default, this is just `Bijectors.bijector(dist)`.
204+
!!! warning
205+
Note that currently this is not used by `Bijectors.logpdf_with_trans`,
206+
hence that needs to be overloaded separately if the intention is
207+
to change behavior of an existing distribution.
208+
"""
209+
link_transform(dist) = bijector(dist)
210+
211+
"""
212+
invlink_transform(dist)
213+
Return the unconstrained-to-constrained bijector for distribution `dist`.
214+
By default, this is just `inverse(link_transform(dist))`.
215+
!!! warning
216+
Note that currently this is not used by `Bijectors.logpdf_with_trans`,
217+
hence that needs to be overloaded separately if the intention is
218+
to change behavior of an existing distribution.
219+
"""
220+
invlink_transform(dist) = inverse(link_transform(dist))
221+
200222
#####################################################
201223
# Helper functions for vectorize/reconstruct values #
202224
#####################################################
@@ -355,14 +377,11 @@ from_vec_transform(f, sz) = from_vec_transform_for_size(Bijectors.output_size(f,
355377
Return the transformation from the unconstrained vector to the constrained
356378
realization of distribution `dist`.
357379
358-
By default, this is just `inverse(bijector(dist)) ∘ from_vec_transform(dist)`.
359-
360-
See also: [`DynamicPPL.from_vec_transform`](@ref).
380+
See also: [`DynamicPPL.invlink_transform`](@ref), [`DynamicPPL.from_vec_transform`](@ref).
361381
"""
362382
function from_linked_vec_transform(dist::Distribution)
363-
f_link = Bijectors.bijector(dist)
364-
f_invlink = inverse(f_link)
365-
f_vec = from_vec_transform(f_link, size(dist))
383+
f_invlink = invlink_transform(dist)
384+
f_vec = from_vec_transform(inverse(f_invlink), size(dist))
366385
return f_invlink f_vec
367386
end
368387

@@ -372,9 +391,8 @@ end
372391
# TODO(mhauru) Hopefully all this can go once the old Gibbs sampler is removed and
373392
# VarNamedVector takes over from Metadata.
374393
function from_linked_vec_transform(dist::UnivariateDistribution)
375-
f_link = Bijectors.bijector(dist)
376-
f_invlink = inverse(f_link)
377-
f_vec = from_vec_transform(f_link, size(dist))
394+
f_invlink = invlink_transform(dist)
395+
f_vec = from_vec_transform(inverse(f_invlink), size(dist))
378396
f_combined = f_invlink f_vec
379397
sz = Bijectors.output_size(f_combined, size(dist))
380398
return UnwrapSingletonTransform(sz) f_combined

0 commit comments

Comments
 (0)