Skip to content

Commit a646c25

Browse files
committed
Merge remote-tracking branch 'upstream/main' into simformer
2 parents 0ba5863 + 28f3deb commit a646c25

38 files changed

+1111
-712
lines changed

.github/workflows/publish.yml

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,14 +76,13 @@ jobs:
7676
inputs: >-
7777
./dist/*.tar.gz
7878
./dist/*.whl
79-
- name: Create GitHub Release
79+
- name: Ensure GitHub Release exists (no-op if already exists)
8080
env:
8181
GITHUB_TOKEN: ${{ github.token }}
82-
run: >-
83-
gh release create
84-
'${{ github.ref_name }}'
85-
--repo '${{ github.repository }}'
86-
--notes ""
82+
run: |
83+
# If a release for this tag already exists (e.g., created via GH UI), skip creation.
84+
gh release view '${{ github.ref_name }}' --repo '${{ github.repository }}' >/dev/null 2>&1 || \
85+
gh release create '${{ github.ref_name }}' --repo '${{ github.repository }}' --notes ""
8786
- name: Upload artifact signatures to GitHub Release
8887
env:
8988
GITHUB_TOKEN: ${{ github.token }}

CHANGELOG.md

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,17 @@
11
# Changelog
22

3-
# Changelog
4-
53
## v0.25.0
64

75
### ✨ Highlights
86

97
#### 🚀 New Inference Methods
108

11-
* **MNPE class similar to MNLE** by @dgedon in https://github.com/sbi-dev/sbi/pull/1362
12-
* **Implementing SNPE-B (#199)** by @etouron1 in https://github.com/sbi-dev/sbi/pull/1471
9+
* **MNPE class mixed parameter (similar to MNLE)** by @dgedon in https://github.com/sbi-dev/sbi/pull/1362
10+
* **Implementation of SNPE-B (#199)** by @etouron1 in https://github.com/sbi-dev/sbi/pull/1471
1311

1412
#### 🧠 Neural Network Architectures & Embedding Networks
1513

16-
* **Simple transformer implementation** by @NicolasRR in https://github.com/sbi-dev/sbi/pull/1494
14+
* **Add transformer embedding net** by @NicolasRR in https://github.com/sbi-dev/sbi/pull/1494
1715
* **Add embedding net that uses 1D causal convolutions (#1459)** by @Aranka-S in https://github.com/sbi-dev/sbi/pull/1499
1816
* **Add LRU-backed embedding networks** by @famura in https://github.com/sbi-dev/sbi/pull/1512
1917
* **Add ResNet as embedding model** by @StefanWahl in https://github.com/sbi-dev/sbi/pull/1472
@@ -22,12 +20,13 @@
2220
#### ⭐ Major Features & Capabilities
2321

2422
* **Unify flow matching and score-based models** by @StarostinV in https://github.com/sbi-dev/sbi/pull/1497
25-
* **Model misspecification based on MMD** by @coschroeder in https://github.com/sbi-dev/sbi/pull/1502
23+
* **Model misspecification detection based on MMD** by @coschroeder in https://github.com/sbi-dev/sbi/pull/1502
2624
* **Marginal estimator log-prob based test for misspecification** by @swag2198 in https://github.com/sbi-dev/sbi/pull/1522
2725
* **Adding interface for unconditional flow training** by @plcrodrigues in https://github.com/sbi-dev/sbi/pull/1470
2826
* **Support using trained estimators in Pyro models** by @sethaxen in https://github.com/sbi-dev/sbi/pull/1491
2927
* **Add util to generate mcmc samples from user defined potential (#1405)** by @hayden-johnson in https://github.com/sbi-dev/sbi/pull/1483
3028
* **Logit transform** by @anastasiakrouglova in https://github.com/sbi-dev/sbi/pull/1485
29+
* **Log-prob for iid data for score estimators** by @Kartik-Sama in https://github.com/sbi-dev/sbi/pull/1508
3130

3231
#### 📚 Documentation & Tutorials
3332

@@ -46,6 +45,8 @@
4645
* fix: cap max_sampling_batch_size to prevent excessive memory by @janfb in https://github.com/sbi-dev/sbi/pull/1624
4746
* 1561 computation of denoising posterior precision matrix in jac method score fn iid by @manuelgloeckler in https://github.com/sbi-dev/sbi/pull/1636
4847
* fix xfail test, fix deprecation warnings by @janfb in https://github.com/sbi-dev/sbi/pull/1642
48+
* fix: iid-score device handling by @janfb in https://github.com/sbi-dev/sbi/pull/1650
49+
* fix: fmpe singularity on sde sampling by @manuelgloeckler in https://github.com/sbi-dev/sbi/pull/1661
4950

5051
### 🛠️ Maintenance & Improvements
5152

@@ -60,6 +61,7 @@
6061
* Use TypeAlias and consistent naming for sbi types by @janfb in https://github.com/sbi-dev/sbi/pull/1637
6162
* Add protocol for estimator builder by @abelaba in https://github.com/sbi-dev/sbi/pull/1633
6263
* Improve abc implementation by @janfb in https://github.com/sbi-dev/sbi/pull/1615
64+
* Refactor RatioEstimator to subclass ConditionalEstimator @abelaba in https://github.com/sbi-dev/sbi/pull/1652
6365

6466
#### 🏷️ Type Hints & API Improvements
6567

@@ -86,6 +88,8 @@
8688
* chore: reorder setup steps for Python and uv in CI/CD workflows by @janfb in https://github.com/sbi-dev/sbi/pull/1601
8789
* Fix/lc2st numpy type fixes by @janfb in https://github.com/sbi-dev/sbi/pull/1613
8890
* Fix failing CI on main. by @janfb in https://github.com/sbi-dev/sbi/pull/1618
91+
* Fix slow vector field tests by @janfb in https://github.com/sbi-dev/sbi/pull/1657
92+
* Add tests for sensitivity analysis by @janfb in https://github.com/sbi-dev/sbi/pull/1662
8993

9094
#### 📖 Documentation & Website
9195

@@ -113,6 +117,9 @@
113117
* fixed misrendered bullet list, tested locally by @psteinb in https://github.com/sbi-dev/sbi/pull/1594
114118
* Improvements to L-C2ST tutorial by @michaeldeistler in https://github.com/sbi-dev/sbi/pull/1588
115119
* docs: Change colortheme in light mode by @michaeldeistler in https://github.com/sbi-dev/sbi/pull/1638
120+
* Posterior parameters doc by @abelaba in https://github.com/sbi-dev/sbi/pull/1644
121+
* fix contributing links by @janfb in https://github.com/sbi-dev/sbi/pull/1647
122+
* docs: add posterior parameters dataclass how to guide by @abelaba in https://github.com/sbi-dev/sbi/pull/1654
116123

117124
#### 🏗️ Infrastructure & Dependencies
118125

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[![PyPI version](https://badge.fury.io/py/sbi.svg)](https://badge.fury.io/py/sbi)
22
[![Conda Version](https://img.shields.io/conda/vn/conda-forge/sbi.svg)](https://github.com/conda-forge/sbi-feedstock)
33
[![Contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://sbi.readthedocs.io/en/latest/contributing.html)
4-
[![Tests](https://github.com/sbi-dev/sbi/actions/workflows/ci.yml/badge.svg)](https://github.com/sbi-dev/sbi/actions)
4+
[![Tests](https://github.com/sbi-dev/sbi/actions/workflows/cd.yml/badge.svg)](https://github.com/sbi-dev/sbi/actions)
55
[![codecov](https://codecov.io/gh/sbi-dev/sbi/branch/main/graph/badge.svg)](https://codecov.io/gh/sbi-dev/sbi)
66
[![GitHub license](https://img.shields.io/github/license/sbi-dev/sbi)](https://github.com/sbi-dev/sbi/blob/master/LICENSE.txt)
77
[![DOI](https://joss.theoj.org/papers/10.21105/joss.07754/status.svg)](https://doi.org/10.21105/joss.07754)

docs/how_to_guide.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,3 +79,12 @@ Visualization
7979
:maxdepth: 1
8080

8181
how_to_guide/05_conditional_distributions.ipynb
82+
83+
84+
Posterior Parameters
85+
--------------------
86+
87+
.. toctree::
88+
:maxdepth: 1
89+
90+
how_to_guide/19_posterior_parameters.ipynb

docs/how_to_guide/19_posterior_parameters.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
"id": "8c4df01a",
3030
"metadata": {},
3131
"source": [
32-
"# Usage"
32+
"## Usage"
3333
]
3434
},
3535
{

sbi/analysis/sensitivity_analysis.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,9 @@ def destandardizing_net(batch_t: Tensor, min_std: float = 1e-7) -> nn.Module:
4747
is_valid_t, *_ = handle_invalid_x(batch_t, True)
4848

4949
t_mean = torch.mean(batch_t[is_valid_t], dim=0)
50-
if len(batch_t > 1):
50+
51+
# Use batch size to decide whether a reliable std can be computed.
52+
if len(batch_t) > 1:
5153
t_std = torch.std(batch_t[is_valid_t], dim=0)
5254
t_std[t_std < min_std] = min_std
5355
else:

sbi/diagnostics/misspecification.py

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from typing import Optional
99

1010
import torch
11-
import torch.nn as nn
1211
from torch import Tensor
1312

1413
from sbi.inference.trainers.npe.npe_base import PosteriorEstimatorTrainer
@@ -149,25 +148,13 @@ def calc_misspecification_mmd(
149148
"no neural net found,"
150149
"neural_net should not be None when mode is 'embedding'"
151150
)
152-
neural_net = inference._neural_net
153-
if neural_net is None:
154-
raise ValueError(
155-
"no neural net found,"
156-
"neural_net should not be None when mode is 'embedding'"
157-
)
158-
if neural_net.embedding_net is None:
159-
raise ValueError(
160-
"no embedding net found,"
161-
"embedding_net should not be None when mode is 'embedding'"
162-
)
163-
if isinstance(neural_net.embedding_net, nn.modules.linear.Identity):
164-
warnings.warn(
165-
"The embedding net might be the identity function,"
166-
"in that case the MMD is computed in the x-space.",
167-
stacklevel=2,
151+
if inference._neural_net.embedding_net is None:
152+
raise AttributeError(
153+
"embedding_net attribute is None but is required for misspecification "
154+
"detection."
168155
)
169-
z_obs = neural_net.embedding_net(x_obs).detach()
170-
z = neural_net.embedding_net(x).detach()
156+
z_obs = inference._neural_net.embedding_net(x_obs).detach()
157+
z = inference._neural_net.embedding_net(x).detach()
171158
else:
172159
raise ValueError("mode should be either 'x_space' or 'embedding'")
173160

sbi/inference/posteriors/vector_field_posterior.py

Lines changed: 61 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
22
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
33

4+
import math
45
import warnings
56
from typing import Dict, Literal, Optional, Union
67

@@ -150,7 +151,9 @@ def sample(
150151
corrector_params: Optional[Dict] = None,
151152
steps: int = 500,
152153
ts: Optional[Tensor] = None,
153-
iid_method: Literal["fnpe", "gauss", "auto_gauss", "jac_gauss"] = "auto_gauss",
154+
iid_method: Optional[
155+
Literal["fnpe", "gauss", "auto_gauss", "jac_gauss"]
156+
] = None,
154157
iid_params: Optional[Dict] = None,
155158
max_sampling_batch_size: int = 10_000,
156159
sample_with: Optional[str] = None,
@@ -210,19 +213,22 @@ def sample(
210213

211214
is_iid = x.shape[0] > 1
212215
self.potential_fn.set_x(
213-
x, x_is_iid=is_iid, iid_method=iid_method, iid_params=iid_params
216+
x,
217+
x_is_iid=is_iid,
218+
iid_method=iid_method or self.potential_fn.iid_method,
219+
iid_params=iid_params,
214220
)
215221

216222
num_samples = torch.Size(sample_shape).numel()
217223

218224
if sample_with == "ode":
219-
samples = rejection.accept_reject_sample(
225+
samples, _ = rejection.accept_reject_sample(
220226
proposal=self.sample_via_ode,
221227
accept_reject_fn=lambda theta: within_support(self.prior, theta),
222228
num_samples=num_samples,
223229
show_progress_bars=show_progress_bars,
224230
max_sampling_batch_size=max_sampling_batch_size,
225-
)[0]
231+
)
226232
elif sample_with == "sde":
227233
proposal_sampling_kwargs = {
228234
"predictor": predictor,
@@ -234,14 +240,14 @@ def sample(
234240
"max_sampling_batch_size": max_sampling_batch_size,
235241
"show_progress_bars": show_progress_bars,
236242
}
237-
samples = rejection.accept_reject_sample(
243+
samples, _ = rejection.accept_reject_sample(
238244
proposal=self._sample_via_diffusion,
239245
accept_reject_fn=lambda theta: within_support(self.prior, theta),
240246
num_samples=num_samples,
241247
show_progress_bars=show_progress_bars,
242248
max_sampling_batch_size=max_sampling_batch_size,
243249
proposal_sampling_kwargs=proposal_sampling_kwargs,
244-
)[0]
250+
)
245251
else:
246252
raise ValueError(
247253
f"Expected sample_with to be 'ode' or 'sde', but got {sample_with}."
@@ -263,6 +269,7 @@ def _sample_via_diffusion(
263269
ts: Optional[Tensor] = None,
264270
max_sampling_batch_size: int = 10_000,
265271
show_progress_bars: bool = True,
272+
save_intermediate: bool = False,
266273
) -> Tensor:
267274
r"""Return samples from posterior distribution $p(\theta|x)$.
268275
@@ -284,20 +291,26 @@ def _sample_via_diffusion(
284291
sample_with: Deprecated - use `.build_posterior(sample_with=...)` prior to
285292
`.sample()`.
286293
show_progress_bars: Whether to show a progress bar during sampling.
294+
save_intermediate: Whether to save intermediate results of the diffusion
295+
process. If True, the returned tensor has shape
296+
`(*sample_shape, steps, *input_shape)`.
287297
"""
288298

289299
if not self.vector_field_estimator.SCORE_DEFINED:
290300
raise ValueError(
291301
"The vector field estimator does not support the 'sde' sampling method."
292302
)
293303

294-
num_samples = torch.Size(sample_shape).numel()
304+
total_samples_needed = torch.Size(sample_shape).numel()
295305

296-
max_sampling_batch_size = (
306+
# Determine effective batch size for sampling
307+
effective_batch_size = (
297308
self.max_sampling_batch_size
298309
if max_sampling_batch_size is None
299310
else max_sampling_batch_size
300311
)
312+
# Ensure we don't use larger batches than total samples needed
313+
effective_batch_size = min(effective_batch_size, total_samples_needed)
301314

302315
# TODO: the time schedule should be provided by the estimator, see issue #1437
303316
if ts is None:
@@ -306,28 +319,46 @@ def _sample_via_diffusion(
306319
ts = torch.linspace(t_max, t_min, steps)
307320
ts = ts.to(self.device)
308321

322+
# Initialize the diffusion sampler
309323
diffuser = Diffuser(
310324
self.potential_fn,
311325
predictor=predictor,
312326
corrector=corrector,
313327
predictor_params=predictor_params,
314328
corrector_params=corrector_params,
315329
)
316-
max_sampling_batch_size = min(max_sampling_batch_size, num_samples)
317-
samples = []
318-
num_iter = num_samples // max_sampling_batch_size
319-
num_iter = (
320-
num_iter + 1 if (num_samples % max_sampling_batch_size) != 0 else num_iter
321-
)
322-
for _ in range(num_iter):
323-
samples.append(
324-
diffuser.run(
325-
num_samples=max_sampling_batch_size,
326-
ts=ts,
327-
show_progress_bars=show_progress_bars,
328-
)
330+
331+
# Calculate how many batches we need
332+
num_batches = math.ceil(total_samples_needed / effective_batch_size)
333+
334+
# Generate samples in batches
335+
all_samples = []
336+
samples_generated = 0
337+
338+
for _ in range(num_batches):
339+
# Calculate how many samples to generate in this batch
340+
remaining_samples = total_samples_needed - samples_generated
341+
current_batch_size = min(effective_batch_size, remaining_samples)
342+
343+
# Generate samples for this batch
344+
batch_samples = diffuser.run(
345+
num_samples=current_batch_size,
346+
ts=ts,
347+
show_progress_bars=show_progress_bars,
348+
save_intermediate=save_intermediate,
349+
)
350+
351+
all_samples.append(batch_samples)
352+
samples_generated += current_batch_size
353+
354+
# Concatenate all batches and ensure we return exactly the requested number
355+
samples = torch.cat(all_samples, dim=0)[:total_samples_needed]
356+
357+
if torch.isnan(samples).all():
358+
raise RuntimeError(
359+
"All samples NaN after diffusion sampling. "
360+
"This may indicate numerical instability in the vector field."
329361
)
330-
samples = torch.cat(samples, dim=0)[:num_samples]
331362

332363
return samples
333364

@@ -382,7 +413,10 @@ def log_prob(
382413
`(len(θ),)`-shaped log posterior probability $\log p(\theta|x)$ for θ in the
383414
support of the prior, -∞ (corresponding to 0 probability) outside.
384415
"""
385-
self.potential_fn.set_x(self._x_else_default_x(x), **(ode_kwargs or {}))
416+
x = self._x_else_default_x(x)
417+
x = reshape_to_batch_event(x, self.vector_field_estimator.condition_shape)
418+
is_iid = x.shape[0] > 1
419+
self.potential_fn.set_x(x, x_is_iid=is_iid, **(ode_kwargs or {}))
386420

387421
theta = ensure_theta_batched(torch.as_tensor(theta))
388422
return self.potential_fn(
@@ -461,14 +495,14 @@ def sample_batched(
461495
max_sampling_batch_size = capped
462496

463497
if self.sample_with == "ode":
464-
samples = rejection.accept_reject_sample(
498+
samples, _ = rejection.accept_reject_sample(
465499
proposal=self.sample_via_ode,
466500
accept_reject_fn=lambda theta: within_support(self.prior, theta),
467501
num_samples=num_samples,
468502
num_xos=batch_size,
469503
show_progress_bars=show_progress_bars,
470504
max_sampling_batch_size=max_sampling_batch_size,
471-
)[0]
505+
)
472506
samples = samples.reshape(
473507
sample_shape + batch_shape + self.vector_field_estimator.input_shape
474508
)
@@ -483,15 +517,15 @@ def sample_batched(
483517
"max_sampling_batch_size": max_sampling_batch_size,
484518
"show_progress_bars": show_progress_bars,
485519
}
486-
samples = rejection.accept_reject_sample(
520+
samples, _ = rejection.accept_reject_sample(
487521
proposal=self._sample_via_diffusion,
488522
accept_reject_fn=lambda theta: within_support(self.prior, theta),
489523
num_samples=num_samples,
490524
num_xos=batch_size,
491525
show_progress_bars=show_progress_bars,
492526
max_sampling_batch_size=max_sampling_batch_size,
493527
proposal_sampling_kwargs=proposal_sampling_kwargs,
494-
)[0]
528+
)
495529
samples = samples.reshape(
496530
sample_shape + batch_shape + self.vector_field_estimator.input_shape
497531
)

0 commit comments

Comments
 (0)