Skip to content

Commit 4cd7700

Browse files
authored
fix: explicit types and '==' for lc2st classifier comparison (#1550)
* fix: replace 'in' operator with '==' for proper classifier comparison
* Update classifier argument to accept classifier classes
* Update classifier initialization in lc2st constructor method
* Remove clf_class parameter
* Change classifier type from ClassifierMixin to BaseEstimator
1 parent 2a35e66 commit 4cd7700

File tree

2 files changed

+30
-28
lines changed

2 files changed

+30
-28
lines changed

sbi/diagnostics/lc2st.py

Lines changed: 29 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
22
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
33

4-
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
4+
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
55

66
import numpy as np
77
import torch
@@ -24,10 +24,9 @@ def __init__(
2424
seed: int = 1,
2525
num_folds: int = 1,
2626
num_ensemble: int = 1,
27-
classifier: str = "mlp",
27+
classifier: Union[str, Type[BaseEstimator]] = MLPClassifier,
2828
z_score: bool = False,
29-
clf_class: Optional[Any] = None,
30-
clf_kwargs: Optional[Dict[str, Any]] = None,
29+
classifier_kwargs: Optional[Dict[str, Any]] = None,
3130
num_trials_null: int = 100,
3231
permutation: bool = True,
3332
) -> None:
@@ -71,10 +70,11 @@ def __init__(
7170
num_ensemble: Number of classifiers for ensembling, defaults to 1.
7271
This is useful to reduce variance coming from the classifier.
7372
z_score: Whether to z-score to normalize the data, defaults to False.
74-
classifier: Classification architecture to use,
75-
possible values: "random_forest" or "mlp", defaults to "mlp".
76-
clf_class: Custom sklearn classifier class, defaults to None.
77-
clf_kwargs: Custom kwargs for the sklearn classifier, defaults to None.
73+
classifier: Classification architecture to use, can be one of the following:
74+
- "random_forest" or "mlp", defaults to "mlp" or
75+
- A classifier class (e.g., RandomForestClassifier, MLPClassifier)
76+
classifier_kwargs: Custom kwargs for the sklearn classifier,
77+
defaults to None.
7878
num_trials_null: Number of trials to estimate the null distribution,
7979
defaults to 100.
8080
permutation: Whether to use the permutation method for the null hypothesis,
@@ -111,10 +111,26 @@ def __init__(
111111
self.num_ensemble = num_ensemble
112112

113113
# initialize classifier
114-
if "mlp" in classifier.lower():
115-
ndim = thetas.shape[-1]
116-
self.clf_class = MLPClassifier
117-
if clf_kwargs is None:
114+
if isinstance(classifier, str):
115+
if classifier.lower() == 'mlp':
116+
classifier = MLPClassifier
117+
elif classifier.lower() == 'random_forest':
118+
classifier = RandomForestClassifier
119+
else:
120+
raise ValueError(
121+
f'Invalid classifier: "{classifier}".'
122+
'Expected "mlp", "random_forest", '
123+
'or a valid scikit-learn classifier class.'
124+
)
125+
assert issubclass(classifier, BaseEstimator), (
126+
"classifier must be a subclass of sklearn's BaseEstimator"
127+
)
128+
self.clf_class = classifier
129+
130+
self.clf_kwargs = classifier_kwargs
131+
if self.clf_kwargs is None:
132+
if self.clf_class == MLPClassifier:
133+
ndim = thetas.shape[-1]
118134
self.clf_kwargs = {
119135
"activation": "relu",
120136
"hidden_layer_sizes": (10 * ndim, 10 * ndim),
@@ -123,19 +139,8 @@ def __init__(
123139
"early_stopping": True,
124140
"n_iter_no_change": 50,
125141
}
126-
elif "random_forest" in classifier.lower():
127-
self.clf_class = RandomForestClassifier
128-
if clf_kwargs is None:
142+
else:
129143
self.clf_kwargs = {}
130-
elif "custom":
131-
if clf_class is None or clf_kwargs is None:
132-
raise ValueError(
133-
"Please provide a valid sklearn classifier class and kwargs."
134-
)
135-
self.clf_class = clf_class
136-
self.clf_kwargs = clf_kwargs
137-
else:
138-
raise NotImplementedError
139144

140145
# initialize classifiers, will be set after training
141146
self.trained_clfs = None

tests/lc2st_test.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717

1818
@pytest.mark.parametrize("method", (LC2ST, LC2ST_NF))
19-
@pytest.mark.parametrize("classifier", ('mlp', 'random_forest', 'custom'))
19+
@pytest.mark.parametrize("classifier", ('mlp', 'random_forest', MLPClassifier))
2020
@pytest.mark.parametrize("cv_folds", (1, 2))
2121
@pytest.mark.parametrize("num_ensemble", (1, 3))
2222
@pytest.mark.parametrize("z_score", (True, False))
@@ -72,9 +72,6 @@ def test_running_lc2st(method, classifier, cv_folds, num_ensemble, z_score):
7272
"num_eval": num_eval,
7373
}
7474
kwargs_eval = {}
75-
if classifier == "custom":
76-
kwargs_test["clf_class"] = MLPClassifier
77-
kwargs_test["clf_kwargs"] = {"alpha": 0.0, "max_iter": 2500}
7875
kwargs_test["classifier"] = classifier
7976

8077
lc2st = method(

0 commit comments

Comments (0)