Skip to content

Commit 8852430

Browse files
committed
add errors for MAP and iid data, adapt tests
1 parent c258cdc commit 8852430

File tree

3 files changed

+16
-3
lines changed

3 files changed

+16
-3
lines changed

sbi/inference/posteriors/score_posterior.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,9 @@ def map(
331331
Returns:
332332
The MAP estimate.
333333
"""
334+
raise NotImplementedError(
335+
"MAP estimation is currently not working accurately for ScorePosterior."
336+
)
334337
return super().map(
335338
x=x,
336339
num_iter=num_iter,

sbi/inference/potentials/score_based_potential.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,9 @@ def gradient(
160160
input=theta, condition=self.x_o, time=time
161161
)
162162
else:
163+
raise NotImplementedError(
164+
"Score accumulation for IID data is not yet fully implemented."
165+
)
163166
if self.prior is None:
164167
raise ValueError(
165168
"Prior must be provided when interpreting the data as IID."

tests/linearGaussian_npse_test.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import List
2+
13
import pytest
24
import torch
35
from torch import eye, ones, zeros
@@ -30,7 +32,7 @@
3032
],
3133
)
3234
def test_c2st_npse_on_linearGaussian(
33-
sde_type, num_dim: int, prior_str: str, sample_with: list[str]
35+
sde_type, num_dim: int, prior_str: str, sample_with: List[str]
3436
):
3537
"""Test whether NPSE infers well a simple example with available ground truth."""
3638

@@ -78,7 +80,7 @@ def test_c2st_npse_on_linearGaussian(
7880
check_c2st(
7981
samples,
8082
target_samples,
81-
alg=f"npse-{sde_type or "vp"}-{prior_str}-{num_dim}D-{method}",
83+
alg=f"npse-{sde_type or 'vp'}-{prior_str}-{num_dim}D-{method}",
8284
)
8385

8486
# Checks for log_prob()
@@ -157,7 +159,12 @@ def simulator(theta):
157159
check_c2st(samples, target_samples, alg="npse_different_dims_and_resume_training")
158160

159161

160-
@pytest.mark.xfail(reason="iid_bridge not working.")
162+
@pytest.mark.xfail(
163+
reason="iid_bridge not working.",
164+
raises=NotImplementedError,
165+
strict=True,
166+
match="Score accumulation*",
167+
)
161168
@pytest.mark.parametrize("num_trials", [2, 10])
162169
def test_npse_iid_inference(num_trials):
163170
"""Test whether NPSE infers well a simple example with available ground truth."""

0 commit comments

Comments
 (0)