Skip to content

Commit 2a51fcf

Browse files
committed
fix(dispatcher): simplify and fix termination handling
1 parent a72c865 commit 2a51fcf

File tree

3 files changed

+155
-42
lines changed

3 files changed

+155
-42
lines changed

src/datachain/query/dispatch.py

Lines changed: 37 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,15 @@
22
from collections.abc import Iterable, Sequence
33
from itertools import chain
44
from multiprocessing import cpu_count
5+
from queue import Empty
56
from sys import stdin
7+
from time import sleep
68
from typing import TYPE_CHECKING, Literal
79

10+
import multiprocess
811
from cloudpickle import load, loads
912
from fsspec.callbacks import DEFAULT_CALLBACK, Callback
10-
from multiprocess import get_context
13+
from multiprocess.queues import Queue as MultiprocessQueue
1114

1215
from datachain.catalog import Catalog
1316
from datachain.catalog.catalog import clone_catalog_with_cache
@@ -25,7 +28,6 @@
2528
from datachain.utils import batched, flatten, safe_closing
2629

2730
if TYPE_CHECKING:
28-
import multiprocess
2931
from sqlalchemy import Select, Table
3032

3133
from datachain.data_storage import AbstractMetastore, AbstractWarehouse
@@ -101,8 +103,8 @@ def udf_worker_entrypoint(fd: int | None = None) -> int:
101103

102104
class UDFDispatcher:
103105
_catalog: Catalog | None = None
104-
task_queue: "multiprocess.Queue | None" = None
105-
done_queue: "multiprocess.Queue | None" = None
106+
task_queue: MultiprocessQueue | None = None
107+
done_queue: MultiprocessQueue | None = None
106108

107109
def __init__(self, udf_info: UdfInfo, buffer_size: int = DEFAULT_BATCH_SIZE):
108110
self.udf_data = udf_info["udf_data"]
@@ -121,7 +123,7 @@ def __init__(self, udf_info: UdfInfo, buffer_size: int = DEFAULT_BATCH_SIZE):
121123
self.buffer_size = buffer_size
122124
self.task_queue = None
123125
self.done_queue = None
124-
self.ctx = get_context("spawn")
126+
self.ctx = multiprocess.get_context("spawn")
125127

126128
@property
127129
def catalog(self) -> "Catalog":
@@ -137,6 +139,8 @@ def _create_worker(self) -> "UDFWorker":
137139
udf: UDFAdapter = loads(self.udf_data)
138140
# Ensure all registered DataModels have rebuilt schemas in worker processes.
139141
ModelStore.rebuild_all()
142+
assert self.task_queue is not None
143+
assert self.done_queue is not None
140144
return UDFWorker(
141145
self.catalog,
142146
udf,
@@ -259,8 +263,6 @@ def run_udf_parallel( # noqa: C901, PLR0912
259263
for p in pool:
260264
p.start()
261265

262-
# Will be set to True if all tasks complete normally
263-
normal_completion = False
264266
try:
265267
# Will be set to True when the input is exhausted
266268
input_finished = False
@@ -283,10 +285,20 @@ def run_udf_parallel( # noqa: C901, PLR0912
283285

284286
# Process all tasks
285287
while n_workers > 0:
286-
try:
287-
result = get_from_queue(self.done_queue)
288-
except KeyboardInterrupt:
289-
break
288+
while True:
289+
try:
290+
result = self.done_queue.get_nowait()
291+
break
292+
except Empty:
293+
for p in pool:
294+
exitcode = p.exitcode
295+
if exitcode not in (None, 0):
296+
message = (
297+
f"Worker {p.name} exited unexpectedly with "
298+
f"code {exitcode}"
299+
)
300+
raise RuntimeError(message) from None
301+
sleep(0.01)
290302

291303
if bytes_downloaded := result.get("bytes_downloaded"):
292304
download_cb.relative_update(bytes_downloaded)
@@ -313,39 +325,23 @@ def run_udf_parallel( # noqa: C901, PLR0912
313325
put_into_queue(self.task_queue, next(input_data))
314326
except StopIteration:
315327
input_finished = True
316-
317-
# Finished with all tasks normally
318-
normal_completion = True
319328
finally:
320-
if not normal_completion:
321-
# Stop all workers if there is an unexpected exception
322-
for _ in pool:
323-
put_into_queue(self.task_queue, STOP_SIGNAL)
324-
325-
# This allows workers (and this process) to exit without
326-
# consuming any remaining data in the queues.
327-
# (If they exit due to an exception.)
328-
self.task_queue.close()
329-
self.task_queue.join_thread()
330-
331-
# Flush all items from the done queue.
332-
# This is needed if any workers are still running.
333-
while n_workers > 0:
334-
result = get_from_queue(self.done_queue)
335-
status = result["status"]
336-
if status != OK_STATUS:
337-
n_workers -= 1
338-
339-
self.done_queue.close()
340-
self.done_queue.join_thread()
341-
342-
# Wait for workers to stop
329+
for p in pool:
330+
if p.is_alive():
331+
p.terminate()
332+
343333
for p in pool:
344334
p.join()
335+
self.task_queue.cancel_join_thread()
336+
self.done_queue.cancel_join_thread()
337+
self.task_queue.close()
338+
self.done_queue.close()
339+
self.task_queue = None
340+
self.done_queue = None
345341

346342

347343
class DownloadCallback(Callback):
348-
def __init__(self, queue: "multiprocess.Queue") -> None:
344+
def __init__(self, queue: MultiprocessQueue) -> None:
349345
self.queue = queue
350346
super().__init__()
351347

@@ -360,7 +356,7 @@ class ProcessedCallback(Callback):
360356
def __init__(
361357
self,
362358
name: Literal["processed", "generated"],
363-
queue: "multiprocess.Queue",
359+
queue: MultiprocessQueue,
364360
) -> None:
365361
self.name = name
366362
self.queue = queue
@@ -375,8 +371,8 @@ def __init__(
375371
self,
376372
catalog: "Catalog",
377373
udf: "UDFAdapter",
378-
task_queue: "multiprocess.Queue",
379-
done_queue: "multiprocess.Queue",
374+
task_queue: MultiprocessQueue,
375+
done_queue: MultiprocessQueue,
380376
query: "Select",
381377
table: "Table",
382378
cache: bool,

src/datachain/query/queue.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import datetime
22
from collections.abc import Iterable, Iterator
3-
from queue import Empty, Full, Queue
3+
from queue import Empty, Full
44
from struct import pack, unpack
55
from time import sleep
66
from typing import Any
77

88
import msgpack
9+
from multiprocess.queues import Queue
910

1011
from datachain.query.batch import RowsOutput
1112

tests/func/test_udf.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22
import os
33
import pickle
44
import posixpath
5+
import sys
6+
import time
57

8+
import multiprocess as mp
69
import pytest
710

811
import datachain as dc
@@ -550,6 +553,119 @@ def name_len_error(_name):
550553
chain.show()
551554

552555

556+
@pytest.mark.parametrize(
    "failure_mode,expected_exit_code,error_marker",
    [
        ("exception", 1, "Worker 1 failure!"),
        ("keyboard_interrupt", -2, "KeyboardInterrupt"),
        ("sys_exit", 1, None),
        ("os_exit", 1, None),  # os._exit - immediate termination
    ],
)
def test_udf_parallel_worker_failure_exits_peers(
    test_session_tmpfile,
    tmp_path,
    capfd,
    failure_mode,
    expected_exit_code,
    error_marker,
):
    """
    Test that when one worker fails, all other workers exit immediately.

    Tests different failure modes:
    - exception: Worker raises RuntimeError (normal exception)
    - keyboard_interrupt: Worker raises KeyboardInterrupt (simulates Ctrl+C)
    - sys_exit: Worker calls sys.exit(1) (Python-level exit, non-zero code)
    - os_exit: Worker calls os._exit(1) (immediate process termination)

    Every worker drops a ``<process name>.started`` marker file into a shared
    directory and waits (up to one second) until all expected workers have
    done so. Only then does ``Worker-UDF-1`` fail, so its peers are known to
    be running at the moment of the failure.
    """
    # Windows uses different exit codes for KeyboardInterrupt
    # 3221225786 (0xC000013A) is STATUS_CONTROL_C_EXIT on Windows
    # while POSIX systems use -2 (SIGINT)
    if sys.platform == "win32" and failure_mode == "keyboard_interrupt":
        expected_exit_code = 3221225786

    vals = list(range(100))

    # Single source of truth for the parallelism setting, the in-worker
    # barrier and the final marker-count assertion.
    expected_workers = 3

    # Plain string path: it is captured by the UDF closure that runs in the
    # worker processes.
    barrier_dir_str = str(tmp_path / "udf_workers_barrier")
    os.makedirs(barrier_dir_str, exist_ok=True)

    def slow_process(val: int) -> int:
        proc_name = mp.current_process().name
        with open(os.path.join(barrier_dir_str, f"{proc_name}.started"), "w") as f:
            f.write(str(time.time()))

        # Wait until all expected workers have written their markers.
        # A monotonic clock keeps the deadline immune to wall-clock jumps.
        deadline = time.monotonic() + 1.0
        while time.monotonic() < deadline:
            try:
                count = len(
                    [n for n in os.listdir(barrier_dir_str) if n.endswith(".started")]
                )
            except FileNotFoundError:
                count = 0
            if count >= expected_workers:
                break
            time.sleep(0.01)

        if proc_name == "Worker-UDF-1":
            if failure_mode == "exception":
                raise RuntimeError("Worker 1 failure!")
            if failure_mode == "keyboard_interrupt":
                raise KeyboardInterrupt("Worker interrupted")
            if failure_mode == "sys_exit":
                sys.exit(1)
            if failure_mode == "os_exit":
                os._exit(1)

        # Surviving workers stay busy long enough that the test can only
        # finish quickly if they are stopped from the outside.
        time.sleep(5)
        return val * 2

    chain = (
        dc.read_values(val=vals, session=test_session_tmpfile)
        .settings(parallel=expected_workers)
        .map(slow_process, output={"result": int})
    )

    start = time.monotonic()
    with pytest.raises(RuntimeError, match="UDF Execution Failed!") as exc_info:
        list(chain.to_iter("result"))
    elapsed = time.monotonic() - start

    # Verify timing: should exit immediately when worker fails
    assert elapsed < 10, f"took {elapsed:.1f}s, should exit immediately"

    # Verify multiple workers were started via barrier markers.
    # The directory was created above, so listing it cannot raise here.
    started_files = [n for n in os.listdir(barrier_dir_str) if n.endswith(".started")]
    assert len(started_files) == expected_workers, (
        f"Expected all 3 workers to start, but saw markers for: {started_files}"
    )

    captured = capfd.readouterr()

    # Verify the RuntimeError has a meaningful message with exit code
    error_message = str(exc_info.value)
    assert f"UDF Execution Failed! Exit code: {expected_exit_code}" in error_message, (
        f"Expected exit code {expected_exit_code}, got: {error_message}"
    )

    if error_marker:
        assert error_marker in captured.err, (
            f"Expected '{error_marker}' in stderr for {failure_mode} mode. "
            f"stderr output: {captured.err[:500]}"
        )

    # No leaked-semaphore warnings may be printed on stderr during shutdown.
    assert "semaphore" not in captured.err
667+
668+
553669
@pytest.mark.parametrize(
554670
"cloud_type,version_aware",
555671
[("s3", True)],

0 commit comments

Comments
 (0)