@@ -50,27 +50,25 @@ def __init__(
50
50
:param parse_features: Coerce all features to float.
51
51
"""
52
52
53
- self .__df = self ._convert_psdf_to_sdf (df )
54
53
self .__sample_col = sample_col
55
54
self .__label_col = label_col
56
- self .__row_index_name = row_index_col
55
+ self .__row_index_col = row_index_col
56
+ self .__df = df
57
57
58
58
# check input dataframe
59
59
self ._check_df ()
60
60
61
61
# replace dots in column names, if any.
62
62
if parse_col_names :
63
- # TODO: Dots in column names are prone to errors, since dots are used to access attributes from DataFrame.
64
- # Should we make this replacement optional? Or print out a warning?
65
63
self .__df = self .__df .toDF (* (c .replace ('.' , '_' ) for c in self .__df .columns ))
66
64
67
65
# If the specified row index column name does not exist, add row index to the dataframe
68
- if self .__row_index_name not in self .__df .columns :
69
- self .__df = self ._add_row_index (index_name = self .__row_index_name )
66
+ if self .__row_index_col not in self .__df .columns :
67
+ self .__df = self ._add_row_index (index_name = self .__row_index_col )
70
68
71
69
if parse_features :
72
70
# coerce all features to float
73
- non_features_cols = [self .__sample_col , self .__label_col , self .__row_index_name ]
71
+ non_features_cols = [self .__sample_col , self .__label_col , self .__row_index_col ]
74
72
feature_cols = [c for c in self .__df .columns if c not in non_features_cols ]
75
73
self .__df = self .__df .withColumns ({c : self .__df [c ].cast ('float' ) for c in feature_cols })
76
74
@@ -87,7 +85,7 @@ def _check_df(self):
87
85
raise ValueError (f"Column sample name { self .__sample_col } not found..." )
88
86
elif self .__label_col not in col_names :
89
87
raise ValueError (f"Column label name { self .__label_col } not found..." )
90
- elif not isinstance (self .__row_index_name , str ):
88
+ elif not isinstance (self .__row_index_col , str ):
91
89
raise ValueError ("Row index column name must be a valid string..." )
92
90
else :
93
91
pass
@@ -97,21 +95,24 @@ def _set_indexed_cols(self) -> Series:
97
95
Create a distributed indexed Series representing features.
98
96
:return: Pandas on (PoS) Series
99
97
"""
100
- non_features_cols = [self .__sample_col , self .__label_col , self .__row_index_name ]
98
+ non_features_cols = [self .__sample_col , self .__label_col , self .__row_index_col ]
101
99
features = [f for f in self .__df .columns if f not in non_features_cols ]
102
100
return Series (features )
103
101
104
def _set_indexed_rows(self) -> pd.Series:
    """
    Create an indexed Series mapping each row index to its sample label.

    The label and row index columns are selected together and collected to
    the driver in a single pass, so each label stays paired with the index
    of the row it came from.

    NOTE(review): this assumes the internal DataFrame is a Spark DataFrame
    (the rest of the class calls ``toDF``/``withColumns``/``select`` on it).
    On a Spark DataFrame ``df[col]`` is a lazy ``Column`` expression with no
    materialized ``.values``, hence the explicit ``select(...).toPandas()``.

    :return: Pandas Series with the labels as values and the row indices
             as index.
    """
    # Select both columns at once: one Spark job, and rows cannot get
    # misaligned the way two independent collections could.
    selected = self.__df.select(self.__row_index_col, self.__label_col)

    # Materialize the two (small) columns on the driver.
    local = selected.toPandas()

    # Positional access keeps this unambiguous even if the two column
    # names happen to collide.
    row_indices = local.iloc[:, 0].values
    labels = local.iloc[:, 1].values

    return pd.Series(data=labels, index=row_indices)
115
116
116
117
def get_features_indexed (self ) -> Series :
117
118
"""
@@ -223,7 +224,7 @@ def get_row_index_name(self) -> str:
223
224
224
225
:return: Row id column name.
225
226
"""
226
- return self .__row_index_name
227
+ return self .__row_index_col
227
228
228
229
def _add_row_index (self , index_name : str = '_row_index' ) -> pd .DataFrame :
229
230
"""
@@ -276,12 +277,12 @@ def filter_features(self, features: List[str], keep: bool = True) -> 'FSDataFram
276
277
sdf = sdf .select (
277
278
self .__sample_col ,
278
279
self .__label_col ,
279
- self .__row_index_name ,
280
+ self .__row_index_col ,
280
281
* features )
281
282
else :
282
283
sdf = sdf .drop (* features )
283
284
284
- fsdf_filtered = self .update (sdf , self .__sample_col , self .__label_col , self .__row_index_name )
285
+ fsdf_filtered = self .update (sdf , self .__sample_col , self .__label_col , self .__row_index_col )
285
286
count_b = fsdf_filtered .count_features ()
286
287
287
288
logger .info (f"{ count_b } features out of { count_a } remain after applying this filter..." )
0 commit comments