Skip to content

Commit b7e1515

Browse files
Fixups for the documentation (#1585)
1 parent 483c42c commit b7e1515

File tree

5 files changed

+64
-52
lines changed

5 files changed

+64
-52
lines changed

docs/applications.rst

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
SBI application explorer
22
========================
33

4+
We recommend to explore the applications
5+
`in fullscreen <https://sbi-applications-explorer.streamlit.app/>`_.
6+
47
.. raw:: html
58

69
<iframe
@@ -19,8 +22,7 @@ About the SBI application explorer
1922
The SBI application explorer shows how SBI is applied across different research fields
2023
and data types, in particular with respect to the number of parameters and the number
2124
of simulations used. This is to gain a quick overview over existing applications, and
22-
the right settings for your own work. We recommend to explore the applications
23-
[in fullscreen](https://sbi-applications-explorer.streamlit.app/).
25+
the right settings for your own work.
2426

2527

2628
Features
@@ -32,7 +34,7 @@ Features
3234

3335

3436
Contributing
35-
-----------
37+
------------
3638

3739
The data comes from a curated list of SBI papers.
3840

docs/how_to_guide/13_diagnostics_lc2st.ipynb

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,30 +27,37 @@
2727
"from sbi.analysis.plot import pp_plot_lc2st\n",
2828
"\n",
2929
"# Sample calibration data.\n",
30-
"theta_cal = prior.sample((NUM_CAL,))\n",
31-
"x_cal = simulator(theta_cal)\n",
32-
"post_samples_cal = posterior.sample_batched((1,), x=x_cal)[0]\n",
30+
"num_lc2st_samples = 1_000\n",
31+
"prior_samples = prior.sample((num_lc2st_samples,))\n",
32+
"prior_predictives = simulator(prior_samples)\n",
33+
"\n",
34+
"# Generate one posterior sample for every prior predictive.\n",
35+
"post_samples_cal = []\n",
36+
"for x in prior_predictives:\n",
37+
" post_samples_cal.append(posterior.sample((1,), x=x)[0])\n",
38+
"post_samples_cal = torch.stack(post_samples_cal)\n",
3339
"\n",
3440
"# Train the L-C2ST classifier.\n",
3541
"lc2st = LC2ST(\n",
36-
" thetas=theta_cal,\n",
37-
" xs=x_cal,\n",
42+
" thetas=prior_samples,\n",
43+
" xs=prior_predictives,\n",
3844
" posterior_samples=post_samples_cal,\n",
3945
" classifier=\"mlp\",\n",
4046
" num_ensemble=1,\n",
4147
")\n",
4248
"_ = lc2st.train_under_null_hypothesis()\n",
4349
"_ = lc2st.train_on_observed_data()\n",
4450
"\n",
51+
"# Note: x_o must have a batch-dimension. I.e. `x_o.shape == (1, observation_shape)`.\n",
4552
"post_samples_star = posterior.sample((10_000,), x=x_o)\n",
4653
"probs_data, _ = lc2st.get_scores(\n",
47-
" theta_o=post_samples_star[i],\n",
54+
" theta_o=post_samples_star,\n",
4855
" x_o=x_o,\n",
4956
" return_probs=True,\n",
5057
" trained_clfs=lc2st.trained_clfs\n",
5158
")\n",
5259
"probs_null, _ = lc2st.get_statistics_under_null_hypothesis(\n",
53-
" theta_o=post_samples_star[i],\n",
60+
" theta_o=post_samples_star,\n",
5461
" x_o=x_o,\n",
5562
" return_probs=True\n",
5663
")\n",
@@ -71,14 +78,14 @@
7178
"cell_type": "markdown",
7279
"metadata": {},
7380
"source": [
74-
"<img src=\"data/l_c2st_pp_plot.png\" width=\"500\">"
81+
"<img src=\"data/L_C2ST_pp_plot.png\" width=\"500\">"
7582
]
7683
},
7784
{
7885
"cell_type": "markdown",
7986
"metadata": {},
8087
"source": [
81-
"If the red line lies within the gray region, the we cannot reject the null-hypothesis that the approximate posterior matches the true posterior. If the red line is below the gray area, then the `posterior_estimator` is over-confident. If the red line is above the gray area, then the `posterior_estimator` is under-confident."
88+
"If the red line lies within the gray region, the we cannot reject the null-hypothesis that the approximate posterior matches the true posterior. If the red line is below the gray area, then the `posterior` is over-confident. If the red line is above the gray area, then the `posterior` is under-confident."
8289
]
8390
},
8491
{

docs/how_to_guide/15_expected_coverage.ipynb

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@
22
"cells": [
33
{
44
"cell_type": "markdown",
5-
"id": "e26566da",
5+
"id": "699388aa",
66
"metadata": {},
77
"source": [
88
"# How to run expected coverage"
99
]
1010
},
1111
{
1212
"cell_type": "markdown",
13-
"id": "9eedd737",
13+
"id": "4e734b07",
1414
"metadata": {},
1515
"source": [
1616
"Expected coverage provides a simple and interpretable tool to diagnose issues in the posterior. In comparison to other diagnostic tools such as L-C2ST, it requires relatively few additional simulations (~200) and it does not rely on any additional hyperparameters (as TARP would) or additional neural network training.\n",
@@ -22,23 +22,23 @@
2222
},
2323
{
2424
"cell_type": "markdown",
25-
"id": "1ca1862e",
25+
"id": "24ccbf7f",
2626
"metadata": {},
2727
"source": [
2828
"## Main syntax"
2929
]
3030
},
3131
{
3232
"cell_type": "markdown",
33-
"id": "bdd96080",
33+
"id": "f83b9e30",
3434
"metadata": {},
3535
"source": [
3636
"```python\n",
3737
"from sbi.diagnostics import run_sbc\n",
3838
"from sbi.analysis.plot import sbc_rank_plot\n",
3939
"\n",
4040
"# Obtain your `posterior_estimator` with NPE, NLE, NRE.\n",
41-
"posterior_estimator = DirectPosterior(posterior_net, prior)\n",
41+
"posterior = inference.build_posterior()\n",
4242
"\n",
4343
"num_sbc_samples = 200 # choose a number of sbc runs, should be ~100s\n",
4444
"prior_samples = prior.sample((num_sbc_samples,))\n",
@@ -49,8 +49,8 @@
4949
"ranks, dap_samples = run_sbc(\n",
5050
" prior_samples,\n",
5151
" prior_predictives,\n",
52-
" posterior_estimator,\n",
53-
" reduce_fns=lambda theta, x: -posterior_estimator.log_prob(theta, x),\n",
52+
" posterior,\n",
53+
" reduce_fns=lambda theta, x: -posterior.log_prob(theta, x),\n",
5454
" num_posterior_samples=num_posterior_samples,\n",
5555
" use_batched_sampling=False, # `True` can give speed-ups, but can cause memory issues.\n",
5656
")\n",
@@ -66,35 +66,35 @@
6666
},
6767
{
6868
"cell_type": "markdown",
69-
"id": "4a02be94",
69+
"id": "0a683358",
7070
"metadata": {},
7171
"source": [
7272
"This will return a figure such as the following:"
7373
]
7474
},
7575
{
7676
"cell_type": "markdown",
77-
"id": "a84f920d",
77+
"id": "aa7dc091",
7878
"metadata": {},
7979
"source": [
8080
"<img src=\"data/sbc_rank_plot.png\" width=\"500\">"
8181
]
8282
},
8383
{
8484
"cell_type": "markdown",
85-
"id": "7bc986ee",
85+
"id": "065fae1f",
8686
"metadata": {},
8787
"source": [
8888
"You can interpret this plots as follows:\n",
89-
"- If the blue line is below the diagonal, then the `posterior_estimator` is (on average) over\n",
89+
"- If the blue line is below the diagonal, then the `posterior` is (on average) over\n",
9090
"-confident.\n",
91-
"- If the line is above the gray region, then the `posterior_esitmator` is, on average, under-confident.\n",
91+
"- If the line is above the gray region, then the `posterior` is, on average, under-confident.\n",
9292
"- If the line is within the gray region, then we cannot reject the null hypothesis that the posterior is well-calibrated."
9393
]
9494
},
9595
{
9696
"cell_type": "markdown",
97-
"id": "c9a19771",
97+
"id": "78f37ce0",
9898
"metadata": {},
9999
"source": [
100100
"## Citation\n",

docs/how_to_guide/16_sbc.ipynb

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@
22
"cells": [
33
{
44
"cell_type": "markdown",
5-
"id": "b917881c",
5+
"id": "7e0516a5",
66
"metadata": {},
77
"source": [
88
"# How to run simulation-based calibration (SBC)"
99
]
1010
},
1111
{
1212
"cell_type": "markdown",
13-
"id": "4e4f7f81",
13+
"id": "df0d4628",
1414
"metadata": {},
1515
"source": [
1616
"Similar to expected coverage, simulation-based calibration (SBC) provides a simple and interpretable tool to diagnose issues in the posterior. It also requires relatively few additional simulations (~200) and it does not rely on any additional hyperparameters (as TARP would) or additional neural network training.\n",
@@ -22,22 +22,22 @@
2222
},
2323
{
2424
"cell_type": "markdown",
25-
"id": "8e21e1ce",
25+
"id": "11650061",
2626
"metadata": {},
2727
"source": [
2828
"## Main syntax"
2929
]
3030
},
3131
{
3232
"cell_type": "markdown",
33-
"id": "abeefc9d",
33+
"id": "8174d1d2",
3434
"metadata": {},
3535
"source": [
3636
"```python\n",
3737
"from sbi.diagnostics import run_sbc\n",
3838
"\n",
3939
"# Obtain your `posterior_estimator` with NPE, NLE, NRE.\n",
40-
"posterior_estimator = DirectPosterior(posterior_net, prior)\n",
40+
"posterior = inference.build_posterior()\n",
4141
"\n",
4242
"num_sbc_samples = 200 # choose a number of sbc runs, should be ~100s\n",
4343
"prior_samples = prior.sample((num_sbc_samples,))\n",
@@ -48,7 +48,7 @@
4848
"ranks, dap_samples = run_sbc(\n",
4949
" prior_samples,\n",
5050
" prior_predictives,\n",
51-
" posterior_estimator,\n",
51+
" posterior,\n",
5252
" num_posterior_samples=num_posterior_samples,\n",
5353
" use_batched_sampling=False, # `True` can give speed-ups, but can cause memory issues.\n",
5454
")\n",
@@ -63,31 +63,31 @@
6363
},
6464
{
6565
"cell_type": "markdown",
66-
"id": "345cb5a1",
66+
"id": "902dcd07",
6767
"metadata": {},
6868
"source": [
6969
"The only difference to running expected coverage is that we did not pass `run_sbc(..., reduce_fns=...)` and we visualize it differently by not passing `sbc_rank_plot(..., plot_type=\"cdf\")`"
7070
]
7171
},
7272
{
7373
"cell_type": "markdown",
74-
"id": "705887d9",
74+
"id": "13128f34",
7575
"metadata": {},
7676
"source": [
7777
"This will return a figure such as the following:"
7878
]
7979
},
8080
{
8181
"cell_type": "markdown",
82-
"id": "ddaa4e01",
82+
"id": "1c0d27fe",
8383
"metadata": {},
8484
"source": [
8585
"<img src=\"data/sbc_plot.png\" width=\"500\">"
8686
]
8787
},
8888
{
8989
"cell_type": "markdown",
90-
"id": "51f22670",
90+
"id": "f8625c0c",
9191
"metadata": {},
9292
"source": [
9393
"This plots as many plots as there are parameters. For each of the parameters you can interpret the the shape of the red bars as follows:\n",
@@ -100,23 +100,23 @@
100100
},
101101
{
102102
"cell_type": "markdown",
103-
"id": "27b8f89e",
103+
"id": "b7e76003",
104104
"metadata": {},
105105
"source": [
106106
"## Example"
107107
]
108108
},
109109
{
110110
"cell_type": "markdown",
111-
"id": "5560ee0f",
111+
"id": "df48ce24",
112112
"metadata": {},
113113
"source": [
114114
"For a detailed example and further explanation, see [this tutorial](https://sbi.readthedocs.io/en/latest/advanced_tutorials/11_diagnostics_simulation_based_calibration.html)."
115115
]
116116
},
117117
{
118118
"cell_type": "markdown",
119-
"id": "f26b06c0",
119+
"id": "cdba7af6",
120120
"metadata": {},
121121
"source": [
122122
"## Citation\n",

0 commit comments

Comments
 (0)