Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ Bug Fixes
:py:func:`open_dataset` is called with a non-existent local file path
(:issue:`10896`).
By `Kristian Kollsgård <https://github.com/kkollsga>`_.
- Fix a regression where :py:func:`open_mfdataset` could hang indefinitely with
``engine="h5netcdf"`` and ``parallel=True`` on distributed schedulers when
opening file-like objects from remote filesystems (:issue:`10807`).

Documentation
~~~~~~~~~~~~~
Expand Down
28 changes: 18 additions & 10 deletions xarray/backends/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,17 @@ def _multi_file_closer(closers):
closer()


def _preprocess_mfdataset(
ds: Dataset, preprocess: Callable[[Dataset], Dataset]
) -> Dataset:
# Preserve the underlying file closer if preprocess returns a Dataset without one.
# This keeps resource cleanup reliable while allowing arbitrary preprocess functions.
processed_ds = preprocess(ds)
if processed_ds._close is None and ds._close is not None:
processed_ds.set_close(ds._close)
return processed_ds


def load_dataset(filename_or_obj: T_PathFileOrDataStore, **kwargs) -> Dataset:
"""Open, load into memory, and close a Dataset from a file or file-like
object.
Expand Down Expand Up @@ -1617,14 +1628,10 @@ class (a subclass of ``BackendEntrypoint``) can also be used.
if parallel:
import dask

# wrap the open_dataset, getattr, and preprocess with delayed
# wrap open_dataset and preprocess with delayed
open_ = dask.delayed(open_dataset)
getattr_ = dask.delayed(getattr)
if preprocess is not None:
preprocess = dask.delayed(preprocess)
else:
open_ = open_dataset
getattr_ = getattr

if errors not in ("raise", "warn", "ignore"):
raise ValueError(
Expand Down Expand Up @@ -1652,14 +1659,15 @@ class (a subclass of ``BackendEntrypoint``) can also be used.
combined_ids_paths = _infer_concat_order_from_positions(paths)
ids = list(combined_ids_paths.keys())

closers = [getattr_(ds, "_close") for ds in datasets]
if preprocess is not None:
datasets = [preprocess(ds) for ds in datasets]
datasets = [_preprocess_mfdataset(ds, preprocess) for ds in datasets]

if parallel:
# calling compute here will return the datasets/file_objs lists,
# the underlying datasets will still be stored as dask arrays
datasets, closers = dask.compute(datasets, closers)
# calling compute here will return the list of datasets; the underlying
# arrays remain lazy and dask-backed.
(datasets,) = dask.compute(datasets)

closers = [ds._close for ds in datasets if ds._close is not None]

# Combine all datasets, closing them in case of a ValueError
try:
Expand Down
5 changes: 5 additions & 0 deletions xarray/backends/file_manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import atexit
import pickle
import threading
import uuid
import warnings
Expand Down Expand Up @@ -422,6 +423,10 @@ def __del__(self) -> None:
def __getstate__(self):
    """Return picklable state: ``(opener, args, mode, lock, kwargs)``.

    The open file handle is deliberately excluded from the state so the
    file is re-opened (via the opener) when the manager is unpickled.
    """
    # file is intentionally omitted: we want to open it again
    opener = _get_none if self._closed else self._opener
    # Fail fast if opener arguments are not serializable. Without this guard,
    # distributed execution can block while attempting to serialize delayed
    # tasks that capture unpickleable file-like handles.
    pickle.dumps((self._args, self._kwargs))
    return (opener, self._args, self._mode, self._lock, self._kwargs)

def __setstate__(self, state) -> None:
Expand Down
48 changes: 41 additions & 7 deletions xarray/backends/h5netcdf_.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import functools
import io
import os
from collections.abc import Iterable
from collections.abc import Iterable, Mapping
from typing import TYPE_CHECKING, Any, Self

import numpy as np
Expand Down Expand Up @@ -232,12 +232,25 @@ def open(
else:
lock = combine_locks([HDF5_LOCK, get_write_lock(filename)])

manager_cls = (
CachingFileManager
if isinstance(filename, str) and not is_remote_uri(filename)
else PickleableFileManager
)
manager = manager_cls(h5netcdf.File, filename, mode=mode, kwargs=kwargs)
manager: FileManager[Any]
if isinstance(filename, str) and not is_remote_uri(filename):
manager = CachingFileManager(
h5netcdf.File, filename, mode=mode, kwargs=kwargs
)
elif mode == "r" and _is_fsspec_file_obj(filename):
# Reopen fsspec-backed files from fs/path instead of serializing a live
# file handle across distributed workers.
manager = PickleableFileManager(
_open_h5netcdf_from_fsspec,
filename.fs,
filename.path,
mode=mode,
kwargs={"h5netcdf_kwargs": kwargs},
)
else:
manager = PickleableFileManager(
h5netcdf.File, filename, mode=mode, kwargs=kwargs
)

return cls(
manager,
Expand Down Expand Up @@ -465,6 +478,27 @@ def _normalize_filename_or_obj(
return _normalize_path(filename_or_obj)


def _is_fsspec_file_obj(obj: Any) -> bool:
fs = getattr(obj, "fs", None)
path = getattr(obj, "path", None)
return fs is not None and path is not None and callable(getattr(fs, "open", None))


def _open_h5netcdf_from_fsspec(
    fs: Any,
    path: str,
    *,
    mode: str = "r",
    h5netcdf_kwargs: Mapping[str, Any] | None = None,
):
    """Open an h5netcdf File by reopening *path* through the fsspec filesystem *fs*.

    Passing ``fs``/``path`` instead of a live file handle keeps this opener
    picklable, so it can be shipped to distributed workers.
    """
    import h5netcdf

    # h5py wants a binary handle, so map the dataset read mode onto "rb".
    open_mode = "rb" if mode == "r" else mode
    handle = fs.open(path, mode=open_mode)
    extra_kwargs = dict(h5netcdf_kwargs) if h5netcdf_kwargs is not None else {}
    return h5netcdf.File(handle, mode=mode, **extra_kwargs)


class H5netcdfBackendEntrypoint(BackendEntrypoint):
"""
Backend for netCDF files based on the h5netcdf package.
Expand Down
61 changes: 61 additions & 0 deletions xarray/tests/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import io
import pickle
from typing import TYPE_CHECKING, Any

Expand Down Expand Up @@ -40,6 +41,7 @@
has_netCDF4,
has_scipy,
requires_cftime,
requires_h5netcdf,
requires_netCDF4,
requires_zarr,
)
Expand Down Expand Up @@ -176,6 +178,65 @@ def test_open_mfdataset_multiple_files_parallel_distributed(parallel, tmp_path):
assert_identical(tf["test"], da)


@requires_h5netcdf
def test_open_mfdataset_file_like_parallel_distributed_h5netcdf():
    # Build a reference DataArray and serialize it as two in-memory netCDF halves.
    values = np.arange(80).reshape(20, 4)
    expected = xr.DataArray(
        values,
        coords={"time": np.arange(20), "x": np.arange(4)},
        dims=("time", "x"),
        name="v",
    )

    buffers = []
    for start in (0, 10):
        chunk = expected.isel(time=slice(start, start + 10)).to_dataset()
        buffers.append(io.BytesIO(chunk.to_netcdf(engine="h5netcdf")))

    # Opening file-like objects in parallel on a distributed cluster must
    # round-trip the data without hanging.
    with cluster() as (s, [_a, _b]):
        with Client(s["address"]):
            with xr.open_mfdataset(
                buffers,
                engine="h5netcdf",
                parallel=True,
                concat_dim="time",
                combine="nested",
            ) as actual:
                assert_identical(actual["v"], expected)


@requires_h5netcdf
@pytest.mark.timeout(30)
def test_open_mfdataset_parallel_distributed_h5netcdf_fsspec_file_objects(tmp_path):
    fsspec = pytest.importorskip("fsspec")

    # Reference data split into two on-disk netCDF chunks.
    values = np.arange(80).reshape(20, 4)
    expected = xr.DataArray(
        values,
        coords={"time": np.arange(20), "x": np.arange(4)},
        dims=("time", "x"),
        name="v",
    )

    chunk_paths = []
    for start in (0, 10):
        target = tmp_path / f"chunk_{start}.nc"
        expected.isel(time=slice(start, start + 10)).to_dataset().to_netcdf(
            target, engine="h5netcdf"
        )
        chunk_paths.append(target)

    fs = fsspec.filesystem("file")
    with (
        fs.open(str(chunk_paths[0]), "rb") as f0,
        fs.open(str(chunk_paths[1]), "rb") as f1,
    ):
        # Regression test for GH10807:
        # the buggy implementation builds delayed tasks that serialize closer
        # callables extracted from worker-side datasets. With h5netcdf and
        # file-like inputs this can block indefinitely on distributed.
        with cluster() as (s, [_a, _b]):
            with Client(s["address"]):
                with xr.open_mfdataset(
                    [f0, f1],
                    engine="h5netcdf",
                    parallel=True,
                    concat_dim="time",
                    combine="nested",
                ) as actual:
                    assert_identical(actual["v"], expected)


# TODO: move this to test_backends.py
@requires_cftime
@requires_netCDF4
Expand Down
Loading