File tree Expand file tree Collapse file tree 1 file changed +7
-1
lines changed Expand file tree Collapse file tree 1 file changed +7
-1
lines changed Original file line number Diff line number Diff line change 8
8
from typing import Optional
9
9
10
10
import torch
11
- from torch import Tensor
11
+ from torch import Tensor , nn
12
12
13
13
from sbi .inference .trainers .npe .npe_base import PosteriorEstimatorTrainer
14
14
from sbi .neural_nets .estimators import UnconditionalDensityEstimator
@@ -148,6 +148,12 @@ def calc_misspecification_mmd(
148
148
"no neural net found,"
149
149
"neural_net should not be None when mode is 'embedding'"
150
150
)
151
+ if isinstance (inference ._neural_net .embedding_net , nn .modules .linear .Identity ):
152
+ warnings .warn (
153
+ "The embedding net might be the identity function,"
154
+ "in that case the MMD is computed in the x-space." ,
155
+ stacklevel = 2 ,
156
+ )
151
157
if inference ._neural_net .embedding_net is None :
152
158
raise AttributeError (
153
159
"embedding_net attribute is None but is required for misspecification "
You can’t perform that action at this time.
0 commit comments