Skip to content

Commit afa9a9f

Browse files
committed
Merge branch 'RC_v1.6.0' of https://github.com/tobac-project/tobac into xarray_notebook_update
2 parents f7f9ede + 4de8d3f commit afa9a9f

10 files changed

+266
-45
lines changed

tobac/merge_split.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,12 @@ def merge_split_MEST(
185185
if PBC_flag in ["hdim_1", "hdim_2", "both"]:
186186
# Note that we multiply by dxy to get the distances in spatial coordinates
187187
dist_func = build_distance_function(
188-
min_h1 * dxy, max_h1 * dxy, min_h2 * dxy, max_h2 * dxy, PBC_flag, is_3D
188+
min_h1 * dxy if min_h1 is not None else None,
189+
max_h1 * dxy if max_h1 is not None else None,
190+
min_h2 * dxy if min_h2 is not None else None,
191+
max_h2 * dxy if max_h2 is not None else None,
192+
PBC_flag,
193+
is_3D,
189194
)
190195
cell_start_tree = BallTree(
191196
cell_start_locations, metric="pyfunc", func=dist_func

tobac/tests/test_analysis_spatial.py

Lines changed: 74 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@
2020
from tobac.utils.datetime import to_cftime, to_datetime64
2121

2222

23-
def test_calculate_distance():
23+
def test_calculate_distance_xy():
24+
"""
25+
Test for tobac.analysis.spatial.calculate_distance with cartesian coordinates
26+
"""
2427
test_features = pd.DataFrame(
2528
{
2629
"feature": [1, 2],
@@ -36,6 +39,13 @@ def test_calculate_distance():
3639

3740
assert calculate_distance(test_features.iloc[0], test_features.iloc[1]) == 1000
3841

42+
43+
def test_calculate_distance_latlon():
44+
"""
45+
Test for tobac.analysis.spatial.calculate_distance with latitude/longitude
46+
coordinates
47+
"""
48+
3949
test_features = pd.DataFrame(
4050
{
4151
"feature": [1, 2],
@@ -53,6 +63,27 @@ def test_calculate_distance():
5363
test_features.iloc[0], test_features.iloc[1]
5464
) == pytest.approx(1.11e5, rel=1e4)
5565

66+
67+
def test_calculate_distance_latlon_wrong_order():
68+
"""
69+
Test for tobac.analysis.spatial.calculate_distance with latitude/longitude
70+
coordinates provided in the wrong order. When lat/lon are provided with
71+
standard naming the function should detect this and switch their order to
72+
ensure that haversine distances are calculated correctly.
73+
"""
74+
75+
test_features = pd.DataFrame(
76+
{
77+
"feature": [1, 2],
78+
"frame": [0, 0],
79+
"time": [
80+
datetime(2000, 1, 1),
81+
datetime(2000, 1, 1),
82+
],
83+
"longitude": [0, 1],
84+
"latitude": [0, 0],
85+
}
86+
)
5687
# Test that if latitude and longitude coord names are given in the wrong order, then they are swapped:
5788
# (expectation is hdim1=y=latitude, hdim2=x=longitude, doesn't matter for x/y but does matter for lat/lon)
5889
assert calculate_distance(
@@ -64,14 +95,16 @@ def test_calculate_distance():
6495
) == pytest.approx(1.11e5, rel=1e4)
6596

6697

67-
def test_calculate_distance_errors():
68-
# Test invalid method_distance
98+
def test_calculate_distance_error_invalid_method():
99+
"""Test invalid method_distance"""
69100
with pytest.raises(ValueError, match="method_distance invalid*"):
70101
calculate_distance(
71102
pd.DataFrame(), pd.DataFrame(), method_distance="invalid_method_distance"
72103
)
73104

74-
# Test no horizontal coordinates"
105+
106+
def test_calculate_distance_error_no_coords():
107+
"""Test no horizontal coordinates in input dataframe"""
75108
test_features = pd.DataFrame(
76109
{
77110
"feature": [1, 2],
@@ -86,7 +119,9 @@ def test_calculate_distance_errors():
86119
with pytest.raises(ValueError):
87120
calculate_distance(test_features.iloc[0], test_features.iloc[1])
88121

89-
# Test dataframes with mismatching coordinates:
122+
123+
def test_calculate_distance_error_mismatched_coords():
124+
"""Test dataframes with mismatching coordinates"""
90125
with pytest.raises(ValueError, match="Discovered coordinates*"):
91126
calculate_distance(
92127
pd.DataFrame(
@@ -109,7 +144,9 @@ def test_calculate_distance_errors():
109144
),
110145
)
111146

112-
# Test invalid method:
147+
148+
def test_calculate_distance_error_no_method():
149+
"""Test hdim1_coord/hdim2_coord specified but no method_distance"""
113150
test_features = pd.DataFrame(
114151
{
115152
"feature": [1, 2],
@@ -122,30 +159,31 @@ def test_calculate_distance_errors():
122159
"projection_y_coordinate": [0, 0],
123160
}
124161
)
125-
with pytest.raises(ValueError):
126-
calculate_distance(
127-
test_features.iloc[0],
128-
test_features.iloc[1],
129-
method_distance="invalid_method",
130-
)
131162

132-
# Test hdim1_coord/hdim2_coord specified but no method_distance
133-
with pytest.raises(ValueError):
163+
with pytest.raises(ValueError, match="method_distance parameter must*"):
134164
calculate_distance(
135165
test_features.iloc[0],
136166
test_features.iloc[1],
137167
hdim1_coord="projection_y_coordinate",
138168
)
139169

140-
with pytest.raises(ValueError):
170+
with pytest.raises(ValueError, match="method_distance parameter must*"):
141171
calculate_distance(
142172
test_features.iloc[0],
143173
test_features.iloc[1],
144174
hdim2_coord="projection_x_coordinate",
145175
)
146176

147177

148-
def test_calculate_velocity_individual_xy():
178+
@pytest.mark.parametrize(
179+
"x_coord, y_coord",
180+
[("x", "y"), ("projection_x_coordinate", "projection_y_coordinate")],
181+
)
182+
def test_calculate_velocity_individual_xy(x_coord, y_coord):
183+
"""
184+
Test calculate_velocity_individual gives the correct result for a single
185+
track with different x/y coordinate names
186+
"""
149187
test_features = pd.DataFrame(
150188
{
151189
"feature": [1, 2],
@@ -154,8 +192,8 @@ def test_calculate_velocity_individual_xy():
154192
datetime(2000, 1, 1, 0, 0),
155193
datetime(2000, 1, 1, 0, 10),
156194
],
157-
"projection_x_coordinate": [0, 6000],
158-
"projection_y_coordinate": [0, 0],
195+
x_coord: [0, 6000],
196+
y_coord: [0, 0],
159197
}
160198
)
161199

@@ -164,29 +202,40 @@ def test_calculate_velocity_individual_xy():
164202
== 10
165203
)
166204

205+
206+
@pytest.mark.parametrize(
207+
"lat_coord, lon_coord", [("lat", "lon"), ("latitude", "longitude")]
208+
)
209+
def test_calculate_velocity_individual_latlon(lat_coord, lon_coord):
210+
"""
211+
Test calculate_velocity_individual gives the correct result for a single
212+
track with different lat/lon coordinate names
213+
"""
167214
test_features = pd.DataFrame(
168215
{
169216
"feature": [1, 2],
170-
"frame": [0, 1],
217+
"frame": [0, 0],
171218
"time": [
172219
datetime(2000, 1, 1, 0, 0),
173220
datetime(2000, 1, 1, 0, 10),
174221
],
175-
"x": [0, 6000],
176-
"y": [0, 0],
222+
lon_coord: [0, 1],
223+
lat_coord: [0, 0],
177224
}
178225
)
179226

180-
assert (
181-
calculate_velocity_individual(test_features.iloc[0], test_features.iloc[1])
182-
== 10
183-
)
227+
assert calculate_velocity_individual(
228+
test_features.iloc[0], test_features.iloc[1]
229+
) == pytest.approx(1.11e5 / 600, rel=1e2)
184230

185231

186232
@pytest.mark.parametrize(
187233
"time_format", ("datetime", "datetime64", "proleptic_gregorian", "360_day")
188234
)
189235
def test_calculate_velocity(time_format):
236+
"""
237+
Test velocity calculation using different time formats
238+
"""
190239
test_features = pd.DataFrame(
191240
{
192241
"feature": [1, 2, 3, 4],

tobac/tests/test_datetime.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010

1111
def test_to_cftime():
12+
"""Test conversion of datetime types to cftime calendars"""
1213
test_dates = [
1314
"2000-01-01",
1415
"2000-01-01 00:00:00",
@@ -34,8 +35,24 @@ def test_to_cftime():
3435
2000, 1, 1
3536
)
3637

38+
# Test array-like input
39+
for date in test_dates:
40+
assert datetime_utils.to_cftime([date], "standard")[0] == cftime.datetime(
41+
2000, 1, 1
42+
)
43+
assert datetime_utils.to_cftime([date], "gregorian")[
44+
0
45+
] == cftime.DatetimeGregorian(2000, 1, 1)
46+
assert datetime_utils.to_cftime([date], "360_day")[0] == cftime.Datetime360Day(
47+
2000, 1, 1
48+
)
49+
assert datetime_utils.to_cftime([date], "365_day")[0] == cftime.DatetimeNoLeap(
50+
2000, 1, 1
51+
)
52+
3753

3854
def test_to_timestamp():
55+
"""Test conversion of various datetime types to pandas timestamps"""
3956
test_dates = [
4057
"2000-01-01",
4158
"2000-01-01 00:00:00",
@@ -52,8 +69,13 @@ def test_to_timestamp():
5269
for date in test_dates:
5370
assert datetime_utils.to_timestamp(date) == pd.to_datetime("2000-01-01")
5471

72+
# Test array input
73+
for date in test_dates:
74+
assert datetime_utils.to_timestamp([date])[0] == pd.to_datetime("2000-01-01")
75+
5576

5677
def test_to_datetime():
78+
"""Test conversion of various datetime types to python datetime"""
5779
test_dates = [
5880
"2000-01-01",
5981
"2000-01-01 00:00:00",
@@ -70,8 +92,13 @@ def test_to_datetime():
7092
for date in test_dates:
7193
assert datetime_utils.to_datetime(date) == datetime(2000, 1, 1)
7294

95+
# Test array input
96+
for date in test_dates:
97+
assert datetime_utils.to_datetime([date])[0] == datetime(2000, 1, 1)
98+
7399

74100
def test_to_datetime64():
101+
"""Test conversion of various datetime types to numpy datetime64"""
75102
test_dates = [
76103
"2000-01-01",
77104
"2000-01-01 00:00:00",
@@ -90,8 +117,15 @@ def test_to_datetime64():
90117
"2000-01-01 00:00:00.000000000"
91118
)
92119

120+
# Test array input
121+
for date in test_dates:
122+
assert datetime_utils.to_datetime64([date])[0] == np.datetime64(
123+
"2000-01-01 00:00:00.000000000"
124+
)
125+
93126

94127
def test_to_datestr():
128+
"""Test conversion of various datetime types to ISO format datestring"""
95129
test_dates = [
96130
"2000-01-01",
97131
"2000-01-01 00:00:00",
@@ -113,6 +147,9 @@ def test_to_datestr():
113147

114148

115149
def test_to_datestr_array():
150+
"""Test conversion of arrays of various datetime types to ISO format
151+
datestring
152+
"""
116153
test_dates = [
117154
"2000-01-01",
118155
"2000-01-01 00:00:00",
@@ -132,6 +169,7 @@ def test_to_datestr_array():
132169

133170

134171
def test_match_datetime_format():
172+
"""Test match_datetime_format for various datetime-like combinations"""
135173
test_dates = [
136174
"2000-01-01T00:00:00.000000000",
137175
datetime(2000, 1, 1),
@@ -149,6 +187,9 @@ def test_match_datetime_format():
149187

150188

151189
def test_match_datetime_format_array():
190+
"""Test match_datetime_format for various datetime-like combinations with
191+
array input
192+
"""
152193
test_dates = [
153194
"2000-01-01T00:00:00.000000000",
154195
datetime(2000, 1, 1),
@@ -168,6 +209,8 @@ def test_match_datetime_format_array():
168209

169210

170211
def test_match_datetime_format_error():
171-
# Test that if a non datetime-like object is provided as tagert a ValueError is raised:
172-
with pytest.raises(ValueError):
212+
"""Test that if a non datetime-like object is provided as target to
213+
match_datetime_format, a ValueError is raised
214+
"""
215+
with pytest.raises(ValueError, match="Target is not a valid datetime*"):
173216
datetime_utils.match_datetime_format(datetime(2000, 1, 1), 1.5)

tobac/tests/test_feature_detection.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -874,6 +874,9 @@ def test_feature_detection_preserve_datetime_3d():
874874

875875

876876
def test_feature_detection_360_day_calendar():
877+
"""Tests that datetime format and feature detection work correctly with
878+
cftime 360-day calendars
879+
"""
877880
test_dset_size = (50, 50)
878881
test_hdim_1_pt = 20.0
879882
test_hdim_2_pt = 20.0

tobac/tests/test_generators.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414

1515
def test_field_and_features_over_time():
16+
"""Test iterating over field_and_features_over_time generator"""
1617
test_data = xr.DataArray(
1718
np.zeros([2, 10, 10]),
1819
dims=("time", "y", "x"),
@@ -56,6 +57,9 @@ def test_field_and_features_over_time():
5657

5758

5859
def test_field_and_features_over_time_time_padding():
60+
"""Test the time_padding functionality of field_and_features_over_time
61+
generator
62+
"""
5963
test_data = xr.DataArray(
6064
np.zeros([1, 10, 10]),
6165
dims=("time", "y", "x"),
@@ -104,6 +108,7 @@ def test_field_and_features_over_time_time_padding():
104108

105109

106110
def test_field_and_features_over_time_cftime():
111+
"""Test field_and_features_over_time when given cftime datetime formats"""
107112
test_data = xr.DataArray(
108113
np.zeros([2, 10, 10]),
109114
dims=("time", "y", "x"),
@@ -153,6 +158,9 @@ def test_field_and_features_over_time_cftime():
153158

154159

155160
def test_field_and_features_over_time_time_var_name():
161+
"""Test field_and_features_over_time generator works correctly with a time
162+
coordinate name other than "time"
163+
"""
156164
# Test non-standard time coord name:
157165
test_data = xr.DataArray(
158166
np.zeros([2, 10, 10]),
@@ -180,6 +188,10 @@ def test_field_and_features_over_time_time_var_name():
180188

181189

182190
def test_field_and_features_over_time_time_var_name_error():
191+
"""Test that field_and_features_over_time generator raises the correct
192+
error when the names of the time coordinates do not match between the given
193+
data and dataframe
194+
"""
183195
# Test if time_var_name not in dataarray:
184196
test_data = xr.DataArray(
185197
np.zeros([2, 10, 10]),
@@ -199,7 +211,7 @@ def test_field_and_features_over_time_time_var_name_error():
199211
}
200212
)
201213

202-
with pytest.raises(ValueError):
214+
with pytest.raises(ValueError, match="time not present in input field*"):
203215
next(generators.field_and_features_over_time(test_data, test_features))
204216

205217
# Test if time var name not in dataframe:
@@ -221,5 +233,5 @@ def test_field_and_features_over_time_time_var_name_error():
221233
}
222234
)
223235

224-
with pytest.raises(ValueError):
236+
with pytest.raises(ValueError, match="time not present in input feature*"):
225237
next(generators.field_and_features_over_time(test_data, test_features))

0 commit comments

Comments
 (0)