Skip to content

Commit 8921afa

Browse files
authored
Make generator configurable for stream executor (#59)
This makes it possible to feed some data to the first tasks instead of monotonic batches. This also means the iterable might terminate prior to max_batches.
1 parent 185df12 commit 8921afa

File tree

3 files changed

+38
-4
lines changed

3 files changed

+38
-4
lines changed

dplutils/pipeline/ray.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,8 @@ class RayStreamGraphExecutor(StreamingGraphExecutor):
153153
all pending tasks for ray_poll_timeout seconds. The timeout gives
154154
opportunity to re-evaluate cluster resources in case it has expanded
155155
since last scheduling loop
156+
\*args, \*\*kwargs: These are passed to
157+
:py:class:`StreamingGraphExecutor<dplutils.pipeline.stream.StreamingGraphExecutor>`
156158
"""
157159
def __init__(self, *args, ray_poll_timeout: int = 20, **kwargs):
158160
super().__init__(*args, **kwargs)

dplutils/pipeline/stream.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
import networkx as nx
44
from abc import ABC, abstractmethod
55
from collections import deque
6+
from collections.abc import Generator
67
from dataclasses import dataclass, field
7-
from typing import Any
8+
from typing import Any, Callable
89
from dplutils.pipeline import PipelineTask, PipelineExecutor
910
from dplutils.pipeline.utils import deque_extract
1011

@@ -56,6 +57,15 @@ class StreamingGraphExecutor(PipelineExecutor, ABC):
5657
default, for each run, it generates an indefinite stream of input dataframes
5758
tagged with a monotonically incrementing batch id.
5859
60+
Args:
61+
max_batches: maximum number of batches from the source generator to feed
62+
to the input task(s). Default is None, which means either exhaust the
63+
source generator or run indefinitely.
64+
generator: A callable that when called returns a generator which yields
65+
dataframes. The yielded dataframes are assumed to be a single row, in
66+
which case input task batching will be honored.
67+
68+
5969
Implementations must override abstract methods for (remote) task submission
6070
and polling. The following must be overriden, see their docs for more:
6171
@@ -66,17 +76,18 @@ class StreamingGraphExecutor(PipelineExecutor, ABC):
6676
- :meth:`task_submit`
6777
- :meth:`task_submittable`
6878
"""
69-
def __init__(self, graph, max_batches=None):
79+
def __init__(self, graph, max_batches: int=None, generator: Callable[[], Generator[pd.DataFrame, None, None]]=None):
7080
super().__init__(graph)
7181
self.max_batches = max_batches
7282
# make a local copy of the graph with each node wrapped in a tracker
7383
# object
7484
self.stream_graph = nx.relabel_nodes(self.graph, StreamTask)
85+
self.generator_fun = generator or self.source_generator_fun
7586

7687
def execute(self):
7788
self.n_sourced = 0
7889
self.source_exhausted = False
79-
self.source_generator = self.source_generator_fun()
90+
self.source_generator = self.generator_fun()
8091
while True:
8192
batch = self.execute_until_output()
8293
if batch is None:
@@ -120,13 +131,18 @@ def resolve_completed(self):
120131
def process_source(self, source):
121132
source_batch = []
122133
for _ in range(source.task.batch_size or 1):
123-
source_batch.append(next(self.source_generator))
134+
try:
135+
source_batch.append(next(self.source_generator))
136+
except StopIteration:
137+
self.source_exhausted = True
138+
return
124139
self.n_sourced += 1
125140
if self.n_sourced == self.max_batches:
126141
self.source_exhausted = True
127142
break
128143
source.pending.appendleft(self.task_submit(source.task, source_batch))
129144
source.counter += 1
145+
return
130146

131147
def enqueue_tasks(self):
132148
# Work through the graph in reverse order, submitting any tasks as

tests/pipeline/test_stream_executor.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import pytest
2+
import pandas as pd
13
from dplutils.pipeline import PipelineTask
24
from dplutils.pipeline.stream import LocalSerialExecutor, StreamTask
35
from test_suite import PipelineExecutorTestSuite
@@ -24,3 +26,17 @@ def test_stream_exhausted_indicator_considers_splits(dummy_steps):
2426
assert pl.task_exhausted(a_task)
2527
a_task.split_pending.append(1)
2628
assert not pl.task_exhausted(a_task)
29+
30+
31+
@pytest.mark.parametrize('max_batches', [1,10,None])
32+
def test_stream_executor_generator_override(max_batches):
33+
st = PipelineTask('task_name', lambda x: x)
34+
def generator():
35+
n = 12
36+
for i in range(n):
37+
yield pd.DataFrame({'customgen': [i]})
38+
pl = LocalSerialExecutor([st], max_batches=max_batches, generator=generator)
39+
res = list(pl.run())
40+
expected_rows = max_batches if max_batches else 12
41+
assert len(res) == expected_rows
42+
assert pd.concat(res).customgen.to_list() == list(range(expected_rows))

0 commit comments

Comments
 (0)