Skip to content

Commit 7b72071

Browse files
committed
Add tests for InitContext behaviour
1 parent b3b9df6 commit 7b72071

File tree

3 files changed

+184
-16
lines changed

3 files changed

+184
-16
lines changed

src/DynamicPPL.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,11 @@ export AbstractVarInfo,
109109
ConditionContext,
110110
assume,
111111
tilde_assume,
112+
# Initialisation
113+
AbstractInitStrategy,
114+
PriorInit,
115+
UniformInit,
116+
ParamsInit,
112117
# Pseudo distributions
113118
NamedDist,
114119
NoDist,

src/contexts/init.jl

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -147,17 +147,11 @@ function tilde_assume(
147147
# are linked.
148148
insert_transformed_value = in_varinfo ? istrans(vi, vn) : istrans(vi)
149149
f = if insert_transformed_value
150-
to_linked_internal_transform(vi, vn, dist)
150+
link_transform(dist)
151151
else
152-
to_internal_transform(vi, vn, dist)
152+
identity
153153
end
154-
# TODO(penelopeysm): We would really like to do:
155-
# y, logjac = with_logabsdet_jacobian(f, x)
156-
# Unfortunately, `to_{linked_}internal_transform` returns a function that
157-
# always converts x to a vector, i.e., if dist is univariate, f(x) will be
158-
# a vector of length 1. It would be nice if we could unify these.
159-
y = f(x)
160-
logjac = logabsdetjac(insert_transformed_value ? link_transform(dist) : identity, x)
154+
y, logjac = with_logabsdet_jacobian(f, x)
161155
# Add the new value to the VarInfo. `push!!` errors if the value already
162156
# exists, hence the need for setindex!!.
163157
if in_varinfo

test/contexts.jl

Lines changed: 176 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,9 @@ using DynamicPPL:
2020
hasconditioned_nested,
2121
getconditioned_nested,
2222
collapse_prefix_stack,
23-
prefix_cond_and_fixed_variables,
24-
getvalue
23+
prefix_cond_and_fixed_variables
24+
using LinearAlgebra: I
25+
using Random: Xoshiro
2526

2627
using EnzymeCore
2728

@@ -103,7 +104,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
103104
# sometimes only the main symbol (e.g. it contains `x` when
104105
# `vn` is `x[1]`)
105106
for vn in conditioned_vns
106-
val = DynamicPPL.getvalue(conditioned_values, vn)
107+
val = getvalue(conditioned_values, vn)
107108
# These VarNames are present in the conditioning values, so
108109
# we should always be able to extract the value.
109110
@test hasconditioned_nested(context, vn)
@@ -433,12 +434,180 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
433434
end
434435

435436
@testset "InitContext" begin
436-
@testset "PriorInit" begin end
437+
empty_varinfos = [
438+
VarInfo(),
439+
DynamicPPL.typed_varinfo(VarInfo()),
440+
VarInfo(DynamicPPL.VarNamedVector()),
441+
DynamicPPL.typed_vector_varinfo(DynamicPPL.typed_varinfo(VarInfo())),
442+
SimpleVarInfo(),
443+
SimpleVarInfo(Dict{VarName,Any}()),
444+
]
445+
446+
@model function test_init_model()
447+
x ~ Normal()
448+
y ~ MvNormal(fill(x, 2), I)
449+
1.0 ~ Normal()
450+
return nothing
451+
end
452+
function test_generating_new_values(strategy::AbstractInitStrategy)
453+
@testset "generating new values: $(typeof(strategy))" begin
454+
# Check that init!! can generate values that weren't there
455+
# previously.
456+
model = test_init_model()
457+
for empty_vi in empty_varinfos
458+
this_vi = deepcopy(empty_vi)
459+
_, vi = DynamicPPL.init!!(model, this_vi, strategy)
460+
@test Set(keys(vi)) == Set([@varname(x), @varname(y)])
461+
x, y = vi[@varname(x)], vi[@varname(y)]
462+
@test x isa Real
463+
@test y isa AbstractVector{<:Real}
464+
@test length(y) == 2
465+
(; logprior, loglikelihood) = getlogp(vi)
466+
@test logpdf(Normal(), x) + logpdf(MvNormal(fill(x, 2), I), y) ==
467+
logprior
468+
@test logpdf(Normal(), 1.0) == loglikelihood
469+
end
470+
end
471+
end
472+
function test_replacing_values(strategy::AbstractInitStrategy)
473+
@testset "replacing old values: $(typeof(strategy))" begin
474+
# Check that init!! can overwrite values that were already there.
475+
model = test_init_model()
476+
for empty_vi in empty_varinfos
477+
# start by generating some rubbish values
478+
vi = deepcopy(empty_vi)
479+
old_x, old_y = 100000.00, [300000.00, 500000.00]
480+
push!!(vi, @varname(x), old_x, Normal())
481+
push!!(vi, @varname(y), old_y, MvNormal(fill(old_x, 2), I))
482+
# then overwrite it
483+
_, new_vi = DynamicPPL.init!!(model, vi, strategy)
484+
new_x, new_y = new_vi[@varname(x)], new_vi[@varname(y)]
485+
# check that the values are (presumably) different
486+
@test old_x != new_x
487+
@test old_y != new_y
488+
end
489+
end
490+
end
491+
function test_rng_respected(strategy::AbstractInitStrategy)
492+
@testset "check that RNG is respected: $(typeof(strategy))" begin
493+
model = test_init_model()
494+
for empty_vi in empty_varinfos
495+
_, vi1 = DynamicPPL.init!!(
496+
Xoshiro(468), model, deepcopy(empty_vi), strategy
497+
)
498+
_, vi2 = DynamicPPL.init!!(
499+
Xoshiro(468), model, deepcopy(empty_vi), strategy
500+
)
501+
_, vi3 = DynamicPPL.init!!(
502+
Xoshiro(469), model, deepcopy(empty_vi), strategy
503+
)
504+
@test vi1[@varname(x)] == vi2[@varname(x)]
505+
@test vi1[@varname(y)] == vi2[@varname(y)]
506+
@test vi1[@varname(x)] != vi3[@varname(x)]
507+
@test vi1[@varname(y)] != vi3[@varname(y)]
508+
end
509+
end
510+
end
437511

438-
@testset "UniformInit" begin end
512+
@testset "PriorInit" begin
513+
test_generating_new_values(PriorInit())
514+
test_replacing_values(PriorInit())
515+
test_rng_respected(PriorInit())
516+
517+
@testset "check that values are within support" begin
518+
# Not many other sensible checks we can do for priors.
519+
@model just_unif() = x ~ Uniform(0.0, 1e-7)
520+
for _ in 1:100
521+
_, vi = DynamicPPL.init!!(just_unif(), VarInfo(), PriorInit())
522+
@test vi[@varname(x)] isa Real
523+
@test 0.0 <= vi[@varname(x)] <= 1e-7
524+
end
525+
end
526+
end
439527

440-
@testset "ParamsInit" begin end
528+
@testset "UniformInit" begin
529+
test_generating_new_values(UniformInit())
530+
test_replacing_values(UniformInit())
531+
test_rng_respected(UniformInit())
532+
533+
@testset "check that bounds are respected" begin
534+
@testset "unconstrained" begin
535+
umin, umax = -1.0, 1.0
536+
@model just_norm() = x ~ Normal()
537+
for _ in 1:100
538+
_, vi = DynamicPPL.init!!(
539+
just_norm(), VarInfo(), UniformInit(umin, umax)
540+
)
541+
@test vi[@varname(x)] isa Real
542+
@test umin <= vi[@varname(x)] <= umax
543+
end
544+
end
545+
@testset "constrained" begin
546+
umin, umax = -1.0, 1.0
547+
@model just_beta() = x ~ Beta(2, 2)
548+
inv_bijector = inverse(Bijectors.bijector(Beta(2, 2)))
549+
tmin, tmax = inv_bijector(umin), inv_bijector(umax)
550+
for _ in 1:100
551+
_, vi = DynamicPPL.init!!(
552+
just_beta(), VarInfo(), UniformInit(umin, umax)
553+
)
554+
@test vi[@varname(x)] isa Real
555+
@test tmin <= vi[@varname(x)] <= tmax
556+
end
557+
end
558+
end
559+
end
441560

442-
@testset "rng is respected (at least with PriorInit" begin end
561+
@testset "ParamsInit" begin
562+
@testset "given full set of parameters" begin
563+
# test_init_model has x ~ Normal() and y ~ MvNormal(zeros(2), I)
564+
my_x, my_y = 1.0, [2.0, 3.0]
565+
params_nt = (; x=my_x, y=my_y)
566+
params_dict = Dict(@varname(x) => my_x, @varname(y) => my_y)
567+
model = test_init_model()
568+
for empty_vi in empty_varinfos
569+
_, vi = DynamicPPL.init!!(
570+
model, deepcopy(empty_vi), ParamsInit(params_nt)
571+
)
572+
@test vi[@varname(x)] == my_x
573+
@test vi[@varname(y)] == my_y
574+
logp_nt = getlogp(vi)
575+
_, vi = DynamicPPL.init!!(
576+
model, deepcopy(empty_vi), ParamsInit(params_dict)
577+
)
578+
@test vi[@varname(x)] == my_x
579+
@test vi[@varname(y)] == my_y
580+
logp_dict = getlogp(vi)
581+
@test logp_nt == logp_dict
582+
end
583+
end
584+
585+
@testset "given only partial parameters" begin
586+
# In this case, we expect `ParamsInit` to use the value of x, and
587+
# generate a new value for y.
588+
my_x = 1.0
589+
params_nt = (; x=my_x)
590+
params_dict = Dict(@varname(x) => my_x)
591+
model = test_init_model()
592+
for empty_vi in empty_varinfos
593+
_, vi = DynamicPPL.init!!(
594+
Xoshiro(468), model, deepcopy(empty_vi), ParamsInit(params_nt)
595+
)
596+
@test vi[@varname(x)] == my_x
597+
nt_y = vi[@varname(y)]
598+
@test nt_y isa AbstractVector{<:Real}
599+
@test length(nt_y) == 2
600+
_, vi = DynamicPPL.init!!(
601+
Xoshiro(469), model, deepcopy(empty_vi), ParamsInit(params_dict)
602+
)
603+
@test vi[@varname(x)] == my_x
604+
dict_y = vi[@varname(y)]
605+
@test dict_y isa AbstractVector{<:Real}
606+
@test length(dict_y) == 2
607+
# the values should be different since we used different seeds
608+
@test dict_y != nt_y
609+
end
610+
end
611+
end
443612
end
444613
end

0 commit comments

Comments
 (0)