Skip to content

Commit a69ac12

Browse files
committed
Minor changes in constants.py
1 parent 10ee2e8 commit a69ac12

File tree

4 files changed

+38
-4
lines changed

4 files changed

+38
-4
lines changed

fslite/fs/constants.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,23 @@
11
"""
22
This file contains a list of constants used in the feature selection and machine learning methods.
33
"""
4+
from typing import Dict, List, Union
45

56
FS_METHODS = {
67
'univariate': {
78
"title": 'Univariate Feature Selection',
89
"methods": [
910
{
1011
'name': 'anova',
11-
'description': 'ANOVA univariate feature selection (F-classification)'
12+
'description': 'Univariate ANOVA feature selection (f-classification)'
13+
},
14+
{
15+
'name': 'u_corr',
16+
'description': 'Univariate correlation'
17+
},
18+
{
19+
'name': 'f_regression',
20+
'description': 'Univariate f-regression'
1221
}
1322
]
1423
},
@@ -68,7 +77,7 @@ def get_fs_methods():
6877
"""
6978
return FS_METHODS
7079

71-
def get_fs_method_details(method_name: str):
80+
def get_fs_method_details(method_name: str) -> Union[Dict, None]:
7281
"""
7382
Get the details of the feature selection method, this function search in all-methods definitions
7483
and get the details of the method with the given name. If the method is not found, it returns None.
@@ -82,3 +91,25 @@ def get_fs_method_details(method_name: str):
8291
if method['name'].lower() == method_name.lower():
8392
return method
8493
return None
94+
95+
def get_fs_univariate_methods() -> List:
96+
"""
97+
Get the list of univariate methods implemented in the library
98+
:return: list
99+
"""
100+
univariate_methods = FS_METHODS['univariate']
101+
univariate_names = [method["name"] for method in univariate_methods["methods"]]
102+
return univariate_names
103+
104+
def is_valid_univariate_method(method_name: str) -> bool:
105+
"""
106+
This method check if the given method name is a supported univariate method
107+
:param method_name method name
108+
:return: boolean
109+
"""
110+
for method in FS_METHODS["univariate"]["methods"]:
111+
if method["name"].lower() == method_name:
112+
return True
113+
return False
114+
115+

fslite/fs/fdataframe.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,6 @@ def select_features_by_index(self, feature_indexes: List[int]) -> 'FSDataFrame':
221221
def to_pandas(self) -> DataFrame:
222222
"""
223223
Return the DataFrame representation of the FSDataFrame.
224-
225224
:return: Pandas DataFrame.
226225
"""
227226

fslite/fs/univariate.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import pandas as pd
66
from sklearn.feature_selection import SelectKBest, f_classif, f_regression
77

8+
from fslite.fs.constants import get_fs_univariate_methods, is_valid_univariate_method
89
from fslite.fs.fdataframe import FSDataFrame
910

1011
logging.basicConfig(format="%(levelname)s (%(name)s %(lineno)s): %(message)s")
@@ -100,6 +101,9 @@ def univariate_filter(df: FSDataFrame,
100101
:return: Filtered DataFrame with selected features
101102
"""
102103

104+
if not is_valid_univariate_method(univariate_method):
105+
raise NotImplementedError("The provided method {} is not implented !! please select one from this list {}".format(univariate_method, get_fs_univariate_methods()))
106+
103107
selected_features = []
104108

105109
if univariate_method == 'anova':

fslite/tests/test_univariate_methods.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def test_univariate_filter_corr():
1616
# create FSDataFrame instance
1717
fs_df = FSDataFrame(df=df,sample_col='Sample',label_col='label')
1818

19-
fsdf_filtered = univariate_filter(fs_df,univariate_method='u_corr', corr_threshold=0.3)
19+
fsdf_filtered = univariate_filter(fs_df, univariate_method='u_corr', corr_threshold=0.3)
2020

2121
assert fs_df.count_features() == 500
2222
assert fsdf_filtered.count_features() == 211

0 commit comments

Comments
 (0)