Skip to content

Commit 0c4e60b

Browse files
committed
local test fix
1 parent 5ce203f commit 0c4e60b

File tree

2 files changed

+2
-3
lines changed

2 files changed

+2
-3
lines changed

tests/lc2st_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ def test_lc2st_false_positiv_rate(method):
197197

198198
# use big num_train and num_epochs to obtain "good" estimator
199199
# (convergence of the estimator)
200-
num_train = 10_000
200+
num_train = 1_000
201201
num_epochs = 200
202202

203203
num_cal = 1_000
@@ -215,7 +215,7 @@ def test_lc2st_false_positiv_rate(method):
215215
# Train the neural posterior estimators
216216
inference = NPE(prior, density_estimator='maf')
217217
inference = inference.append_simulations(theta=theta_train, x=x_train)
218-
npe = inference.train(training_batch_size=1000, max_num_epochs=num_epochs)
218+
npe = inference.train(training_batch_size=100, max_num_epochs=num_epochs)
219219

220220
thetas = prior.sample((num_cal,))
221221
xs = simulator(thetas)

tutorials/13_diagnostics_lc2st.ipynb

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,6 @@
277277
"# sample calibration data\n",
278278
"theta_cal = prior.sample((NUM_CAL,))\n",
279279
"x_cal = simulator(theta_cal)\n",
280-
"# post_samples_cal = npe.sample((1,), x_cal).reshape(-1, theta_cal.shape[-1]).detach()\n",
281280
"post_samples_cal = posterior.sample_batched((1,), x=x_cal)[0]"
282281
]
283282
},

0 commit comments

Comments
 (0)