@@ -65,8 +65,10 @@ class StreamingGraphExecutor(PipelineExecutor, ABC):
             to the input task(s). Default is None, which means either exhaust the
             source generator or run indefinitely.
         generator: A callable that when called returns a generator which yields
-            dataframes. The yielded dataframes are assumed to be a single row, in
-            which case input task batching will be honored.
+            dataframes. The driver will call ``len()`` on the yielded dataframes to
+            obtain the number of rows and will split and batch according to the
+            input task settings. Each generated dataframe, regardless of size,
+            counts as a single source batch with respect to ``max_batches``.


    Implementations must override abstract methods for (remote) task submission
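
Per the revised docstring, a conforming generator may now yield multi-row dataframes. A minimal sketch of such a generator, assuming pandas frames (the names here are illustrative, not part of the executor's API):

```python
import pandas as pd

def make_source():
    """Illustrative source generator: yields multi-row dataframes."""
    for start in range(0, 100, 25):
        # len(df) == 25; the driver splits/merges rows to honor the input
        # task's batch_size, but each yielded frame still counts as one
        # source batch toward max_batches.
        yield pd.DataFrame({"x": range(start, start + 25)})
```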
@@ -136,39 +138,41 @@ def resolve_completed(self):
             task.data_in.extendleft(self.task_resolve_output(ready))
         return None

-    def process_source(self, source):
-        source_batch = []
-        for _ in range(source.task.batch_size or 1):
+    def _feed_source(self, source):
+        if self.source_exhausted:
+            return
+        total_length = sum(i.length for i in source.data_in)
+        while total_length < (source.task.batch_size or 1):
             try:
-                source_batch.append(next(self.source_generator))
+                next_df = next(self.source_generator)
             except StopIteration:
                 self.source_exhausted = True
-                if len(source_batch) == 0:
-                    return
                 break
+            # We feed any generated source to all source tasks, similar to the way
+            # upstream forked outputs broadcast. We add to data_in so that any
+            # necessary batching and splitting is handled by the normal procedure.
+            for task in self.stream_graph.source_tasks:
+                task.data_in.append(StreamBatch(data=next_df, length=len(next_df)))
             self.n_sourced += 1
             if self.n_sourced == self.max_batches:
                 self.source_exhausted = True
                 break
-        source.pending.appendleft(self.task_submit(source.task, source_batch))
-        source.counter += 1
-        return
+            total_length += len(next_df)

     def enqueue_tasks(self):
         # Work through the graph in reverse order, submitting any tasks as
         # needed. Reverse order ensures we prefer to send tasks that are closer
         # to the end of the pipeline and only feed as necessary.
-        rank = 0
-        for task in self.stream_graph.walk_back(sort_key=lambda x: x.counter):
-            if task in self.stream_graph.source_tasks or len(task.data_in) == 0:
-                continue
+        def _handle_one_task(task, rank):
+            eligible = submitted = False
+            if len(task.data_in) == 0:
+                return (eligible, submitted)

             batch_size = task.task.batch_size
             if batch_size is not None:
                 for batch in deque_extract(task.data_in, lambda b: b.length > batch_size):
                     task.split_pending.appendleft(self.split_batch_submit(batch, batch_size))

-            eligible = False
             while len(task.data_in) > 0:
                 num_to_merge = deque_num_merge(task.data_in, batch_size)
                 if num_to_merge == 0:
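
For illustration of the broadcast behavior `_feed_source` adds, here is a self-contained sketch with stand-in types (`StreamBatch` and the task stubs below are hypothetical simplifications of the executor's real classes):

```python
from collections import deque, namedtuple

import pandas as pd

# Hypothetical stand-ins for the executor's types.
StreamBatch = namedtuple("StreamBatch", ["data", "length"])
SourceTask = namedtuple("SourceTask", ["name", "data_in"])

source_tasks = [SourceTask("a", deque()), SourceTask("b", deque())]
next_df = pd.DataFrame({"x": [1, 2, 3]})

# Broadcast: the same generated frame is appended to every source task's
# input deque, so the normal splitting/merging path sees identical inputs.
for task in source_tasks:
    task.data_in.append(StreamBatch(data=next_df, length=len(next_df)))

assert all(t.data_in[0].length == 3 for t in source_tasks)
```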
@@ -184,18 +188,29 @@ def enqueue_tasks(self):
                 merged = [task.data_in.pop().data for i in range(num_to_merge)]
                 task.pending.appendleft(self.task_submit(task.task, merged))
                 task.counter += 1
+                submitted = True
+            return (eligible, submitted)

+        # Proceed through all non-source tasks; source tasks are handled
+        # separately below since they need to be fed from the generator.
+        rank = 0
+        for task in self.stream_graph.walk_back(sort_key=lambda x: x.counter):
+            if task in self.stream_graph.source_tasks:
+                continue
+            eligible, _ = _handle_one_task(task, rank)
             if eligible:  # update rank of this task if it _could_ be done, whether or not it was
                 rank += 1

-        # in least-run order try to enqueue as many of the source tasks as can fit
+        # Source as many inputs as can fit on source tasks. We prioritize flushing
+        # the input queue and, secondarily, the invocation count in case batch sizes differ.
         while True:
             task_scheduled = False
-            for source in sorted(self.stream_graph.source_tasks, key=lambda x: x.counter):
-                if self.source_exhausted or not self.task_submittable(source.task, rank):
-                    continue
-                self.process_source(source)
-                task_scheduled = True
+            for task in sorted(self.stream_graph.source_tasks, key=lambda x: (-len(x.data_in), x.counter)):
+                if self.task_submittable(task.task, rank):
+                    self._feed_source(task)
+                    _, task_scheduled = _handle_one_task(task, rank)
+                if task_scheduled:  # we want to re-evaluate the sort order
+                    break
             if not task_scheduled:
                 break

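The new two-part sort key can be checked in isolation; a small sketch with stub objects (illustrative only) shows that the fullest input queue wins and ties fall back to the least-invoked task:

```python
from collections import namedtuple

Stub = namedtuple("Stub", ["name", "data_in", "counter"])

tasks = [
    Stub("a", [1], 5),     # short queue, more invocations
    Stub("b", [1, 2], 9),  # fullest queue wins despite highest counter
    Stub("c", [1], 2),     # ties with "a" on queue length, fewer invocations
]
order = sorted(tasks, key=lambda x: (-len(x.data_in), x.counter))
assert [t.name for t in order] == ["b", "c", "a"]
```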