
Commit eae9cc9

tests: fix slow vector field tests, fix iid-scores (#1657)
* wip: fix vf tests
* adapt tests
* refactor score utils, small fixes
* refactor vf slow tests
* remove NaN check during diffusion
* move NaN check to last diffusion step
* skip iid-score tests for NPSE as well
* move NaN check to posterior.sample level, update tests, fix rejection sampling warning
1 parent e3fdb10 commit eae9cc9

File tree

6 files changed: +204 -143 lines


sbi/inference/posteriors/vector_field_posterior.py

Lines changed: 52 additions & 26 deletions
@@ -1,6 +1,7 @@
 # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
 # under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

+import math
 import warnings
 from typing import Dict, Literal, Optional, Union

@@ -150,7 +151,9 @@ def sample(
         corrector_params: Optional[Dict] = None,
         steps: int = 500,
         ts: Optional[Tensor] = None,
-        iid_method: Literal["fnpe", "gauss", "auto_gauss", "jac_gauss"] = "auto_gauss",
+        iid_method: Optional[
+            Literal["fnpe", "gauss", "auto_gauss", "jac_gauss"]
+        ] = None,
         iid_params: Optional[Dict] = None,
         max_sampling_batch_size: int = 10_000,
         sample_with: Optional[str] = None,
@@ -201,19 +204,22 @@ def sample(
         x = reshape_to_batch_event(x, self.vector_field_estimator.condition_shape)
         is_iid = x.shape[0] > 1
         self.potential_fn.set_x(
-            x, x_is_iid=is_iid, iid_method=iid_method, iid_params=iid_params
+            x,
+            x_is_iid=is_iid,
+            iid_method=iid_method or self.potential_fn.iid_method,
+            iid_params=iid_params,
         )

         num_samples = torch.Size(sample_shape).numel()

         if sample_with == "ode":
-            samples = rejection.accept_reject_sample(
+            samples, _ = rejection.accept_reject_sample(
                 proposal=self.sample_via_ode,
                 accept_reject_fn=lambda theta: within_support(self.prior, theta),
                 num_samples=num_samples,
                 show_progress_bars=show_progress_bars,
                 max_sampling_batch_size=max_sampling_batch_size,
-            )[0]
+            )
         elif sample_with == "sde":
             proposal_sampling_kwargs = {
                 "predictor": predictor,
@@ -225,14 +231,14 @@ def sample(
                 "max_sampling_batch_size": max_sampling_batch_size,
                 "show_progress_bars": show_progress_bars,
             }
-            samples = rejection.accept_reject_sample(
+            samples, _ = rejection.accept_reject_sample(
                 proposal=self._sample_via_diffusion,
                 accept_reject_fn=lambda theta: within_support(self.prior, theta),
                 num_samples=num_samples,
                 show_progress_bars=show_progress_bars,
                 max_sampling_batch_size=max_sampling_batch_size,
                 proposal_sampling_kwargs=proposal_sampling_kwargs,
-            )[0]
+            )
         else:
             raise ValueError(
                 f"Expected sample_with to be 'ode' or 'sde', but got {sample_with}."
@@ -282,13 +288,16 @@ def _sample_via_diffusion(
                 "The vector field estimator does not support the 'sde' sampling method."
             )

-        num_samples = torch.Size(sample_shape).numel()
+        total_samples_needed = torch.Size(sample_shape).numel()

-        max_sampling_batch_size = (
+        # Determine effective batch size for sampling
+        effective_batch_size = (
             self.max_sampling_batch_size
             if max_sampling_batch_size is None
             else max_sampling_batch_size
         )
+        # Ensure we don't use larger batches than total samples needed
+        effective_batch_size = min(effective_batch_size, total_samples_needed)

         # TODO: the time schedule should be provided by the estimator, see issue #1437
         if ts is None:
@@ -297,28 +306,45 @@ def _sample_via_diffusion(
             ts = torch.linspace(t_max, t_min, steps)
         ts = ts.to(self.device)

+        # Initialize the diffusion sampler
         diffuser = Diffuser(
             self.potential_fn,
             predictor=predictor,
             corrector=corrector,
             predictor_params=predictor_params,
             corrector_params=corrector_params,
         )
-        max_sampling_batch_size = min(max_sampling_batch_size, num_samples)
-        samples = []
-        num_iter = num_samples // max_sampling_batch_size
-        num_iter = (
-            num_iter + 1 if (num_samples % max_sampling_batch_size) != 0 else num_iter
-        )
-        for _ in range(num_iter):
-            samples.append(
-                diffuser.run(
-                    num_samples=max_sampling_batch_size,
-                    ts=ts,
-                    show_progress_bars=show_progress_bars,
-                )
+
+        # Calculate how many batches we need
+        num_batches = math.ceil(total_samples_needed / effective_batch_size)
+
+        # Generate samples in batches
+        all_samples = []
+        samples_generated = 0
+
+        for _ in range(num_batches):
+            # Calculate how many samples to generate in this batch
+            remaining_samples = total_samples_needed - samples_generated
+            current_batch_size = min(effective_batch_size, remaining_samples)
+
+            # Generate samples for this batch
+            batch_samples = diffuser.run(
+                num_samples=current_batch_size,
+                ts=ts,
+                show_progress_bars=show_progress_bars,
+            )
+
+            all_samples.append(batch_samples)
+            samples_generated += current_batch_size
+
+        # Concatenate all batches and ensure we return exactly the requested number
+        samples = torch.cat(all_samples, dim=0)[:total_samples_needed]
+
+        if torch.isnan(samples).all():
+            raise RuntimeError(
+                "All samples NaN after diffusion sampling. "
+                "This may indicate numerical instability in the vector field."
             )
-        samples = torch.cat(samples, dim=0)[:num_samples]

         return samples

@@ -443,14 +469,14 @@ def sample_batched(
             max_sampling_batch_size = capped

         if self.sample_with == "ode":
-            samples = rejection.accept_reject_sample(
+            samples, _ = rejection.accept_reject_sample(
                 proposal=self.sample_via_ode,
                 accept_reject_fn=lambda theta: within_support(self.prior, theta),
                 num_samples=num_samples,
                 num_xos=batch_size,
                 show_progress_bars=show_progress_bars,
                 max_sampling_batch_size=max_sampling_batch_size,
-            )[0]
+            )
             samples = samples.reshape(
                 sample_shape + batch_shape + self.vector_field_estimator.input_shape
             )
@@ -465,15 +491,15 @@ def sample_batched(
                 "max_sampling_batch_size": max_sampling_batch_size,
                 "show_progress_bars": show_progress_bars,
             }
-            samples = rejection.accept_reject_sample(
+            samples, _ = rejection.accept_reject_sample(
                 proposal=self._sample_via_diffusion,
                 accept_reject_fn=lambda theta: within_support(self.prior, theta),
                 num_samples=num_samples,
                 num_xos=batch_size,
                 show_progress_bars=show_progress_bars,
                 max_sampling_batch_size=max_sampling_batch_size,
                 proposal_sampling_kwargs=proposal_sampling_kwargs,
-            )[0]
+            )
             samples = samples.reshape(
                 sample_shape + batch_shape + self.vector_field_estimator.input_shape
             )
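
Note: the reworked `_sample_via_diffusion` replaces the floor-division-plus-remainder loop with an explicit `math.ceil` batch count and a running counter. A minimal standalone sketch of that batching scheme (illustrative only, not the sbi API; the function name is hypothetical):

import math


def batch_sizes(total_samples_needed: int, max_batch_size: int) -> list:
    """Return per-batch sizes that sum exactly to the requested total."""
    effective_batch_size = min(max_batch_size, total_samples_needed)
    num_batches = math.ceil(total_samples_needed / effective_batch_size)
    sizes, generated = [], 0
    for _ in range(num_batches):
        current = min(effective_batch_size, total_samples_needed - generated)
        sizes.append(current)
        generated += current
    return sizes


# The last batch shrinks instead of overshooting and truncating.
assert batch_sizes(25, 10) == [10, 10, 5]
assert sum(batch_sizes(7, 100)) == 7

Because each batch size is exact, the final `[:total_samples_needed]` slice becomes a safeguard rather than a correctness requirement, and the all-NaN guard now raises here instead of inside the diffusion loop.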

sbi/inference/potentials/score_fn_iid.py

Lines changed: 2 additions & 2 deletions
@@ -661,7 +661,7 @@ def estimate_posterior_precision(
     precision_est_budget = min(int(prior.event_shape[0] * 1000), 5000)

     thetas = posterior.sample_batched(
-        torch.Size([precision_est_budget]),
+        sample_shape=torch.Size([precision_est_budget]),
         x=conditions,
         show_progress_bars=False,
         steps=precision_initial_sampler_steps,
@@ -740,7 +740,7 @@ def ensure_lam_positive_definite(
     denoising_posterior_precision: torch.Tensor,
     N: int,
     precision_nugget: float = 0.1,
-) -> (torch.Tensor, torch.Tensor):
+) -> tuple[torch.Tensor, torch.Tensor]:
     r"""
     Ensure that the matrix is positive definite.
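
Note: the annotation fix matters because `(torch.Tensor, torch.Tensor)` is a tuple value, not a type, so type checkers reject it; the builtin generic from PEP 585 spells the intended type. A minimal illustration (both function names are hypothetical):

import torch


# Before: a parenthesized pair is a tuple *instance*, not a type annotation;
# mypy/pyright flag it as invalid, even though it runs.
def bad_signature() -> (torch.Tensor, torch.Tensor):
    return torch.zeros(2), torch.ones(2)


# After: the builtin generic (Python >= 3.9) expresses "a pair of tensors".
def good_signature() -> tuple[torch.Tensor, torch.Tensor]:
    return torch.zeros(2), torch.ones(2)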

sbi/inference/potentials/vector_field_potential.py

Lines changed: 38 additions & 38 deletions
@@ -21,42 +21,6 @@
 from sbi.utils.torchutils import ensure_theta_batched


-def vector_field_estimator_based_potential(
-    vector_field_estimator: ConditionalVectorFieldEstimator,
-    prior: Optional[Distribution],
-    x_o: Optional[Tensor],
-    enable_transform: bool = True,
-    **kwargs,
-) -> Tuple["VectorFieldBasedPotential", TorchTransform]:
-    r"""Returns the potential function gradient for vector field estimators.
-
-    Args:
-        vector_field_estimator: The neural network modelling the vector field.
-        prior: The prior distribution.
-        x_o: The observed data at which to evaluate the vector field.
-        enable_transform: Whether to enable transforms. Not supported yet.
-        **kwargs: Additional keyword arguments passed to
-            `VectorFieldBasedPotential`.
-    Returns:
-        The potential function and a transformation that maps
-        to unconstrained space.
-    """
-    device = str(next(vector_field_estimator.parameters()).device)
-
-    potential_fn = VectorFieldBasedPotential(
-        vector_field_estimator, prior, x_o, device=device, **kwargs
-    )
-
-    if prior is not None:
-        theta_transform = mcmc_transform(
-            prior, device=device, enable_transform=enable_transform
-        )
-    else:
-        theta_transform = torch.distributions.transforms.identity_transform
-
-    return potential_fn, theta_transform
-
-
 class VectorFieldBasedPotential(BasePotential):
     def __init__(
         self,
@@ -130,7 +94,7 @@ def set_x(
         self,
         x_o: Optional[Tensor],
         x_is_iid: Optional[bool] = False,
-        iid_method: Literal["fnpe", "gauss", "auto_gauss", "jac_gauss"] = "auto_gauss",
+        iid_method: Optional[str] = None,
         iid_params: Optional[Dict[str, Any]] = None,
         **ode_kwargs,
     ):
@@ -149,7 +113,7 @@ def set_x(
             ode_kwargs: Additional keyword arguments for the neural ODE.
         """
         super().set_x(x_o, x_is_iid)
-        self.iid_method = iid_method
+        self.iid_method = iid_method or self.iid_method
         self.iid_params = iid_params
         # NOTE: Once IID potential evaluation is supported. This needs to be adapted.
         # See #1450.
@@ -286,6 +250,42 @@ def rebuild_flow(self, **kwargs) -> NormalizingFlow:
         return flow


+def vector_field_estimator_based_potential(
+    vector_field_estimator: ConditionalVectorFieldEstimator,
+    prior: Optional[Distribution],
+    x_o: Optional[Tensor],
+    enable_transform: bool = True,
+    **kwargs,
+) -> Tuple[VectorFieldBasedPotential, TorchTransform]:
+    r"""Returns the potential function gradient for vector field estimators.
+
+    Args:
+        vector_field_estimator: The neural network modelling the vector field.
+        prior: The prior distribution.
+        x_o: The observed data at which to evaluate the vector field.
+        enable_transform: Whether to enable transforms. Not supported yet.
+        **kwargs: Additional keyword arguments passed to
+            `VectorFieldBasedPotential`.
+    Returns:
+        The potential function and a transformation that maps
+        to unconstrained space.
+    """
+    device = str(next(vector_field_estimator.parameters()).device)
+
+    potential_fn = VectorFieldBasedPotential(
+        vector_field_estimator, prior, x_o, device=device, **kwargs
+    )
+
+    if prior is not None:
+        theta_transform = mcmc_transform(
+            prior, device=device, enable_transform=enable_transform
+        )
+    else:
+        theta_transform = torch.distributions.transforms.identity_transform
+
+    return potential_fn, theta_transform
+
+
 class DifferentiablePotentialFunction(torch.autograd.Function):
     """
     A wrapper of `VectorFieldBasedPotential` with a custom autograd function
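
Note: moving `vector_field_estimator_based_potential` below the class lets the return annotation reference `VectorFieldBasedPotential` directly instead of the string forward reference. A usage sketch under stated assumptions (`trained_estimator` stands in for a fitted `ConditionalVectorFieldEstimator`; shapes are illustrative):

import torch
from sbi.inference.potentials.vector_field_potential import (
    vector_field_estimator_based_potential,
)
from sbi.utils import BoxUniform

# Placeholder: a trained vector field estimator, e.g. from NPSE/FMPE training.
prior = BoxUniform(low=-2 * torch.ones(3), high=2 * torch.ones(3))
x_o = torch.randn(5, 2)  # five iid observations (illustrative shape)

potential_fn, theta_transform = vector_field_estimator_based_potential(
    trained_estimator, prior=prior, x_o=x_o
)

# With the new set_x signature, iid_method=None keeps the previously
# configured method instead of silently resetting it to "auto_gauss".
potential_fn.set_x(x_o, x_is_iid=True, iid_method=None)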

sbi/samplers/rejection/rejection.py

Lines changed: 6 additions & 2 deletions
@@ -269,7 +269,9 @@ def accept_reject_sample(
     pbar = tqdm(
         disable=not show_progress_bars,
         total=num_samples,
-        desc=f"Drawing {num_samples} posterior samples for {num_xos} observations",
+        desc=f"Drawing {num_samples} samples for {num_xos} observation" + "s"
+        if num_xos > 1
+        else "",
     )

     accepted = [[] for _ in range(num_xos)]
@@ -280,6 +282,7 @@ def accept_reject_sample(
     sampling_batch_size = min(num_samples, max_sampling_batch_size)
     num_sampled_total = torch.zeros(num_xos)
     num_samples_possible = 0
+
     while num_remaining > 0:
         # Sample and reject.
         candidates = proposal(
@@ -288,6 +291,7 @@ def accept_reject_sample(
         )
         # SNPE-style rejection-sampling when the proposal is the neural net.
         are_accepted = accept_reject_fn(candidates)
+
         # Reshape necessary in certain cases which do not follow the shape conventions
         # of the "DensityEstimator" class.
         are_accepted = are_accepted.reshape(sampling_batch_size, num_xos)
@@ -323,7 +327,7 @@ def accept_reject_sample(
         max(int(1.5 * num_remaining / max(min_acceptance_rate, 1e-12)), 100),
     )
     if (
-        num_samples_possible > 1000
+        num_samples_possible > (sampling_batch_size - 1)
        and min_acceptance_rate < warn_acceptance
        and not leakage_warning_raised
    ):
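
Note: the leakage warning previously required more than 1000 proposed samples before it could fire; it now fires once one full sampling batch has been evaluated, so small batch sizes no longer suppress it. A simplified sketch of the condition (the real code also tracks a `leakage_warning_raised` flag; the function name and default are assumptions):

def should_warn_about_leakage(
    num_samples_possible: int,
    min_acceptance_rate: float,
    sampling_batch_size: int,
    warn_acceptance: float = 0.01,
) -> bool:
    # New rule: one completed batch of evidence is enough to warn.
    return (
        num_samples_possible > (sampling_batch_size - 1)
        and min_acceptance_rate < warn_acceptance
    )


# A batch of 100 candidates at 0.1% acceptance now triggers the warning;
# under the old fixed ">1000" rule it would have stayed silent.
assert should_warn_about_leakage(100, 0.001, sampling_batch_size=100)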

sbi/samplers/score/diffuser.py

Lines changed: 18 additions & 6 deletions
@@ -143,22 +143,34 @@ def run(
         Returns:
             Tensor: Samples from the distribution(s).
         """
+        # Initialize samples from the base distribution
         samples = self.initialize(num_samples).to(ts.device)
+
+        # Set up progress bar for time-stepping through the diffusion process
+        total_time_steps = ts.numel() - 1  # We skip the first time point
         pbar = tqdm(
             range(1, ts.numel()),
             disable=not show_progress_bars,
-            desc=f"Drawing {num_samples} posterior samples",
+            desc=f"Generating {num_samples} posterior samples in {total_time_steps} "
+            "diffusion steps.",
         )

         if save_intermediate:
             intermediate_samples = [samples]

-        for i in pbar:
-            t1 = ts[i - 1]
-            t0 = ts[i]
-            samples = self.predictor(samples, t1, t0)
+        # Step through the diffusion process from t_max to t_min
+        for time_step_idx in pbar:
+            # Get current and next time points (going backwards in time)
+            t_current = ts[time_step_idx - 1]  # Previous time point
+            t_next = ts[time_step_idx]  # Current time point
+
+            # Apply predictor step
+            samples = self.predictor(samples, t_current, t_next)
+
+            # Apply corrector step if available
             if self.corrector is not None:
-                samples = self.corrector(samples, t0, t1)
+                samples = self.corrector(samples, t_next, t_current)
+
            if save_intermediate:
                intermediate_samples.append(samples)
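
Note: the renamed loop variables make the time convention explicit: the grid runs from t_max down to t_min, the predictor steps from the earlier grid point to the later one, and the corrector receives its time arguments in the opposite order. A standalone sketch with identity placeholders standing in for sbi's predictor and corrector:

import torch


def predictor(samples, t_current, t_next):
    return samples  # placeholder; real predictors integrate the reverse SDE


def corrector(samples, t_next, t_current):
    return samples  # placeholder; real correctors refine samples at t_next


ts = torch.linspace(1.0, 1e-3, steps=500)  # t_max -> t_min
samples = torch.randn(64, 2)  # drawn from the base distribution

for i in range(1, ts.numel()):
    t_current, t_next = ts[i - 1], ts[i]
    samples = predictor(samples, t_current, t_next)
    samples = corrector(samples, t_next, t_current)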
