Skip to content

Commit 586e982

Browse files
authored
Merge pull request #52 from ncooder/docs_correction
Corrected docstrings and parameter names
2 parents 58d9e9d + 949f3a9 commit 586e982

File tree

1 file changed

+16
-12
lines changed

1 file changed

+16
-12
lines changed

src/arfs/feature_selection/mrmr.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
33
This module provides MinRedundancyMaxRelevance (MRMR) feature selection for classification or regression tasks.
44
In a classification task, the target should be of object or pandas category dtype, while in a regression task,
5-
the target should be of numpy categorical dtype. The predictors can be categorical or numerical without requiring encoding,
5+
the target should be numeric. The predictors can be categorical or numerical without requiring encoding,
66
as the appropriate method (correlation, correlation ratio, or Theil's U) will be automatically selected based on the data type.
77
88
Module Structure:
@@ -42,16 +42,16 @@ class MinRedundancyMaxRelevance(SelectorMixin, BaseEstimator):
4242
relevance_func: callable, optional
4343
relevance function having arguments "X", "y", "sample_weight" and returning a pd.Series
4444
containing a score of relevance for each feature
45-
redundancy: callable, optional
45+
redundancy_func: callable, optional
4646
Redundancy method.
4747
If callable, it should take "X", "sample_weight" as input and return a pandas.Series
4848
containing a score of redundancy for each feature.
49-
denominator: str or callable (optional, default='mean')
49+
denominator_func: str or callable (optional, default='mean')
5050
Synthesis function to apply to the denominator of MRMR score.
5151
If string, name of method. Supported: 'max', 'mean'.
5252
If callable, it should take an iterable as input and return a scalar.
5353
task: str
54-
either "regression" or "classifiction"
54+
either "regression" or "classification"
5555
only_same_domain: bool (optional, default=False)
5656
If False, all the necessary correlation coefficients are computed.
5757
If True, only features belonging to the same domain are compared.
@@ -60,7 +60,7 @@ class MinRedundancyMaxRelevance(SelectorMixin, BaseEstimator):
6060
return_scores: bool (optional, default=False)
6161
If False, only the list of selected features is returned.
6262
If True, a tuple containing (list of selected features, relevance, redundancy) is returned.
63-
n_jobs: int (optional, default=-1)
63+
n_jobs: int (optional, default=1)
6464
Maximum number of workers to use. Only used when relevance = "f" or redundancy = "corr".
6565
If -1, use as many workers as min(cpu count, number of features).
6666
show_progress: bool (optional, default=True)
@@ -89,10 +89,11 @@ class MinRedundancyMaxRelevance(SelectorMixin, BaseEstimator):
8989
>>> pred_name = [f"pred_{i}" for i in range(X.shape[1])]
9090
>>> X.columns = pred_name
9191
>>> y.name = "target"
92-
>>> fs_mrmr = MinRedundancyMaxRelevance(n_features_to_select=5,
92+
>>> fs_mrmr = MinRedundancyMaxRelevance(
93+
...     n_features_to_select=5,
9394
...     relevance_func=None,
9495
...     redundancy_func=None,
95-
>>> task= "regression",#"classification",
96+
...     task="regression",  # "classification",
9697
...     denominator_func=np.mean,
9798
...     only_same_domain=False,
9899
...     return_scores=False,
@@ -146,16 +147,16 @@ def fit(self, X, y, sample_weight=None):
146147
X : pd.DataFrame, shape (n_samples, n_features)
147148
Data from which to compute variances, where `n_samples` is
148149
the number of samples and `n_features` is the number of features.
149-
y : any, default=None
150-
Ignored. This parameter exists only for compatibility with
151-
sklearn.pipeline.Pipeline.
150+
y : array-like or pd.Series of shape (n_samples,)
151+
Target vector. Must be numeric for regression or categorical for classification.
152152
sample_weight : pd.Series, optional, shape (n_samples,)
153153
weights for computing the statistics (e.g. weighted average)
154154
155155
Returns
156156
-------
157157
self : object
158-
Returns the instance itself.
158+
If `return_scores=False`, returns self.
159+
If `return_scores=True`, returns (selected_features, relevance_scores, redundancy_scores).
159160
"""
160161

161162
if isinstance(X, pd.DataFrame):
@@ -212,6 +213,9 @@ def fit(self, X, y, sample_weight=None):
212213
[x in self.selected_features for x in self.feature_names_in_]
213214
)
214215
self.not_selected_features_ = self.not_selected_features
216+
217+
if self.return_scores:
218+
return self.selected_features_, self.relevance_, self.redundancy_
215219
return self
216220

217221
def transform(self, X):
@@ -232,7 +236,7 @@ def transform(self, X):
232236
raise TypeError("X is not a dataframe")
233237
return X[self.selected_features_]
234238

235-
def fit_transform(self, X, y, sample_weight=None):
239+
def fit_transform(self, X, y, sample_weight=None, **fit_params):
236240
"""
237241
Fit to data, then transform it.
238242
Fits transformer to `X` and `y` and optionally sample_weight

0 commit comments

Comments
 (0)