
Commit 30e0659

Merge branch 'refactor-py' of https://github.com/bigbio/fsspark into refactor-py
# Conflicts:
#   fsspark/fs/fdataframe.py
2 parents cc471f0 + b99aee0 commit 30e0659

File tree

4 files changed: +46 additions, -98 deletions


environment.yml

Lines changed: 0 additions & 1 deletion
@@ -10,5 +10,4 @@ dependencies:
   - pyspark~=3.3.0
   - networkx~=2.8.7
   - numpy~=1.23.4
-  - pandas~=1.5.1
   - pyarrow~=8.0.0

fsspark/fs/fdataframe.py

Lines changed: 19 additions & 18 deletions
@@ -51,27 +51,25 @@ def __init__(
         :param parse_features: Coerce all features to float.
         """
 
-        self.__df = df
         self.__sample_col = sample_col
         self.__label_col = label_col
-        self.__row_index_name = row_index_col
+        self.__row_index_col = row_index_col
+        self.__df = df
 
         # check input dataframe
         self._check_df()
 
         # replace dots in column names, if any.
         if parse_col_names:
-            # TODO: Dots in column names are prone to errors, since dots are used to access attributes from DataFrame.
-            # Should we make this replacement optional? Or print out a warning?
             self.__df = self.__df.toDF(*(c.replace('.', '_') for c in self.__df.columns))
 
         # If the specified row index column name does not exist, add row index to the dataframe
-        if self.__row_index_name not in self.__df.columns:
-            self.__df = self._add_row_index(index_name=self.__row_index_name)
+        if self.__row_index_col not in self.__df.columns:
+            self.__df = self._add_row_index(index_name=self.__row_index_col)
 
         if parse_features:
             # coerce all features to float
-            non_features_cols = [self.__sample_col, self.__label_col, self.__row_index_name]
+            non_features_cols = [self.__sample_col, self.__label_col, self.__row_index_col]
             feature_cols = [c for c in self.__df.columns if c not in non_features_cols]
            self.__df = self.__df.withColumns({c: self.__df[c].cast('float') for c in feature_cols})
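The last three lines of this hunk drive the optional float coercion: every column that is not the sample, label, or row-index column is treated as a feature and cast to float. The diff does this through a Spark-style withColumns/cast call; the snippet below is only a minimal pandas sketch of the same idea, using made-up column names matching the layout the new test feeds in.

    import pandas as pd

    # Toy frame with the layout FSDataFrame expects (illustrative column names).
    df = pd.DataFrame({
        'sample_id': ['s1', 's2'],
        'label': ['A', 'B'],
        '_row_index': [0, 1],
        'feature1': ['0.1', '0.2'],   # stored as strings on purpose
        'feature2': [1, 2],
    })

    # Everything that is not a metadata column counts as a feature.
    non_features_cols = ['sample_id', 'label', '_row_index']
    feature_cols = [c for c in df.columns if c not in non_features_cols]

    # Coerce only the feature columns to float; metadata columns are untouched.
    df[feature_cols] = df[feature_cols].astype('float')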

@@ -88,7 +86,7 @@ def _check_df(self):
             raise ValueError(f"Column sample name {self.__sample_col} not found...")
         elif self.__label_col not in col_names:
             raise ValueError(f"Column label name {self.__label_col} not found...")
-        elif not isinstance(self.__row_index_name, str):
+        elif not isinstance(self.__row_index_col, str):
             raise ValueError("Row index column name must be a valid string...")
         else:
             pass
@@ -98,21 +96,24 @@ def _set_indexed_cols(self) -> Series:
         Create a distributed indexed Series representing features.
         :return: Pandas on (PoS) Series
         """
-        non_features_cols = [self.__sample_col, self.__label_col, self.__row_index_name]
+        non_features_cols = [self.__sample_col, self.__label_col, self.__row_index_col]
         features = [f for f in self.__df.columns if f not in non_features_cols]
         return Series(features)
 
-    def _set_indexed_rows(self) -> Series:
+    def _set_indexed_rows(self) -> pd.Series:
         """
-        Create a distributed indexed Series representing samples labels.
-        It will use existing row indices, if any.
+        Create an indexed Series representing sample labels.
+        It will use existing row indices from the DataFrame.
 
         :return: Pandas Series
         """
 
-        label = self.__df[self.__label_col]
-        row_index = self.__df[self.__row_index_name]
-        return pd.Series(data=label.values, index=row_index.values)
+        # Extract the label and row index columns from the DataFrame
+        labels = self.__df[self.__label_col]
+        row_indices = self.__df[self.__row_index_col]
+
+        # Create a Pandas Series with row_indices as index and labels as values
+        return pd.Series(data=labels.values, index=row_indices.values)
 
     def get_features_indexed(self) -> Series:
         """
@@ -224,7 +225,7 @@ def get_row_index_name(self) -> str:
 
         :return: Row id column name.
         """
-        return self.__row_index_name
+        return self.__row_index_col
 
     def _add_row_index(self, index_name: str = '_row_index') -> pd.DataFrame:
         """
@@ -277,12 +278,12 @@ def filter_features(self, features: List[str], keep: bool = True) -> 'FSDataFram
             sdf = sdf.select(
                 self.__sample_col,
                 self.__label_col,
-                self.__row_index_name,
+                self.__row_index_col,
                 *features)
         else:
             sdf = sdf.drop(*features)
 
-        fsdf_filtered = self.update(sdf, self.__sample_col, self.__label_col, self.__row_index_name)
+        fsdf_filtered = self.update(sdf, self.__sample_col, self.__label_col, self.__row_index_col)
         count_b = fsdf_filtered.count_features()
 
         logger.info(f"{count_b} features out of {count_a} remain after applying this filter...")
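The filter_features hunk only renames the row-index attribute, but the surrounding logic is worth spelling out: with keep=True the sample, label, and row-index columns plus the requested features are selected; otherwise the requested features are dropped. A rough pandas equivalent is sketched below; the helper and its meta_cols argument are hypothetical, and the real code operates on a Spark-style select/drop API.

    from typing import List
    import pandas as pd

    def filter_features(df: pd.DataFrame, features: List[str],
                        meta_cols: List[str], keep: bool = True) -> pd.DataFrame:
        """Keep or drop the given feature columns, always preserving metadata columns."""
        if keep:
            # Keep sample/label/row-index columns plus the requested features.
            return df[meta_cols + features]
        # Otherwise drop the requested features and keep everything else.
        return df.drop(columns=features)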

fsspark/tests/test_FSDataFrame.py

Lines changed: 0 additions & 79 deletions
This file was deleted.

fsspark/tests/test_fsdataframe.py

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+import pytest
+import pandas as pd
+from fsspark.fs.fdataframe import FSDataFrame
+
+def test_initializes_fsdataframe():
+
+    # Create a sample DataFrame
+    data = {
+        'sample_id': [1, 2, 3],
+        'label': ['A', 'B', 'C'],
+        'feature1': [0.1, 0.2, 0.3],
+        'feature2': [1.1, 1.2, 1.3]
+    }
+    df = pd.DataFrame(data)
+
+    # Initialize FSDataFrame
+    fs_df = FSDataFrame(
+        df=df,
+        sample_col='sample_id',
+        label_col='label',
+        row_index_col='_row_index',
+        parse_col_names=False,
+        parse_features=False
+    )
+
+    # Assertions to check if the initialization is correct
+    assert (fs_df.get_sdf(), df)
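One thing to note about the new test: the final check, assert (fs_df.get_sdf(), df), always passes, because a two-element tuple is truthy regardless of its contents. A stricter variant, assuming get_sdf() hands back the underlying pandas DataFrame unchanged (this diff does not show get_sdf), could compare the frames explicitly:

    import pandas as pd

    # Assumes get_sdf() returns the pandas DataFrame that was passed in.
    pd.testing.assert_frame_equal(fs_df.get_sdf(), df)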

0 commit comments
