Skip to content

Commit 089d1d3

Browse files
committed
fix: change net kwarg
1 parent d070d2a commit 089d1d3

File tree

1 file changed

+4
-4
lines changed
  • sbi/neural_nets/net_builders

1 file changed

+4
-4
lines changed

sbi/neural_nets/net_builders/mnle.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def build_mnle(
5656
z_score_x: Optional[str] = "independent",
5757
z_score_y: Optional[str] = "independent",
5858
flow_model: str = "nsf",
59-
categorical_model: str = "made",
59+
categorical_model: str = "mlp",
6060
embedding_net: nn.Module = nn.Identity(),
6161
combined_embedding_net: Optional[nn.Module] = None,
6262
num_transforms: int = 2,
@@ -104,7 +104,7 @@ def build_mnle(
104104
flow_model: type of flow model to use for the continuous part of the
105105
data.
106106
categorical_model: type of categorical net to use for the discrete part of
107-
the data. Can be "made" or "categorical".
107+
the data. Can be "made" or "mlp".
108108
embedding_net: Optional embedding network for y, required if y is > 1D.
109109
combined_embedding_net: Optional embedding for combining the discrete
110110
part of the input and the embedded condition into a joined
@@ -157,7 +157,7 @@ def build_mnle(
157157
num_layers=hidden_layers,
158158
embedding_net=embedding_net,
159159
)
160-
elif categorical_model == "categorical":
160+
elif categorical_model == "mlp":
161161
discrete_net = build_categoricalmassestimator(
162162
disc_x,
163163
batch_y,
@@ -169,7 +169,7 @@ def build_mnle(
169169
)
170170
else:
171171
raise ValueError(
172-
f"Unknown categorical net {categorical_model}. Must be 'made' or 'categorical'."
172+
f"Unknown categorical net {categorical_model}. Must be 'made' or 'mlp'."
173173
)
174174

175175
if combined_embedding_net is None:

0 commit comments

Comments
 (0)