Skip to content

Commit aae14ed

Browse files
authored
Merge pull request #517 from tobac-project/main
Merge changes from `main` into `RC_v1.6.x`
2 parents 1433722 + 38b2ca0 commit aae14ed

File tree

2 files changed

+42
-5
lines changed

2 files changed

+42
-5
lines changed

tobac/tests/test_xarray_utils.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from typing import Union
66

7+
import pandas as pd
78
import pytest
89
import numpy as np
910
import xarray as xr
@@ -156,6 +157,31 @@ def test_find_axis_from_dim_coord(
156157
(1, 1),
157158
{"test_coord1": (1, 1, 1), "test_coord_time": (5, 6, 7)},
158159
),
160+
(
161+
["time", "x", "y"],
162+
{
163+
"test_coord_datetime": (
164+
"time",
165+
pd.date_range(
166+
datetime.datetime(2000, 1, 1),
167+
datetime.datetime(2000, 1, 1, 6),
168+
freq="1h",
169+
inclusive="left",
170+
),
171+
),
172+
"test_coord_time": ("time", [5, 6, 7, 8, 9, 10]),
173+
},
174+
(1, 1),
175+
{
176+
"test_coord_datetime": pd.date_range(
177+
datetime.datetime(2000, 1, 1),
178+
datetime.datetime(2000, 1, 1, 3),
179+
freq="1h",
180+
inclusive="left",
181+
),
182+
"test_coord_time": (5, 6, 7),
183+
},
184+
),
159185
],
160186
)
161187
def test_add_coordinates_to_features_interpolate_along_other_dims(

tobac/utils/internal/xarray_utils.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -420,9 +420,20 @@ def add_coordinates_to_features(
420420
except KeyError:
421421
pass
422422

423-
return_feat_df[interp_coord_name] = renamed_dim_da[interp_coord].interp(
424-
coords={
425-
dim: dim_interp_coords[dim] for dim in renamed_dim_da[interp_coord].dims
426-
}
427-
)
423+
if renamed_dim_da[interp_coord].dtype.kind in "uifc":
424+
# Interpolate over the coordinate
425+
return_feat_df[interp_coord_name] = renamed_dim_da[interp_coord].interp(
426+
coords={
427+
dim: dim_interp_coords[dim]
428+
for dim in renamed_dim_da[interp_coord].dims
429+
}
430+
)
431+
else:
432+
# If non-numeric, we should instead just index the nearest values:
433+
return_feat_df[interp_coord_name] = renamed_dim_da[interp_coord].isel(
434+
**{
435+
dim: np.round(dim_interp_coords[dim]).astype(int)
436+
for dim in renamed_dim_da[interp_coord].dims
437+
}
438+
)
428439
return return_feat_df

0 commit comments

Comments
 (0)