
Commit 8e08dbe

Fix dataframe splitting numpy issue (#91)
Resolves #87
1 parent 2c4572a commit 8e08dbe

File tree

4 files changed: +63 -5 lines changed


dplutils/pipeline/ray.py

Lines changed: 3 additions & 2 deletions
@@ -10,6 +10,7 @@
 from dplutils import observer
 from dplutils.pipeline import OutputBatch, PipelineExecutor, PipelineTask
 from dplutils.pipeline.stream import StreamBatch, StreamingGraphExecutor
+from dplutils.pipeline.utils import split_dataframe


 def set_run_id(inputs, run_id):
@@ -31,7 +32,7 @@ def wrapper(indf):
         if task.batch_size is None:
             return funcwrapper(indf, **kwargs)

-        splits = np.array_split(indf, np.ceil(len(indf) / task.batch_size))
+        splits = split_dataframe(indf, max_rows=task.batch_size)
         refs = [
             ray.remote(funcwrapper)
             .options(
@@ -123,7 +124,7 @@ def wrapper(*df_list):


 def stream_split_func(df, splits):
-    splits = np.array_split(df, splits)
+    splits = split_dataframe(df, num_splits=splits)
     return [len(i) for i in splits] + splits

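The two hunks above swap the numpy-based splitting for the new split_dataframe helper at the points where ray.py fans a dataframe out into per-batch remote tasks. Below is a minimal sketch, not the repository's code, of that submission pattern; submit_in_batches and double are hypothetical names, and an initialized Ray runtime is assumed.

# Sketch of the batch-submission pattern (hypothetical helper names).
import pandas as pd
import ray

from dplutils.pipeline.utils import split_dataframe


def double(df):
    return df * 2


def submit_in_batches(func, indf, batch_size):
    # Split the input so no remote task receives more than batch_size rows,
    # then submit one Ray task per chunk and gather the results.
    splits = split_dataframe(indf, max_rows=batch_size)
    refs = [ray.remote(func).remote(chunk) for chunk in splits]
    return ray.get(refs)


if __name__ == "__main__":
    ray.init()
    parts = submit_in_batches(double, pd.DataFrame({"a": range(10)}), batch_size=4)
    print(pd.concat(parts))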
dplutils/pipeline/stream.py

Lines changed: 2 additions & 2 deletions
@@ -9,7 +9,7 @@
 import pandas as pd

 from dplutils.pipeline import OutputBatch, PipelineExecutor, PipelineTask
-from dplutils.pipeline.utils import deque_extract
+from dplutils.pipeline.utils import deque_extract, split_dataframe


 @dataclass
@@ -310,7 +310,7 @@ def task_submit(self, pt, df_list):

     def split_batch_submit(self, stream_batch, max_rows):
         df = stream_batch.data
-        return np.array_split(df, np.ceil(len(df) / max_rows))
+        return split_dataframe(df, max_rows=max_rows)

     def task_resolve_output(self, to):
         if isinstance(to, list):

dplutils/pipeline/utils.py

Lines changed: 19 additions & 0 deletions
@@ -1,3 +1,6 @@
+import numpy as np
+
+
 def dict_from_coord(coord, value):
     d = {}
     if "." in coord:
@@ -16,3 +19,19 @@ def deque_extract(queue, condition):
             yield queue.pop()
         else:
             queue.rotate()
+
+
+def split_dataframe(df, max_rows=None, num_splits=None):
+    """Split dataframe by max number of rows or a given number of splits.
+
+    If num_splits is provided, max_rows is ignored and num_splits splits are
+    produced, each with an approximately equal number of rows (at most one may
+    differ in size).
+
+    If max_rows is provided, produce the smallest number of splits that
+    ensures no single split has more than max_rows rows.
+    """
+    if num_splits is None:
+        num_splits = np.ceil(len(df) / max_rows)
+    chunks = np.linspace(0, num_splits, num=len(df), endpoint=False, dtype=np.int32)
+    return [chunk for _, chunk in df.groupby(chunks)]
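For reference, a short usage sketch of the new helper (not part of the commit); the chunk lengths shown follow from the integer-cast linspace grouping above.

import pandas as pd

from dplutils.pipeline.utils import split_dataframe

df = pd.DataFrame({"a": range(10)})

by_rows = split_dataframe(df, max_rows=4)     # ceil(10 / 4) = 3 splits
print([len(c) for c in by_rows])              # [4, 3, 3]

by_count = split_dataframe(df, num_splits=5)  # exactly 5 splits
print([len(c) for c in by_count])             # [2, 2, 2, 2, 2]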
Lines changed: 39 additions & 1 deletion
@@ -1,5 +1,43 @@
-from dplutils.pipeline.utils import dict_from_coord
+from collections import defaultdict
+
+import pandas as pd
+import pytest
+
+from dplutils.pipeline.utils import dict_from_coord, split_dataframe


 def test_dict_from_coord():
     assert dict_from_coord("a.b.c", "value") == {"a": {"b": {"c": "value"}}}
+
+
+@pytest.mark.parametrize(
+    "df_len, max_rows, num_splits, expected_splits",
+    [
+        (100, 20, None, 5),
+        (101, 20, None, 6),
+        (101, 21, None, 5),
+        (10, 1, None, 10),
+        (10, 5, None, 2),
+        (10, 10, None, 1),
+        (100, None, 5, 5),
+        (101, None, 5, 5),
+        (101, 50, 5, 5),  # this should prefer num_splits, not max
+    ],
+)
+def test_split_dataframe(df_len, max_rows, num_splits, expected_splits):
+    df = pd.DataFrame({"a": range(df_len), "b": range(df_len)})
+    splits = split_dataframe(df, max_rows=max_rows, num_splits=num_splits)
+    assert len(splits) == expected_splits
+    assert sum(len(i) for i in splits) == df_len
+    if max_rows and not num_splits:
+        assert max(len(i) for i in splits) <= max_rows
+    for s in splits:
+        assert isinstance(s, pd.DataFrame)
+        assert list(s.columns) == ["a", "b"]
+    # all but one should be same length
+    len_map = defaultdict(int)
+    for s in splits:
+        len_map[len(s)] += 1
+    assert len(len_map.keys()) <= 2
+    if len(len_map) > 1:
+        assert 1 in len_map.values()
