1
1
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
2
2
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
3
3
4
- from typing import Any , Callable , Dict , List , Optional , Tuple , Union
4
+ from typing import Any , Callable , Dict , List , Optional , Tuple , Type , Union
5
5
6
6
import numpy as np
7
7
import torch
@@ -24,10 +24,9 @@ def __init__(
24
24
seed : int = 1 ,
25
25
num_folds : int = 1 ,
26
26
num_ensemble : int = 1 ,
27
- classifier : str = "mlp" ,
27
+ classifier : Union [ str , Type [ BaseEstimator ]] = MLPClassifier ,
28
28
z_score : bool = False ,
29
- clf_class : Optional [Any ] = None ,
30
- clf_kwargs : Optional [Dict [str , Any ]] = None ,
29
+ classifier_kwargs : Optional [Dict [str , Any ]] = None ,
31
30
num_trials_null : int = 100 ,
32
31
permutation : bool = True ,
33
32
) -> None :
@@ -71,10 +70,11 @@ def __init__(
71
70
num_ensemble: Number of classifiers for ensembling, defaults to 1.
72
71
This is useful to reduce variance coming from the classifier.
73
72
z_score: Whether to z-score to normalize the data, defaults to False.
74
- classifier: Classification architecture to use,
75
- possible values: "random_forest" or "mlp", defaults to "mlp".
76
- clf_class: Custom sklearn classifier class, defaults to None.
77
- clf_kwargs: Custom kwargs for the sklearn classifier, defaults to None.
73
+ classifier: Classification architecture to use, can be one of the following:
74
+ - "random_forest" or "mlp", defaults to "mlp" or
75
+ - A classifier class (e.g., RandomForestClassifier, MLPClassifier)
76
+ classifier_kwargs: Custom kwargs for the sklearn classifier,
77
+ defaults to None.
78
78
num_trials_null: Number of trials to estimate the null distribution,
79
79
defaults to 100.
80
80
permutation: Whether to use the permutation method for the null hypothesis,
@@ -111,10 +111,26 @@ def __init__(
111
111
self .num_ensemble = num_ensemble
112
112
113
113
# initialize classifier
114
- if "mlp" in classifier .lower ():
115
- ndim = thetas .shape [- 1 ]
116
- self .clf_class = MLPClassifier
117
- if clf_kwargs is None :
114
+ if isinstance (classifier , str ):
115
+ if classifier .lower () == 'mlp' :
116
+ classifier = MLPClassifier
117
+ elif classifier .lower () == 'random_forest' :
118
+ classifier = RandomForestClassifier
119
+ else :
120
+ raise ValueError (
121
+ f'Invalid classifier: "{ classifier } ".'
122
+ 'Expected "mlp", "random_forest", '
123
+ 'or a valid scikit-learn classifier class.'
124
+ )
125
+ assert issubclass (classifier , BaseEstimator ), (
126
+ "classifier must be a subclass of sklearn's BaseEstimator"
127
+ )
128
+ self .clf_class = classifier
129
+
130
+ self .clf_kwargs = classifier_kwargs
131
+ if self .clf_kwargs is None :
132
+ if self .clf_class == MLPClassifier :
133
+ ndim = thetas .shape [- 1 ]
118
134
self .clf_kwargs = {
119
135
"activation" : "relu" ,
120
136
"hidden_layer_sizes" : (10 * ndim , 10 * ndim ),
@@ -123,19 +139,8 @@ def __init__(
123
139
"early_stopping" : True ,
124
140
"n_iter_no_change" : 50 ,
125
141
}
126
- elif "random_forest" in classifier .lower ():
127
- self .clf_class = RandomForestClassifier
128
- if clf_kwargs is None :
142
+ else :
129
143
self .clf_kwargs = {}
130
- elif "custom" :
131
- if clf_class is None or clf_kwargs is None :
132
- raise ValueError (
133
- "Please provide a valid sklearn classifier class and kwargs."
134
- )
135
- self .clf_class = clf_class
136
- self .clf_kwargs = clf_kwargs
137
- else :
138
- raise NotImplementedError
139
144
140
145
# initialize classifiers, will be set after training
141
146
self .trained_clfs = None
0 commit comments