Skip to content

Commit 733e748

Browse files
committed
Use a queue to read batches to catch data starvation.
Read batches into a queue so that we can use the blocking Queue.get() to wait for a batch with a timeout. Using a timeout is useful for catching cases where a dead iterator will lead to data starvation and a non-terminating process. #140
1 parent 73875cd commit 733e748

File tree

1 file changed

+41
-1
lines changed

1 file changed

+41
-1
lines changed

programl/models/model.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
# limitations under the License.
1616
"""Base class for implementing classifier models."""
1717
import pickle
18+
from queue import Empty, Queue
19+
from threading import Thread
1820
from typing import Any, Dict, Iterable
1921

2022
from programl.models.batch_data import BatchData
@@ -24,6 +26,22 @@
2426
from programl.util.py.progress import NullContext, ProgressContext
2527

2628

29+
class BatchQueue(Thread):
30+
"""A thread which reads batches onto a queue. That is all it does.
31+
c.f. https://youtu.be/X7HmltUWXgs
32+
"""
33+
34+
def __init__(self, batches: Iterable[BatchData], queue: Queue):
35+
super().__init__()
36+
self.batches = batches
37+
self.queue = queue
38+
39+
def run(self) -> None:
40+
for i, batch_data in enumerate(self.batches, start=1):
41+
self.queue.put((i, batch_data))
42+
self.queue.put((None, None))
43+
44+
2745
class Model(object):
2846
"""Abstract base class for implementing classifiers.
2947
@@ -86,12 +104,34 @@ def RunBatches(
86104
self,
87105
epoch_type: epoch_pb2.EpochType,
88106
batches: Iterable[BatchData],
107+
timeout: float = 60,
89108
**rolling_results_builder_opts,
90109
) -> epoch_pb2.EpochResults:
110+
# Read batches into a queue so that we can use the blocking Queue.get()
111+
# to wait for a batch with a timeout. Using a timeout is useful for
112+
# catching cases where a dead iterator will lead to data starvation and
113+
# a non-terminating process.
114+
# See <https://github.com/ChrisCummins/ProGraML/issues/140>.
115+
queue = Queue(maxsize=128)
116+
batches = BatchQueue(batches, queue)
117+
batches.start()
118+
91119
with RollingResultsBuilder(**rolling_results_builder_opts) as results_builder:
92-
for i, batch_data in enumerate(batches):
120+
while True:
121+
try:
122+
i, batch_data = queue.get(timeout=timeout)
123+
except Empty as e:
124+
raise ValueError(
125+
f"Model received no batches within {timeout:.1f}s timeout, "
126+
"did your batch generator die?"
127+
) from e
128+
# Done.
129+
if not batch_data:
130+
break
93131
batch_results = self.RunBatch(epoch_type, batch_data)
94132
results_builder.AddBatch(batch_data, batch_results, weight=None)
133+
134+
batches.join()
95135
return results_builder.results.ToEpochResults()
96136

97137
def RestoreCheckpoint(self, checkpoint: checkpoint_pb2.Checkpoint):

0 commit comments

Comments
 (0)