 import networkx as nx
 from abc import ABC, abstractmethod
 from collections import deque
+from collections.abc import Generator
 from dataclasses import dataclass, field
-from typing import Any
+from typing import Any, Callable
 from dplutils.pipeline import PipelineTask, PipelineExecutor
 from dplutils.pipeline.utils import deque_extract
 
@@ -56,6 +57,15 @@ class StreamingGraphExecutor(PipelineExecutor, ABC):
     default, for each run, it generates a indefinite stream of input dataframes
     tagged with a monotonically incrementing batch id.
 
+    Args:
+      max_batches: maximum number of batches from the source generator to feed
+        to the input task(s). Default is None, which means either exhaust the
+        source generator or run indefinitely.
+      generator: A callable that when called returns a generator which yields
+        dataframes. The yielded dataframes are assumed to be a single row, in
+        which case input task batching will be honored.
+
+
     Implementations must override abstract methods for (remote) task submission
     and polling. The following must be overriden, see their docs for more:
 
@@ -66,17 +76,18 @@ class StreamingGraphExecutor(PipelineExecutor, ABC):
     - :meth:`task_submit`
     - :meth:`task_submittable`
     """
-    def __init__(self, graph, max_batches=None):
+    def __init__(self, graph, max_batches: int = None, generator: Callable[[], Generator[pd.DataFrame, None, None]] = None):
         super().__init__(graph)
         self.max_batches = max_batches
         # make a local copy of the graph with each node wrapped in a tracker
         # object
         self.stream_graph = nx.relabel_nodes(self.graph, StreamTask)
+        self.generator_fun = generator or self.source_generator_fun
 
     def execute(self):
         self.n_sourced = 0
         self.source_exhausted = False
-        self.source_generator = self.source_generator_fun()
+        self.source_generator = self.generator_fun()
         while True:
             batch = self.execute_until_output()
             if batch is None:
@@ -120,13 +131,18 @@ def resolve_completed(self):
     def process_source(self, source):
         source_batch = []
         for _ in range(source.task.batch_size or 1):
-            source_batch.append(next(self.source_generator))
+            try:
+                source_batch.append(next(self.source_generator))
+            except StopIteration:
+                self.source_exhausted = True
+                return
             self.n_sourced += 1
             if self.n_sourced == self.max_batches:
                 self.source_exhausted = True
                 break
         source.pending.appendleft(self.task_submit(source.task, source_batch))
         source.counter += 1
+        return
 
     def enqueue_tasks(self):
         # Work through the graph in reverse order, submitting any tasks as
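
For context, a minimal sketch of the source contract introduced here: generator is a zero-argument callable that returns a generator of single-row pandas dataframes. The subclass, graph, and names below (my_source, MyStreamExecutor, task_graph, the record_id column) are illustrative placeholders, not part of this change; only the generator keyword and the exhaustion behavior come from the diff above.

import pandas as pd

def my_source():
    # Zero-argument callable returning a generator of single-row dataframes,
    # matching the documented contract of the new ``generator`` argument.
    for i in range(5):
        yield pd.DataFrame({"record_id": [i]})

# A concrete StreamingGraphExecutor subclass (placeholder name) would then be
# constructed roughly as:
#
#     executor = MyStreamExecutor(task_graph, generator=my_source)
#
# The executor calls self.generator_fun() once per run to obtain the stream
# and pulls input batches with next(); when my_source is exhausted, the new
# StopIteration handler in process_source marks the source exhausted so the
# run winds down instead of waiting for more input. Passing max_batches still
# caps the number of batches drawn from the generator.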