Skip to content

Commit 05aa2e0

Browse files
authored
Fix bug in source generation for input batch size > 1 (#75)
Resolves #74
1 parent 7ca5ef8 commit 05aa2e0

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

dplutils/pipeline/stream.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -135,7 +135,9 @@ def process_source(self, source):
135135
source_batch.append(next(self.source_generator))
136136
except StopIteration:
137137
self.source_exhausted = True
138-
return
138+
if len(source_batch) == 0:
139+
return
140+
break
139141
self.n_sourced += 1
140142
if self.n_sourced == self.max_batches:
141143
self.source_exhausted = True

tests/pipeline/test_stream_executor.py

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -40,3 +40,20 @@ def generator():
4040
expected_rows = max_batches if max_batches else 12
4141
assert len(res) == expected_rows
4242
assert pd.concat(res).customgen.to_list() == list(range(expected_rows))
43+
44+
45+
@pytest.mark.parametrize('gen_n', [0, 4])
46+
def test_stream_executor_exhausts_input_when_source_batchsize_larger_than_input(gen_n):
47+
st = PipelineTask('task_name', lambda x: x, batch_size=10)
48+
def generator():
49+
n = gen_n
50+
for i in range(n):
51+
yield pd.DataFrame({'customgen': [i]})
52+
pl = LocalSerialExecutor([st], generator=generator)
53+
res = [i.data for i in pl.run()]
54+
if gen_n > 0:
55+
assert len(res) == 1
56+
assert len(res[0]) == gen_n
57+
else:
58+
# do not submit empty source batches
59+
assert len(res) == 0

0 commit comments

Comments
 (0)