22from collections .abc import Iterable , Sequence
33from itertools import chain
44from multiprocessing import cpu_count
5+ from queue import Empty
56from sys import stdin
7+ from time import sleep
68from typing import TYPE_CHECKING , Literal
79
10+ import multiprocess
811from cloudpickle import load , loads
912from fsspec .callbacks import DEFAULT_CALLBACK , Callback
10- from multiprocess import get_context
13+ from multiprocess . queues import Queue as MultiprocessQueue
1114
1215from datachain .catalog import Catalog
1316from datachain .catalog .catalog import clone_catalog_with_cache
2528from datachain .utils import batched , flatten , safe_closing
2629
2730if TYPE_CHECKING :
28- import multiprocess
2931 from sqlalchemy import Select , Table
3032
3133 from datachain .data_storage import AbstractMetastore , AbstractWarehouse
@@ -101,8 +103,8 @@ def udf_worker_entrypoint(fd: int | None = None) -> int:
101103
102104class UDFDispatcher :
103105 _catalog : Catalog | None = None
104- task_queue : "multiprocess.Queue | None" = None
105- done_queue : "multiprocess.Queue | None" = None
106+ task_queue : MultiprocessQueue | None = None
107+ done_queue : MultiprocessQueue | None = None
106108
107109 def __init__ (self , udf_info : UdfInfo , buffer_size : int = DEFAULT_BATCH_SIZE ):
108110 self .udf_data = udf_info ["udf_data" ]
@@ -121,7 +123,7 @@ def __init__(self, udf_info: UdfInfo, buffer_size: int = DEFAULT_BATCH_SIZE):
121123 self .buffer_size = buffer_size
122124 self .task_queue = None
123125 self .done_queue = None
124- self .ctx = get_context ("spawn" )
126+ self .ctx = multiprocess . get_context ("spawn" )
125127
126128 @property
127129 def catalog (self ) -> "Catalog" :
@@ -137,6 +139,8 @@ def _create_worker(self) -> "UDFWorker":
137139 udf : UDFAdapter = loads (self .udf_data )
138140 # Ensure all registered DataModels have rebuilt schemas in worker processes.
139141 ModelStore .rebuild_all ()
142+ assert self .task_queue is not None
143+ assert self .done_queue is not None
140144 return UDFWorker (
141145 self .catalog ,
142146 udf ,
@@ -259,8 +263,6 @@ def run_udf_parallel( # noqa: C901, PLR0912
259263 for p in pool :
260264 p .start ()
261265
262- # Will be set to True if all tasks complete normally
263- normal_completion = False
264266 try :
265267 # Will be set to True when the input is exhausted
266268 input_finished = False
@@ -283,10 +285,20 @@ def run_udf_parallel( # noqa: C901, PLR0912
283285
284286 # Process all tasks
285287 while n_workers > 0 :
286- try :
287- result = get_from_queue (self .done_queue )
288- except KeyboardInterrupt :
289- break
288+ while True :
289+ try :
290+ result = self .done_queue .get_nowait ()
291+ break
292+ except Empty :
293+ for p in pool :
294+ exitcode = p .exitcode
295+ if exitcode not in (None , 0 ):
296+ message = (
297+ f"Worker { p .name } exited unexpectedly with "
298+ f"code { exitcode } "
299+ )
300+ raise RuntimeError (message ) from None
301+ sleep (0.01 )
290302
291303 if bytes_downloaded := result .get ("bytes_downloaded" ):
292304 download_cb .relative_update (bytes_downloaded )
@@ -313,39 +325,23 @@ def run_udf_parallel( # noqa: C901, PLR0912
313325 put_into_queue (self .task_queue , next (input_data ))
314326 except StopIteration :
315327 input_finished = True
316-
317- # Finished with all tasks normally
318- normal_completion = True
319328 finally :
320- if not normal_completion :
321- # Stop all workers if there is an unexpected exception
322- for _ in pool :
323- put_into_queue (self .task_queue , STOP_SIGNAL )
324-
325- # This allows workers (and this process) to exit without
326- # consuming any remaining data in the queues.
327- # (If they exit due to an exception.)
328- self .task_queue .close ()
329- self .task_queue .join_thread ()
330-
331- # Flush all items from the done queue.
332- # This is needed if any workers are still running.
333- while n_workers > 0 :
334- result = get_from_queue (self .done_queue )
335- status = result ["status" ]
336- if status != OK_STATUS :
337- n_workers -= 1
338-
339- self .done_queue .close ()
340- self .done_queue .join_thread ()
341-
342- # Wait for workers to stop
329+ for p in pool :
330+ if p .is_alive ():
331+ p .terminate ()
332+
343333 for p in pool :
344334 p .join ()
335+ self .task_queue .cancel_join_thread ()
336+ self .done_queue .cancel_join_thread ()
337+ self .task_queue .close ()
338+ self .done_queue .close ()
339+ self .task_queue = None
340+ self .done_queue = None
345341
346342
347343class DownloadCallback (Callback ):
348- def __init__ (self , queue : "multiprocess.Queue" ) -> None :
344+ def __init__ (self , queue : MultiprocessQueue ) -> None :
349345 self .queue = queue
350346 super ().__init__ ()
351347
@@ -360,7 +356,7 @@ class ProcessedCallback(Callback):
360356 def __init__ (
361357 self ,
362358 name : Literal ["processed" , "generated" ],
363- queue : "multiprocess.Queue" ,
359+ queue : MultiprocessQueue ,
364360 ) -> None :
365361 self .name = name
366362 self .queue = queue
@@ -375,8 +371,8 @@ def __init__(
375371 self ,
376372 catalog : "Catalog" ,
377373 udf : "UDFAdapter" ,
378- task_queue : "multiprocess.Queue" ,
379- done_queue : "multiprocess.Queue" ,
374+ task_queue : MultiprocessQueue ,
375+ done_queue : MultiprocessQueue ,
380376 query : "Select" ,
381377 table : "Table" ,
382378 cache : bool ,
0 commit comments