
Commit 597da91

fix: add nan check to _loss methods. (#1361)
* add nan check to _loss method.
* add checks for NREs
1 parent 0bb84d9 commit 597da91

7 files changed: +25, -5 lines
sbi/inference/trainers/nle/nle_base.py

Lines changed: 4 additions & 1 deletion

@@ -22,6 +22,7 @@
     reshape_to_batch_event,
 )
 from sbi.utils import check_estimator_arg, check_prior, x_shape_from_simulation
+from sbi.utils.torchutils import assert_all_finite


 class LikelihoodEstimator(NeuralInference, ABC):
@@ -381,4 +382,6 @@ def _loss(self, theta: Tensor, x: Tensor) -> Tensor:
             theta, event_shape=self._neural_net.condition_shape
         )
         x = reshape_to_batch_event(x, event_shape=self._neural_net.input_shape)
-        return self._neural_net.loss(x, condition=theta)
+        loss = self._neural_net.loss(x, condition=theta)
+        assert_all_finite(loss, "NLE loss")
+        return loss

sbi/inference/trainers/npe/npe_base.py

Lines changed: 2 additions & 0 deletions

@@ -42,6 +42,7 @@
     warn_if_zscoring_changes_data,
 )
 from sbi.utils.sbiutils import ImproperEmpirical, mask_sims_from_prior
+from sbi.utils.torchutils import assert_all_finite


 class PosteriorEstimator(NeuralInference, ABC):
@@ -609,6 +610,7 @@ def _loss(
         # Must be extended ones other Estimators are implemented. See #966,
         loss = -self._log_prob_proposal_posterior(theta, x, masks, proposal)

+        assert_all_finite(loss, "NPE loss")
         return calibration_kernel(x) * loss

     def _check_proposal(self, proposal):

sbi/inference/trainers/npse/npse.py

Lines changed: 2 additions & 0 deletions

@@ -28,6 +28,7 @@
     warn_if_zscoring_changes_data,
 )
 from sbi.utils.sbiutils import ImproperEmpirical, mask_sims_from_prior
+from sbi.utils.torchutils import assert_all_finite


 class NPSE(NeuralInference):
@@ -510,6 +511,7 @@ def _loss(
                 "Multi-round NPSE with arbitrary proposals is not implemented"
             )

+        assert_all_finite(loss, "NPSE loss")
         return calibration_kernel(x) * loss

     def _converged(self, epoch: int, stop_after_epochs: int) -> bool:

sbi/inference/trainers/nre/bnre.py

Lines changed: 4 additions & 1 deletion

@@ -10,6 +10,7 @@
 from sbi.inference.trainers.nre.nre_a import NRE_A
 from sbi.sbi_types import TensorboardSummaryWriter
 from sbi.utils.sbiutils import del_entries
+from sbi.utils.torchutils import assert_all_finite


 class BNRE(NRE_A):
@@ -142,4 +143,6 @@ def _loss(
             .square()
         )

-        return bce + regularization_strength * regularizer
+        loss = bce + regularization_strength * regularizer
+        assert_all_finite(loss, "BNRE loss")
+        return loss

sbi/inference/trainers/nre/nre_a.py

Lines changed: 4 additions & 1 deletion

@@ -10,6 +10,7 @@
 from sbi.inference.trainers.nre.nre_base import RatioEstimator
 from sbi.sbi_types import TensorboardSummaryWriter
 from sbi.utils.sbiutils import del_entries
+from sbi.utils.torchutils import assert_all_finite


 class NRE_A(RatioEstimator):
@@ -124,4 +125,6 @@ def _loss(self, theta: Tensor, x: Tensor, num_atoms: int) -> Tensor:
         labels[1::2] = 0.0

         # Binary cross entropy to learn the likelihood (AALR-specific)
-        return nn.BCELoss()(likelihood, labels)
+        loss = nn.BCELoss()(likelihood, labels)
+        assert_all_finite(loss, "NRE-A loss")
+        return loss

sbi/inference/trainers/nre/nre_b.py

Lines changed: 4 additions & 1 deletion

@@ -10,6 +10,7 @@
 from sbi.inference.trainers.nre.nre_base import RatioEstimator
 from sbi.sbi_types import TensorboardSummaryWriter
 from sbi.utils.sbiutils import del_entries
+from sbi.utils.torchutils import assert_all_finite


 class NRE_B(RatioEstimator):
@@ -122,4 +123,6 @@ def _loss(self, theta: Tensor, x: Tensor, num_atoms: int) -> Tensor:
         # "correct" one for the 1-out-of-N classification.
         log_prob = logits[:, 0] - torch.logsumexp(logits, dim=-1)

-        return -torch.mean(log_prob)
+        loss = -torch.mean(log_prob)
+        assert_all_finite(loss, "NRE-B loss")
+        return loss

sbi/inference/trainers/nre/nre_c.py

Lines changed: 5 additions & 1 deletion

@@ -10,6 +10,7 @@
 from sbi.inference.trainers.nre.nre_base import RatioEstimator
 from sbi.sbi_types import TensorboardSummaryWriter
 from sbi.utils.sbiutils import del_entries
+from sbi.utils.torchutils import assert_all_finite


 class NRE_C(RatioEstimator):
@@ -190,7 +191,10 @@ def _loss(

         # relative weights. p_marginal := p_0, and p_joint := p_K * K from the notation.
         p_marginal, p_joint = self._get_prior_probs_marginal_and_joint(gamma)
-        return -torch.mean(p_marginal * log_prob_marginal + p_joint * log_prob_joint)
+
+        loss = -torch.mean(p_marginal * log_prob_marginal + p_joint * log_prob_joint)
+        assert_all_finite(loss, "NRE-C loss")
+        return loss

     @staticmethod
     def _get_prior_probs_marginal_and_joint(gamma: float) -> Tuple[float, float]:
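
To see the guard fire, one can feed a loss tensor containing a NaN through the check. This is a hypothetical snippet, not part of the commit, and it assumes the helper raises an AssertionError on non-finite values; the exact error type depends on its implementation.

import torch
from sbi.utils.torchutils import assert_all_finite

loss = torch.tensor([0.3, float("nan"), 1.2])
try:
    assert_all_finite(loss, "NLE loss")  # same call pattern as in the diff above
except AssertionError as err:  # assumed error type, see note above
    print(f"caught non-finite loss: {err}")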
