Skip to content

Commit 90d3d5e

Browse files
committed
Re-inserted check on calc misspecificaton mmd
1 parent 5fff8db commit 90d3d5e

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

sbi/diagnostics/misspecification.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from typing import Optional
99

1010
import torch
11-
from torch import Tensor
11+
from torch import Tensor, nn
1212

1313
from sbi.inference.trainers.npe.npe_base import PosteriorEstimatorTrainer
1414
from sbi.neural_nets.estimators import UnconditionalDensityEstimator
@@ -148,6 +148,12 @@ def calc_misspecification_mmd(
148148
"no neural net found,"
149149
"neural_net should not be None when mode is 'embedding'"
150150
)
151+
if isinstance(inference._neural_net.embedding_net, nn.modules.linear.Identity):
152+
warnings.warn(
153+
"The embedding net might be the identity function,"
154+
"in that case the MMD is computed in the x-space.",
155+
stacklevel=2,
156+
)
151157
if inference._neural_net.embedding_net is None:
152158
raise AttributeError(
153159
"embedding_net attribute is None but is required for misspecification "

0 commit comments

Comments
 (0)