@@ -533,68 +533,6 @@ def test_point_logps():
533
533
assert "a" in logp_vals .keys ()
534
534
535
535
536
- class TestUpdateStartVals (SeededTest ):
537
- def setup_method (self ):
538
- super ().setup_method ()
539
-
540
- def test_soft_update_all_present (self ):
541
- model = pm .Model ()
542
- start = {"a" : 1 , "b" : 2 }
543
- test_point = {"a" : 3 , "b" : 4 }
544
- model .update_start_vals (start , test_point )
545
- assert start == {"a" : 1 , "b" : 2 }
546
-
547
- def test_soft_update_one_missing (self ):
548
- model = pm .Model ()
549
- start = {
550
- "a" : 1 ,
551
- }
552
- test_point = {"a" : 3 , "b" : 4 }
553
- model .update_start_vals (start , test_point )
554
- assert start == {"a" : 1 , "b" : 4 }
555
-
556
- def test_soft_update_empty (self ):
557
- model = pm .Model ()
558
- start = {}
559
- test_point = {"a" : 3 , "b" : 4 }
560
- model .update_start_vals (start , test_point )
561
- assert start == test_point
562
-
563
- def test_soft_update_transformed (self ):
564
- with pm .Model () as model :
565
- pm .Exponential ("a" , 1 )
566
- start = {"a" : 2.0 }
567
- test_point = {"a_log__" : 0 }
568
- model .update_start_vals (start , test_point )
569
- assert_almost_equal (np .exp (start ["a_log__" ]), start ["a" ])
570
-
571
- def test_soft_update_parent (self ):
572
- with pm .Model () as model :
573
- a = pm .Uniform ("a" , lower = 0.0 , upper = 1.0 )
574
- b = pm .Uniform ("b" , lower = 2.0 , upper = 3.0 )
575
- pm .Uniform ("lower" , lower = a , upper = 3.0 )
576
- pm .Uniform ("upper" , lower = 0.0 , upper = b )
577
- pm .Uniform ("interv" , lower = a , upper = b )
578
-
579
- initial_point = {
580
- "a_interval__" : np .array (0.0 , dtype = aesara .config .floatX ),
581
- "b_interval__" : np .array (0.0 , dtype = aesara .config .floatX ),
582
- "lower_interval__" : np .array (0.0 , dtype = aesara .config .floatX ),
583
- "upper_interval__" : np .array (0.0 , dtype = aesara .config .floatX ),
584
- "interv_interval__" : np .array (0.0 , dtype = aesara .config .floatX ),
585
- }
586
- start = {"a" : 0.3 , "b" : 2.1 , "lower" : 1.4 , "upper" : 1.4 , "interv" : 1.4 }
587
- test_point = {
588
- "lower_interval__" : - 0.3746934494414109 ,
589
- "upper_interval__" : 0.693147180559945 ,
590
- "interv_interval__" : 0.4519851237430569 ,
591
- }
592
- model .update_start_vals (start , initial_point )
593
- assert_almost_equal (start ["lower_interval__" ], test_point ["lower_interval__" ])
594
- assert_almost_equal (start ["upper_interval__" ], test_point ["upper_interval__" ])
595
- assert_almost_equal (start ["interv_interval__" ], test_point ["interv_interval__" ])
596
-
597
-
598
536
class TestShapeEvaluation :
599
537
def test_eval_rv_shapes (self ):
600
538
with pm .Model (
@@ -626,17 +564,21 @@ def test_valid_start_point(self):
626
564
a = pm .Uniform ("a" , lower = 0.0 , upper = 1.0 )
627
565
b = pm .Uniform ("b" , lower = 2.0 , upper = 3.0 )
628
566
629
- start = {"a" : 0.3 , "b" : 2.1 }
630
- model .update_start_vals (start , model .initial_point )
567
+ start = {
568
+ "a_interval__" : model .rvs_to_values [a ].tag .transform .forward (a , 0.3 ).eval (),
569
+ "b_interval__" : model .rvs_to_values [b ].tag .transform .forward (b , 2.1 ).eval (),
570
+ }
631
571
model .check_start_vals (start )
632
572
633
573
def test_invalid_start_point (self ):
634
574
with pm .Model () as model :
635
575
a = pm .Uniform ("a" , lower = 0.0 , upper = 1.0 )
636
576
b = pm .Uniform ("b" , lower = 2.0 , upper = 3.0 )
637
577
638
- start = {"a" : np .nan , "b" : np .nan }
639
- model .update_start_vals (start , model .initial_point )
578
+ start = {
579
+ "a_interval__" : np .nan ,
580
+ "b_interval__" : model .rvs_to_values [b ].tag .transform .forward (b , 2.1 ).eval (),
581
+ }
640
582
with pytest .raises (pm .exceptions .SamplingError ):
641
583
model .check_start_vals (start )
642
584
@@ -645,8 +587,11 @@ def test_invalid_variable_name(self):
645
587
a = pm .Uniform ("a" , lower = 0.0 , upper = 1.0 )
646
588
b = pm .Uniform ("b" , lower = 2.0 , upper = 3.0 )
647
589
648
- start = {"a" : 0.3 , "b" : 2.1 , "c" : 1.0 }
649
- model .update_start_vals (start , model .initial_point )
590
+ start = {
591
+ "a_interval__" : model .rvs_to_values [a ].tag .transform .forward (a , 0.3 ).eval (),
592
+ "b_interval__" : model .rvs_to_values [b ].tag .transform .forward (b , 2.1 ).eval (),
593
+ "c" : 1.0 ,
594
+ }
650
595
with pytest .raises (KeyError ):
651
596
model .check_start_vals (start )
652
597
0 commit comments