Skip to content

Commit 75f5bca

Browse files
authored
Merge pull request #178 from knutfrode/dev
More methods now return Xarray Dataarrays instead of numpy arrays
2 parents d7e5b1a + a2cd113 commit 75f5bca

File tree

2 files changed

+110
-44
lines changed

2 files changed

+110
-44
lines changed

trajan/traj.py

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,21 @@ def ensure_time_dim(ds, time_dim):
4242
return ds
4343

4444
def grid_area(lons, lats):
45-
"""Calculate the area of each grid cell"""
45+
"""
46+
Calculate the area of each grid cell.
47+
48+
Parameters
49+
----------
50+
lons : array-like
51+
Longitudes of the grid.
52+
lats : array-like
53+
Latitudes of the grid.
54+
55+
Returns
56+
-------
57+
xarray.DataArray
58+
Grid cell areas with dimensions ('lat', 'lon').
59+
"""
4660
from shapely.geometry import Polygon
4761

4862
if lons.ndim == 1:
@@ -59,7 +73,7 @@ def grid_area(lons, lats):
5973
polygon = Polygon([(lon[0], lat[0]), (lon[1], lat[1]), (lon[2], lat[2]), (lon[3], lat[3])])
6074
grid_areas[i, j] = abs(geod.geometry_area_perimeter(polygon)[0])
6175

62-
return grid_areas
76+
return xr.DataArray(grid_areas, name="grid_area", attrs={'units': 'm^2', 'description': 'Area of each grid cell'})
6377

6478

6579
class Traj:
@@ -595,15 +609,17 @@ def assign_cf_attrs(self,
595609
return ds
596610

597611
def index_of_last(self):
598-
"""Find index of last valid position along each trajectory.
612+
"""
613+
Find the index of the last valid position along each trajectory.
599614
600615
Returns
601616
-------
602-
array-like
603-
Array of the index of the last valid position along each trajectory.
617+
xarray.DataArray
618+
Index of the last valid position for each trajectory.
619+
Dimensions: ('trajectory',).
604620
"""
605-
return np.ma.notmasked_edges(np.ma.masked_invalid(self.ds.lon.values),
606-
axis=1)[1][1]
621+
last_indices = np.ma.notmasked_edges(np.ma.masked_invalid(self.ds.lon.values), axis=1)[1][1]
622+
return xr.DataArray(last_indices, dims=[self.trajectory_dim], name="index_of_last")
607623

608624
@abstractmethod
609625
def speed(self) -> xr.DataArray:
@@ -855,35 +871,33 @@ def convex_hull_contains_point(self, lon, lat):
855871
return p.contains_points(point)[0]
856872

857873
def get_area_convex_hull(self):
858-
"""Return the area [m2] of the convex hull spanned by all positions.
874+
"""
875+
Calculate the area [m2] of the convex hull spanned by all positions.
859876
860877
Returns
861878
-------
862-
scalar
863-
Area [m2] of convex hull around all positions.
879+
xarray.DataArray
880+
Area of the convex hull in square meters.
864881
"""
865-
866882
from scipy.spatial import ConvexHull
867883

868884
lon = self.ds.lon.where(self.ds.status == 0)
869885
lat = self.ds.lat.where(self.ds.status == 0)
870886
fin = np.isfinite(lat + lon)
871887
if np.sum(fin) <= 3:
872-
return 0
888+
return xr.DataArray(0, name="convex_hull_area", attrs={"units": "m2"})
873889
if len(np.unique(lat)) == 1 and len(np.unique(lon)) == 1:
874-
return 0
890+
return xr.DataArray(0, name="convex_hull_area", attrs={"units": "m2"})
875891
lat = lat[fin]
876892
lon = lon[fin]
877-
# An equal area projection centered around the particles
878893
aea = pyproj.Proj(
879894
f'+proj=aea +lat_0={lat.mean().values} +lat_1={lat.min().values} +lat_2={lat.max().values} +lon_0={lon.mean().values} +x_0=0 +y_0=0 +datum=NAD83 +units=m +no_defs'
880895
)
881-
882896
x, y = aea(lat, lon, inverse=False)
883897
fin = np.isfinite(x + y)
884898
points = np.vstack((y.T, x.T)).T
885899
hull = ConvexHull(points)
886-
return np.array(hull.volume) # volume=area for 2D as here
900+
return xr.DataArray(hull.volume, name="convex_hull_area", attrs={"units": "m2"})
887901

888902
@abstractmethod
889903
def gridtime(self, times, time_varname=None) -> xr.Dataset:
@@ -1277,7 +1291,7 @@ def make_grid(self, dx, dy=None, z=None,
12771291

12781292
x = np.arange(xmin, xmax + dx*2, dx) # One extra row/column
12791293
y = np.arange(ymin, ymax + dy*2, dy)
1280-
area = grid_area(x, y)
1294+
area = grid_area(x, y).data
12811295

12821296
# Create Xarray Dataset
12831297
data_vars = {}

trajan/traj1d.py

Lines changed: 79 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,18 @@ def __init__(self, ds, trajectory_dim, obs_dim, time_varname):
1919
super().__init__(ds, trajectory_dim, obs_dim, time_varname)
2020

2121
def timestep(self):
22-
"""Time step between observations in seconds."""
23-
return ((self.ds.time[1] - self.ds.time[0]) /
24-
np.timedelta64(1, 's')).values
22+
"""
23+
Calculate the time step between observations in seconds.
24+
25+
Returns
26+
-------
27+
xarray.DataArray
28+
Time step between observations with a single value.
29+
Attributes:
30+
- units: seconds
31+
"""
32+
timestep = ((self.ds.time[1] - self.ds.time[0]) / np.timedelta64(1, 's')).values
33+
return xr.DataArray(timestep, name="timestep", attrs={"units": "seconds"})
2534

2635
def is_1d(self):
2736
return True
@@ -30,54 +39,78 @@ def is_2d(self):
3039
return False
3140

3241
def to_2d(self, obs_dim='obs'):
42+
"""
43+
Convert the dataset to a 2D representation.
44+
45+
Parameters
46+
----------
47+
obs_dim : str, optional
48+
Name of the observation dimension in the 2D representation, by default 'obs'.
49+
50+
Returns
51+
-------
52+
xarray.Dataset
53+
Dataset with a 2D representation of trajectories.
54+
"""
3355
ds = self.ds.copy()
34-
time = ds[self.time_varname].rename({
35-
self.time_varname: obs_dim
36-
}).expand_dims(
37-
dim={
38-
self.trajectory_dim: ds.sizes[self.trajectory_dim]
39-
}).assign_coords({self.trajectory_dim: ds[self.trajectory_dim]})
40-
# TODO should also add cf_role here
56+
time = ds[self.time_varname].rename({self.time_varname: obs_dim}).expand_dims(
57+
dim={self.trajectory_dim: ds.sizes[self.trajectory_dim]}
58+
).assign_coords({self.trajectory_dim: ds[self.trajectory_dim]})
4159
ds = ds.rename({self.time_varname: obs_dim})
4260
ds[self.time_varname] = time
43-
ds[obs_dim] = np.arange(0, ds.sizes[obs_dim])
44-
61+
ds[obs_dim] = xr.DataArray(np.arange(0, ds.sizes[obs_dim]), dims=[obs_dim])
4562
return ds
4663

4764
def to_1d(self):
4865
return self.ds.copy()
4966

5067
def time_to_next(self):
68+
"""
69+
Calculate the time difference to the next observation.
70+
71+
Returns
72+
-------
73+
xarray.DataArray
74+
Time difference to the next observation with the same dimensions as the dataset.
75+
Attributes:
76+
- units: seconds
77+
"""
5178
time_step = self.ds.time[1] - self.ds.time[0]
52-
return time_step
79+
return xr.DataArray(time_step, name="time_to_next", attrs={"units": "seconds"})
5380

5481
def velocity_spectrum(self):
55-
82+
"""
83+
Calculate the velocity spectrum for a single trajectory.
84+
85+
Returns
86+
-------
87+
xarray.DataArray
88+
Velocity spectrum with dimensions ('period').
89+
Attributes:
90+
- units: power
91+
"""
5692
if self.ds.sizes[self.trajectory_dim] > 1:
57-
raise ValueError(
58-
'Spectrum can only be calculated for a single trajectory')
93+
raise ValueError('Spectrum can only be calculated for a single trajectory')
5994

6095
u, v = self.velocity_components()
6196
u = u.squeeze()
6297
v = v.squeeze()
6398
u = u[np.isfinite(u)]
6499
v = v[np.isfinite(v)]
65100

66-
timestep_h = (self.ds.time[1] - self.ds.time[0]) / np.timedelta64(
67-
1, 'h') # hours since start
101+
timestep_h = (self.ds.time[1] - self.ds.time[0]) / np.timedelta64(1, 'h') # hours since start
68102

69103
ps = np.abs(np.fft.rfft(np.abs(u + 1j * v)))
70104
freq = np.fft.rfftfreq(n=u.size, d=timestep_h.values)
71105
freq[0] = np.nan
72106

73107
da = xr.DataArray(
74108
data=ps,
75-
name='velocity spectrum',
109+
name='velocity_spectrum',
76110
dims=['period'],
77-
coords={'period': (['period'], 1 / freq, {
78-
'units': 'hours'
79-
})},
80-
attrs={'units': 'power'})
111+
coords={'period': (['period'], 1 / freq, {'units': 'hours'})},
112+
attrs={'units': 'power'}
113+
)
81114

82115
return da
83116

@@ -146,9 +179,27 @@ def skill_matching(traj, expected):
146179

147180
return s
148181

149-
def skill(self, expected, method='liu-weissberg', **kwargs) -> xr.Dataset:
150-
151-
expected = expected.traj # Normalise
182+
def skill(self, expected, method='liu-weissberg', **kwargs) -> xr.DataArray:
183+
"""
184+
Calculate the skill score for trajectories.
185+
186+
Parameters
187+
----------
188+
expected : Traj1d
189+
Expected trajectory dataset.
190+
method : str, optional
191+
Skill score method, by default 'liu-weissberg'.
192+
**kwargs : dict
193+
Additional arguments for the skill score calculation.
194+
195+
Returns
196+
-------
197+
xarray.DataArray
198+
Skill score with dimensions matching the dataset.
199+
Attributes:
200+
- method: Skill score calculation method.
201+
"""
202+
expected = expected.traj # Normalize
152203
expected_trajdim = expected.trajectory_dim
153204
self_trajdim = self.trajectory_dim
154205

@@ -157,7 +208,8 @@ def skill(self, expected, method='liu-weissberg', **kwargs) -> xr.Dataset:
157208
if numtraj_self > 1 and numtraj_expected > 1 and numtraj_self != numtraj_expected:
158209
raise ValueError(
159210
'Datasets must have the same number of trajectories, or a single trajectory. '
160-
f'This dataset: {numtraj_self}, expected: {numtraj_expected}.')
211+
f'This dataset: {numtraj_self}, expected: {numtraj_expected}.'
212+
)
161213

162214
numobs_self = self.ds.sizes[self.obs_dim]
163215
numobs_expected = expected.ds.sizes[expected.obs_dim]

0 commit comments

Comments
 (0)