diff --git a/src/datachain/query/dispatch.py b/src/datachain/query/dispatch.py
index 939d5c146..eea0fd11f 100644
--- a/src/datachain/query/dispatch.py
+++ b/src/datachain/query/dispatch.py
@@ -2,12 +2,17 @@
+import contextlib
 from collections.abc import Iterable, Sequence
 from itertools import chain
 from multiprocessing import cpu_count
+from queue import Empty
 from sys import stdin
+from time import monotonic, sleep
 from typing import TYPE_CHECKING, Literal
 
+import multiprocess
 from cloudpickle import load, loads
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
-from multiprocess import get_context
+from multiprocess.context import Process
+from multiprocess.queues import Queue as MultiprocessQueue
 
 from datachain.catalog import Catalog
 from datachain.catalog.catalog import clone_catalog_with_cache
@@ -25,7 +30,6 @@ from datachain.utils import batched, flatten, safe_closing
 
 if TYPE_CHECKING:
-    import multiprocess
     from sqlalchemy import Select, Table
 
     from datachain.data_storage import AbstractMetastore, AbstractWarehouse
 
@@ -101,8 +105,8 @@ def udf_worker_entrypoint(fd: int | None = None) -> int:
 
 class UDFDispatcher:
     _catalog: Catalog | None = None
-    task_queue: "multiprocess.Queue | None" = None
-    done_queue: "multiprocess.Queue | None" = None
+    task_queue: MultiprocessQueue | None = None
+    done_queue: MultiprocessQueue | None = None
 
     def __init__(self, udf_info: UdfInfo, buffer_size: int = DEFAULT_BATCH_SIZE):
         self.udf_data = udf_info["udf_data"]
@@ -121,7 +125,7 @@ def __init__(self, udf_info: UdfInfo, buffer_size: int = DEFAULT_BATCH_SIZE):
         self.buffer_size = buffer_size
         self.task_queue = None
         self.done_queue = None
-        self.ctx = get_context("spawn")
+        self.ctx = multiprocess.get_context("spawn")
 
     @property
     def catalog(self) -> "Catalog":
@@ -259,8 +263,6 @@ def run_udf_parallel(  # noqa: C901, PLR0912
         for p in pool:
             p.start()
 
-        # Will be set to True if all tasks complete normally
-        normal_completion = False
        try:
             # Will be set to True when the input is exhausted
             input_finished = False
@@ -283,10 +285,20 @@ def run_udf_parallel(  # noqa: C901, PLR0912
 
             # Process all tasks
             while n_workers > 0:
-                try:
-                    result = get_from_queue(self.done_queue)
-                except KeyboardInterrupt:
-                    break
+                while True:
+                    try:
+                        result = self.done_queue.get_nowait()
+                        break
+                    except Empty:
+                        for p in pool:
+                            exitcode = p.exitcode
+                            if exitcode not in (None, 0):
+                                message = (
+                                    f"Worker {p.name} exited unexpectedly with "
+                                    f"code {exitcode}"
+                                )
+                                raise RuntimeError(message) from None
+                        sleep(0.01)
 
                 if bytes_downloaded := result.get("bytes_downloaded"):
                     download_cb.relative_update(bytes_downloaded)
@@ -313,39 +325,50 @@ def run_udf_parallel(  # noqa: C901, PLR0912
                     put_into_queue(self.task_queue, next(input_data))
                 except StopIteration:
                     input_finished = True
-
-            # Finished with all tasks normally
-            normal_completion = True
         finally:
-            if not normal_completion:
-                # Stop all workers if there is an unexpected exception
-                for _ in pool:
-                    put_into_queue(self.task_queue, STOP_SIGNAL)
-
-                # This allows workers (and this process) to exit without
-                # consuming any remaining data in the queues.
-                # (If they exit due to an exception.)
-                self.task_queue.close()
-                self.task_queue.join_thread()
-
-                # Flush all items from the done queue.
-                # This is needed if any workers are still running.
-                while n_workers > 0:
-                    result = get_from_queue(self.done_queue)
-                    status = result["status"]
-                    if status != OK_STATUS:
-                        n_workers -= 1
-
-                self.done_queue.close()
-                self.done_queue.join_thread()
+            self._shutdown_workers(pool)
+
+    def _shutdown_workers(self, pool: list[Process]) -> None:
+        self._terminate_pool(pool)
+        self._drain_queue(self.done_queue)
+        self._drain_queue(self.task_queue)
+        self._close_queue(self.done_queue)
+        self._close_queue(self.task_queue)
+
+    def _terminate_pool(self, pool: list[Process]) -> None:
+        for proc in pool:
+            if proc.is_alive():
+                proc.terminate()
+
+        deadline = monotonic() + 1.0
+        for proc in pool:
+            if not proc.is_alive():
+                continue
+            remaining = deadline - monotonic()
+            if remaining > 0:
+                proc.join(remaining)
+            if proc.is_alive():
+                proc.kill()
+                proc.join(timeout=0.2)
+
+    def _drain_queue(self, queue: MultiprocessQueue) -> None:
+        while True:
+            try:
+                queue.get_nowait()
+            except Empty:
+                return
+            except (OSError, ValueError):
+                return
 
-            # Wait for workers to stop
-            for p in pool:
-                p.join()
+    def _close_queue(self, queue: MultiprocessQueue) -> None:
+        with contextlib.suppress(OSError, ValueError):
+            queue.close()
+        with contextlib.suppress(RuntimeError, AssertionError, ValueError):
+            queue.join_thread()
 
 
 class DownloadCallback(Callback):
-    def __init__(self, queue: "multiprocess.Queue") -> None:
+    def __init__(self, queue: MultiprocessQueue) -> None:
         self.queue = queue
         super().__init__()
 
@@ -360,7 +383,7 @@ class ProcessedCallback(Callback):
     def __init__(
         self,
         name: Literal["processed", "generated"],
-        queue: "multiprocess.Queue",
+        queue: MultiprocessQueue,
     ) -> None:
         self.name = name
         self.queue = queue
@@ -375,8 +398,8 @@ def __init__(
         self,
         catalog: "Catalog",
         udf: "UDFAdapter",
-        task_queue: "multiprocess.Queue",
-        done_queue: "multiprocess.Queue",
+        task_queue: MultiprocessQueue,
+        done_queue: MultiprocessQueue,
         query: "Select",
         table: "Table",
         cache: bool,
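Reviewer note (not part of the patch): the core behavioral change in `run_udf_parallel` is swapping the old blocking `get_from_queue` for a non-blocking poll that also watches worker exit codes, so the dispatcher fails fast instead of hanging when a worker dies without reporting back. A minimal sketch of the same pattern, using stdlib `multiprocessing` rather than `multiprocess`, with a hypothetical `poll_done_queue` helper:

```python
import time
from multiprocessing import Process, Queue
from queue import Empty


def poll_done_queue(done_queue: Queue, workers: list[Process]):
    """Wait for one result, failing fast if any worker has died."""
    while True:
        try:
            return done_queue.get_nowait()
        except Empty:
            for proc in workers:
                # exitcode is None while the process is alive, 0 on a
                # clean exit, positive after sys.exit(n) or an uncaught
                # exception, and -N if the process was killed by signal N.
                if proc.exitcode not in (None, 0):
                    raise RuntimeError(
                        f"Worker {proc.name} exited unexpectedly "
                        f"with code {proc.exitcode}"
                    )
            time.sleep(0.01)  # bound CPU use between polls
```

A plain blocking `get()` can wait forever on a queue whose only producer is already dead; the 10 ms sleep keeps the poll cheap while still detecting failure almost immediately, which is what the `elapsed < 10` assertion in the new test below relies on.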
diff --git a/src/datachain/query/queue.py b/src/datachain/query/queue.py
index e5b047e4e..3ac572923 100644
--- a/src/datachain/query/queue.py
+++ b/src/datachain/query/queue.py
@@ -1,11 +1,12 @@
 import datetime
 from collections.abc import Iterable, Iterator
-from queue import Empty, Full, Queue
+from queue import Empty, Full
 from struct import pack, unpack
 from time import sleep
 from typing import Any
 
 import msgpack
+from multiprocess.queues import Queue
 
 from datachain.query.batch import RowsOutput
 
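Reviewer note (not part of the patch): `_shutdown_workers` tears the queues down in a deliberate order: drain, then `close()`, then `join_thread()`. Draining first matters because `join_thread()` blocks until the feeder thread has flushed buffered items into the underlying pipe, and once the workers are gone a full pipe would never be read, so the flush could hang. A sketch of that ordering with stdlib `multiprocessing.Queue` (same API as `multiprocess.queues.Queue`; `drain_and_close` is a hypothetical helper):

```python
import contextlib
from multiprocessing import Queue
from queue import Empty


def drain_and_close(q: Queue) -> None:
    """Empty a queue, then shut down its feeder thread."""
    while True:
        try:
            q.get_nowait()
        except Empty:
            break  # nothing left to discard
        except (OSError, ValueError):
            break  # queue already closed or its pipe is broken
    with contextlib.suppress(OSError, ValueError):
        q.close()  # this process will put no more data on the queue
    with contextlib.suppress(RuntimeError, AssertionError, ValueError):
        q.join_thread()  # only valid after close(); waits for the feeder
```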
diff --git a/tests/func/test_udf.py b/tests/func/test_udf.py
index 15b714c04..c5cec3e9b 100644
--- a/tests/func/test_udf.py
+++ b/tests/func/test_udf.py
@@ -2,7 +2,10 @@
 import os
 import pickle
 import posixpath
+import sys
+import time
 
+import multiprocess as mp
 import pytest
 
 import datachain as dc
@@ -560,6 +563,119 @@ def name_len_error(_name):
         chain.show()
 
 
+@pytest.mark.parametrize(
+    "failure_mode,expected_exit_code,error_marker",
+    [
+        ("exception", 1, "Worker 1 failure!"),
+        ("keyboard_interrupt", -2, "KeyboardInterrupt"),
+        ("sys_exit", 1, None),
+        ("os_exit", 1, None),  # os._exit: immediate termination
+    ],
+)
+def test_udf_parallel_worker_failure_exits_peers(
+    test_session_tmpfile,
+    tmp_path,
+    capfd,
+    failure_mode,
+    expected_exit_code,
+    error_marker,
+):
+    """
+    Test that when one worker fails, all other workers exit immediately.
+
+    Tests different failure modes:
+    - exception: Worker raises RuntimeError (normal exception)
+    - keyboard_interrupt: Worker raises KeyboardInterrupt (simulates Ctrl+C)
+    - sys_exit: Worker calls sys.exit() (clean Python exit)
+    - os_exit: Worker calls os._exit() (immediate process termination)
+    """
+    import platform
+
+    # Windows uses different exit codes for KeyboardInterrupt:
+    # 3221225786 (0xC000013A) is STATUS_CONTROL_C_EXIT on Windows,
+    # while POSIX systems use -2 (SIGINT).
+    if platform.system() == "Windows" and failure_mode == "keyboard_interrupt":
+        expected_exit_code = 3221225786
+
+    vals = list(range(100))
+
+    barrier_dir = tmp_path / "udf_workers_barrier"
+    barrier_dir_str = str(barrier_dir)
+    os.makedirs(barrier_dir_str, exist_ok=True)
+    expected_workers = 3
+
+    def slow_process(val: int) -> int:
+        proc_name = mp.current_process().name
+        with open(os.path.join(barrier_dir_str, f"{proc_name}.started"), "w") as f:
+            f.write(str(time.time()))
+
+        # Wait until all expected workers have written their markers
+        deadline = time.time() + 1.0
+        while time.time() < deadline:
+            try:
+                count = len(
+                    [n for n in os.listdir(barrier_dir_str) if n.endswith(".started")]
+                )
+            except FileNotFoundError:
+                count = 0
+            if count >= expected_workers:
+                break
+            time.sleep(0.01)
+
+        if proc_name == "Worker-UDF-1":
+            if failure_mode == "exception":
+                raise RuntimeError("Worker 1 failure!")
+            if failure_mode == "keyboard_interrupt":
+                raise KeyboardInterrupt("Worker interrupted")
+            if failure_mode == "sys_exit":
+                sys.exit(1)
+            if failure_mode == "os_exit":
+                os._exit(1)
+        time.sleep(5)
+        return val * 2
+
+    chain = (
+        dc.read_values(val=vals, session=test_session_tmpfile)
+        .settings(parallel=3)
+        .map(slow_process, output={"result": int})
+    )
+
+    start = time.time()
+    with pytest.raises(RuntimeError, match="UDF Execution Failed!") as exc_info:
+        list(chain.to_iter("result"))
+    elapsed = time.time() - start
+
+    # Verify timing: should exit immediately when a worker fails
+    assert elapsed < 10, f"took {elapsed:.1f}s, should exit immediately"
+
+    # Verify multiple workers were started via barrier markers
+    try:
+        started_files = [
+            n for n in os.listdir(barrier_dir_str) if n.endswith(".started")
+        ]
+    except FileNotFoundError:
+        started_files = []
+    assert len(started_files) == expected_workers, (
+        f"Expected {expected_workers} workers to start, saw: {started_files}"
+    )
+
+    captured = capfd.readouterr()
+
+    # Verify the RuntimeError has a meaningful message with the exit code
+    error_message = str(exc_info.value)
+    assert f"UDF Execution Failed! Exit code: {expected_exit_code}" in error_message, (
+        f"Expected exit code {expected_exit_code}, got: {error_message}"
+    )
+
+    if error_marker:
+        assert error_marker in captured.err, (
+            f"Expected '{error_marker}' in stderr for {failure_mode} mode. "
+            f"stderr output: {captured.err[:500]}"
+        )
+
+    assert "semaphore" not in captured.err  # no leaked-semaphore warnings
+
+
 @pytest.mark.parametrize(
     "cloud_type,version_aware",
     [("s3", True)],
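Reviewer note (not part of the patch): the `expected_exit_code` values in the parametrization follow standard `multiprocessing` exitcode semantics on POSIX: a positive code comes from `sys.exit(n)`, `os._exit(n)`, or an uncaught exception, while a negative `-N` means the child was terminated with signal `N` (the `-2` expectation matches SIGINT; Windows reports `STATUS_CONTROL_C_EXIT` instead, which the test special-cases). A standalone illustration, assuming a POSIX start method:

```python
import signal
import sys
import time
from multiprocessing import Process


def clean_exit() -> None:
    sys.exit(3)  # SystemExit code propagates to Process.exitcode


def hang() -> None:
    time.sleep(60)


if __name__ == "__main__":
    p = Process(target=clean_exit)
    p.start()
    p.join()
    assert p.exitcode == 3  # positive: explicit exit code

    q = Process(target=hang)
    q.start()
    q.terminate()  # sends SIGTERM on POSIX
    q.join()
    assert q.exitcode == -signal.SIGTERM  # negative: killed by that signal
```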