
Commit e63e7a3

Fix streaming input for batch and multiple sources (#99)
This fixes the behavior of multiple source tasks to match that of a branched task (outputs are broadcast to children). It also adds support for generated input of length > 1, which is now handled as expected, including splitting the batch where necessary.
1 parent f4a82f2 commit e63e7a3
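For orientation, here is a minimal sketch of the usage this change enables: a source generator that yields multi-row dataframes, with rows split according to the input task's batch size. The import paths, the toy two-task graph, and the lambda transforms are assumptions modeled on the tests added in this commit, not taken from library documentation.

```python
import pandas as pd

# Import paths are assumptions; the class names appear in the tests in this commit.
from dplutils.pipeline import PipelineGraph, PipelineTask
from dplutils.pipeline.stream import LocalSerialExecutor


def generator():
    # Each yielded frame has 2 rows; per the updated docstring, every yielded
    # frame counts as a single source batch with respect to max_batches.
    for i in range(4):
        yield pd.DataFrame({"col": range(2)})


# Hypothetical two-task pipeline; task names and transforms are illustrative only.
t1 = PipelineTask("task1", lambda df: df.assign(stage1=True))
t2 = PipelineTask("task2", lambda df: df.assign(stage2=True))
graph = PipelineGraph([(t1, t2)])

# With task1.batch_size = 1, each 2-row generated frame should be split before
# being fed to task1, roughly doubling the number of batches flowing downstream.
pl = LocalSerialExecutor(graph, generator=generator).set_config("task1.batch_size", 1)
outputs = [batch.data for batch in pl.run()]
print(len(outputs))
```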

4 files changed: 73 additions & 22 deletions

dplutils/pipeline/stream.py

Lines changed: 37 additions & 22 deletions
@@ -65,8 +65,10 @@ class StreamingGraphExecutor(PipelineExecutor, ABC):
         to the input task(s). Default is None, which means either exhaust the
         source generator or run indefinitely.
     generator: A callable that when called returns a generator which yields
-        dataframes. The yielded dataframes are assumed to be a single row, in
-        which case input task batching will be honored.
+        dataframes. The driver will call ``len()`` on the yielded dataframes to
+        obtain the number of rows and will split and batch according to the input
+        task settings. Each generated dataframe, regardless of size, counts as
+        a single source batch with respect to ``max_batches``.


     Implementations must override abstract methods for (remote) task submission
@@ -136,39 +138,41 @@ def resolve_completed(self):
                 task.data_in.extendleft(self.task_resolve_output(ready))
         return None

-    def process_source(self, source):
-        source_batch = []
-        for _ in range(source.task.batch_size or 1):
+    def _feed_source(self, source):
+        if self.source_exhausted:
+            return
+        total_length = sum(i.length for i in source.data_in)
+        while total_length < (source.task.batch_size or 1):
             try:
-                source_batch.append(next(self.source_generator))
+                next_df = next(self.source_generator)
             except StopIteration:
                 self.source_exhausted = True
-                if len(source_batch) == 0:
-                    return
                 break
+            # We feed any generated source to all source tasks, similar to the way
+            # upstream forked outputs broadcast. We add to data_in so that any
+            # necessary batching and splitting can be handled by normal procedure.
+            for task in self.stream_graph.source_tasks:
+                task.data_in.append(StreamBatch(data=next_df, length=len(next_df)))
             self.n_sourced += 1
             if self.n_sourced == self.max_batches:
                 self.source_exhausted = True
                 break
-        source.pending.appendleft(self.task_submit(source.task, source_batch))
-        source.counter += 1
-        return
+            total_length += len(next_df)

     def enqueue_tasks(self):
         # Work through the graph in reverse order, submitting any tasks as
         # needed. Reverse order ensures we prefer to send tasks that are closer
         # to the end of the pipeline and only feed as necessary.
-        rank = 0
-        for task in self.stream_graph.walk_back(sort_key=lambda x: x.counter):
-            if task in self.stream_graph.source_tasks or len(task.data_in) == 0:
-                continue
+        def _handle_one_task(task, rank):
+            eligible = submitted = False
+            if len(task.data_in) == 0:
+                return (eligible, submitted)

             batch_size = task.task.batch_size
             if batch_size is not None:
                 for batch in deque_extract(task.data_in, lambda b: b.length > batch_size):
                     task.split_pending.appendleft(self.split_batch_submit(batch, batch_size))

-            eligible = False
             while len(task.data_in) > 0:
                 num_to_merge = deque_num_merge(task.data_in, batch_size)
                 if num_to_merge == 0:
@@ -184,18 +188,29 @@ def enqueue_tasks(self):
                 merged = [task.data_in.pop().data for i in range(num_to_merge)]
                 task.pending.appendleft(self.task_submit(task.task, merged))
                 task.counter += 1
+                submitted = True
+            return (eligible, submitted)

+        # proceed through all non-source tasks, which will be handled separately
+        # below due to the need to feed from generator.
+        rank = 0
+        for task in self.stream_graph.walk_back(sort_key=lambda x: x.counter):
+            if task in self.stream_graph.source_tasks:
+                continue
+            eligible, _ = _handle_one_task(task, rank)
             if eligible:  # update rank of this task if it _could_ be done, whether or not it was
                 rank += 1

-        # in least-run order try to enqueue as many of the source tasks as can fit
+        # Source as many inputs as can fit on source tasks. We prioritize flushing the
+        # input queue and secondarily on number of invocations in case batch sizes differ.
         while True:
             task_scheduled = False
-            for source in sorted(self.stream_graph.source_tasks, key=lambda x: x.counter):
-                if self.source_exhausted or not self.task_submittable(source.task, rank):
-                    continue
-                self.process_source(source)
-                task_scheduled = True
+            for task in sorted(self.stream_graph.source_tasks, key=lambda x: (-len(x.data_in), x.counter)):
+                if self.task_submittable(task.task, rank):
+                    self._feed_source(task)
+                    _, task_scheduled = _handle_one_task(task, rank)
+                if task_scheduled:  # we want to re-evaluate the sort order
+                    break
             if not task_scheduled:
                 break

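To illustrate the broadcast behavior implemented above, here is a small sketch of a two-source graph fed from a generator; it mirrors the multi_source_graph fixture and the new suite test further down. Import paths and the generator contents are assumptions.

```python
import pandas as pd

# Import paths are assumptions; class names follow the fixtures and tests in this commit.
from dplutils.pipeline import PipelineGraph, PipelineTask
from dplutils.pipeline.stream import LocalSerialExecutor

# Two source tasks feeding one sink, as in the multi_source_graph fixture below.
src_a = PipelineTask("srcA", lambda df: df.assign(a="A"))
src_b = PipelineTask("srcB", lambda df: df.assign(b="B"))
sink = PipelineTask("sink", lambda df: df.assign(sink="sink"))
graph = PipelineGraph([(src_a, sink), (src_b, sink)])


def generator():
    # Each generated frame carries an id so we can see it reach both sources.
    for i in range(3):
        yield pd.DataFrame({"id": [i]})


# Each generated frame is appended to the data_in of every source task, so the
# srcA and srcB branches should both see ids 0, 1 and 2 in the final output.
pl = LocalSerialExecutor(graph, generator=generator)
final = pd.concat([batch.data for batch in pl.run()])
print(sorted(final[final.a == "A"].id), sorted(final[final.b == "B"].id))
```
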
tests/conftest.py

Lines changed: 8 additions & 0 deletions
@@ -68,6 +68,14 @@ def multi_output_graph():
     return PipelineGraph([(t1, t2A), (t1, t2B)])


+@pytest.fixture
+def multi_source_graph():
+    t1A = PipelineTask("srcA", lambda x: x.assign(a="A"))
+    t1B = PipelineTask("srcB", lambda x: x.assign(b="B"))
+    t2 = PipelineTask("sink", lambda x: x.assign(sink="sink"))
+    return PipelineGraph([(t1A, t2), (t1B, t2)])
+
+
 @pytest.fixture
 def graph_suite(dummy_steps, dummy_pipeline_graph, multi_output_graph):
     return {

tests/pipeline/test_stream_executor.py

Lines changed: 15 additions & 0 deletions
@@ -62,3 +62,18 @@ def generator():
     else:
         # do not submit empty source batches
         assert len(res) == 0
+
+
+def test_stream_executor_input_batch_size_splits(dummy_steps):
+    def generator():
+        for i in range(4):
+            yield pd.DataFrame({"col": range(2)})
+
+    # sanity check to ensure the test below actually inspects the split action
+    pl = LocalSerialExecutor(dummy_steps, generator=generator)
+    res = [i.data for i in pl.run()]
+    assert len(res) == 4
+    # explicitly set batch size so we should call split on each input
+    pl = LocalSerialExecutor(dummy_steps, generator=generator).set_config("task1.batch_size", 1)
+    res = [i.data for i in pl.run()]
+    assert len(res) == 8

tests/pipeline/test_suite.py

Lines changed: 13 additions & 0 deletions
@@ -1,3 +1,5 @@
+from math import ceil
+
 import pandas as pd
 import pytest


@@ -105,3 +107,14 @@ def test_with_merge_batch(self, dummy_steps, max_batches):
         assert all([len(i.data) == expected_len for i in res])
         final = pd.concat([i.data for i in res])
         assert final["id"].nunique() == max_batches
+
+    def test_all_sources_get_each_input_and_batch(self, multi_source_graph, max_batches):
+        pl = self.executor(multi_source_graph, max_batches=max_batches).set_config("srcA.batch_size", 4)
+        res = list(pl.run())
+        expected_len = max_batches + ceil(max_batches / 4)
+        assert len(res) == expected_len
+        # now ensure that each source task got the same set of batches
+        final = pd.concat([i.data for i in res])
+        idset_a = set(final[final.a == "A"].id)
+        idset_b = set(final[final.b == "B"].id)
+        assert idset_a == idset_b == set(range(max_batches))
