@@ -20,8 +20,9 @@ using DynamicPPL:
20
20
hasconditioned_nested,
21
21
getconditioned_nested,
22
22
collapse_prefix_stack,
23
- prefix_cond_and_fixed_variables,
24
- getvalue
23
+ prefix_cond_and_fixed_variables
24
+ using LinearAlgebra: I
25
+ using Random: Xoshiro
25
26
26
27
using EnzymeCore
27
28
@@ -103,7 +104,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
103
104
# sometimes only the main symbol (e.g. it contains `x` when
104
105
# `vn` is `x[1]`)
105
106
for vn in conditioned_vns
106
- val = DynamicPPL . getvalue (conditioned_values, vn)
107
+ val = getvalue (conditioned_values, vn)
107
108
# These VarNames are present in the conditioning values, so
108
109
# we should always be able to extract the value.
109
110
@test hasconditioned_nested (context, vn)
@@ -433,12 +434,180 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
433
434
end
434
435
435
436
@testset " InitContext" begin
436
- @testset " PriorInit" begin end
437
+ empty_varinfos = [
438
+ VarInfo (),
439
+ DynamicPPL. typed_varinfo (VarInfo ()),
440
+ VarInfo (DynamicPPL. VarNamedVector ()),
441
+ DynamicPPL. typed_vector_varinfo (DynamicPPL. typed_varinfo (VarInfo ())),
442
+ SimpleVarInfo (),
443
+ SimpleVarInfo (Dict {VarName,Any} ()),
444
+ ]
445
+
446
+ @model function test_init_model ()
447
+ x ~ Normal ()
448
+ y ~ MvNormal (fill (x, 2 ), I)
449
+ 1.0 ~ Normal ()
450
+ return nothing
451
+ end
452
+ function test_generating_new_values (strategy:: AbstractInitStrategy )
453
+ @testset " generating new values: $(typeof (strategy)) " begin
454
+ # Check that init!! can generate values that weren't there
455
+ # previously.
456
+ model = test_init_model ()
457
+ for empty_vi in empty_varinfos
458
+ this_vi = deepcopy (empty_vi)
459
+ _, vi = DynamicPPL. init!! (model, this_vi, strategy)
460
+ @test Set (keys (vi)) == Set ([@varname (x), @varname (y)])
461
+ x, y = vi[@varname (x)], vi[@varname (y)]
462
+ @test x isa Real
463
+ @test y isa AbstractVector{<: Real }
464
+ @test length (y) == 2
465
+ (; logprior, loglikelihood) = getlogp (vi)
466
+ @test logpdf (Normal (), x) + logpdf (MvNormal (fill (x, 2 ), I), y) ==
467
+ logprior
468
+ @test logpdf (Normal (), 1.0 ) == loglikelihood
469
+ end
470
+ end
471
+ end
472
+ function test_replacing_values (strategy:: AbstractInitStrategy )
473
+ @testset " replacing old values: $(typeof (strategy)) " begin
474
+ # Check that init!! can overwrite values that were already there.
475
+ model = test_init_model ()
476
+ for empty_vi in empty_varinfos
477
+ # start by generating some rubbish values
478
+ vi = deepcopy (empty_vi)
479
+ old_x, old_y = 100000.00 , [300000.00 , 500000.00 ]
480
+ push!! (vi, @varname (x), old_x, Normal ())
481
+ push!! (vi, @varname (y), old_y, MvNormal (fill (old_x, 2 ), I))
482
+ # then overwrite it
483
+ _, new_vi = DynamicPPL. init!! (model, vi, strategy)
484
+ new_x, new_y = new_vi[@varname (x)], new_vi[@varname (y)]
485
+ # check that the values are (presumably) different
486
+ @test old_x != new_x
487
+ @test old_y != new_y
488
+ end
489
+ end
490
+ end
491
+ function test_rng_respected (strategy:: AbstractInitStrategy )
492
+ @testset " check that RNG is respected: $(typeof (strategy)) " begin
493
+ model = test_init_model ()
494
+ for empty_vi in empty_varinfos
495
+ _, vi1 = DynamicPPL. init!! (
496
+ Xoshiro (468 ), model, deepcopy (empty_vi), strategy
497
+ )
498
+ _, vi2 = DynamicPPL. init!! (
499
+ Xoshiro (468 ), model, deepcopy (empty_vi), strategy
500
+ )
501
+ _, vi3 = DynamicPPL. init!! (
502
+ Xoshiro (469 ), model, deepcopy (empty_vi), strategy
503
+ )
504
+ @test vi1[@varname (x)] == vi2[@varname (x)]
505
+ @test vi1[@varname (y)] == vi2[@varname (y)]
506
+ @test vi1[@varname (x)] != vi3[@varname (x)]
507
+ @test vi1[@varname (y)] != vi3[@varname (y)]
508
+ end
509
+ end
510
+ end
437
511
438
- @testset " UniformInit" begin end
512
+ @testset " PriorInit" begin
513
+ test_generating_new_values (PriorInit ())
514
+ test_replacing_values (PriorInit ())
515
+ test_rng_respected (PriorInit ())
516
+
517
+ @testset " check that values are within support" begin
518
+ # Not many other sensible checks we can do for priors.
519
+ @model just_unif () = x ~ Uniform (0.0 , 1e-7 )
520
+ for _ in 1 : 100
521
+ _, vi = DynamicPPL. init!! (just_unif (), VarInfo (), PriorInit ())
522
+ @test vi[@varname (x)] isa Real
523
+ @test 0.0 <= vi[@varname (x)] <= 1e-7
524
+ end
525
+ end
526
+ end
439
527
440
- @testset " ParamsInit" begin end
528
+ @testset " UniformInit" begin
529
+ test_generating_new_values (UniformInit ())
530
+ test_replacing_values (UniformInit ())
531
+ test_rng_respected (UniformInit ())
532
+
533
+ @testset " check that bounds are respected" begin
534
+ @testset " unconstrained" begin
535
+ umin, umax = - 1.0 , 1.0
536
+ @model just_norm () = x ~ Normal ()
537
+ for _ in 1 : 100
538
+ _, vi = DynamicPPL. init!! (
539
+ just_norm (), VarInfo (), UniformInit (umin, umax)
540
+ )
541
+ @test vi[@varname (x)] isa Real
542
+ @test umin <= vi[@varname (x)] <= umax
543
+ end
544
+ end
545
+ @testset " constrained" begin
546
+ umin, umax = - 1.0 , 1.0
547
+ @model just_beta () = x ~ Beta (2 , 2 )
548
+ inv_bijector = inverse (Bijectors. bijector (Beta (2 , 2 )))
549
+ tmin, tmax = inv_bijector (umin), inv_bijector (umax)
550
+ for _ in 1 : 100
551
+ _, vi = DynamicPPL. init!! (
552
+ just_beta (), VarInfo (), UniformInit (umin, umax)
553
+ )
554
+ @test vi[@varname (x)] isa Real
555
+ @test tmin <= vi[@varname (x)] <= tmax
556
+ end
557
+ end
558
+ end
559
+ end
441
560
442
- @testset " rng is respected (at least with PriorInit" begin end
561
+ @testset " ParamsInit" begin
562
+ @testset " given full set of parameters" begin
563
+ # test_init_model has x ~ Normal() and y ~ MvNormal(zeros(2), I)
564
+ my_x, my_y = 1.0 , [2.0 , 3.0 ]
565
+ params_nt = (; x= my_x, y= my_y)
566
+ params_dict = Dict (@varname (x) => my_x, @varname (y) => my_y)
567
+ model = test_init_model ()
568
+ for empty_vi in empty_varinfos
569
+ _, vi = DynamicPPL. init!! (
570
+ model, deepcopy (empty_vi), ParamsInit (params_nt)
571
+ )
572
+ @test vi[@varname (x)] == my_x
573
+ @test vi[@varname (y)] == my_y
574
+ logp_nt = getlogp (vi)
575
+ _, vi = DynamicPPL. init!! (
576
+ model, deepcopy (empty_vi), ParamsInit (params_dict)
577
+ )
578
+ @test vi[@varname (x)] == my_x
579
+ @test vi[@varname (y)] == my_y
580
+ logp_dict = getlogp (vi)
581
+ @test logp_nt == logp_dict
582
+ end
583
+ end
584
+
585
+ @testset " given only partial parameters" begin
586
+ # In this case, we expect `ParamsInit` to use the value of x, and
587
+ # generate a new value for y.
588
+ my_x = 1.0
589
+ params_nt = (; x= my_x)
590
+ params_dict = Dict (@varname (x) => my_x)
591
+ model = test_init_model ()
592
+ for empty_vi in empty_varinfos
593
+ _, vi = DynamicPPL. init!! (
594
+ Xoshiro (468 ), model, deepcopy (empty_vi), ParamsInit (params_nt)
595
+ )
596
+ @test vi[@varname (x)] == my_x
597
+ nt_y = vi[@varname (y)]
598
+ @test nt_y isa AbstractVector{<: Real }
599
+ @test length (nt_y) == 2
600
+ _, vi = DynamicPPL. init!! (
601
+ Xoshiro (469 ), model, deepcopy (empty_vi), ParamsInit (params_dict)
602
+ )
603
+ @test vi[@varname (x)] == my_x
604
+ dict_y = vi[@varname (y)]
605
+ @test dict_y isa AbstractVector{<: Real }
606
+ @test length (dict_y) == 2
607
+ # the values should be different since we used different seeds
608
+ @test dict_y != nt_y
609
+ end
610
+ end
611
+ end
443
612
end
444
613
end
0 commit comments