@@ -38,11 +38,9 @@ def _create_npe(num_simulations, max_epochs=None):
38
38
inference = NPE (prior , density_estimator = 'maf' )
39
39
inference = inference .append_simulations (theta = theta_train , x = x_train )
40
40
41
- train_kwargs = {"training_batch_size" : 100 }
42
- if max_epochs :
43
- train_kwargs ["max_num_epochs" ] = max_epochs
44
-
45
- return inference .train (** train_kwargs )
41
+ return inference .train (
42
+ max_num_epochs = 2 ** 31 - 1 if max_epochs is None else max_epochs ,
43
+ )
46
44
47
45
return _create_npe
48
46
@@ -54,7 +52,7 @@ def badly_trained_npe(npe_factory):
54
52
55
53
@pytest .fixture (scope = "session" )
56
54
def well_trained_npe (npe_factory ):
57
- return npe_factory (num_simulations = 10_000 )
55
+ return npe_factory (num_simulations = 5_000 )
58
56
59
57
60
58
@pytest .fixture (scope = "session" )
@@ -276,7 +274,7 @@ def test_lc2st_false_positiv_rate(method, basic_setup, well_trained_npe, set_see
276
274
proportion_rejected = torch .tensor (results ).float ().mean ()
277
275
278
276
assert proportion_rejected < (1 - confidence_level ), (
279
- f "LC2ST p-values too small, test should be rejected \
280
- less then { (1 - confidence_level ) * 100 } % of the time, \
281
- but was rejected { proportion_rejected * 100 } % of the time."
277
+ "LC2ST p-values too small, test should be rejected "
278
+ f" less then { (1 - confidence_level ) * 100.0 :<.2f } % of the time, "
279
+ f" but was rejected { proportion_rejected * 100.0 :<.2f } % of the time."
282
280
)
0 commit comments