diff --git a/uxarray/core/dataarray.py b/uxarray/core/dataarray.py
index e392b27c8..42a585851 100644
--- a/uxarray/core/dataarray.py
+++ b/uxarray/core/dataarray.py
@@ -2,7 +2,7 @@
 
 import warnings
 from html import escape
-from typing import TYPE_CHECKING, Any, Hashable, Literal, Mapping, Optional
+from typing import TYPE_CHECKING, Any, Callable, Hashable, Literal, Mapping, Optional
 from warnings import warn
 
 import numpy as np
@@ -19,6 +19,7 @@
     _calculate_edge_node_difference,
     _calculate_grad_on_edge_from_faces,
 )
+from uxarray.constants import GRID_DIMS
 from uxarray.core.utils import _map_dims_to_ugrid
 from uxarray.core.zonal import _compute_non_conservative_zonal_mean
 from uxarray.cross_sections import UxDataArrayCrossSectionAccessor
@@ -1229,7 +1230,6 @@ def isel(
         ValueError
             If more than one grid dimension is selected and `ignore_grid=False`.
         """
-        from uxarray.constants import GRID_DIMS
         from uxarray.core.dataarray import UxDataArray
 
         # merge dict‐style + kw‐style indexers
@@ -1424,6 +1424,100 @@ def get_dual(self):
 
         return uxda
 
+    def neighborhood_filter(
+        self,
+        func: Callable = np.mean,
+        r: float = 1.0,
+    ) -> UxDataArray:
+        """Apply a neighborhood filter along the grid dimension.
+
+        Each element of the grid dimension is replaced by ``func`` evaluated
+        over the values of all elements whose coordinates lie within radius
+        ``r`` of that element. Leading (non-grid) dimensions are filtered
+        independently of one another.
+
+        Parameters
+        ----------
+        func : Callable, default=np.mean
+            Reduction applied to each neighborhood. It is called with a 1-D
+            array of the neighborhood values and must return a scalar.
+        r : float, default=1.
+            Radius of neighborhood. For spherical coordinates, the radius is in
+            units of degrees, and for cartesian coordinates, the radius is in
+            meters.
+
+        Returns
+        -------
+        uxda_filter : UxDataArray
+            Copy of this data array holding the filtered data. Elements for
+            which the radius query returns no neighbors are set to NaN.
+
+        Raises
+        ------
+        ValueError
+            If the data is not face-, node- or edge-centered, if the grid
+            dimension is not the last dimension, or if the ball tree reports an
+            unknown coordinate system.
+        """
+        if self._face_centered():
+            data_mapping, element = "face centers", "face"
+        elif self._node_centered():
+            data_mapping, element = "nodes", "node"
+        elif self._edge_centered():
+            data_mapping, element = "edge centers", "edge"
+        else:
+            raise ValueError(
+                "Data_mapping is not face, node, or edge. Could not define data_mapping."
+            )
+
+        # The filter indexes the trailing axis, so the grid dimension must be last.
+        # Validated with an exception rather than ``assert`` so it also runs under -O.
+        if self.dims[-1] not in GRID_DIMS:
+            raise ValueError(
+                f"expected last dimension of uxDataArray {self.dims[-1]} "
+                f"to be one of {GRID_DIMS}"
+            )
+
+        # reconstruct because the cached tree could be built from
+        # face centers, edge centers or nodes.
+        tree = self.uxgrid.get_ball_tree(coordinates=data_mapping, reconstruct=True)
+
+        coordinate_system = tree.coordinate_system
+        if coordinate_system == "spherical":
+            axes = ("lon", "lat")
+        elif coordinate_system == "cartesian":
+            axes = ("x", "y", "z")
+        else:
+            raise ValueError(
+                f"Invalid coordinate_system. Expected either 'spherical' or 'cartesian', but received {coordinate_system}"
+            )
+
+        # Query points are the grid's own coordinates, e.g. ``face_lon`` and
+        # ``face_lat`` for face-centered data on a spherical tree.
+        dest_coords = np.vstack(
+            [getattr(self.uxgrid, f"{element}_{axis}").values for axis in axes]
+        ).T
+
+        neighbor_indices = tree.query_radius(dest_coords, r=r)
+
+        # Start from NaN rather than uninitialized memory so that elements
+        # without neighbors hold a well-defined value.
+        destination_data = np.full(self.data.shape, np.nan)
+
+        # Reduce over the neighborhood (last) axis only, so that leading
+        # dimensions such as time are not collapsed into a single value.
+        for i, idx in enumerate(neighbor_indices):
+            if len(idx):
+                destination_data[..., i] = np.apply_along_axis(
+                    func, -1, self.data[..., idx]
+                )
+
+        # Construct UxDataArray for filtered variable.
+        uxda_filter = self._copy()
+        uxda_filter.data = destination_data
+
+        return uxda_filter
+
     def where(self, cond: Any, other: Any = dtypes.NA, drop: bool = False):
         return UxDataArray(super().where(cond, other, drop), uxgrid=self.uxgrid)
 
diff --git a/uxarray/core/dataset.py b/uxarray/core/dataset.py
index b887cbe06..a95ad0c30 100644
--- a/uxarray/core/dataset.py
+++ b/uxarray/core/dataset.py
@@ -3,7 +3,7 @@
 import os
 import sys
 from html import escape
-from typing import IO, Any, Optional, Union
+from typing import IO, Any, Callable, Optional, Union
 from warnings import warn
 
 import numpy as np
@@ -13,6 +13,7 @@
 from xarray.core.utils import UncachedAccessor
 
 import uxarray
+from uxarray.constants import GRID_DIMS
 from uxarray.core.dataarray import UxDataArray
 from uxarray.core.utils import _map_dims_to_ugrid
 from uxarray.formatting_html import dataset_repr
@@ -443,6 +444,51 @@ def to_array(self) -> UxDataArray:
         xarr = super().to_array()
         return UxDataArray(xarr, uxgrid=self.uxgrid)
 
+    def neighborhood_filter(
+        self,
+        func: Callable = np.mean,
+        r: float = 1.0,
+    ):
+        """Apply a neighborhood filter to every grid-dimensioned variable.
+
+        Data variables without a grid dimension are carried over unchanged.
+
+        Parameters
+        ----------
+        func : Callable, default=np.mean
+            Apply this function to neighborhood
+        r : float, default=1.
+            Radius of neighborhood. For spherical coordinates, the radius is in
+            units of degrees, and for cartesian coordinates, the radius is in
+            meters.
+
+        Returns
+        -------
+        destination_uxds : UxDataset
+            Copy of this dataset in which each variable that has a grid
+            dimension is replaced by its filtered counterpart.
+        """
+        destination_uxds = self._copy()
+
+        # Loop through uxDataArrays in uxDataset
+        for var_name in self.data_vars:
+            uxda = self[var_name]
+
+            # Skip if uxDataArray has no GRID dimension.
+            grid_dims = [dim for dim in uxda.dims if dim in GRID_DIMS]
+            if not grid_dims:
+                continue
+
+            # Put GRID dimension last for UxDataArray.neighborhood_filter.
+            remember_dim_order = uxda.dims
+            uxda = uxda.transpose(..., grid_dims[0])
+            # Filter uxDataArray.
+            uxda = uxda.neighborhood_filter(func, r)
+            # Restore old dimension order.
+            destination_uxds[var_name] = uxda.transpose(*remember_dim_order)
+
+        return destination_uxds
+
     def to_xarray(self, grid_format: str = "UGRID") -> xr.Dataset:
         """
         Converts a ``ux.UXDataset`` to a ``xr.Dataset``.