Skip to content

Commit c34c5b8

Browse files
committed
Added mock-based test, made by Copilot
1 parent 6f188e2 commit c34c5b8

File tree

2 files changed

+122
-5
lines changed

2 files changed

+122
-5
lines changed

tests/test_traj.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
import pytest
2+
import numpy as np
3+
import xarray as xr
4+
import pyproj
5+
from datetime import datetime, timedelta
6+
from trajan.traj import Traj, grid_area
7+
8+
9+
@pytest.fixture
10+
def mock_dataset():
11+
# Making a 1D dataset for basic testing
12+
times = np.array([datetime(2023, 1, 1) + timedelta(hours=i) for i in range(5)])
13+
#times = np.repeat(times[np.newaxis, :], 2, axis=0)
14+
data = {
15+
"lon": (["trajectory", "obs"], np.random.rand(2, 5)),
16+
"lat": (["trajectory", "obs"], np.random.rand(2, 5)),
17+
"time": (["obs"], times)
18+
#"time": (["trajectory", "obs"], times)
19+
}
20+
coords = {"trajectory": [0, 1], "obs": range(5)}
21+
return xr.Dataset(data, coords=coords)
22+
23+
24+
@pytest.fixture
25+
def mock_traj(mock_dataset):
26+
return Traj(mock_dataset, trajectory_dim="trajectory", obs_dim="obs", time_varname="time")
27+
28+
29+
def test_grid_area():
30+
lons = np.linspace(0, 10, 5)
31+
lats = np.linspace(0, 10, 5)
32+
area = grid_area(lons, lats)
33+
assert isinstance(area, xr.DataArray)
34+
assert area.name == "grid_area"
35+
#assert "lat" in area.dims
36+
#assert "lon" in area.dims
37+
38+
39+
def test_index_of_last(mock_traj):
40+
result = mock_traj.index_of_last()
41+
assert isinstance(result, xr.DataArray)
42+
assert result.name == "index_of_last"
43+
assert "trajectory" in result.dims
44+
45+
46+
#def test_speed(mock_traj):
47+
# result = mock_traj.speed()
48+
# assert isinstance(result, xr.DataArray)
49+
# assert result.name == "speed"
50+
# assert "obs" in result.dims
51+
52+
53+
#def test_time_to_next(mock_traj):
54+
# result = mock_traj.time_to_next()
55+
# assert isinstance(result, xr.DataArray)
56+
# assert result.name == "time_to_next"
57+
# assert "obs" in result.dims
58+
59+
60+
def test_distance_to_next(mock_traj):
61+
result = mock_traj.distance_to_next()
62+
assert isinstance(result, xr.DataArray)
63+
assert result.name is None
64+
assert "obs" in result.dims
65+
66+
67+
def test_azimuth_to_next(mock_traj):
68+
result = mock_traj.azimuth_to_next()
69+
assert isinstance(result, xr.DataArray)
70+
assert result.name is None
71+
assert "obs" in result.dims
72+
73+
74+
#def test_velocity_components(mock_traj):
75+
# u, v = mock_traj.velocity_components()
76+
# assert isinstance(u, xr.DataArray)
77+
# assert isinstance(v, xr.DataArray)
78+
# assert u.name == "u_velocity"
79+
# assert v.name == "v_velocity"
80+
# assert "obs" in u.dims
81+
# assert "obs" in v.dims
82+
83+
84+
def test_get_area_convex_hull(mock_traj):
85+
result = mock_traj.get_area_convex_hull()
86+
assert isinstance(result, xr.DataArray)
87+
assert result.name == "convex_hull_area"
88+
assert result.attrs["units"] == "m2"
89+
90+
91+
def test_make_grid(mock_traj):
92+
result = mock_traj.make_grid(dx=10000, dy=None)
93+
assert isinstance(result, xr.Dataset)
94+
assert "cell_area" in result.data_vars
95+
96+
97+
def test_crop(mock_traj):
98+
result = mock_traj.crop(lonmin=0, lonmax=5, latmin=0, latmax=5)
99+
assert isinstance(result, xr.Dataset)
100+
101+
102+
def test_contained_in(mock_traj):
103+
result = mock_traj.contained_in(lonmin=0, lonmax=5, latmin=0, latmax=5)
104+
assert isinstance(result, xr.Dataset)
105+
106+
107+
def test_assign_cf_attrs(mock_traj):
108+
result = mock_traj.assign_cf_attrs(creator_name="Test", creator_email="test@example.com")
109+
assert isinstance(result, xr.Dataset)
110+
assert result.attrs["creator_name"] == "Test"
111+
assert result.attrs["creator_email"] == "test@example.com"

trajan/traj.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -641,7 +641,8 @@ def speed(self) -> xr.DataArray:
641641
distance = self.distance_to_next()
642642
timedelta_seconds = self.time_to_next() / np.timedelta64(1, 's')
643643

644-
return distance / timedelta_seconds
644+
speed = distance / timedelta_seconds
645+
return speed
645646

646647
@abstractmethod
647648
def time_to_next(self) -> pd.Timedelta:
@@ -881,15 +882,20 @@ def get_area_convex_hull(self):
881882
"""
882883
from scipy.spatial import ConvexHull
883884

884-
lon = self.ds.lon.where(self.ds.status == 0)
885-
lat = self.ds.lat.where(self.ds.status == 0)
885+
if 'status' in self.ds.variables:
886+
lon = self.ds.lon.where(self.ds.status == 0) # OpenDrift specific
887+
lat = self.ds.lat.where(self.ds.status == 0)
888+
else:
889+
lon = self.ds.lon.where(np.isfinite(self.ds.lon) is True)
890+
lat = self.ds.lat.where(np.isfinite(self.ds.lat) is True)
891+
886892
fin = np.isfinite(lat + lon)
887893
if np.sum(fin) <= 3:
888894
return xr.DataArray(0, name="convex_hull_area", attrs={"units": "m2"})
889895
if len(np.unique(lat)) == 1 and len(np.unique(lon)) == 1:
890896
return xr.DataArray(0, name="convex_hull_area", attrs={"units": "m2"})
891-
lat = lat[fin]
892-
lon = lon[fin]
897+
lat = lat.where(fin)
898+
lon = lon.where(fin)
893899
aea = pyproj.Proj(
894900
f'+proj=aea +lat_0={lat.mean().values} +lat_1={lat.min().values} +lat_2={lat.max().values} +lon_0={lon.mean().values} +x_0=0 +y_0=0 +datum=NAD83 +units=m +no_defs'
895901
)

0 commit comments

Comments
 (0)