@@ -51,27 +51,25 @@ def __init__(
51
51
:param parse_features: Coerce all features to float.
52
52
"""
53
53
54
- self .__df = df
55
54
self .__sample_col = sample_col
56
55
self .__label_col = label_col
57
- self .__row_index_name = row_index_col
56
+ self .__row_index_col = row_index_col
57
+ self .__df = df
58
58
59
59
# check input dataframe
60
60
self ._check_df ()
61
61
62
62
# replace dots in column names, if any.
63
63
if parse_col_names :
64
- # TODO: Dots in column names are prone to errors, since dots are used to access attributes from DataFrame.
65
- # Should we make this replacement optional? Or print out a warning?
66
64
self .__df = self .__df .toDF (* (c .replace ('.' , '_' ) for c in self .__df .columns ))
67
65
68
66
# If the specified row index column name does not exist, add row index to the dataframe
69
- if self .__row_index_name not in self .__df .columns :
70
- self .__df = self ._add_row_index (index_name = self .__row_index_name )
67
+ if self .__row_index_col not in self .__df .columns :
68
+ self .__df = self ._add_row_index (index_name = self .__row_index_col )
71
69
72
70
if parse_features :
73
71
# coerce all features to float
74
- non_features_cols = [self .__sample_col , self .__label_col , self .__row_index_name ]
72
+ non_features_cols = [self .__sample_col , self .__label_col , self .__row_index_col ]
75
73
feature_cols = [c for c in self .__df .columns if c not in non_features_cols ]
76
74
self .__df = self .__df .withColumns ({c : self .__df [c ].cast ('float' ) for c in feature_cols })
77
75
@@ -88,7 +86,7 @@ def _check_df(self):
88
86
raise ValueError (f"Column sample name { self .__sample_col } not found..." )
89
87
elif self .__label_col not in col_names :
90
88
raise ValueError (f"Column label name { self .__label_col } not found..." )
91
- elif not isinstance (self .__row_index_name , str ):
89
+ elif not isinstance (self .__row_index_col , str ):
92
90
raise ValueError ("Row index column name must be a valid string..." )
93
91
else :
94
92
pass
@@ -98,21 +96,24 @@ def _set_indexed_cols(self) -> Series:
98
96
Create a distributed indexed Series representing features.
99
97
:return: Pandas on (PoS) Series
100
98
"""
101
- non_features_cols = [self .__sample_col , self .__label_col , self .__row_index_name ]
99
+ non_features_cols = [self .__sample_col , self .__label_col , self .__row_index_col ]
102
100
features = [f for f in self .__df .columns if f not in non_features_cols ]
103
101
return Series (features )
104
102
105
- def _set_indexed_rows (self ) -> Series :
103
+ def _set_indexed_rows (self ) -> pd . Series :
106
104
"""
107
- Create a distributed indexed Series representing samples labels.
108
- It will use existing row indices, if any .
105
+ Create an indexed Series representing sample labels.
106
+ It will use existing row indices from the DataFrame .
109
107
110
108
:return: Pandas Series
111
109
"""
112
110
113
- label = self .__df [self .__label_col ]
114
- row_index = self .__df [self .__row_index_name ]
115
- return pd .Series (data = label .values , index = row_index .values )
111
+ # Extract the label and row index columns from the DataFrame
112
+ labels = self .__df [self .__label_col ]
113
+ row_indices = self .__df [self .__row_index_col ]
114
+
115
+ # Create a Pandas Series with row_indices as index and labels as values
116
+ return pd .Series (data = labels .values , index = row_indices .values )
116
117
117
118
def get_features_indexed (self ) -> Series :
118
119
"""
@@ -224,7 +225,7 @@ def get_row_index_name(self) -> str:
224
225
225
226
:return: Row id column name.
226
227
"""
227
- return self .__row_index_name
228
+ return self .__row_index_col
228
229
229
230
def _add_row_index (self , index_name : str = '_row_index' ) -> pd .DataFrame :
230
231
"""
@@ -277,12 +278,12 @@ def filter_features(self, features: List[str], keep: bool = True) -> 'FSDataFram
277
278
sdf = sdf .select (
278
279
self .__sample_col ,
279
280
self .__label_col ,
280
- self .__row_index_name ,
281
+ self .__row_index_col ,
281
282
* features )
282
283
else :
283
284
sdf = sdf .drop (* features )
284
285
285
- fsdf_filtered = self .update (sdf , self .__sample_col , self .__label_col , self .__row_index_name )
286
+ fsdf_filtered = self .update (sdf , self .__sample_col , self .__label_col , self .__row_index_col )
286
287
count_b = fsdf_filtered .count_features ()
287
288
288
289
logger .info (f"{ count_b } features out of { count_a } remain after applying this filter..." )
0 commit comments