28 changes: 11 additions & 17 deletions .github/workflows/build_docs.yml
@@ -5,6 +5,7 @@ on:
       - main
   release:
     types: [ published ]
+  workflow_dispatch:
 
 jobs:
   docs:
@@ -17,31 +18,24 @@ jobs:
           fetch-depth: 0
           lfs: false
 
-      - name: Set up Python
-        uses: actions/setup-python@v2
+      - name: Install uv and set the python version
+        uses: astral-sh/setup-uv@v5
         with:
           python-version: '3.10'
+          enable-cache: true
+          cache-dependency-glob: "pyproject.toml"
 
-      - name: Cache dependency
-        id: cache-dependencies
-        uses: actions/cache@v4
-        with:
-          path: ~/.cache/pip
-          key: ${{ runner.os }}-pip
-
-      - name: Install sbi and dependencies
-        run: |
-          python -m pip install --upgrade pip
-          python -m pip install .[doc]
+      - name: Install dependencies with uv
+        run: uv sync --all-extras --doc
 
       - name: strip output except plots and prints from tutorial notebooks
         run: |
-          python tests/strip_notebook_outputs.py tutorials/
+          uv run python tests/strip_notebook_outputs.py tutorials/
 
       - name: convert notebooks to markdown
         run: |
           cd docs
-          jupyter nbconvert --to markdown ../tutorials/*.ipynb --output-dir docs/tutorials/
+          uv run jupyter nbconvert --to markdown ../tutorials/*.ipynb --output-dir docs/tutorials/
 
       - name: Configure Git user for bot
         run: |
@@ -52,10 +46,10 @@
         if: ${{ github.event_name == 'push' }}
         run: |
           cd docs
-          mike deploy dev --push
+          uv run mike deploy dev --push
 
       - name: Build and deploy the lastest documentation upon new release
         if: ${{ github.event_name == 'release' }}
         run: |
           cd docs
-          mike deploy ${{ github.event.release.name }} latest -u --push
+          uv run mike deploy ${{ github.event.release.name }} latest -u --push
9 changes: 4 additions & 5 deletions .github/workflows/cd.yml
@@ -28,16 +28,15 @@ jobs:
       - name: Install uv and set the python version
         uses: astral-sh/setup-uv@v5
         with:
-          python-version: '3.9'
+          python-version: '3.10'
           enable-cache: true
+          cache-dependency-glob: "pyproject.toml"
 
       - name: Install dependencies with uv
-        run: |
-          uv pip install -e .[dev]
+        run: uv sync --all-extras --dev
 
       - name: Run the fast and the slow CPU tests with coverage
-        run: |
-          uv run pytest -v -x -n auto -m "not gpu" --cov=sbi --cov-report=xml tests/
+        run: uv run pytest -v -x -n auto -m "not gpu" --cov=sbi --cov-report=xml tests/
 
       - name: Upload coverage to Codecov
         uses: codecov/codecov-action@v4-beta
3 changes: 2 additions & 1 deletion .github/workflows/ci.yml
@@ -20,7 +20,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: ['3.9', '3.12']
+        python-version: ['3.10', '3.13']
 
     steps:
       - name: Checkout
@@ -34,6 +34,7 @@
         with:
           python-version: ${{ matrix.python-version }}
           enable-cache: true
+          cache-dependency-glob: "pyproject.toml"
 
       - name: Install dependencies with uv
         run: |
12 changes: 5 additions & 7 deletions .github/workflows/lint.yml
@@ -22,7 +22,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions/setup-python@v5
         with:
-          python-version: '3.9'
+          python-version: '3.10'
       - uses: pre-commit/action@v3.0.1
         with:
           extra_args: --all-files --show-diff-on-failure
@@ -40,14 +40,12 @@
       - name: Install uv and set the python version
         uses: astral-sh/setup-uv@v5
         with:
-          python-version: '3.9'
+          python-version: '3.10'
           enable-cache: true
+          cache-dependency-glob: "pyproject.toml"
 
       - name: Install dependencies with uv
-        run: |
-          uv pip install torch --extra-index-url https://download.pytorch.org/whl/cpu
-          uv pip install -e .[dev]
+        run: uv sync --all-extras --dev
 
       - name: Check types with pyright
-        run: |
-          uv run pyright sbi
+        run: uv run pyright sbi
8 changes: 4 additions & 4 deletions .github/workflows/publish.yml
@@ -15,7 +15,7 @@ jobs:
       - name: Set up Python
         uses: actions/setup-python@v5
         with:
-          python-version: "3.9"
+          python-version: "3.10"
       - name: Install pypa/build
         run: >-
           python3 -m
@@ -25,7 +25,7 @@
       - name: Build a binary wheel and a source tarball
         run: python3 -m build
       - name: Store the distribution packages
-        uses: actions/upload-artifact@v3
+        uses: actions/upload-artifact@v4
         with:
           name: python-package-distributions
           path: dist/
@@ -45,7 +45,7 @@
 
     steps:
       - name: Download all the dists
-        uses: actions/download-artifact@v3
+        uses: actions/download-artifact@v4
         with:
           name: python-package-distributions
           path: dist/
@@ -66,7 +66,7 @@
 
     steps:
       - name: Download all the dists
-        uses: actions/download-artifact@v3
+        uses: actions/download-artifact@v4
        with:
           name: python-package-distributions
           path: dist/
4 changes: 2 additions & 2 deletions README.md
@@ -59,14 +59,14 @@ posterior = inference.build_posterior()
 
 ### Installation
 
-`sbi` requires Python 3.9 or higher. While a GPU isn't necessary, it can improve
+`sbi` requires Python 3.10 or higher. While a GPU isn't necessary, it can improve
 performance in some cases. We recommend using a virtual environment with
 [`conda`](https://docs.conda.io/en/latest/miniconda.html) for an easy setup.
 
 If `conda` is installed on the system, an environment for installing `sbi` can be created as follows:
 
 ```
-conda create -n sbi_env python=3.9 && conda activate sbi_env
+conda create -n sbi_env python=3.10 && conda activate sbi_env
 ```
 
 ### From PyPI
2 changes: 1 addition & 1 deletion docs/docs/install.md
@@ -1,6 +1,6 @@
 # Installation
 
-`sbi` requires Python 3.9 or higher. A GPU is not required, but can lead to
+`sbi` requires Python 3.10 or higher. A GPU is not required, but can lead to
 speed-up in some cases. We recommend using a
 [`conda`](https://docs.conda.io/en/latest/miniconda.html) virtual environment
 ([Miniconda installation
6 changes: 3 additions & 3 deletions pyproject.toml
@@ -24,7 +24,7 @@ classifiers = [
"Programming Language :: Python :: 3",
"Development Status :: 3 - Alpha",
]
requires-python = ">=3.9"
requires-python = ">=3.10"
dynamic = ["version"]
readme = "README.md"
keywords = ["Bayesian inference", "simulation-based inference", "PyTorch"]
@@ -38,7 +38,7 @@ dependencies = [
     "pyknos>=0.16.0",
     "pyro-ppl>=1.3.1",
     "scikit-learn",
-    "scipy<1.13",
+    "scipy",
     "tensorboard",
     "torch>=1.13.0, <2.6.0",
     "tqdm",
@@ -140,7 +140,7 @@ xfail_strict = true
 [tool.pyright]
 include = ["sbi"]
 exclude = ["**/__pycache__", "**/__node_modules__", ".git", "docs", "tutorials", "tests"]
-python_version = "3.9"
+python_version = "3.10"
 reportUnsupportedDunderAll = false
 reportGeneralTypeIssues = false
 reportInvalidTypeForm = false
18 changes: 13 additions & 5 deletions sbi/analysis/plot.py
@@ -787,7 +787,9 @@
     diag_kwargs_list = to_list_kwargs(diag_kwargs, len(samples))
     diag_func = get_diag_funcs(diag_list)
     diag_kwargs_filled = []
-    for i, (diag_i, diag_kwargs_i) in enumerate(zip(diag_list, diag_kwargs_list)):
+    for i, (diag_i, diag_kwargs_i) in enumerate(
+        zip(diag_list, diag_kwargs_list, strict=False)
+    ):
         diag_kwarg_filled_i = _get_default_diag_kwargs(diag_i, i)
         # update the defaults dictionary with user provided values
         diag_kwarg_filled_i = _update(diag_kwarg_filled_i, diag_kwargs_i)
@@ -798,7 +800,9 @@
     upper_kwargs_list = to_list_kwargs(upper_kwargs, len(samples))
     upper_func = get_offdiag_funcs(upper_list)
     upper_kwargs_filled = []
-    for i, (upper_i, upper_kwargs_i) in enumerate(zip(upper_list, upper_kwargs_list)):
+    for i, (upper_i, upper_kwargs_i) in enumerate(
+        zip(upper_list, upper_kwargs_list, strict=False)
+    ):
         upper_kwarg_filled_i = _get_default_offdiag_kwargs(upper_i, i)
         # update the defaults dictionary with user provided values
         upper_kwarg_filled_i = _update(upper_kwarg_filled_i, upper_kwargs_i)
@@ -809,7 +813,9 @@
     lower_kwargs_list = to_list_kwargs(lower_kwargs, len(samples))
     lower_func = get_offdiag_funcs(lower_list)
     lower_kwargs_filled = []
-    for i, (lower_i, lower_kwargs_i) in enumerate(zip(lower_list, lower_kwargs_list)):
+    for i, (lower_i, lower_kwargs_i) in enumerate(
+        zip(lower_list, lower_kwargs_list, strict=False)
+    ):
         lower_kwarg_filled_i = _get_default_offdiag_kwargs(lower_i, i)
         # update the defaults dictionary with user provided values
         lower_kwarg_filled_i = _update(lower_kwarg_filled_i, lower_kwargs_i)
@@ -910,7 +916,9 @@
     diag_kwargs_list = to_list_kwargs(diag_kwargs, len(samples))
     diag_func = get_diag_funcs(diag_list)
     diag_kwargs_filled = []
-    for i, (diag_i, diag_kwargs_i) in enumerate(zip(diag_list, diag_kwargs_list)):
+    for i, (diag_i, diag_kwargs_i) in enumerate(
+        zip(diag_list, diag_kwargs_list, strict=False)
+    ):
         diag_kwarg_filled_i = _get_default_diag_kwargs(diag_i, i)
         diag_kwarg_filled_i = _update(diag_kwarg_filled_i, diag_kwargs_i)
         diag_kwargs_filled.append(diag_kwarg_filled_i)
@@ -2031,7 +2039,7 @@
     # normalize color intensity
     norm = Normalize(vmin=vmin, vmax=vmax)
     # set color intensity
-    for w, p in zip(weights, patches):
+    for w, p in zip(weights, patches, strict=False):
         p.set_facecolor(cmap(w))
     if show_colorbar:
         plt.colorbar(cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax_, label=label)
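A note on the recurring `strict=False` changes in this file and below: the `strict` keyword of `zip()` was added in Python 3.10 (PEP 618), which is exactly what the `requires-python = ">=3.10"` bump above makes available. A minimal sketch of the semantics (illustrative names, not `sbi` code):

```python
colors = ["red", "green", "blue"]
weights = [0.2, 0.9]  # one element short

# strict=False keeps the classic behavior: truncate to the shorter iterable.
assert list(zip(colors, weights, strict=False)) == [("red", 0.2), ("green", 0.9)]

# strict=True instead raises on a length mismatch.
try:
    list(zip(colors, weights, strict=True))
except ValueError:
    print("length mismatch detected")
```

Passing `strict=False` explicitly preserves the old truncating behavior while satisfying linters (e.g. ruff's `B905` check) that flag `zip()` calls lacking an explicit `strict` argument.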
6 changes: 3 additions & 3 deletions sbi/diagnostics/sbc.py
@@ -138,7 +138,7 @@ def _run_sbc(
     ranks = torch.zeros((num_sbc_samples, len(reduce_fns)))
     # Iterate over all sbc samples and calculate ranks.
     for sbc_idx, (true_theta, x_i) in tqdm(
-        enumerate(zip(thetas, xs)),
+        enumerate(zip(thetas, xs, strict=False)),
         total=num_sbc_samples,
         disable=not show_progress_bar,
         desc=f"Calculating ranks for {num_sbc_samples} sbc samples.",
@@ -188,7 +188,7 @@ def get_nltp(thetas: Tensor, xs: Tensor, posterior: NeuralPosterior) -> Tensor:
     nltp = torch.zeros(thetas.shape[0])
     unnormalized_log_prob = not isinstance(posterior, (DirectPosterior, ScorePosterior))
 
-    for idx, (tho, xo) in enumerate(zip(thetas, xs)):
+    for idx, (tho, xo) in enumerate(zip(thetas, xs, strict=False)):
         # Log prob of true params under posterior.
         if unnormalized_log_prob:
             nltp[idx] = -posterior.potential(tho, x=xo)
@@ -266,7 +266,7 @@ def check_prior_vs_dap(prior_samples: Tensor, dap_samples: Tensor) -> Tensor:
 
     return torch.tensor([
         c2st(s1.unsqueeze(1), s2.unsqueeze(1))
-        for s1, s2 in zip(prior_samples.T, dap_samples.T)
+        for s1, s2 in zip(prior_samples.T, dap_samples.T, strict=False)
     ])
 
 
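The `_run_sbc` loop above pairs each ground-truth `theta` with its observation `x` and records the rank of the true parameter among posterior samples. A condensed sketch of that rank computation (the Gaussian `toy_posterior_sample` is a stand-in for `posterior.sample`, not the `sbi` API):

```python
import torch

def toy_posterior_sample(num_samples: int, x_i: torch.Tensor) -> torch.Tensor:
    # Stand-in posterior: Gaussian centered on the observation.
    return x_i + torch.randn(num_samples, x_i.shape[-1])

thetas = torch.randn(5, 2)             # ground-truth parameters
xs = thetas + 0.1 * torch.randn(5, 2)  # corresponding observations

ranks = torch.zeros(len(thetas), thetas.shape[-1])
for idx, (true_theta, x_i) in enumerate(zip(thetas, xs, strict=False)):
    posterior_samples = toy_posterior_sample(1000, x_i)
    # Per-dimension rank of the true parameter among posterior samples;
    # for a well-calibrated posterior these ranks are uniformly distributed.
    ranks[idx] = (posterior_samples < true_theta).sum(dim=0)
```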
4 changes: 4 additions & 0 deletions sbi/neural_nets/factory.py
@@ -112,6 +112,7 @@ def classifier_nn(
                 check_net_device(embedding_net_theta, "cpu", embedding_net_warn_msg),
                 check_net_device(embedding_net_x, "cpu", embedding_net_warn_msg),
             ),
+            strict=False,
         ),
         **kwargs,
     )
@@ -194,6 +195,7 @@ def likelihood_nn(
                 check_net_device(embedding_net, "cpu", embedding_net_warn_msg),
                 num_components,
             ),
+            strict=False,
         ),
         **kwargs,
     )
@@ -332,6 +334,7 @@ def posterior_nn(
                 check_net_device(embedding_net, "cpu", embedding_net_warn_msg),
                 num_components,
             ),
+            strict=False,
         ),
         **kwargs,
     )
@@ -434,6 +437,7 @@ def posterior_score_nn(
                 hidden_features,
                 embedding_net,
             ),
+            strict=False,
         ),
         **kwargs,
     )
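The four `factory.py` hunks all add `strict=False` to a `zip()` that pairs builder argument names with their values inside a `dict(...)`. A generic sketch of that kwargs-building pattern (the `arg_names` tuple and `build_net` signature are hypothetical, not the actual factory API):

```python
def build_net(model: str = "maf", hidden_features: int = 50, z_score: bool = True) -> dict:
    # Pair argument names with local values; strict=False silences the
    # zip-without-strict lint while keeping the truncating behavior.
    arg_names = ("model", "hidden_features", "z_score")
    return dict(zip(arg_names, (model, hidden_features, z_score), strict=False))

assert build_net() == {"model": "maf", "hidden_features": 50, "z_score": True}
```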
2 changes: 1 addition & 1 deletion sbi/samplers/mcmc/slice_numpy.py
@@ -289,7 +289,7 @@ def run(self, num_samples: int) -> np.ndarray:
         ):
             all_samples: Sequence[np.ndarray] = Parallel(n_jobs=self.num_workers)(  # pyright: ignore[reportAssignmentType]
                 delayed(self.run_fun)(num_samples, initial_params_batch, seed)
-                for initial_params_batch, seed in zip(self.x, seeds)
+                for initial_params_batch, seed in zip(self.x, seeds, strict=False)
             )
 
         samples = np.stack(all_samples).astype(np.float32)
6 changes: 3 additions & 3 deletions sbi/samplers/vi/vi_divergence_optimizers.py
@@ -200,7 +200,7 @@

     def update_state(self) -> None:
         """This updates the current state."""
-        for state_para, para in zip(self.state_dict, self.q.parameters()):
+        for state_para, para in zip(self.state_dict, self.q.parameters(), strict=False):
             if torch.isfinite(para).all():
                 state_para.data = para.data.clone()
             else:
@@ -213,7 +213,7 @@
         Args:
             warm_up_rounds: Number of warm_up_round one should do after failure.
         """
-        for state_para, para in zip(self.state_dict, self.q.parameters()):
+        for state_para, para in zip(self.state_dict, self.q.parameters(), strict=False):
             para.data = state_para.data.clone().to(para.device)
         self._optimizer.__init__(self.q.parameters(), self.learning_rate)
         self.warm_up(warm_up_rounds)
@@ -464,7 +464,7 @@
     def update_surrogate_q(self) -> None:
         """Updates the surrogate with new parameters."""
         for param, param_surro in zip(
-            self.q.parameters(), self._surrogate_q.parameters()
+            self.q.parameters(), self._surrogate_q.parameters(), strict=False
         ):
             param_surro.data = param.data
             param_surro.requires_grad = False
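The optimizer methods above snapshot `self.q.parameters()` into a stored state and copy the state back when an update produces non-finite values. A small self-contained sketch of that save/restore pattern (a plain `nn.Linear` stands in for the variational distribution `q`):

```python
import torch
from torch import nn

q = nn.Linear(4, 4)
snapshot = [p.data.clone() for p in q.parameters()]  # save known-good state

with torch.no_grad():
    for p in q.parameters():
        p.add_(float("nan"))  # simulate a diverged update

# Restore parameters in place, mirroring the failure-recovery path above.
for saved, p in zip(snapshot, q.parameters(), strict=False):
    if not torch.isfinite(p).all():
        p.data = saved.clone().to(p.device)

assert all(torch.isfinite(p).all() for p in q.parameters())
```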
2 changes: 1 addition & 1 deletion sbi/simulators/simutils.py
@@ -73,7 +73,7 @@ def simulator_seeded(theta: Tensor, seed) -> Tensor:
         ) as _:
             simulation_outputs: List[Tensor] = Parallel(n_jobs=num_workers)(  # pyright: ignore[reportAssignmentType]
                 delayed(simulator_seeded)(batch, batch_seed)
-                for batch, batch_seed in zip(batches, batch_seeds)
+                for batch, batch_seed in zip(batches, batch_seeds, strict=False)
             )
     else:
         pbar = tqdm(
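`simutils.py` runs a seeded simulator over parameter batches in parallel with joblib. A self-contained sketch of the same pattern (the toy simulator and batch sizes are assumptions, not `sbi` internals):

```python
import torch
from joblib import Parallel, delayed

def simulator_seeded(theta: torch.Tensor, seed: int) -> torch.Tensor:
    # Seeding each batch keeps parallel runs reproducible regardless of
    # worker scheduling.
    torch.manual_seed(seed)
    return theta + torch.randn_like(theta)

batches = list(torch.randn(100, 3).split(10))  # 10 batches of 10 parameters
batch_seeds = torch.randint(1_000_000, (len(batches),)).tolist()

outputs = Parallel(n_jobs=2)(
    delayed(simulator_seeded)(batch, batch_seed)
    for batch, batch_seed in zip(batches, batch_seeds, strict=False)
)
simulations = torch.cat(outputs)  # shape (100, 3)
```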
4 changes: 2 additions & 2 deletions sbi/utils/conditional_density_utils.py
@@ -73,7 +73,7 @@ def _compute_covariance(
     # Compute E[X] * E[Y].
     expected_values_of_marginals = [
         _expected_value_f_of_x(prob.unsqueeze(0), lim.unsqueeze(0))
-        for prob, lim in zip(_calc_marginals(probs, limits), limits)
+        for prob, lim in zip(_calc_marginals(probs, limits), limits, strict=False)
     ]
 
     return expected_value_of_joint - f(*expected_values_of_marginals)
@@ -104,7 +104,7 @@ def _expected_value_f_of_x(
 
     x_values_over_which_we_integrate = [
         torch.linspace(lim[0].item(), lim[1].item(), prob.shape[0], device=probs.device)
-        for lim, prob in zip(torch.flip(limits, [0]), probs)
+        for lim, prob in zip(torch.flip(limits, [0]), probs, strict=False)
     ]  # See #403 and #404 for flip().
     grids = list(torch.meshgrid(x_values_over_which_we_integrate))
     expected_val = torch.sum(f(*grids) * probs)
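`_expected_value_f_of_x` approximates E[f(X)] by laying coordinate grids over `limits` and summing `f` against the discretized density. A compact sketch of that grid construction (the random `probs` and the pinned `indexing="ij"` are assumptions for illustration):

```python
import torch

probs = torch.rand(50, 50)
probs /= probs.sum()  # treat as a normalized density over the grid
limits = torch.tensor([[-1.0, 1.0], [-1.0, 1.0]])

grids = torch.meshgrid(
    *[
        torch.linspace(lo, hi, n)
        for (lo, hi), n in zip(limits.tolist(), probs.shape, strict=False)
    ],
    indexing="ij",
)
expected_xy = torch.sum(grids[0] * grids[1] * probs)  # E[X * Y] on the grid
```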
2 changes: 1 addition & 1 deletion sbi/utils/diagnostics_utils.py
@@ -75,7 +75,7 @@ def sample_fun(
         tqdm(
             Parallel(return_as="generator", n_jobs=num_workers)(
                 delayed(sample_fun)(posterior, sample_shape, x=x, seed=s)
-                for x, s in zip(xs, seeds)
+                for x, s in zip(xs, seeds, strict=False)
             ),
             disable=not show_progress_bar,
             total=len(xs),
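`diagnostics_utils.py` streams joblib results through `tqdm` via `return_as="generator"`, so the bar advances as each posterior-sampling job finishes rather than jumping at the end. A sketch of the same pattern (requires joblib >= 1.3; `work` is a stand-in for `sample_fun`):

```python
import time
from joblib import Parallel, delayed
from tqdm import tqdm

def work(i: int) -> int:
    time.sleep(0.1)  # stand-in for sampling one posterior
    return i * i

results = list(
    tqdm(
        Parallel(return_as="generator", n_jobs=2)(
            delayed(work)(i) for i in range(20)
        ),
        total=20,  # the generator has no len(), so tell tqdm the total
    )
)
```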