
Commit a8a8f32

Skip zero-length batches in stream executor (#90)
Zero-length data batches risk starting tasks and serializing data unnecessarily, consuming resources on batches that could never yield any output. Resolves #62
1 parent 8e08dbe · commit a8a8f32

3 files changed: +31 -0 lines changed


dplutils/pipeline/stream.py

Lines changed: 2 additions & 0 deletions
@@ -124,6 +124,8 @@ def resolve_completed(self):
         for task in self.stream_graph.walk_fwd():
             for ready in deque_extract(task.pending, self.is_task_ready):
                 block_info = self.task_resolve_output(ready)
+                if block_info.length == 0:
+                    continue
                 if task in self.stream_graph.sink_tasks:
                     return OutputBatch(block_info.data, task=task.name)
                 else:
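
The skip is small but load-bearing: a zero-length batch can never contribute rows downstream, so dropping it at resolution time avoids scheduling follow-on tasks and serializing empty frames. A minimal standalone sketch of the same pattern, using hypothetical helper names rather than the actual dplutils internals:

import pandas as pd


def forward_nonempty(completed_batches):
    # Drop zero-length batches before they reach downstream tasks: an empty
    # batch would still start a task and serialize data for no useful output.
    for batch in completed_batches:
        if len(batch) == 0:
            continue
        yield batch


# Example: only the non-empty frames survive the filter.
batches = [pd.DataFrame({"id": [1, 2]}), pd.DataFrame([]), pd.DataFrame({"id": [3]})]
print([len(b) for b in forward_nonempty(batches)])  # -> [2, 1]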

tests/conftest.py

Lines changed: 5 additions & 0 deletions
@@ -46,6 +46,11 @@ def dummy_steps():
     ]


+@pytest.fixture
+def blackhole_step():
+    return [PipelineTask("blackhole", func=lambda x: pd.DataFrame([]))]
+
+
 @pytest.fixture
 def dummy_pipeline_graph():
     t1 = PipelineTask("task1", lambda x: x.assign(t1="1"))

tests/pipeline/test_suite.py

Lines changed: 24 additions & 0 deletions
@@ -2,6 +2,7 @@
 import pytest

 from dplutils.pipeline import OutputBatch
+from dplutils.pipeline.task import PipelineTask


 @pytest.mark.parametrize("max_batches", (1, 4, 10))
@@ -22,6 +23,29 @@ def test_run_simple_pipeline_iterator(self, dummy_steps, max_batches):
         assert isinstance(res.data, pd.DataFrame)
         assert set(res.data.columns).issuperset({"id", "step1", "step2"})

+    def test_pipeline_throws_away_empty_batches(self, blackhole_step, max_batches, tmp_path):
+        # Flag to ensure that subsequent tasks are not called on empty batches. This is
+        # useful since the check of yielded results only inspects pipeline output; we also
+        # want to ensure tasks don't needlessly get called. The flag is implemented with a
+        # file so it is more portable across subprocess calls (e.g. as used in ray).
+        flag_file = tmp_path / "called.flag"
+
+        def set_counter(x):
+            flag_file.write_text("")
+            return x
+
+        # first ensure operation of our counter instrument
+        pl = self.executor([PipelineTask("nocalls", set_counter)], max_batches=max_batches)
+        res = list(pl.run())
+        assert len(res) == max_batches
+        assert flag_file.exists()
+        # now ensure we toss empty dataframes
+        flag_file.unlink()
+        pl = self.executor([*blackhole_step, PipelineTask("nocalls", set_counter)], max_batches=max_batches)
+        res = list(pl.run())
+        assert len(res) == 0
+        assert not flag_file.exists()
+
     def test_run_dag_pipeline(self, dummy_pipeline_graph, max_batches):
         pl = self.executor(dummy_pipeline_graph, max_batches=max_batches)
         it = pl.run()
