Skip to content

Commit 0d1b2c7

Browse files
authored
Merge pull request #180 from knutfrode/dev
Updated traj2d.py to return Xarray Dataarrays instead of numpy arrays
2 parents a3e3bf5 + 6f188e2 commit 0d1b2c7

File tree

2 files changed

+94
-53
lines changed

2 files changed

+94
-53
lines changed

docs/source/index.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,8 @@ Indices and tables
9191
* :ref:`genindex`
9292

9393

94+
9495
.. |date| date::
96+
.. |time| date:: %H:%M
9597

96-
Last Updated on |date|
98+
Last Updated on |date| at |time|

trajan/traj2d.py

Lines changed: 91 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -29,25 +29,40 @@ def __init__(self, ds, trajectory_dim, obs_dim, time_varname):
2929

3030
def timestep(self, average=np.nanmedian):
3131
"""
32-
Return median time step between observations in seconds.
32+
Calculate the average (by default the median) time step between observations in seconds.
33+
34+
Parameters
35+
----------
36+
average : callable, optional
37+
Function to calculate the average time step, by default `np.nanmedian`.
38+
39+
Returns
40+
-------
41+
xarray.DataArray
42+
Average time step between observations, as computed by `average` (the median by default).
43+
Attributes:
44+
- units: seconds
3345
"""
3446
td = np.diff(self.ds.time, axis=1) / np.timedelta64(1, 's')
3547
td = average(td)
36-
return td
48+
return xr.DataArray(td, name="timestep", attrs={"units": "seconds"})
3749

3850
def time_to_next(self):
39-
"""Return time from one position to the next.
40-
41-
Returned datatype is np.timedelta64
42-
Last time is repeated for last position (which has no next position)
51+
"""
52+
Calculate the time difference to the next observation.
53+
54+
Returns
55+
-------
56+
xarray.DataArray
57+
Time difference to the next observation with the same dimensions as the dataset. The last time step is repeated for the final observation, which has no next position.
58+
Attributes:
59+
- units: seconds
4360
"""
4461
time = self.ds.time
4562
lenobs = self.ds.sizes[self.obs_dim]
46-
td = time.isel(obs=slice(1, lenobs)) - time.isel(
47-
obs=slice(0, lenobs - 1))
48-
td = xr.concat((td, td.isel(obs=-1)),
49-
dim=self.obs_dim) # repeating last time step
50-
return td
63+
td = time.isel(obs=slice(1, lenobs)) - time.isel(obs=slice(0, lenobs - 1))
64+
td = xr.concat((td, td.isel(obs=-1)), dim=self.obs_dim) # Repeat last time step
65+
return td.astype("timedelta64[s]").rename("time_to_next").assign_attrs({"units": "seconds"})
5166

5267
def is_1d(self):
5368
return False
@@ -56,8 +71,19 @@ def is_2d(self):
5671
return True
5772

5873
def insert_nan_where(self, condition):
59-
"""Insert NaN-values in trajectories after given positions, shifting rest of trajectory."""
74+
"""
75+
Insert NaN values in trajectories after given positions, shifting the rest of the trajectory.
76+
77+
Parameters
78+
----------
79+
condition : xarray.DataArray
80+
Boolean condition indicating where NaN values should be inserted.
6081
82+
Returns
83+
-------
84+
xarray.Dataset
85+
Dataset with NaN values inserted at specified positions.
86+
"""
6187
index_of_last = self.index_of_last()
6288
num_inserts = condition.sum(dim=self.obs_dim)
6389
max_obs = (index_of_last + num_inserts).max().values
@@ -66,28 +92,26 @@ def insert_nan_where(self, condition):
6692
trajcoord = range(self.ds.sizes[self.trajectory_dim])
6793
nd = xr.Dataset(
6894
coords={
69-
self.trajectory_dim:
70-
([self.trajectory_dim],
71-
range(self.ds.sizes[self.trajectory_dim])),
72-
self.obs_dim:
73-
([self.obs_dim], range(max_obs)) # Longest trajectory
95+
self.trajectory_dim: ([self.trajectory_dim], trajcoord),
96+
self.obs_dim: ([self.obs_dim], range(max_obs)) # Longest trajectory
7497
},
75-
attrs=self.ds.attrs)
98+
attrs=self.ds.attrs
99+
)
76100

77101
# Add extended variables
78102
for varname, var in self.ds.data_vars.items():
79103
if self.obs_dim not in var.dims:
80104
nd[varname] = var
81105
continue
82-
# Create empty dataarray to hold interpolated values for given variable
106+
107+
# Create empty DataArray to hold interpolated values for the variable
83108
da = xr.DataArray(
84-
data=np.zeros(tuple(nd.sizes[di] for di in nd.dims)) * np.nan,
109+
data=np.full((nd.sizes[self.trajectory_dim], nd.sizes[self.obs_dim]), np.nan),
85110
dims=nd.dims,
86111
attrs=var.attrs,
87112
)
88113

89-
for t in range(self.ds.sizes[
90-
self.trajectory_dim]): # loop over trajectories
114+
for t in range(self.ds.sizes[self.trajectory_dim]): # Loop over trajectories
91115
numins = num_inserts[t]
92116
olddata = var.isel(trajectory=t).values
93117
wh = np.argwhere(condition.isel(trajectory=t).values) + 1
@@ -102,14 +126,11 @@ def insert_nan_where(self, condition):
102126
else:
103127
na = np.atleast_1d(np.nan)
104128
newdata = np.concatenate(
105-
[np.concatenate((ss, na)) for ss in s])
129+
[np.concatenate((ss, na)) for ss in s]
130+
)
106131

107-
newdata = newdata[slice(0, max_obs -
108-
1)] # truncating, should be checked
109-
da[{
110-
self.trajectory_dim: t,
111-
self.obs_dim: slice(0, len(newdata))
112-
}] = newdata
132+
newdata = newdata[:max_obs] # Truncate to max_obs
133+
da[{self.trajectory_dim: t, self.obs_dim: slice(0, len(newdata))}] = newdata
113134

114135
nd[varname] = da.astype(var.dtype)
115136

@@ -119,21 +140,29 @@ def insert_nan_where(self, condition):
119140
return nd
120141

121142
def drop_where(self, condition):
122-
"""Remove positions where condition is True, shifting rest of trajectory."""
143+
"""
144+
Remove positions where the condition is True, shifting the rest of the trajectory.
123145
146+
Parameters
147+
----------
148+
condition : xarray.DataArray
149+
Boolean condition indicating positions to drop.
150+
151+
Returns
152+
-------
153+
xarray.Dataset
154+
Dataset with positions removed where the condition is True.
155+
"""
124156
trajs = []
125157
newlen = 0
126158
for i in range(self.ds.sizes[self.trajectory_dim]):
127-
new = self.ds.isel(trajectory=i).drop_sel(obs=np.where(
128-
condition.isel(
129-
trajectory=i))[0]) # Dropping from given trajectory
159+
new = self.ds.isel(trajectory=i).drop_sel(obs=np.where(condition.isel(trajectory=i))[0])
130160
newlen = max(newlen, new.sizes[self.obs_dim])
131161
trajs.append(new)
132162

133-
# Ensure all trajectories have equal length, by padding with NaN at end
163+
# Ensure all trajectories have equal length by padding with NaN at the end
134164
trajs = [
135-
t.pad(
136-
pad_width={self.obs_dim: (0, newlen - t.sizes[self.obs_dim])})
165+
t.pad(pad_width={self.obs_dim: (0, newlen - t.sizes[self.obs_dim])})
137166
for t in trajs
138167
]
139168

@@ -261,11 +290,26 @@ def to_1d(self):
261290

262291
@__require_obs_dim__
263292
def gridtime(self, times, time_varname=None, round=True):
264-
if isinstance(times, str) or isinstance(
265-
times, pd.Timedelta): # Make time series with given interval
266-
if round is True:
267-
start_time = np.nanmin(np.asarray(
268-
self.ds.time.dt.floor(times)))
293+
"""
294+
Interpolate the dataset to a given time grid.
295+
296+
Parameters
297+
----------
298+
times : str, pandas.Timedelta, or numpy.ndarray
299+
Time grid to interpolate to. If a string or Timedelta, it specifies the interval.
300+
time_varname : str, optional
301+
Name of the time variable, by default the dataset's time variable.
302+
round : bool, optional
303+
Whether to floor the start time and ceil the end time to a multiple of the interval (only used when `times` is given as an interval), by default True.
304+
305+
Returns
306+
-------
307+
xarray.Dataset
308+
Dataset interpolated to the specified time grid.
309+
"""
310+
if isinstance(times, (str, pd.Timedelta)): # Create time series with given interval
311+
if round:
312+
start_time = np.nanmin(np.asarray(self.ds.time.dt.floor(times)))
269313
end_time = np.nanmax(np.asarray(self.ds.time.dt.ceil(times)))
270314
else:
271315
start_time = np.nanmin(np.asarray(self.ds.time))
@@ -281,33 +325,28 @@ def gridtime(self, times, time_varname=None, round=True):
281325
time_varname = self.time_varname if time_varname is None else time_varname
282326

283327
d = None
284-
285328
for t in range(self.ds.sizes[self.trajectory_dim]):
286-
dt = self.ds.isel({self.trajectory_dim : t}) \
287-
.dropna(self.obs_dim, how='all')
288-
289-
dt = dt.assign_coords({self.obs_dim : dt[self.time_varname].values }) \
329+
dt = self.ds.isel({self.trajectory_dim: t}).dropna(self.obs_dim, how="all")
330+
dt = dt.assign_coords({self.obs_dim: dt[self.time_varname].values}) \
290331
.drop_vars(self.time_varname) \
291-
.rename({self.obs_dim : time_varname}) \
332+
.rename({self.obs_dim: time_varname}) \
292333
.set_index({time_varname: time_varname})
293334

294335
_, ui = np.unique(dt[time_varname], return_index=True)
295336
dt = dt.isel({time_varname: ui})
296-
dt = dt.isel(
297-
{time_varname: np.where(~pd.isna(dt[time_varname].values))[0]})
337+
dt = dt.isel({time_varname: np.where(~pd.isna(dt[time_varname].values))[0]})
298338

299339
if dt.sizes[time_varname] > 0:
300340
dt = dt.interp({time_varname: times})
301341
else:
302-
logger.warning(f"time dimension ({time_varname}) is zero size")
342+
logger.warning(f"Time dimension ({time_varname}) is zero size")
303343

304344
if d is None:
305345
d = dt.expand_dims(self.trajectory_dim)
306346
else:
307347
d = xr.concat((d, dt), self.trajectory_dim)
308348

309-
d = d.assign_coords(
310-
{self.trajectory_dim: self.ds[self.trajectory_dim]})
349+
d = d.assign_coords({self.trajectory_dim: self.ds[self.trajectory_dim]})
311350

312351
return d
313352

0 commit comments

Comments
 (0)