diff --git a/uxarray/remap/apply_func.py b/uxarray/remap/apply_func.py new file mode 100644 index 000000000..e5e914856 --- /dev/null +++ b/uxarray/remap/apply_func.py @@ -0,0 +1,251 @@ +from __future__ import annotations +from collections.abc import Callable +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from uxarray.core.dataset import UxDataset + from uxarray.core.dataarray import UxDataArray + +import numpy as np + +import uxarray.core.dataarray +import uxarray.core.dataset +from uxarray.grid import Grid +import warnings + + +def _apply_func_remap( + source_grid: Grid, + destination_grid: Grid, + source_data: np.ndarray, + remap_to: str = "face centers", + coord_type: str = "spherical", + func: Callable = np.mean, + r: float = 1.0, +) -> np.array: + """Apply neighborhood function Remapping between two grids. + + Parameters: + ----------- + source_grid : Grid + Source grid that data is mapped from. + destination_grid : Grid + Destination grid to remap data to. + source_data : np.ndarray + Data variable to remap. + remap_to : str, default="nodes" + Location of where to map data, either "nodes", "edge centers", or "face centers". + coord_type: str, default="spherical" + Coordinate type to use for nearest neighbor query, either "spherical" or "Cartesian". + r : float, default=1. + radius of neighborhoodFor spherical coordinates, the radius is in units of degrees, + and for cartesian coordinates, the radius is in meters. + + Returns: + -------- + destination_data : np.ndarray + Data mapped to the destination grid. + """ + + source_data = np.asarray(source_data) + n_elements = source_data.shape[-1] + + if n_elements == source_grid.n_node: + source_data_mapping = "nodes" + elif n_elements == source_grid.n_face: + source_data_mapping = "face centers" + elif n_elements == source_grid.n_edge: + source_data_mapping = "edge centers" + else: + raise ValueError( + f"Invalid source_data shape. The final dimension should match the number of corner " + f"nodes ({source_grid.n_node}), edge nodes ({source_grid.n_edge}), or face centers ({source_grid.n_face}) " + f"in the source grid, but received: {source_data.shape}" + ) + + if coord_type == "spherical": + if remap_to == "nodes": + lon, lat = ( + destination_grid.node_lon.values, + destination_grid.node_lat.values, + ) + elif remap_to == "face centers": + lon, lat = ( + destination_grid.face_lon.values, + destination_grid.face_lat.values, + ) + elif remap_to == "edge centers": + lon, lat = ( + destination_grid.edge_lon.values, + destination_grid.edge_lat.values, + ) + else: + raise ValueError( + f"Invalid remap_to. Expected 'nodes', 'edge centers', or 'face centers', " + f"but received: {remap_to}" + ) + + _source_tree = source_grid.get_ball_tree(coordinates=source_data_mapping) + + dest_coords = np.vstack([lon, lat]).T + + neighbor_indices = _source_tree.query_radius(dest_coords, r=r) + + elif coord_type == "cartesian": + if remap_to == "nodes": + x, y, z = ( + destination_grid.node_x.values, + destination_grid.node_y.values, + destination_grid.node_z.values, + ) + elif remap_to == "face centers": + x, y, z = ( + destination_grid.face_x.values, + destination_grid.face_y.values, + destination_grid.face_z.values, + ) + elif remap_to == "edge centers": + x, y, z = ( + destination_grid.edge_x.values, + destination_grid.edge_y.values, + destination_grid.edge_z.values, + ) + else: + raise ValueError( + f"Invalid remap_to. Expected 'nodes', 'edge centers', or 'face centers', " + f"but received: {remap_to}" + ) + + _source_tree = source_grid.get_ball_tree( + coordinates=source_data_mapping, + coordinate_system="cartesian", + distance_metric="minkowski", + ) + + dest_coords = np.vstack([x, y, z]).T + + neighbor_indices = _source_tree.query_radius(dest_coords, r=r) + + else: + raise ValueError( + f"Invalid coord_type. Expected either 'spherical' or 'cartesian', but received {coord_type}" + ) + + # make destination_shape a list instead of immutable tuple + destination_shape = list(source_data.shape) + # last dimension has same number of elements as neighbor_indices list + destination_shape[-1] = len(neighbor_indices) + destination_data = np.empty(destination_shape) + # Apply function to indices on last axis. + for i, idx in enumerate(neighbor_indices): + if len(idx): + destination_data[..., i] = func(source_data[..., idx]) + + return destination_data + + +def _apply_func_remap_uxda( + source_uxda: UxDataArray, + destination_grid: Grid, + remap_to: str = "face centers", + coord_type: str = "spherical", + func: Callable = np.mean, + r=1.0, +): + """Neighborhood function Remapping implementation for ``UxDataArray``. + + Parameters + --------- + source_uxda : UxDataArray + Source UxDataArray for remapping + destination_grid : Grid + Destination grid for remapping + remap_to : str, default="nodes" + Location of where to map data, either "nodes", "edge centers", or "face centers" + coord_type : str, default="spherical" + Indicates whether to remap using on Spherical or Cartesian coordinates for the computations when + remapping. + r : float, default=1. + Radius of neighborhood. + """ + + # check dimensions remapped to and from + if ( + (source_uxda._node_centered() and remap_to != "nodes") + or (source_uxda._face_centered() and remap_to != "face centers") + or (source_uxda._edge_centered() and remap_to != "edge centers") + ): + warnings.warn( + f"Your data is stored on {source_uxda.dims[-1]}, but you are remapping to {remap_to}" + ) + + # prepare dimensions + if remap_to == "nodes": + destination_dim = "n_node" + elif remap_to == "face centers": + destination_dim = "n_face" + else: + destination_dim = "n_edge" + + destination_dims = list(source_uxda.dims) + destination_dims[-1] = destination_dim + + # perform remapping + destination_data = _apply_func_remap( + source_uxda.uxgrid, + destination_grid, + source_uxda.data, + remap_to, + coord_type, + func, + r, + ) + # construct data array for remapping variable + uxda_remap = uxarray.core.dataarray.UxDataArray( + data=destination_data, + name=source_uxda.name, + coords=source_uxda.coords, + dims=destination_dims, + uxgrid=destination_grid, + ) + return uxda_remap + + +def _apply_func_remap_uxds( + source_uxds: UxDataset, + destination_grid: Grid, + remap_to: str = "face centers", + coord_type: str = "spherical", + func: Callable = np.mean, + r: float = 1.0, +): + """Neighboohood function implementation for ``UxDataset``. + + Parameters + --------- + source_uxds : UxDataset + Source UxDataset for remapping + destination_grid : Grid + Destination grid for remapping + remap_to : str, default="nodes" + Location of where to map data, either "nodes", "edge centers", or "face centers" + coord_type : str, default="spherical" + Indicates whether to remap using on Spherical or Cartesian coordinates + func : Callable = np.mean + function to apply to neighborhood + r : float, default=1. + Radius of neighborhood in deg + """ + + destination_uxds = uxarray.core.dataset.UxDataset(uxgrid=destination_grid) + for var_name in source_uxds.data_vars: + destination_uxds[var_name] = _apply_func_remap_uxda( + source_uxds[var_name], + destination_uxds, + remap_to, + coord_type, + func, + r, + ) + + return destination_uxds diff --git a/uxarray/remap/dataarray_accessor.py b/uxarray/remap/dataarray_accessor.py index 4a5f21dfe..5bfb6a06d 100644 --- a/uxarray/remap/dataarray_accessor.py +++ b/uxarray/remap/dataarray_accessor.py @@ -1,4 +1,5 @@ from __future__ import annotations +from collections.abc import Callable from typing import TYPE_CHECKING, Optional from warnings import warn @@ -6,12 +7,14 @@ from uxarray.remap.inverse_distance_weighted import ( _inverse_distance_weighted_remap_uxda, ) +from uxarray.remap.apply_func import _apply_func_remap_uxda if TYPE_CHECKING: from uxarray.core.dataset import UxDataset from uxarray.core.dataarray import UxDataArray from uxarray.grid import Grid +import numpy as np class UxDataArrayRemapAccessor: @@ -26,6 +29,9 @@ def __repr__(self): " * nearest_neighbor(destination_obj, remap_to, coord_type)\n" ) methods_heading += " * inverse_distance_weighted(destination_obj, remap_to, coord_type, power, k)\n" + methods_heading += ( + " * apply_func(destination_grid, remap_to, coord_type, func, r)\n" + ) return prefix + methods_heading @@ -46,7 +52,7 @@ def nearest_neighbor( destination_obj : Grid, UxDataArray, UxDataset Optional destination for remapping, deprecating remap_to : str, default="nodes" - Location of where to map data, either "nodes" or "face centers" + Location of where to map data, either "nodes", "edge centers", or "face centers" coord_type : str, default="spherical" Indicates whether to remap using on spherical or cartesian coordinates """ @@ -71,6 +77,37 @@ def nearest_neighbor( self.uxda, destination_obj, remap_to, coord_type ) + def apply_func( + self, + destination_grid: Grid = None, + remap_to: str = "face centers", + coord_type: str = "spherical", + func: Callable = np.mean, + r=1, + ): + """Neighborhood function Remapping between a source (``UxDataArray``) + and destination.`. + + Parameters + --------- + destination_grid : Grid + Destination Grid for remapping + remap_to : str, default="nodes" + Location of where to map data, either "nodes", "edge centers", or "face centers" + coord_type : str, default="spherical" + Indicates whether to remap using on spherical or cartesian coordinates + func : Callable, default = np.mean + Function to apply to neighborhood + r : float, default=1 + Radius of neighborhood in deg + """ + if destination_grid is None: + raise ValueError("Destination needed for remap.") + + return _apply_func_remap_uxda( + self.uxda, destination_grid, remap_to, coord_type, func, r + ) + def inverse_distance_weighted( self, destination_grid: Optional[Grid] = None, @@ -90,7 +127,7 @@ def inverse_distance_weighted( destination_obj : Grid, UxDataArray, UxDataset Optional destination for remapping, deprecating remap_to : str, default="nodes" - Location of where to map data, either "nodes" or "face centers" + Location of where to map data, either "nodes", "edge centers", or "face centers" coord_type : str, default="spherical" Indicates whether to remap using on spherical or cartesian coordinates power : int, default=2 diff --git a/uxarray/remap/dataset_accessor.py b/uxarray/remap/dataset_accessor.py index d5a9edf3f..4ab30ebf8 100644 --- a/uxarray/remap/dataset_accessor.py +++ b/uxarray/remap/dataset_accessor.py @@ -1,4 +1,5 @@ from __future__ import annotations +from collections.abc import Callable from typing import TYPE_CHECKING, Optional from warnings import warn @@ -6,12 +7,14 @@ from uxarray.remap.inverse_distance_weighted import ( _inverse_distance_weighted_remap_uxds, ) +from uxarray.remap.apply_func import _apply_func_remap_uxds if TYPE_CHECKING: from uxarray.core.dataset import UxDataset from uxarray.core.dataarray import UxDataArray from uxarray.grid import Grid +import numpy as np class UxDatasetRemapAccessor: @@ -26,6 +29,9 @@ def __repr__(self): " * nearest_neighbor(destination_obj, remap_to, coord_type)\n" ) methods_heading += " * inverse_distance_weighted(destination_obj, remap_to, coord_type, power, k)\n" + methods_heading += ( + " * apply_func(destination_grid, remap_to, coord_type, func, r)\n" + ) return prefix + methods_heading @@ -72,6 +78,37 @@ def nearest_neighbor( self.uxds, destination_obj, remap_to, coord_type ) + def apply_func( + self, + destination_grid: Grid = None, + remap_to: str = "face centers", + coord_type: str = "spherical", + func: Callable = np.mean, + r=1, + ): + """Neighborhood function Remapping between a source (``UxDataset``) and + destination.`. + + Parameters + --------- + destination_grid : Grid + Destination Grid for remapping + remap_to : str, default="nodes" + Location of where to map data, either "nodes", "edge centers", or "face centers" + coord_type : str, default="spherical" + Indicates whether to remap using on spherical or cartesian coordinates + func : Callable, default = np.mean + Function to apply to neighborhood + r : float, default=1 + Radius of neighborhood in deg + """ + if destination_grid is None: + raise ValueError("Destination needed for remap.") + + return _apply_func_remap_uxds( + self.uxds, destination_grid, remap_to, coord_type, func, r + ) + def inverse_distance_weighted( self, destination_grid: Optional[Grid] = None,