106 changes: 64 additions & 42 deletions src/datachain/query/dispatch.py
@@ -2,12 +2,16 @@
from collections.abc import Iterable, Sequence
from itertools import chain
from multiprocessing import cpu_count
from queue import Empty
from sys import stdin
from time import monotonic, sleep
from typing import TYPE_CHECKING, Literal

import multiprocess
from cloudpickle import load, loads
from fsspec.callbacks import DEFAULT_CALLBACK, Callback
from multiprocess import get_context
from multiprocess.context import Process
from multiprocess.queues import Queue as MultiprocessQueue

from datachain.catalog import Catalog
from datachain.catalog.catalog import clone_catalog_with_cache
@@ -25,7 +29,6 @@
from datachain.utils import batched, flatten, safe_closing

if TYPE_CHECKING:
import multiprocess
from sqlalchemy import Select, Table

from datachain.data_storage import AbstractMetastore, AbstractWarehouse
@@ -101,8 +104,8 @@ def udf_worker_entrypoint(fd: int | None = None) -> int:

class UDFDispatcher:
_catalog: Catalog | None = None
task_queue: "multiprocess.Queue | None" = None
done_queue: "multiprocess.Queue | None" = None
task_queue: MultiprocessQueue | None = None
done_queue: MultiprocessQueue | None = None
Comment on lines +107 to +108 (Copilot AI, Oct 13, 2025):

Consider adding assertions in the run_udf_parallel method to verify these queues are initialized before creating workers, as suggested in the PR description.
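
A minimal sketch of the suggested guard (hypothetical code, not part of this PR; the class and method names only mirror the PR's UDFDispatcher for illustration):

import multiprocess

class DispatcherSketch:
    """Illustrative stand-in for UDFDispatcher."""

    def __init__(self) -> None:
        self.ctx = multiprocess.get_context("spawn")
        self.task_queue = None
        self.done_queue = None

    def _create_queues(self) -> None:
        self.task_queue = self.ctx.Queue()
        self.done_queue = self.ctx.Queue()

    def run_udf_parallel(self) -> None:
        # Fail fast if the queues were never created, rather than hitting an
        # AttributeError or a hang once worker processes have been started.
        assert self.task_queue is not None, "task_queue is not initialized"
        assert self.done_queue is not None, "done_queue is not initialized"
        # ... start workers that read from task_queue and write to done_queue ...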


def __init__(self, udf_info: UdfInfo, buffer_size: int = DEFAULT_BATCH_SIZE):
self.udf_data = udf_info["udf_data"]
@@ -121,7 +124,7 @@ def __init__(self, udf_info: UdfInfo, buffer_size: int = DEFAULT_BATCH_SIZE):
self.buffer_size = buffer_size
self.task_queue = None
self.done_queue = None
self.ctx = get_context("spawn")
self.ctx = multiprocess.get_context("spawn")

@property
def catalog(self) -> "Catalog":
@@ -259,8 +262,6 @@ def run_udf_parallel( # noqa: C901, PLR0912
for p in pool:
p.start()

# Will be set to True if all tasks complete normally
normal_completion = False
try:
# Will be set to True when the input is exhausted
input_finished = False
@@ -283,10 +284,20 @@

# Process all tasks
while n_workers > 0:
try:
result = get_from_queue(self.done_queue)
except KeyboardInterrupt:
break
while True:
try:
result = self.done_queue.get_nowait()
break
except Empty:
for p in pool:
exitcode = p.exitcode
if exitcode not in (None, 0):
message = (
f"Worker {p.name} exited unexpectedly with "
f"code {exitcode}"
)
raise RuntimeError(message) from None
sleep(0.01)

if bytes_downloaded := result.get("bytes_downloaded"):
download_cb.relative_update(bytes_downloaded)
@@ -313,39 +324,50 @@ def run_udf_parallel( # noqa: C901, PLR0912
put_into_queue(self.task_queue, next(input_data))
except StopIteration:
input_finished = True

# Finished with all tasks normally
normal_completion = True
finally:
if not normal_completion:
# Stop all workers if there is an unexpected exception
for _ in pool:
put_into_queue(self.task_queue, STOP_SIGNAL)

# This allows workers (and this process) to exit without
# consuming any remaining data in the queues.
# (If they exit due to an exception.)
self.task_queue.close()
self.task_queue.join_thread()

# Flush all items from the done queue.
# This is needed if any workers are still running.
while n_workers > 0:
result = get_from_queue(self.done_queue)
status = result["status"]
if status != OK_STATUS:
n_workers -= 1

self.done_queue.close()
self.done_queue.join_thread()
self._shutdown_workers(pool)

def _shutdown_workers(self, pool: list[Process]) -> None:
self._terminate_pool(pool)
self._drain_queue(self.done_queue)
self._drain_queue(self.task_queue)
self._close_queue(self.done_queue)
self._close_queue(self.task_queue)

def _terminate_pool(self, pool: list[Process]) -> None:
for proc in pool:
if proc.is_alive():
proc.terminate()

deadline = monotonic() + 1.0
for proc in pool:
if not proc.is_alive():
continue
remaining = deadline - monotonic()
if remaining > 0:
proc.join(remaining)
if proc.is_alive():
proc.kill()
proc.join(timeout=0.2)

def _drain_queue(self, queue: MultiprocessQueue) -> None:
while True:
try:
queue.get_nowait()
except Empty:
return
except (OSError, ValueError):
return

# Wait for workers to stop
for p in pool:
p.join()
def _close_queue(self, queue: MultiprocessQueue) -> None:
with contextlib.suppress(OSError, ValueError):
queue.close()
with contextlib.suppress(RuntimeError, AssertionError, ValueError):
queue.join_thread()


class DownloadCallback(Callback):
def __init__(self, queue: "multiprocess.Queue") -> None:
def __init__(self, queue: MultiprocessQueue) -> None:
self.queue = queue
super().__init__()

@@ -360,7 +382,7 @@ class ProcessedCallback(Callback):
def __init__(
self,
name: Literal["processed", "generated"],
queue: "multiprocess.Queue",
queue: MultiprocessQueue,
) -> None:
self.name = name
self.queue = queue
@@ -375,8 +397,8 @@ def __init__(
self,
catalog: "Catalog",
udf: "UDFAdapter",
task_queue: "multiprocess.Queue",
done_queue: "multiprocess.Queue",
task_queue: MultiprocessQueue,
done_queue: MultiprocessQueue,
query: "Select",
table: "Table",
cache: bool,
3 changes: 2 additions & 1 deletion src/datachain/query/queue.py
@@ -1,11 +1,12 @@
import datetime
from collections.abc import Iterable, Iterator
from queue import Empty, Full, Queue
from queue import Empty, Full
from struct import pack, unpack
from time import sleep
from typing import Any

import msgpack
from multiprocess.queues import Queue

from datachain.query.batch import RowsOutput

116 changes: 116 additions & 0 deletions tests/func/test_udf.py
@@ -2,7 +2,10 @@
import os
import pickle
import posixpath
import sys
import time

import multiprocess as mp
import pytest

import datachain as dc
@@ -560,6 +563,119 @@ def name_len_error(_name):
chain.show()


@pytest.mark.parametrize(
"failure_mode,expected_exit_code,error_marker",
[
("exception", 1, "Worker 1 failure!"),
("keyboard_interrupt", -2, "KeyboardInterrupt"),
("sys_exit", 1, None),
("os_exit", 1, None), # os._exit - immediate termination
],
)
def test_udf_parallel_worker_failure_exits_peers(
test_session_tmpfile,
tmp_path,
capfd,
failure_mode,
expected_exit_code,
error_marker,
):
"""
Test that when one worker fails, all other workers exit immediately.

Tests different failure modes:
- exception: Worker raises RuntimeError (normal exception)
- keyboard_interrupt: Worker raises KeyboardInterrupt (simulates Ctrl+C)
- sys_exit: Worker calls sys.exit() (clean Python exit)
- os_exit: Worker calls os._exit() (immediate process termination)
"""
import platform

# Windows uses different exit codes for KeyboardInterrupt
# 3221225786 (0xC000013A) is STATUS_CONTROL_C_EXIT on Windows
# while POSIX systems use -2 (SIGINT)
if platform.system() == "Windows" and failure_mode == "keyboard_interrupt":
expected_exit_code = 3221225786

vals = list(range(100))

barrier_dir = tmp_path / "udf_workers_barrier"
barrier_dir_str = str(barrier_dir)
os.makedirs(barrier_dir_str, exist_ok=True)
expected_workers = 3

def slow_process(val: int) -> int:
proc_name = mp.current_process().name
with open(os.path.join(barrier_dir_str, f"{proc_name}.started"), "w") as f:
f.write(str(time.time()))

# Wait until all expected workers have written their markers
deadline = time.time() + 1.0
while time.time() < deadline:
try:
count = len(
[n for n in os.listdir(barrier_dir_str) if n.endswith(".started")]
)
except FileNotFoundError:
count = 0
if count >= expected_workers:
break
time.sleep(0.01)

if proc_name == "Worker-UDF-1":
if failure_mode == "exception":
raise RuntimeError("Worker 1 failure!")
if failure_mode == "keyboard_interrupt":
raise KeyboardInterrupt("Worker interrupted")
if failure_mode == "sys_exit":
sys.exit(1)
if failure_mode == "os_exit":
os._exit(1)
time.sleep(5)
return val * 2

chain = (
dc.read_values(val=vals, session=test_session_tmpfile)
.settings(parallel=3)
.map(slow_process, output={"result": int})
)

start = time.time()
with pytest.raises(RuntimeError, match="UDF Execution Failed!") as exc_info:
list(chain.to_iter("result"))
elapsed = time.time() - start

# Verify timing: should exit immediately when worker fails
assert elapsed < 10, f"took {elapsed:.1f}s, should exit immediately"

# Verify multiple workers were started via barrier markers
try:
started_files = [
n for n in os.listdir(barrier_dir_str) if n.endswith(".started")
]
except FileNotFoundError:
started_files = []
assert len(started_files) == 3, (
f"Expected all 3 workers to start, but saw markers for: {started_files}"
)

captured = capfd.readouterr()

# Verify the RuntimeError has a meaningful message with exit code
error_message = str(exc_info.value)
assert f"UDF Execution Failed! Exit code: {expected_exit_code}" in error_message, (
f"Expected exit code {expected_exit_code}, got: {error_message}"
)

if error_marker:
assert error_marker in captured.err, (
f"Expected '{error_marker}' in stderr for {failure_mode} mode. "
f"stderr output: {captured.err[:500]}"
)

assert "semaphore" not in captured.err


@pytest.mark.parametrize(
"cloud_type,version_aware",
[("s3", True)],