
Commit 0ea6f63

Mini sbibm (#1335)
* Start of mini_sbibm
* Working prototype
* Rename to get globally discovered
* ruff
* extended to something reasonable
* remove example
* different modes
* Added docs and formating
* colored numbers
* formating and stuff
* formating?
* Fixing suggestions and adding dependency
* Using taskt dict
* licence note
* Update tests/bm_test.py (Co-authored-by: Jan <janfb@users.noreply.github.com>)
* num_eval_obs fix, remove xdist comment
* Clarifing the fixture
* remove comment
* Adding xdist cache to gitignore
* load with weights only False to have pytorch 2.6 compatibility
* torch 2.6 compatibility
* Small user guide to it
* Updated with working xdist support now. Simplified API. Saves result human-readable
* formating
* Update docs/docs/contribute.md (Co-authored-by: Jan <janfb@users.noreply.github.com>)
* Fix negative interaction with normal pytest xdist

---------

Co-authored-by: Jan <janfb@users.noreply.github.com>
1 parent 43ae353 commit 0ea6f63

72 files changed: +1,095 −2 lines (large commit; only a subset of the changed files is shown below)

.gitignore

Lines changed: 3 additions & 0 deletions

```diff
@@ -61,6 +61,9 @@ coverage.xml
 *,cover
 .hypothesis/
 tests/.mypy_cache
+.xdist_results/
+.xdist_harvested/
+.bm_results/

 # Translations
 *.mo
```

docs/docs/contribute.md

Lines changed: 34 additions & 0 deletions

@@ -210,6 +210,40 @@ fails (xfailed).
 - Commit and push again until CI tests pass. Don't hesitate to ask for help by
   commenting on the PR.
 
+#### mini-sbibm tests
+
+As SBI is a fundamentally data-driven approach, we are not only interested in whether
+the modifications to the codebase "pass the tests" but also in whether they improve or
+at least do not deteriorate the performance of the package for inference. To this end,
+we have a set of *mini-sbibm* tests (a minimal version of the sbi benchmarking package [`sbibm`](https://github.com/sbi-benchmark/sbibm)) that are intended for developers to run locally.
+
+These tests differ from the regular tests in that they always pass (provided there
+are no errors) but output performance metrics that can be compared, e.g., to the
+performance metrics of the main branch or relative to each other. The user-facing API
+is available via `pytest` through custom flags. To run the mini-sbibm tests, you can use
+the following command:
+```bash
+pytest --bm
+```
+This will run all the mini-sbibm tests on all methods with default parameters and output
+the performance metrics nicely formatted to the console. If you have multiple CPU cores
+available, you can run the tests in parallel using the `-n auto` flag:
+```bash
+pytest --bm -n auto
+```
+What if you are currently working on a specific method and want to run the
+mini-sbibm tests only for this class of methods? You can use the `--bm-mode` flag:
+```bash
+pytest --bm --bm-mode npse
+```
+This will run the mini-sbibm tests only for methods of the `npse` class, but with a
+few major hyperparameter choices, such as different base network architectures and
+different diffusion processes.
+
+The currently available modes are: `"npe"`, `"nle"`, `"nre"`, `"fmpe"`, `"npse"`,
+`"snpe"`, `"snle"`, and `"snre"`. If you require another mode, you can add it to the
+test suite in `tests/bm_test.py` (see the sketch after this diff).
+
 ## Contributing to the documentation
 
 Most of the documentation for `sbi` is written in markdown and the website is generated
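The guide above points contributors to `tests/bm_test.py` for adding new modes (the full file appears further down in this commit). As a rough sketch, assuming a hypothetical mode name `"my_npe"` that is not part of this commit, adding a mode amounts to extending the two dictionaries that `pytest_generate_tests` reads:

```python
# Hypothetical sketch only: these entries are not part of this commit.
# A new "my_npe" mode that benchmarks NPE with two flow architectures would be
# registered by extending the dictionaries defined in tests/bm_test.py.
from sbi.inference import NPE

METHOD_GROUPS["my_npe"] = [NPE]  # which trainer classes the mode runs
METHOD_PARAMS["my_npe"] = [  # one benchmark test per kwargs dict
    {"density_estimator": de} for de in ["maf", "nsf"]
]
```

With such entries in place, `pytest --bm --bm-mode my_npe` would parametrize `test_run_benchmark` over the new group, because `pytest_generate_tests` looks the `--bm-mode` value up directly in `METHOD_GROUPS` and `METHOD_PARAMS`.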

pyproject.toml

Lines changed: 3 additions & 1 deletion

```diff
@@ -73,6 +73,7 @@ dev = [
     "pytest-cov",
     "pytest-testmon",
     "pytest-xdist",
+    "pytest-harvest",
     "torchtestcase",
 ]

@@ -132,7 +133,8 @@ testpaths = ["tests"]
 markers = [
     "slow: marks tests as slow (deselect with '-m \"not slow\"')",
     "gpu: marks tests that require a gpu (deselect with '-m \"not gpu\"')",
-    "mcmc: marks tests that require MCMC sampling (deselect with '-m \"not mcmc\"')"
+    "mcmc: marks tests that require MCMC sampling (deselect with '-m \"not mcmc\"')",
+    "benchmark: marks tests that are solely for benchmarking purposes"
 ]
 xfail_strict = true
```

tests/bm_test.py

Lines changed: 272 additions & 0 deletions

@@ -0,0 +1,272 @@ (new file)

```python
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

import pytest
import torch
from pytest_harvest import ResultsBag

from sbi.inference import FMPE, NLE, NPE, NPSE, NRE
from sbi.inference.posteriors.base_posterior import NeuralPosterior
from sbi.inference.trainers.npe import NPE_C
from sbi.inference.trainers.nre import BNRE, NRE_A, NRE_B, NRE_C
from sbi.utils.metrics import c2st

from .mini_sbibm import get_task
from .mini_sbibm.base_task import Task

# Global settings
SEED = 0
TASKS = ["two_moons", "linear_mvg_2d", "gaussian_linear", "slcp"]
NUM_SIMULATIONS = 2000
NUM_EVALUATION_OBS = 3  # Currently only 3 observations tested for speed
NUM_ROUNDS_SEQUENTIAL = 2
NUM_EVALUATION_OBS_SEQ = 1
TRAIN_KWARGS = {}

# Density estimators to test
DENSITY_ESTIMATORS = ["mdn", "made", "maf", "nsf", "maf_rqs"]  # "Kinda exhaustive"
CLASSIFIERS = ["mlp", "resnet"]
NNS = ["mlp", "resnet"]
SCORE_ESTIMATORS = ["mlp", "ada_mlp"]

# Benchmarking method groups, i.e., what to run for different --bm-mode
METHOD_GROUPS = {
    "none": [NPE, NRE, NLE, FMPE, NPSE],
    "npe": [NPE],
    "nle": [NLE],
    "nre": [NRE_A, NRE_B, NRE_C, BNRE],
    "fmpe": [FMPE],
    "npse": [NPSE],
    "snpe": [NPE_C],  # NPE_B not implemented, NPE_A needs Gaussian prior
    "snle": [NLE],
    "snre": [NRE_A, NRE_B, NRE_C, BNRE],
}
METHOD_PARAMS = {
    "none": [{}],
    "npe": [{"density_estimator": de} for de in DENSITY_ESTIMATORS],
    "nle": [{"density_estimator": de} for de in ["maf", "nsf"]],
    "nre": [{"classifier": cl} for cl in CLASSIFIERS],
    "fmpe": [{"density_estimator": nn} for nn in NNS],
    "npse": [
        {"score_estimator": nn, "sde_type": sde}
        for nn in SCORE_ESTIMATORS
        for sde in ["ve", "vp"]
    ],
    "snpe": [{}],
    "snle": [{}],
    "snre": [{}],
}


@pytest.fixture
def method_list(benchmark_mode: str) -> list:
    """
    Fixture to get the list of methods based on the benchmark mode.

    Args:
        benchmark_mode (str): The benchmark mode.

    Returns:
        list: List of methods for the given benchmark mode.
    """
    name = str(benchmark_mode).lower()
    if name not in METHOD_GROUPS:
        raise ValueError(f"Benchmark mode '{benchmark_mode}' is not supported.")
    return METHOD_GROUPS[name]


@pytest.fixture
def kwargs_list(benchmark_mode: str) -> list:
    """
    Fixture to get the list of kwargs based on the benchmark mode.

    Args:
        benchmark_mode (str): The benchmark mode.

    Returns:
        list: List of kwargs for the given benchmark mode.
    """
    name = str(benchmark_mode).lower()
    if name not in METHOD_PARAMS:
        raise ValueError(f"Benchmark mode '{benchmark_mode}' is not supported.")
    return METHOD_PARAMS[name]


# Use pytest.mark.parametrize dynamically
# Generates a list of methods to test based on the benchmark mode
def pytest_generate_tests(metafunc):
    """
    Dynamically generates a list of methods to test based on the benchmark mode.

    Args:
        metafunc: The metafunc object from pytest.
    """
    if "inference_class" in metafunc.fixturenames:
        method_list = metafunc.config.getoption("--bm-mode")
        name = str(method_list).lower()
        method_group = METHOD_GROUPS.get(name, [])
        metafunc.parametrize("inference_class", method_group)
    if "extra_kwargs" in metafunc.fixturenames:
        kwargs_list = metafunc.config.getoption("--bm-mode")
        name = str(kwargs_list).lower()
        kwargs_group = METHOD_PARAMS.get(name, [])
        metafunc.parametrize("extra_kwargs", kwargs_group)


def standard_eval_c2st_loop(posterior: NeuralPosterior, task: Task) -> float:
    """
    Evaluates the C2ST metric for the given posterior and task.

    Args:
        posterior: The posterior distribution.
        task: The task object.

    Returns:
        float: The mean C2ST value.
    """
    c2st_scores = []
    for i in range(1, NUM_EVALUATION_OBS + 1):
        c2st_val = eval_c2st(posterior, task, i)
        c2st_scores.append(c2st_val)

    mean_c2st = sum(c2st_scores) / len(c2st_scores)
    # Convert to float rounded to 3 decimal places
    mean_c2st = float(f"{mean_c2st:.3f}")
    return mean_c2st


def eval_c2st(
    posterior: NeuralPosterior,
    task: Task,
    idx_observation: int,
    num_samples: int = 1000,
) -> float:
    """
    Evaluates the C2ST metric for a specific observation.

    Args:
        posterior: The posterior distribution.
        task: The task object.
        idx_observation (int): The observation index.
        num_samples (int): Number of posterior samples used for the comparison.

    Returns:
        float: The C2ST value.
    """
    x_o = task.get_observation(idx_observation)
    posterior_samples = task.get_reference_posterior_samples(idx_observation)
    approx_posterior_samples = posterior.sample((num_samples,), x=x_o)
    if isinstance(approx_posterior_samples, tuple):
        approx_posterior_samples = approx_posterior_samples[0]
    assert posterior_samples.shape[0] >= num_samples, "Not enough reference samples"
    c2st_val = c2st(posterior_samples[:num_samples], approx_posterior_samples)
    return float(c2st_val)


def train_and_eval_amortized_inference(
    inference_class, task_name: str, extra_kwargs: dict, results_bag: ResultsBag
) -> None:
    """
    Performs amortized inference evaluation.

    Args:
        inference_class: The inference class to run, e.g., NPE, NLE, NRE.
        task_name: The name of the task.
        extra_kwargs: Additional keyword arguments for the method.
        results_bag: The results bag to store evaluation results. Subclass of dict, but
            allows item assignment with dot notation.
    """
    torch.manual_seed(SEED)
    task = get_task(task_name)
    thetas, xs = task.get_data(NUM_SIMULATIONS)
    prior = task.get_prior()

    inference = inference_class(prior, **extra_kwargs)
    _ = inference.append_simulations(thetas, xs).train(**TRAIN_KWARGS)

    posterior = inference.build_posterior()

    mean_c2st = standard_eval_c2st_loop(posterior, task)

    # Cache results
    results_bag.metric = mean_c2st
    results_bag.num_simulations = NUM_SIMULATIONS
    results_bag.task_name = task_name
    results_bag.method = inference_class.__name__ + str(extra_kwargs)


def train_and_eval_sequential_inference(
    inference_class, task_name: str, extra_kwargs: dict, results_bag: ResultsBag
) -> None:
    """
    Performs sequential inference evaluation.

    Args:
        inference_class: The inference class to run, e.g., NPE_C, NLE, NRE_A.
        task_name (str): The name of the task.
        extra_kwargs (dict): Additional keyword arguments for the method.
        results_bag: The results bag to store evaluation results.
    """
    torch.manual_seed(SEED)
    task = get_task(task_name)
    num_simulations = NUM_SIMULATIONS // NUM_ROUNDS_SEQUENTIAL
    thetas, xs = task.get_data(num_simulations)
    prior = task.get_prior()
    idx_eval = NUM_EVALUATION_OBS_SEQ
    x_o = task.get_observation(idx_eval)
    simulator = task.get_simulator()

    # Round 1
    inference = inference_class(prior, **extra_kwargs)
    _ = inference.append_simulations(thetas, xs).train(**TRAIN_KWARGS)

    for _ in range(NUM_ROUNDS_SEQUENTIAL - 1):
        proposal = inference.build_posterior().set_default_x(x_o)
        thetas_i = proposal.sample((num_simulations,))
        xs_i = simulator(thetas_i)
        if "npe" in inference_class.__name__.lower():
            # NPE-type trainers must receive the proposal for multi-round training
            _ = inference.append_simulations(thetas_i, xs_i, proposal=proposal).train(
                **TRAIN_KWARGS
            )
        else:
            inference.append_simulations(thetas_i, xs_i).train(**TRAIN_KWARGS)

    posterior = inference.build_posterior()

    c2st_val = eval_c2st(posterior, task, idx_eval)

    # Cache results
    results_bag.metric = c2st_val
    results_bag.num_simulations = NUM_SIMULATIONS
    results_bag.task_name = task_name
    results_bag.method = inference_class.__name__ + str(extra_kwargs)


@pytest.mark.benchmark
@pytest.mark.parametrize("task_name", TASKS, ids=str)
def test_run_benchmark(
    inference_class,
    task_name: str,
    results_bag,
    extra_kwargs: dict,
    benchmark_mode: str,
) -> None:
    """
    Benchmark test for amortized and sequential inference methods.

    Args:
        inference_class: The inference class to test, i.e., NPE, NLE, NRE ...
        task_name: The name of the task.
        results_bag: The results bag to store evaluation results.
        extra_kwargs: Additional keyword arguments for the method.
        benchmark_mode: The benchmark mode. This is a fixture which, based on user
            input, determines which type of methods should be run.
    """
    if benchmark_mode in ["snpe", "snle", "snre"]:
        train_and_eval_sequential_inference(
            inference_class, task_name, extra_kwargs, results_bag
        )
    else:
        train_and_eval_amortized_inference(
            inference_class, task_name, extra_kwargs, results_bag
        )
```
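Note that `bm_test.py` reads a `--bm-mode` option and a `benchmark_mode` fixture that are not defined in this file; they come from the `tests/conftest.py` changes, which are among the files hidden in this view. A minimal sketch of how such wiring could look, assuming only standard pytest hooks (this is not the actual implementation shipped in the commit):

```python
# Hypothetical conftest.py sketch -- not the implementation shipped in this commit.
# It only illustrates how a --bm/--bm-mode option and a benchmark_mode fixture
# could be wired up so that tests/bm_test.py can read them.
import pytest


def pytest_addoption(parser):
    # Register the custom command-line flags used by the mini-sbibm tests.
    parser.addoption(
        "--bm", action="store_true", default=False,
        help="Run the mini-sbibm benchmark tests.",
    )
    parser.addoption(
        "--bm-mode", action="store", default="none",
        help="Which group of methods to benchmark (e.g. npe, nle, npse).",
    )


def pytest_collection_modifyitems(config, items):
    # Without --bm, skip everything marked as a benchmark; with --bm, run only those.
    run_bm = config.getoption("--bm")
    skip_bm = pytest.mark.skip(reason="needs --bm to run")
    skip_regular = pytest.mark.skip(reason="only benchmarks run with --bm")
    for item in items:
        if "benchmark" in item.keywords and not run_bm:
            item.add_marker(skip_bm)
        elif "benchmark" not in item.keywords and run_bm:
            item.add_marker(skip_regular)


@pytest.fixture
def benchmark_mode(request) -> str:
    # Expose the --bm-mode value to tests as a fixture.
    return request.config.getoption("--bm-mode")
```

The actual conftest presumably also aggregates the `results_bag` entries collected via pytest-harvest into the formatted console output described in the contribution guide above; that part is not sketched here.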
