-
Notifications
You must be signed in to change notification settings - Fork 309
Make NetCDF file cache handling compatible with dask distributed #2822
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
7f6a8d4
6d31c20
1e26d1a
be40c5b
cbd00f0
af4ee66
dad3b14
fc58ca4
09c821a
4f9c5ed
ec76fa6
06d8811
aaf91b9
a2ad42f
9126bbe
5e576f9
63e7507
ea04595
523671a
fde3896
5b137e8
c2b1533
9fce5a7
4993b65
7c173e7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,16 +17,17 @@ | |
# satpy. If not, see <http://www.gnu.org/licenses/>. | ||
"""Helpers for reading netcdf-based files.""" | ||
|
||
import functools | ||
import logging | ||
import warnings | ||
|
||
import dask.array as da | ||
import netCDF4 | ||
import numpy as np | ||
import xarray as xr | ||
|
||
from satpy.readers import open_file_or_filename | ||
from satpy.readers.file_handlers import BaseFileHandler | ||
from satpy.readers.utils import np2str | ||
from satpy.readers.utils import get_serializable_dask_array, np2str | ||
from satpy.utils import get_legacy_chunk_size | ||
|
||
LOG = logging.getLogger(__name__) | ||
|
@@ -85,10 +86,12 @@ class NetCDF4FileHandler(BaseFileHandler): | |
xarray_kwargs (dict): Additional arguments to `xarray.open_dataset` | ||
cache_var_size (int): Cache variables smaller than this size. | ||
cache_handle (bool): Keep files open for lifetime of filehandler. | ||
Uses xarray.backends.CachingFileManager, which uses a least | ||
recently used cache. | ||
|
||
""" | ||
|
||
file_handle = None | ||
manager = None | ||
|
||
def __init__(self, filename, filename_info, filetype_info, | ||
auto_maskandscale=False, xarray_kwargs=None, | ||
|
@@ -99,14 +102,22 @@ def __init__(self, filename, filename_info, filetype_info, | |
self.file_content = {} | ||
self.cached_file_content = {} | ||
self._use_h5netcdf = False | ||
try: | ||
file_handle = self._get_file_handle() | ||
except IOError: | ||
LOG.exception( | ||
"Failed reading file %s. Possibly corrupted file", self.filename) | ||
raise | ||
self._auto_maskandscale = auto_maskandscale | ||
if cache_handle: | ||
self.manager = xr.backends.CachingFileManager( | ||
functools.partial(_nc_dataset_wrapper, | ||
auto_maskandscale=auto_maskandscale), | ||
self.filename, mode="r") | ||
file_handle = self.manager.acquire() | ||
else: | ||
try: | ||
file_handle = self._get_file_handle() | ||
except IOError: | ||
LOG.exception( | ||
"Failed reading file %s. Possibly corrupted file", self.filename) | ||
raise | ||
|
||
self._set_file_handle_auto_maskandscale(file_handle, auto_maskandscale) | ||
self._set_file_handle_auto_maskandscale(file_handle, auto_maskandscale) | ||
self._set_xarray_kwargs(xarray_kwargs, auto_maskandscale) | ||
|
||
listed_variables = filetype_info.get("required_netcdf_variables") | ||
|
@@ -117,14 +128,22 @@ def __init__(self, filename, filename_info, filetype_info, | |
self.collect_dimensions("", file_handle) | ||
self.collect_cache_vars(cache_var_size) | ||
|
||
if cache_handle: | ||
self.file_handle = file_handle | ||
else: | ||
if not cache_handle: | ||
file_handle.close() | ||
|
||
def _get_file_handle(self):
    """Open ``self.filename`` read-only and return the netCDF4 dataset."""
    path = self.filename
    return netCDF4.Dataset(path, "r")
|
||
@property
def file_handle(self):
    """Backward-compatible accessor for the cached file handle.

    Deprecated in favour of ``.manager``; every access emits a
    ``DeprecationWarning``.

    Returns:
        ``None`` when ``self.manager`` is ``None``, otherwise the object
        returned by ``self.manager.acquire()``.
    """
    warnings.warn(
        "attribute .file_handle is deprecated, use .manager instead",
        DeprecationWarning,
        # Point the warning at the code reading the attribute, not at
        # this property, so users can find the call they must migrate.
        stacklevel=2)
    if self.manager is None:
        return None
    return self.manager.acquire()
|
||
@staticmethod | ||
def _set_file_handle_auto_maskandscale(file_handle, auto_maskandscale): | ||
if hasattr(file_handle, "set_auto_maskandscale"): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Not that this has to be handled in your PR, but if I remember correctly this |
||
|
@@ -196,11 +215,8 @@ def _get_required_variable_names(listed_variables, variable_name_replacements): | |
|
||
def __del__(self): | ||
"""Delete the file handler.""" | ||
if self.file_handle is not None: | ||
try: | ||
self.file_handle.close() | ||
except RuntimeError: # presumably closed already | ||
pass | ||
if self.manager is not None: | ||
self.manager.close() | ||
djhoese marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def _collect_global_attrs(self, obj): | ||
"""Collect all the global attributes for the provided file object.""" | ||
|
@@ -289,8 +305,8 @@ def _get_variable(self, key, val): | |
group, key = parts | ||
else: | ||
group = None | ||
if self.file_handle is not None: | ||
val = self._get_var_from_filehandle(group, key) | ||
if self.manager is not None: | ||
val = self._get_var_from_manager(group, key) | ||
else: | ||
val = self._get_var_from_xr(group, key) | ||
return val | ||
|
@@ -319,18 +335,27 @@ def _get_var_from_xr(self, group, key): | |
val.load() | ||
return val | ||
|
||
def _get_var_from_filehandle(self, group, key): | ||
def _get_var_from_manager(self, group, key): | ||
# Not getting coordinates as this is more work, therefore more | ||
# overhead, and those are not used downstream. | ||
|
||
with self.manager.acquire_context() as ds: | ||
if group is not None: | ||
v = ds[group][key] | ||
else: | ||
v = ds[key] | ||
if group is None: | ||
g = self.file_handle | ||
dv = get_serializable_dask_array( | ||
self.manager, key, | ||
chunks=v.shape, dtype=v.dtype) | ||
else: | ||
g = self.file_handle[group] | ||
v = g[key] | ||
dv = get_serializable_dask_array( | ||
self.manager, "/".join([group, key]), | ||
chunks=v.shape, dtype=v.dtype) | ||
attrs = self._get_object_attrs(v) | ||
x = xr.DataArray( | ||
da.from_array(v), dims=v.dimensions, attrs=attrs, | ||
name=v.name) | ||
dv, | ||
dims=v.dimensions, attrs=attrs, name=v.name) | ||
return x | ||
|
||
def __contains__(self, item): | ||
|
@@ -443,3 +468,15 @@ def _get_attr(self, obj, key): | |
if self._use_h5netcdf: | ||
return obj.attrs[key] | ||
return super()._get_attr(obj, key) | ||
|
||
def _nc_dataset_wrapper(*args, auto_maskandscale, **kwargs):
    """Open a ``netCDF4.Dataset`` and set ``auto_maskandscale`` on it.

    Positional arguments and the remaining keyword arguments are forwarded
    unchanged to ``netCDF4.Dataset``.  Combining the open call and the
    mask-and-scale setting in a single callable means it can be passed
    directly to ``CachingFileManager`` as the opener.  The setting is
    applied globally, i.e. for all variables of the dataset.
    """
    dataset = netCDF4.Dataset(*args, **kwargs)
    dataset.set_auto_maskandscale(auto_maskandscale)
    return dataset
Uh oh!
There was an error while loading. Please reload this page.