Skip to content

Commit 40b3ffa

Browse files
authored
Merge pull request #835 from DHI/interpolator_dataclass
Interpolation ids and weights always go together
2 parents c4f6672 + 94bbe93 commit 40b3ffa

File tree

11 files changed

+144
-272
lines changed

11 files changed

+144
-272
lines changed

mikeio/_interpolation.py

Lines changed: 65 additions & 142 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,9 @@
11
from __future__ import annotations
2-
from typing import TYPE_CHECKING, overload
2+
from dataclasses import dataclass
33
import numpy as np
44

5-
if TYPE_CHECKING:
6-
from .dataset import Dataset, DataArray
7-
8-
from .spatial import GeometryUndefined
9-
105

116
def get_idw_interpolant(distances: np.ndarray, p: float = 2) -> np.ndarray:
12-
"""IDW interpolant for 2d array of distances.
13-
14-
https://pro.arcgis.com/en/pro-app/help/analysis/geostatistical-analyst/how-inverse-distance-weighted-interpolation-works.htm
15-
16-
Parameters
17-
----------
18-
distances : array-like
19-
distances between interpolation point and grid point
20-
p : float, optional
21-
power of inverse distance weighting, default=2
22-
23-
Returns
24-
-------
25-
np.array
26-
weights
27-
28-
"""
297
is_1d = distances.ndim == 1
308
if is_1d:
319
distances = np.atleast_2d(distances)
@@ -45,123 +23,68 @@ def get_idw_interpolant(distances: np.ndarray, p: float = 2) -> np.ndarray:
4523
return weights
4624

4725

48-
@overload
49-
def interp2d(
50-
data: np.ndarray | DataArray,
51-
elem_ids: np.ndarray,
52-
weights: np.ndarray | None = None,
53-
shape: tuple[int, ...] | None = None,
54-
) -> np.ndarray: ...
55-
56-
57-
@overload
58-
def interp2d(
59-
data: Dataset,
60-
elem_ids: np.ndarray,
61-
weights: np.ndarray | None = None,
62-
shape: tuple[int, ...] | None = None,
63-
) -> Dataset: ...
64-
65-
66-
def interp2d(
67-
data: Dataset | DataArray | np.ndarray,
68-
elem_ids: np.ndarray,
69-
weights: np.ndarray | None = None,
70-
shape: tuple[int, ...] | None = None,
71-
) -> Dataset | np.ndarray:
72-
"""interp spatially in data (2d only).
73-
74-
Parameters
75-
----------
76-
data : mikeio.Dataset, DataArray, or ndarray
77-
dfsu data
78-
elem_ids : ndarray(int)
79-
n sized array of 1 or more element ids used for interpolation
80-
weights : ndarray(float), optional
81-
weights with same size as elem_ids used for interpolation
82-
shape: tuple, optional
83-
reshape output
84-
85-
Returns
86-
-------
87-
ndarray, Dataset, or DataArray
88-
spatially interped data with same type and shape as input
89-
90-
Examples
91-
--------
92-
>>> elem_ids, weights = dfs.get_spatial_interpolant(coords)
93-
>>> dsi = interp2d(ds, elem_ids, weights)
94-
95-
"""
96-
from .dataset import DataArray, Dataset
97-
98-
if isinstance(data, Dataset):
99-
ds = data.copy()
100-
101-
ni = len(elem_ids)
102-
103-
interp_data_vars = {}
104-
105-
for da in ds:
106-
key = da.name
107-
if "time" not in da.dims:
108-
idatitem = _interp_itemstep(da.to_numpy(), elem_ids, weights)
109-
if shape:
110-
idatitem = idatitem.reshape(*shape)
111-
112-
else:
113-
nt, _ = da.shape
114-
# use dtype of da
115-
idatitem = np.empty(shape=(nt, ni), dtype=da.values.dtype)
116-
for step in range(nt):
117-
idatitem[step, :] = _interp_itemstep(
118-
da[step].to_numpy(), elem_ids, weights
119-
)
120-
if shape:
121-
idatitem = idatitem.reshape((nt, *shape))
122-
123-
dims = ("time", "element") # TODO is this the best?
124-
interp_data_vars[key] = DataArray(
125-
data=idatitem,
126-
time=da.time,
127-
dims=dims,
128-
item=da.item,
129-
geometry=GeometryUndefined(),
130-
)
131-
132-
new_ds = Dataset(interp_data_vars, validate=False)
133-
return new_ds
134-
135-
if isinstance(data, DataArray):
136-
# TODO why doesn't this return a DataArray?
137-
data = data.to_numpy()
138-
139-
if isinstance(data, np.ndarray):
26+
@dataclass
27+
class Interpolant:
28+
ids: np.ndarray
29+
weights: np.ndarray
30+
31+
@staticmethod
32+
def from_distances(distances: np.ndarray, p: float = 2) -> np.ndarray:
33+
"""IDW interpolant for 2d array of distances.
34+
35+
https://pro.arcgis.com/en/pro-app/help/analysis/geostatistical-analyst/how-inverse-distance-weighted-interpolation-works.htm
36+
37+
Parameters
38+
----------
39+
distances : array-like
40+
distances between interpolation point and grid point
41+
p : float, optional
42+
power of inverse distance weighting, default=2
43+
44+
Returns
45+
-------
46+
np.array
47+
weights
48+
49+
"""
50+
return get_idw_interpolant(distances, p)
51+
52+
def interp1d(self, data: np.ndarray) -> np.ndarray:
53+
ids = self.ids
54+
weights = self.weights
55+
result = np.dot(data[:, ids], weights)
56+
assert isinstance(result, np.ndarray)
57+
return result
58+
59+
def interp2d(
60+
self,
61+
data: np.ndarray,
62+
) -> np.ndarray:
63+
"""interp spatially in data (2d only).
64+
65+
Parameters
66+
----------
67+
data : ndarray
68+
dfsu data
69+
70+
Returns
71+
-------
72+
ndarray
73+
spatially interpolated data
74+
75+
"""
76+
weights = self.weights
77+
elem_ids = self.ids
78+
14079
if data.ndim == 1:
141-
# data is single item and single time step
142-
idatitem = _interp_itemstep(data, elem_ids, weights)
143-
if shape:
144-
idatitem = idatitem.reshape(*shape)
145-
return idatitem
146-
147-
ni = len(elem_ids)
148-
datitem = data
149-
nt, _ = datitem.shape
150-
idatitem = np.empty(shape=(nt, ni), dtype=datitem.dtype)
151-
for step in range(nt):
152-
idatitem[step, :] = _interp_itemstep(datitem[step], elem_ids, weights)
153-
if shape:
154-
idatitem = idatitem.reshape((nt, *shape))
155-
return idatitem
156-
157-
158-
def _interp_itemstep(
159-
data: np.ndarray,
160-
elem_ids: np.ndarray,
161-
weights: np.ndarray | None = None,
162-
) -> np.ndarray:
163-
if weights is None:
164-
return data[elem_ids]
165-
else:
166-
idat = data[elem_ids] * weights
167-
return np.sum(idat, axis=1) if weights.ndim == 2 else idat
80+
idat = data[elem_ids] * weights.astype(data.dtype)
81+
return np.sum(idat, axis=1) if weights.ndim == 2 else idat
82+
elif data.ndim == 2:
83+
# data shape: (nt, nelem)
84+
85+
# data[:, elem_ids]: (nt, ni)
86+
# weights: (ni,) or (ni, nweights)
87+
idat = data[:, elem_ids] * weights.astype(data.dtype) # broadcasting
88+
return np.sum(idat, axis=-1) if weights.ndim == 2 else idat
89+
else:
90+
raise ValueError("data must be 1D or 2D array")

mikeio/_track.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def _extract_track(
8484

8585
# spatial interpolation
8686
n_pts = 1 if method == "nearest" else 5
87-
elem_ids, weights = geometry.get_2d_interpolant(
87+
interpolant = geometry.get_2d_interpolant(
8888
coords[i_start : (i_end + 1)], n_nearest=n_pts
8989
)
9090

@@ -131,7 +131,9 @@ def is_EOF(step: int) -> bool:
131131
continue
132132

133133
w = (t_rel[t] - t1) / timestep # time-weight
134-
eid = elem_ids[i_interp]
134+
eid = interpolant.ids[i_interp]
135+
weights = interpolant.weights
136+
# TODO move to interpolation module?
135137
if np.any(eid > 0):
136138
dati = (1 - w) * np.dot(d1[:, eid], weights[i_interp])
137139
dati = dati + w * np.dot(d2[:, eid], weights[i_interp])

mikeio/dataset/_dataarray.py

Lines changed: 31 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,15 @@
2525
import pandas as pd
2626
from mikecore.DfsuFile import DfsuFileType
2727

28+
2829
from ..eum import EUMType, EUMUnit, ItemInfo
2930
from .._time import _get_time_idx_list, _n_selected_timesteps
3031

3132
if TYPE_CHECKING:
3233
from ._dataset import Dataset
3334
import xarray
3435
from numpy.typing import ArrayLike
36+
from mikeio._interpolation import Interpolant
3537

3638

3739
from ..spatial import (
@@ -1014,8 +1016,6 @@ def interp(
10141016
x: float | None = None,
10151017
y: float | None = None,
10161018
z: float | None = None,
1017-
n_nearest: int = 3,
1018-
interpolant: tuple[Any, Any] | None = None,
10191019
**kwargs: Any,
10201020
) -> DataArray:
10211021
"""Interpolate data in time and space.
@@ -1045,7 +1045,7 @@ def interp(
10451045
z-coordinate of point to be interpolated to, by default None
10461046
n_nearest : int, optional
10471047
When using IDW interpolation, how many nearest points should
1048-
be used, by default: 3
1048+
be used.
10491049
interpolant : tuple, optional
10501050
Precomputed interpolant, by default None
10511051
**kwargs: Any
@@ -1075,15 +1075,12 @@ def interp(
10751075
if z is not None:
10761076
raise NotImplementedError()
10771077

1078-
geometry: GeometryPoint2D | GeometryPoint3D | GeometryUndefined = (
1079-
GeometryUndefined()
1080-
)
1078+
geometry: GeometryPoint2D | GeometryUndefined = GeometryUndefined()
1079+
interpolant = kwargs.get("interpolant")
10811080

10821081
# interp in space
1083-
if (x is not None) or (y is not None) or (z is not None):
1084-
coords = [(x, y)]
1085-
1086-
if isinstance(self.geometry, Grid2D): # TODO DIY bilinear interpolation
1082+
if (x is not None) or (y is not None):
1083+
if isinstance(self.geometry, Grid2D):
10871084
if x is None or y is None:
10881085
raise ValueError("both x and y must be specified")
10891086

@@ -1094,31 +1091,27 @@ def interp(
10941091
)
10951092
elif isinstance(self.geometry, Grid1D):
10961093
if interpolant is None:
1097-
interpolant = self.geometry.get_spatial_interpolant(coords) # type: ignore
1098-
dai = self.geometry.interp(self.to_numpy(), *interpolant).flatten()
1094+
assert x is not None
1095+
interpolant = self.geometry.get_spatial_interpolant(x)
1096+
dai = interpolant.interp1d(self.to_numpy()).flatten()
10991097
geometry = GeometryUndefined()
1100-
elif isinstance(self.geometry, GeometryFM3D):
1101-
raise NotImplementedError("Interpolation in 3d is not yet implemented")
11021098
elif isinstance(self.geometry, GeometryFM2D):
11031099
if x is None or y is None:
11041100
raise ValueError("both x and y must be specified")
11051101

11061102
if interpolant is None:
11071103
interpolant = self.geometry.get_2d_interpolant(
1108-
coords, # type: ignore
1109-
n_nearest=n_nearest,
1104+
xy=[(x, y)], # type: ignore
11101105
**kwargs, # type: ignore
11111106
)
1112-
dai = self.geometry.interp2d(self, *interpolant).flatten() # type: ignore
1113-
if z is None:
1114-
geometry = GeometryPoint2D(
1115-
x=x, y=y, projection=self.geometry.projection
1116-
)
1117-
# this is not supported yet (see above)
1118-
# else:
1119-
# geometry = GeometryPoint3D(
1120-
# x=x, y=y, z=z, projection=self.geometry.projection
1121-
# )
1107+
dai = interpolant.interp2d(self.to_numpy()).flatten()
1108+
geometry = GeometryPoint2D(
1109+
x=x, y=y, projection=self.geometry.projection
1110+
)
1111+
else:
1112+
raise NotImplementedError(
1113+
f"Interpolation in {self.geometry} is not yet implemented"
1114+
)
11221115

11231116
da = DataArray(
11241117
data=dai,
@@ -1285,7 +1278,7 @@ def interp_na(self, axis: str = "time", **kwargs: Any) -> DataArray:
12851278
def interp_like(
12861279
self,
12871280
other: DataArray | Grid2D | GeometryFM2D | pd.DatetimeIndex,
1288-
interpolant: tuple[Any, Any] | None = None,
1281+
interpolant: Interpolant | None = None,
12891282
**kwargs: Any,
12901283
) -> DataArray:
12911284
"""Interpolate in space (and in time) to other geometry (and time axis).
@@ -1297,7 +1290,7 @@ def interp_like(
12971290
----------
12981291
other: Dataset, DataArray, Grid2D, GeometryFM, pd.DatetimeIndex
12991292
The target geometry (and time axis) to interpolate to
1300-
interpolant: tuple, optional
1293+
interpolant: Interpolant, optional
13011294
Reuse pre-calculated index and weights
13021295
**kwargs: Any
13031296
additional kwargs are passed to interpolation method
@@ -1337,16 +1330,19 @@ def interp_like(
13371330
raise NotImplementedError()
13381331

13391332
if interpolant is None:
1340-
elem_ids, weights = self.geometry.get_2d_interpolant(xy, **kwargs)
1341-
else:
1342-
elem_ids, weights = interpolant
1333+
interpolant = self.geometry.get_2d_interpolant(xy, **kwargs)
13431334

13441335
if isinstance(geom, (Grid2D, GeometryFM2D)):
1345-
shape = (geom.ny, geom.nx) if isinstance(geom, Grid2D) else None
1336+
ari = interpolant.interp2d(data=self.to_numpy())
1337+
if isinstance(geom, Grid2D):
1338+
shape = (
1339+
(self.n_timesteps, geom.ny, geom.nx)
1340+
if self.dims[0] == "time"
1341+
else (geom.ny, geom.nx)
1342+
)
1343+
ari = ari.reshape(shape)
13461344

1347-
ari = self.geometry.interp2d(
1348-
data=self.to_numpy(), elem_ids=elem_ids, weights=weights, shape=shape
1349-
)
1345+
assert ari.dtype == self.dtype
13501346
else:
13511347
raise NotImplementedError(
13521348
"Interpolation to other geometry not yet supported"

mikeio/dataset/_dataset.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -811,7 +811,6 @@ def interp(
811811
x: float | None = None,
812812
y: float | None = None,
813813
z: float | None = None,
814-
n_nearest: int = 3,
815814
**kwargs: Any,
816815
) -> Dataset:
817816
"""Interpolate data in time and space.
@@ -878,7 +877,6 @@ def interp(
878877
): # TODO remove this when all geometries implements the same method
879878
interpolant = self.geometry.get_2d_interpolant(
880879
xy, # type: ignore
881-
n_nearest=n_nearest,
882880
**kwargs, # type: ignore
883881
)
884882
das = [da.interp(x=x, y=y, interpolant=interpolant) for da in self]

0 commit comments

Comments
 (0)