Skip to content

Commit be335fc

Browse files
committed
move NaN check to posterior.sample level, update tests, fix rejection sampling warning
1 parent 05b3478 commit be335fc

File tree

4 files changed

+12
-9
lines changed

4 files changed

+12
-9
lines changed

sbi/inference/posteriors/vector_field_posterior.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,12 @@ def _sample_via_diffusion(
340340
# Concatenate all batches and ensure we return exactly the requested number
341341
samples = torch.cat(all_samples, dim=0)[:total_samples_needed]
342342

343+
if torch.isnan(samples).all():
344+
raise RuntimeError(
345+
"All samples NaN after diffusion sampling. "
346+
"This may indicate numerical instability in the vector field."
347+
)
348+
343349
return samples
344350

345351
def sample_via_ode(

sbi/samplers/rejection/rejection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ def accept_reject_sample(
327327
max(int(1.5 * num_remaining / max(min_acceptance_rate, 1e-12)), 100),
328328
)
329329
if (
330-
num_samples_possible > 1000
330+
num_samples_possible > (sampling_batch_size - 1)
331331
and min_acceptance_rate < warn_acceptance
332332
and not leakage_warning_raised
333333
):

sbi/samplers/score/diffuser.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -174,13 +174,6 @@ def run(
174174
if save_intermediate:
175175
intermediate_samples.append(samples)
176176

177-
# Check for NaN values after predictor
178-
if torch.isnan(samples).any():
179-
raise RuntimeError(
180-
"NaN values detected after diffusion sampling "
181-
"This may indicate numerical instability in the vector field."
182-
)
183-
184177
if save_intermediate:
185178
return torch.cat(intermediate_samples, dim=0)
186179
else:

tests/linearGaussian_vector_field_test.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,11 @@ def test_vector_field_iid_inference(
373373
num_trials: The number of trials to run.
374374
"""
375375

376-
if prior_type == "uniform" and iid_method in ["gauss", "auto_gauss", "jac_gauss"]:
376+
if (
377+
vector_field_type == "fmpe"
378+
and prior_type == "uniform"
379+
and iid_method in ["gauss", "auto_gauss", "jac_gauss"]
380+
):
377381
# TODO: Predictor produces NaNs for these cases, see #1656
378382
pytest.skip("Known issue of IID methods with uniform priors, see #1656.")
379383

0 commit comments

Comments
 (0)