15 | 15 | # limitations under the License.
16 | 16 | """Base class for implementing classifier models."""
17 | 17 | import pickle
| 18 | +from queue import Empty, Queue
| 19 | +from threading import Thread
18 | 20 | from typing import Any, Dict, Iterable
19 | 21 |
20 | 22 | from programl.models.batch_data import BatchData
24 | 26 | from programl.util.py.progress import NullContext, ProgressContext
25 | 27 |
26 | 28 |
| 29 | +class BatchQueue(Thread):
| 30 | +    """A thread which reads batches onto a queue. That is all it does.
| 31 | +    c.f. https://youtu.be/X7HmltUWXgs
| 32 | +    """
| 33 | +
| 34 | +    def __init__(self, batches: Iterable[BatchData], queue: Queue):
| 35 | +        super().__init__()
| 36 | +        self.batches = batches
| 37 | +        self.queue = queue
| 38 | +
| 39 | +    def run(self) -> None:
| 40 | +        for i, batch_data in enumerate(self.batches, start=1):
| 41 | +            self.queue.put((i, batch_data))
| 42 | +        self.queue.put((None, None))
| 43 | +
| 44 | +
27 | 45 | class Model(object):
28 | 46 |     """Abstract base class for implementing classifiers.
29 | 47 |
@@ -86,12 +104,34 @@ def RunBatches(
86 | 104 |         self,
87 | 105 |         epoch_type: epoch_pb2.EpochType,
88 | 106 |         batches: Iterable[BatchData],
| 107 | +        timeout: float = 60,
89 | 108 |         **rolling_results_builder_opts,
90 | 109 |     ) -> epoch_pb2.EpochResults:
| 110 | +        # Read batches into a queue so that we can use the blocking Queue.get()
| 111 | +        # to wait for a batch with a timeout. Using a timeout is useful for
| 112 | +        # catching cases where a dead iterator will lead to data starvation and
| 113 | +        # a non-terminating process.
| 114 | +        # See <https://github.com/ChrisCummins/ProGraML/issues/140>.
| 115 | +        queue = Queue(maxsize=128)
| 116 | +        batches = BatchQueue(batches, queue)
| 117 | +        batches.start()
| 118 | +
91 | 119 |         with RollingResultsBuilder(**rolling_results_builder_opts) as results_builder:
92 | | -            for i, batch_data in enumerate(batches):
| 120 | +            while True:
| 121 | +                try:
| 122 | +                    i, batch_data = queue.get(timeout=timeout)
| 123 | +                except Empty as e:
| 124 | +                    raise ValueError(
| 125 | +                        f"Model received no batches within {timeout:.1f}s timeout, "
| 126 | +                        "did your batch generator die?"
| 127 | +                    ) from e
| 128 | +                # Done.
| 129 | +                if not batch_data:
| 130 | +                    break
93 | 131 |                 batch_results = self.RunBatch(epoch_type, batch_data)
94 | 132 |                 results_builder.AddBatch(batch_data, batch_results, weight=None)
| 133 | +
| 134 | +        batches.join()
95 | 135 |         return results_builder.results.ToEpochResults()
96 | 136 |
97 | 137 |     def RestoreCheckpoint(self, checkpoint: checkpoint_pb2.Checkpoint):
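The comment added at new lines 110-114 carries the rationale for the whole change: by pushing batches onto a bounded queue from a separate thread, the consumer can use the blocking `Queue.get()` with a timeout, so a batch generator that dies or stalls surfaces as an error instead of a silent, non-terminating run. Below is a minimal, self-contained sketch of that producer/sentinel/timeout pattern using only the standard library; the names `produce`, `consume`, and `stalled_source`, and the `daemon=True` flag, are illustrative assumptions for this sketch, not part of the diff or the ProGraML API.

```python
import time
from queue import Empty, Queue
from threading import Thread
from typing import Iterable, Iterator, List


def produce(items: Iterable[int], queue: Queue) -> None:
    """Enqueue every item, then a None sentinel to signal end-of-stream."""
    for item in items:
        queue.put(item)
    queue.put(None)


def consume(items: Iterable[int], timeout: float = 1.0) -> List[int]:
    """Drain `items` through a bounded queue, failing fast if the source stalls."""
    queue: Queue = Queue(maxsize=4)  # bounded, so the producer cannot run far ahead
    thread = Thread(target=produce, args=(items, queue), daemon=True)
    thread.start()

    results = []
    while True:
        try:
            item = queue.get(timeout=timeout)
        except Empty as e:
            # Mirrors the diff: a stalled producer becomes a loud error, not a hang.
            raise ValueError(
                f"no item received within {timeout:.1f}s, did the source die?"
            ) from e
        if item is None:  # sentinel: the producer finished normally
            break
        results.append(item)
    thread.join()
    return results


if __name__ == "__main__":
    print(consume(range(5)))  # [0, 1, 2, 3, 4]

    def stalled_source() -> Iterator[int]:
        time.sleep(10)  # simulates a dead or starved batch generator
        yield 0

    try:
        consume(stalled_source(), timeout=0.5)
    except ValueError as err:
        print("caught:", err)
```

The `(None, None)` tuple in the diff plays the same role as the `None` sentinel here: it is an in-band end-of-stream marker, which lets the consumer tell "the producer finished" apart from "the producer stopped sending data", the starvation case that motivated the timeout.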