Skip to content

Commit b99aee0

Browse files
committed
first iteration of pandas fdataframe.py
1 parent 70fec44 commit b99aee0

File tree

4 files changed

+48
-100
lines changed

4 files changed

+48
-100
lines changed

environment.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,4 @@ dependencies:
1010
- pyspark~=3.3.0
1111
- networkx~=2.8.7
1212
- numpy~=1.23.4
13-
- pandas~=1.5.1
1413
- pyarrow~=8.0.0

fsspark/fs/fdataframe.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -50,27 +50,25 @@ def __init__(
5050
:param parse_features: Coerce all features to float.
5151
"""
5252

53-
self.__df = self._convert_psdf_to_sdf(df)
5453
self.__sample_col = sample_col
5554
self.__label_col = label_col
56-
self.__row_index_name = row_index_col
55+
self.__row_index_col = row_index_col
56+
self.__df = df
5757

5858
# check input dataframe
5959
self._check_df()
6060

6161
# replace dots in column names, if any.
6262
if parse_col_names:
63-
# TODO: Dots in column names are prone to errors, since dots are used to access attributes from DataFrame.
64-
# Should we make this replacement optional? Or print out a warning?
6563
self.__df = self.__df.toDF(*(c.replace('.', '_') for c in self.__df.columns))
6664

6765
# If the specified row index column name does not exist, add row index to the dataframe
68-
if self.__row_index_name not in self.__df.columns:
69-
self.__df = self._add_row_index(index_name=self.__row_index_name)
66+
if self.__row_index_col not in self.__df.columns:
67+
self.__df = self._add_row_index(index_name=self.__row_index_col)
7068

7169
if parse_features:
7270
# coerce all features to float
73-
non_features_cols = [self.__sample_col, self.__label_col, self.__row_index_name]
71+
non_features_cols = [self.__sample_col, self.__label_col, self.__row_index_col]
7472
feature_cols = [c for c in self.__df.columns if c not in non_features_cols]
7573
self.__df = self.__df.withColumns({c: self.__df[c].cast('float') for c in feature_cols})
7674

@@ -87,7 +85,7 @@ def _check_df(self):
8785
raise ValueError(f"Column sample name {self.__sample_col} not found...")
8886
elif self.__label_col not in col_names:
8987
raise ValueError(f"Column label name {self.__label_col} not found...")
90-
elif not isinstance(self.__row_index_name, str):
88+
elif not isinstance(self.__row_index_col, str):
9189
raise ValueError("Row index column name must be a valid string...")
9290
else:
9391
pass
@@ -97,21 +95,24 @@ def _set_indexed_cols(self) -> Series:
9795
Create a distributed indexed Series representing features.
9896
:return: Pandas on (PoS) Series
9997
"""
100-
non_features_cols = [self.__sample_col, self.__label_col, self.__row_index_name]
98+
non_features_cols = [self.__sample_col, self.__label_col, self.__row_index_col]
10199
features = [f for f in self.__df.columns if f not in non_features_cols]
102100
return Series(features)
103101

104-
def _set_indexed_rows(self) -> Series:
102+
def _set_indexed_rows(self) -> pd.Series:
105103
"""
106-
Create a distributed indexed Series representing samples labels.
107-
It will use existing row indices, if any.
104+
Create an indexed Series representing sample labels.
105+
It will use existing row indices from the DataFrame.
108106
109-
:return: Pandas on (PoS) Series
107+
:return: Pandas Series
110108
"""
111-
# TODO: Check for equivalent to pandas distributed Series in .
112-
label = self.__df.select(self.__label_col).collect()
113-
row_index = self.__df.select(self.__row_index_name).collect()
114-
return Series(label, index=row_index)
109+
110+
# Extract the label and row index columns from the DataFrame
111+
labels = self.__df[self.__label_col]
112+
row_indices = self.__df[self.__row_index_col]
113+
114+
# Create a Pandas Series with row_indices as index and labels as values
115+
return pd.Series(data=labels.values, index=row_indices.values)
115116

116117
def get_features_indexed(self) -> Series:
117118
"""
@@ -223,7 +224,7 @@ def get_row_index_name(self) -> str:
223224
224225
:return: Row id column name.
225226
"""
226-
return self.__row_index_name
227+
return self.__row_index_col
227228

228229
def _add_row_index(self, index_name: str = '_row_index') -> pd.DataFrame:
229230
"""
@@ -276,12 +277,12 @@ def filter_features(self, features: List[str], keep: bool = True) -> 'FSDataFram
276277
sdf = sdf.select(
277278
self.__sample_col,
278279
self.__label_col,
279-
self.__row_index_name,
280+
self.__row_index_col,
280281
*features)
281282
else:
282283
sdf = sdf.drop(*features)
283284

284-
fsdf_filtered = self.update(sdf, self.__sample_col, self.__label_col, self.__row_index_name)
285+
fsdf_filtered = self.update(sdf, self.__sample_col, self.__label_col, self.__row_index_col)
285286
count_b = fsdf_filtered.count_features()
286287

287288
logger.info(f"{count_b} features out of {count_a} remain after applying this filter...")

fsspark/tests/test_FSDataFrame.py

Lines changed: 0 additions & 79 deletions
This file was deleted.

fsspark/tests/test_fsdataframe.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import pytest
2+
import pandas as pd
3+
from fsspark.fs.fdataframe import FSDataFrame
4+
5+
def test_initializes_fsdataframe():
    """Verify that FSDataFrame wraps a pandas DataFrame without altering its data.

    Builds a minimal 3-sample frame with two float features, constructs an
    FSDataFrame around it with parsing disabled, and checks that the wrapped
    frame still contains the original columns and values.
    """

    # Create a sample DataFrame
    data = {
        'sample_id': [1, 2, 3],
        'label': ['A', 'B', 'C'],
        'feature1': [0.1, 0.2, 0.3],
        'feature2': [1.1, 1.2, 1.3]
    }
    df = pd.DataFrame(data)

    # Initialize FSDataFrame
    fs_df = FSDataFrame(
        df=df,
        sample_col='sample_id',
        label_col='label',
        row_index_col='_row_index',
        parse_col_names=False,
        parse_features=False
    )

    # BUG FIX: the original `assert (fs_df.get_sdf(), df)` asserted a
    # two-element tuple, which is always truthy, so the test could never fail.
    # Compare the wrapped frame's contents to the input instead.
    # NOTE(review): assumes get_sdf() now returns the internal pandas
    # DataFrame (the constructor stores `self.__df = df`) — confirm.
    result = fs_df.get_sdf()

    # The constructor adds a '_row_index' column when it is absent, so compare
    # only the columns that were present in the input frame.
    pd.testing.assert_frame_equal(result[df.columns], df)

0 commit comments

Comments
 (0)