@@ -56,7 +56,7 @@ def build_mnle(
56
56
z_score_x : Optional [str ] = "independent" ,
57
57
z_score_y : Optional [str ] = "independent" ,
58
58
flow_model : str = "nsf" ,
59
- categorical_model : str = "made " ,
59
+ categorical_model : str = "mlp " ,
60
60
embedding_net : nn .Module = nn .Identity (),
61
61
combined_embedding_net : Optional [nn .Module ] = None ,
62
62
num_transforms : int = 2 ,
@@ -104,7 +104,7 @@ def build_mnle(
104
104
flow_model: type of flow model to use for the continuous part of the
105
105
data.
106
106
categorical_model: type of categorical net to use for the discrete part of
107
- the data. Can be "made" or "categorical ".
107
+ the data. Can be "made" or "mlp ".
108
108
embedding_net: Optional embedding network for y, required if y is > 1D.
109
109
combined_embedding_net: Optional embedding for combining the discrete
110
110
part of the input and the embedded condition into a joined
@@ -157,7 +157,7 @@ def build_mnle(
157
157
num_layers = hidden_layers ,
158
158
embedding_net = embedding_net ,
159
159
)
160
- elif categorical_model == "categorical " :
160
+ elif categorical_model == "mlp " :
161
161
discrete_net = build_categoricalmassestimator (
162
162
disc_x ,
163
163
batch_y ,
@@ -169,7 +169,7 @@ def build_mnle(
169
169
)
170
170
else :
171
171
raise ValueError (
172
- f"Unknown categorical net { categorical_model } . Must be 'made' or 'categorical '."
172
+ f"Unknown categorical net { categorical_model } . Must be 'made' or 'mlp '."
173
173
)
174
174
175
175
if combined_embedding_net is None :
0 commit comments