
Commit 80740b2

Fix failing tutorials, change MNLE default for log_transform to False (#1367)

* fix FMPE typo, add NLE seed
* fix: change default log_transform to False. This was set to True for positive reaction times but does not hold in general.
* small notebook fixes

1 parent 76c1e1b commit 80740b2
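
The key change: `log_transform_x` for MNLE now defaults to `False`, since a log transform is only valid for strictly positive continuous data (such as reaction times) and does not hold in general. Users whose continuous data are strictly positive can opt back in via `likelihood_nn`, as the updated tests and tutorials do. A minimal sketch, assuming a `prior` is already defined and that the import paths below match this version of sbi:

    from sbi.inference import MNLE
    from sbi.neural_nets import likelihood_nn

    # log_transform_x now defaults to False; opt back in explicitly
    # when the continuous columns are strictly positive.
    estimator_builder = likelihood_nn(model="mnle", log_transform_x=True)
    trainer = MNLE(prior, estimator_builder)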

File tree

5 files changed: +16 −14 lines

sbi/neural_nets/net_builders/mnle.py
Lines changed: 1 addition & 1 deletion

@@ -64,7 +64,7 @@ def build_mnle(
     hidden_features: int = 50,
     hidden_layers: int = 2,
     tail_bound: float = 10.0,
-    log_transform_x: bool = True,
+    log_transform_x: bool = False,
     **kwargs,
 ):
     """Returns a density estimator for mixed data types.

tests/mnle_test.py
Lines changed: 5 additions & 5 deletions

@@ -162,7 +162,9 @@ def test_mnle_accuracy_with_different_samplers_and_trials(
     x = mixed_simulator(theta, stimulus_condition=1.0)

     # MNLE
-    density_estimator = likelihood_nn(model="mnle", flow_model=flow_model)
+    density_estimator = likelihood_nn(
+        model="mnle", flow_model=flow_model, log_transform_x=True
+    )
     trainer = MNLE(prior, density_estimator=density_estimator)
     trainer.append_simulations(theta, x).train(training_batch_size=200)
     posterior = trainer.build_posterior()

@@ -294,7 +296,7 @@ def test_mnle_with_experimental_conditions(mcmc_params_accurate: dict):
     )

     # MNLE
-    estimator_fun = likelihood_nn(model="mnle", z_score_x=None)
+    estimator_fun = likelihood_nn(model="mnle", log_transform_x=True)
     trainer = MNLE(proposal, estimator_fun)
     estimator = trainer.append_simulations(theta, x).train()

@@ -362,9 +364,7 @@ def test_log_likelihood_over_local_iid_theta(
     """

     # train mnle on mixed data
-    trainer = MNLE(
-        density_estimator=likelihood_nn(model="mnle", z_score_x=None),
-    )
+    trainer = MNLE()
     proposal = MultipleIndependent(
         [
             Gamma(torch.tensor([1.0]), torch.tensor([0.5])),

tests/tutorials_test.py
Lines changed: 1 addition & 1 deletion

@@ -22,7 +22,7 @@ def test_tutorials(notebook_path):
     """Test that all notebooks in the tutorials directory can be executed."""
     with open(notebook_path) as f:
         nb = nbformat.read(f, as_version=4)
-    ep = ExecutePreprocessor(timeout=1200, kernel_name='python3')
+    ep = ExecutePreprocessor(timeout=600, kernel_name='python3')
     print(f"Executing notebook {notebook_path}")
     try:
         ep.preprocess(nb, {'metadata': {'path': os.path.dirname(notebook_path)}})

tutorials/16_implemented_methods.ipynb
Lines changed: 3 additions & 2 deletions

@@ -187,7 +187,7 @@
     "from sbi.inference import FMPE\n",
     "\n",
     "inference = FMPE(prior)\n",
-    "# FMPE does support multiple rounds of inference\n",
+    "# FMPE does not support multiple rounds of inference\n",
     "theta = prior.sample((num_sims,))\n",
     "x = simulator(theta)\n",
     "inference.append_simulations(theta, x).train()\n",

@@ -310,7 +310,8 @@
     "\n",
     "inference = MNLE(prior)\n",
     "theta = prior.sample((num_sims,))\n",
-    "x = simulator(theta)\n",
+    "# add a column of discrete data to x.\n",
+    "x = torch.cat((simulator(theta), torch.bernoulli(theta[:, :1])), dim=1)\n",
     "_ = inference.append_simulations(theta, x).train()\n",
     "posterior = inference.build_posterior().set_default_x(x_o)"

tutorials/Example_01_DecisionMakingModel.ipynb
Lines changed: 6 additions & 5 deletions

@@ -129,7 +129,8 @@
     "        Beta(torch.tensor([2.0]), torch.tensor([2.0])),\n",
     "    ],\n",
     "    validate_args=False,\n",
-    ")"
+    ")\n",
+    "prior_transform = mcmc_transform(prior)"

@@ -184,7 +185,7 @@
     "true_posterior = MCMCPosterior(\n",
     "    potential_fn=BinomialGammaPotential(prior, x_o),\n",
     "    proposal=prior,\n",
-    "    theta_transform=mcmc_transform(prior, enable_transform=True),\n",
+    "    theta_transform=prior_transform,\n",
     "    **mcmc_kwargs,\n",
     ")\n",
     "true_samples = true_posterior.sample((num_samples,))"

@@ -228,7 +229,8 @@
     "x = mixed_simulator(theta)\n",
     "\n",
     "# Train MNLE and obtain MCMC-based posterior.\n",
-    "trainer = MNLE()\n",
+    "estimator_builder = likelihood_nn(model=\"mnle\", log_transform_x=True)\n",
+    "trainer = MNLE(proposal, estimator_builder)\n",
     "estimator = trainer.append_simulations(theta, x).train()"

@@ -610,7 +612,7 @@
     }
 ],
 "source": [
-    "estimator_builder = likelihood_nn(model=\"mnle\", z_score_x=None)  # we don't want to z-score the binary data.\n",
+    "estimator_builder = likelihood_nn(model=\"mnle\", log_transform_x=True)\n",
     "trainer = MNLE(proposal, estimator_builder)\n",
     "estimator = trainer.append_simulations(theta, x).train()"

@@ -847,7 +849,6 @@
     "\n",
     "fig, ax = pairplot(\n",
     "    [prior.sample((1000,))] + posterior_samples,\n",
-    "    # points=theta_o,\n",
     "    diag=\"kde\",\n",
     "    upper=\"contour\",\n",
     "    diag_kwargs=dict(bins=100),\n",
