Skip to content

Commit b75beb7

Browse files
authored
Merge pull request #143 from knutfrode/dev
Trajectory dimension is no longer hardcoded as 'trajectory', but is…
2 parents 391f3e2 + d290cde commit b75beb7

File tree

9 files changed

+129
-88
lines changed

9 files changed

+129
-88
lines changed

examples/example_parcels.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,18 @@
88
import trajan as ta
99

1010
ds = xr.open_dataset('../tests/test_data/parcels.zarr', engine='zarr')
11+
#%%
12+
# Print Xarray dataset
1113
print(ds)
12-
ds.traj.plot(land='mask', margin=2)
14+
15+
#%%
16+
# Print trajectory specific information about dataset
17+
print(ds.traj)
18+
19+
#%%
20+
# Basic plot
21+
ds.traj.plot(land='mask', margin=1)
22+
# TODO: we must allow no time dimension for the below to work
1323
#ds.mean('trajectory', skipna=True).traj.plot(color='r', label='Mean trajectory')
1424

1525
plt.show()

tests/test_convert_datalayout.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
def test_to2d(barents):
88
# print(barents)
9-
gr = barents.traj.gridtime('1H')
9+
gr = barents.traj.gridtime('1h')
1010
# print(gr)
1111

1212
assert gr.traj.is_1d()

tests/test_repr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ def test_repr_1d(opendrift_sim):
55
repr = str(opendrift_sim.traj)
66
assert '2015-11-16T00:00' in repr
77
assert 'Timestep: 1:00:00' in repr
8-
assert "67 timesteps time['time'] (1D)" in repr
8+
assert "67 timesteps [obs_dim: time]" in repr
99

1010
def test_repr_2d(test_data):
1111
ds = xr.open_dataset(test_data / 'bug32.nc')

tests/test_skill_score.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,10 @@ def test_barents_align(barents):
3838
b0 = barents.isel(trajectory=0)
3939

4040
# assert b0.sizes['trajectory'] == 1
41-
assert barents.sizes['trajectory'] == 2
41+
assert barents.sizes[barents.traj.trajectory_dim] == 2
4242

4343
(b01, _) = xr.broadcast(b0, barents)
44-
b01 = b01.transpose('trajectory', ...)
44+
b01 = b01.transpose(b01.traj.trajectory_dim, ...)
4545

4646
np.testing.assert_allclose(b01.isel(trajectory=0).lon, barents.isel(trajectory=0).lon)
4747
np.testing.assert_allclose(b01.isel(trajectory=1).lon, barents.isel(trajectory=0).lon)

trajan/accessor.py

Lines changed: 48 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,18 +22,54 @@ def detect_time_variable(ds, obs_dim):
2222

2323
raise ValueError("No time variable detected")
2424

25+
def detect_trajectory_dim(ds):
26+
logger.debug('Detecting trajectory dimension')
27+
if 'trajectory_id' in ds.cf.cf_roles: # This is the proper CF way
28+
trajectory_var = ds.cf['trajectory_id']
29+
if trajectory_var.name in trajectory_var.sizes: # Check that this is a dimension
30+
return trajectory_var.name
31+
else:
32+
if len(trajectory_var.dims) > 1:
33+
logger.warning(f'trajectory_id is {trajectory_var.name}, but dimensions '
34+
'have other names: {str(list(trajectory_var.sizes))}')
35+
elif len(trajectory_var.dims) == 1: # Using the single dimension name
36+
return list(trajectory_var.sizes)[0]
37+
else:
38+
logger.warning('Single trajectory, a trajectory dimension will be added')
39+
return None
40+
41+
logger.warning('No trajectory_id attribute/variable found, trying to identify by name.')
42+
tx = detect_tx_variable(ds)
43+
for tdn in ['trajectory', 'traj']: # Common names of trajectory dimension
44+
if tdn in tx.dims:
45+
return tdn
46+
47+
return None # Did not succeed in detecting trajectory dimension
48+
2549

2650
@xr.register_dataset_accessor("traj")
2751
class TrajA(Traj):
2852
def __new__(cls, ds):
29-
if 'traj' in ds.dims:
30-
logger.info(
31-
'Normalizing dimension name from "traj" to "trajectory".')
32-
ds = ds.rename({'traj': 'trajectory'})
3353

34-
if 'trajectory' not in ds.dims: # Add empty trajectory dimension, if single trajectory
35-
ds = ds.expand_dims({'trajectory': 1})
36-
ds['trajectory'].attrs['cf_role'] = 'trajectory_id'
54+
trajectory_dim = detect_trajectory_dim(ds)
55+
56+
if trajectory_dim is None:
57+
if 'trajectory_id' in ds.cf.cf_roles:
58+
trajectory_id = ds.cf.cf_roles['trajectory_id']
59+
if len(trajectory_id) > 1:
60+
raise ValueError(f'Dataset has several trajectory_id variables: {trajectory_id}')
61+
else:
62+
trajectory_id = trajectory_id[0]
63+
logger.warning(f'Using trajectory_id variable name ({trajectory_id}) '
64+
'as trajectory dimension name')
65+
trajectory_dim = trajectory_id
66+
ds = ds.set_coords(trajectory_dim)
67+
ds = ds.expand_dims(trajectory_dim, create_index_for_new_dim=False)
68+
else:
69+
logger.debug('Creating new trajectory dimension "trajectory"')
70+
trajectory_dim = 'trajectory'
71+
ds = ds.expand_dims({trajectory_dim: 1})
72+
ds[trajectory_dim].attrs['cf_role'] = 'trajectory_id'
3773

3874
obs_dim = None
3975
time_varname = None
@@ -65,8 +101,7 @@ def __new__(cls, ds):
65101
else:
66102
raise ValueError(f"cannot deduce the timecoord; we have the following candidates: {with_standard_name_time_and_dim_index = }")
67103

68-
# discover the trajectorycoord variable name #################
69-
trajectorycoord = ds.cf["trajectory_id"].name
104+
# KFD TODO: the below detection should be generalized to dynamic dimension names
70105

71106
# discover the "rowsize" variable name #######################
72107
# NOTE: this is probably not standard; something to point to the CF conventions? should we need a standard_name for this, instead of the following heuristics?
@@ -82,10 +117,10 @@ def __new__(cls, ds):
82117
raise ValueError("mismatch between the index length and the sum of the deduced trajectory lengths")
83118

84119
logger.debug(
85-
f"1D storage dataset; detected: {obs_dim = }, {timecoord = }, {trajectorycoord = }, {rowsizevar}"
120+
f"1D storage dataset; detected: {obs_dim = }, {timecoord = }, {trajectory_dim = }, {rowsizevar}"
86121
)
87122

88-
return ocls(ds, obs_dim, timecoord, trajectorycoord, rowsizevar)
123+
return ocls(ds, trajectory_dim, obs_dim, timecoord, rowsizevar)
89124

90125
else:
91126
logging.warning(f"{ds} has {tx.dims = } which is of dimension 1 but is not index; this is a bit unusual; try to parse with Traj1d or Traj2d")
@@ -137,4 +172,5 @@ def __new__(cls, ds):
137172
f'Time variable has more than two dimensions: {ds[time_varname].shape}'
138173
)
139174

140-
return ocls(ds, obs_dim, time_varname)
175+
# TODO: The provided attributes could perhaps be added here before returning
176+
return ocls(ds, trajectory_dim, obs_dim, time_varname)

trajan/ragged.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,17 @@
1111
class ContiguousRagged(Traj):
1212
"""An unstructured dataset, where each trajectory may have observations at different times, and all the data for the different trajectories are stored in single arrays with one dimension, contiguously, one trajectory after the other. Typically from a collection of drifters. This class converts contiguous ragged datasets into 2d datasets, so that the Traj2d methods can be leveraged."""
1313

14-
trajdim: str
15-
rowvar: str
14+
rowvar: str # TODO: Should have a more precise name than rowvar
1615

17-
def __init__(self, ds, obs_dim, time_varname, trajectorycoord, rowsizevar):
18-
self.trajdim = trajectorycoord
16+
def __init__(self, ds, trajectory_dim, obs_dim, time_varname, rowsizevar):
1917
self.rowvar = rowsizevar
20-
super().__init__(ds, obs_dim, time_varname)
18+
super().__init__(ds, trajectory_dim, obs_dim, time_varname)
2119

2220
def to_2d(self, obs_dim='obs'):
2321
"""This actually converts a contiguous ragged xarray Dataset into an xarray Dataset that follows the Traj2d conventions."""
2422
global_attrs = self.ds.attrs
2523

26-
nbr_trajectories = len(self.ds[self.trajdim])
24+
nbr_trajectories = len(self.ds[self.trajectory_dim])
2725

2826
# find the longest trajectory
2927
longest_trajectory = np.max(self.ds[self.rowvar].to_numpy())
@@ -32,7 +30,7 @@ def to_2d(self, obs_dim='obs'):
3230

3331
# the trajectory dimension special case (as it is a different kind, and has a different dim than other variables)
3432

35-
array_instruments = self.ds[self.trajdim].to_numpy()
33+
array_instruments = self.ds[self.trajectory_dim].to_numpy()
3634

3735
# the time var (special case as it is of a different type)
3836

trajan/traj.py

Lines changed: 29 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -41,37 +41,40 @@ class Traj:
4141

4242
__gcrs__: pyproj.CRS
4343

44-
def __init__(self, ds, obs_dim, time_varname):
44+
def __init__(self, ds, trajectory_dim, obs_dim, time_varname):
4545
self.ds = ds
4646
self.__plot__ = None
4747
self.__animate__ = None
4848
self.__gcrs__ = pyproj.CRS.from_epsg(4326)
49+
self.trajectory_dim = trajectory_dim # name of trajectory dimension
4950
self.obs_dim = obs_dim # dimension along which time increases
5051
self.time_varname = time_varname
5152

5253
def __repr__(self):
5354
output = '=======================\n'
5455
output += 'TrajAn info:\n'
5556
output += '------------\n'
56-
output += f'{self.ds.sizes["trajectory"]} trajectories\n'
57-
if 'time' in self.ds.variables:
58-
if self.time_varname in self.ds.sizes:
59-
output += f'{self.ds.sizes[self.time_varname]} timesteps'
60-
timevar = self.ds[self.time_varname]
61-
output += f' {timevar.name}{list(timevar.sizes)} ({len(timevar.sizes)}D)\n'
57+
if self.trajectory_dim is None:
58+
output += 'Single trajectory (no trajectory dimension)\n'
59+
else:
60+
output += f'{self.ds.sizes[self.trajectory_dim]} trajectories [trajectory_dim: {self.trajectory_dim}]\n'
61+
if self.time_varname is not None:
62+
output += f'{self.ds.sizes[self.obs_dim]} timesteps [obs_dim: {self.obs_dim}]\n'
63+
timevar = self.ds[self.time_varname]
64+
output += f'Time variable: {timevar.name}{list(timevar.sizes)} ({len(timevar.sizes)}D)\n'
6265
try:
6366
timestep = self.timestep()
6467
timestep = timedelta(seconds=int(timestep))
6568
except:
6669
timestep = '[self.timestep returns error]' # TODO
6770
output += f'Timestep: {timestep}\n'
68-
start_time = self.ds.time.min().data
69-
end_time = self.ds.time.max().data
71+
start_time = self.ds.time.min(skipna=True).data
72+
end_time = self.ds.time.max(skipna=True).data
7073
output += f'Time coverage: {start_time} - {end_time}\n'
7174
else:
7275
output += f'Dataset has no time variable'
73-
output += f'Longitude span: {self.tx.min().data} to {self.tx.max().data}\n'
74-
output += f'Latitude span: {self.ty.min().data} to {self.ty.max().data}\n'
76+
output += f'Longitude span: {self.tx.min(skipna=True).data} to {self.tx.max(skipna=True).data}\n'
77+
output += f'Latitude span: {self.ty.min(skipna=True).data} to {self.ty.max(skipna=True).data}\n'
7578
output += 'Variables:\n'
7679
for var in self.ds.variables:
7780
if var not in ['trajectory', self.obs_dim]:
@@ -352,31 +355,23 @@ def assign_cf_attrs(self,
352355
"""
353356
ds = self.ds.copy(deep=True)
354357

355-
ds['trajectory'] = ds['trajectory'].astype(str)
356-
ds['trajectory'].attrs = {
358+
ds[self.trajectory_dim] = ds[self.trajectory_dim].astype(str)
359+
ds[self.trajectory_dim].attrs = {
357360
'cf_role': 'trajectory_id',
358361
'long_name': 'trajectory name'
359362
}
360363

361364
ds = ds.assign_attrs({
362-
'Conventions':
363-
'CF-1.10',
364-
'featureType':
365-
'trajectory',
366-
'geospatial_lat_min':
367-
np.nanmin(self.tlat),
368-
'geospatial_lat_max':
369-
np.nanmax(self.tlat),
370-
'geospatial_lon_min':
371-
np.nanmin(self.tlon),
372-
'geospatial_lon_max':
373-
np.nanmax(self.tlon),
374-
'time_coverage_start':
375-
pd.to_datetime(
365+
'Conventions': 'CF-1.10',
366+
'featureType': 'trajectory',
367+
'geospatial_lat_min': np.nanmin(self.tlat),
368+
'geospatial_lat_max': np.nanmax(self.tlat),
369+
'geospatial_lon_min': np.nanmin(self.tlon),
370+
'geospatial_lon_max': np.nanmax(self.tlon),
371+
'time_coverage_start': pd.to_datetime(
376372
np.nanmin(ds['time'].values[ds['time'].values != np.datetime64(
377373
'NaT')])).isoformat(),
378-
'time_coverage_end':
379-
pd.to_datetime(
374+
'time_coverage_end': pd.to_datetime(
380375
np.nanmax(ds['time'].values[ds['time'].values != np.datetime64(
381376
'NaT')])).isoformat(),
382377
})
@@ -479,24 +474,25 @@ def distance_to(self, other) -> xr.Dataset:
479474
"""
480475

481476
other = other.broadcast_like(self.ds)
477+
482478
geod = pyproj.Geod(ellps='WGS84')
483-
az_fwd, a2, distance = geod.inv(self.ds.traj.tlon, self.ds.traj.tlat,
479+
az_fwd, a2, distance = geod.inv(self.tlon, self.tlat,
484480
other.traj.tlon, other.traj.tlat)
485481

486482
ds = xr.Dataset()
487483
ds['distance'] = xr.DataArray(distance,
488484
name='distance',
489-
coords=self.ds.traj.tlon.coords,
485+
coords=self.tlon.coords,
490486
attrs={'units': 'm'})
491487

492488
ds['az_fwd'] = xr.DataArray(az_fwd,
493489
name='forward azimuth',
494-
coords=self.ds.traj.tlon.coords,
490+
coords=self.tlon.coords,
495491
attrs={'units': 'degrees'})
496492

497493
ds['az_bwd'] = xr.DataArray(a2,
498494
name='back azimuth',
499-
coords=self.ds.traj.tlon.coords,
495+
coords=self.tlon.coords,
500496
attrs={'units': 'degrees'})
501497

502498
return ds

trajan/traj1d.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ class Traj1d(Traj):
1414
A structured dataset, where each trajectory is always given at the same times. Typically the output from a model or from a gridded dataset.
1515
"""
1616

17-
def __init__(self, ds, obs_dim, time_varname):
18-
super().__init__(ds, obs_dim, time_varname)
17+
def __init__(self, ds, trajectory_dim, obs_dim, time_varname):
18+
super().__init__(ds, trajectory_dim, obs_dim, time_varname)
1919

2020
def timestep(self):
2121
"""Time step between observations in seconds."""
@@ -32,7 +32,8 @@ def to_2d(self, obs_dim='obs'):
3232
ds = self.ds.copy()
3333
time = ds[self.time_varname].rename({
3434
self.time_varname: obs_dim
35-
}).expand_dims(dim={'trajectory': ds.sizes['trajectory']})
35+
}).expand_dims(dim={self.trajectory_dim: ds.sizes[self.trajectory_dim]})
36+
# TODO should also add cf_role here
3637
ds = ds.rename({self.time_varname: obs_dim})
3738
ds[self.time_varname] = time
3839
ds[obs_dim] = np.arange(0, ds.sizes[obs_dim])
@@ -45,7 +46,7 @@ def time_to_next(self):
4546

4647
def velocity_spectrum(self):
4748

48-
if self.ds.sizes['trajectory'] > 1:
49+
if self.ds.sizes[self.trajectory_dim] > 1:
4950
raise ValueError(
5051
'Spectrum can only be calculated for a single trajectory')
5152

@@ -77,7 +78,7 @@ def rotary_spectrum(self):
7778
### TODO unfinished method
7879

7980
from .tools import rotary_spectra
80-
if self.ds.sizes['trajectory'] > 1:
81+
if self.ds.sizes[self.trajectory_dim] > 1:
8182
raise ValueError(
8283
'Spectrum can only be calculated for a single trajectory')
8384

@@ -95,9 +96,9 @@ def rotary_spectrum(self):
9596
plt.show()
9697

9798
def skill(self, other, method='liu-weissberg', **kwargs):
98-
if self.ds.sizes['trajectory'] != other.sizes['trajectory']:
99+
if self.ds.sizes[self.trajectory_dim] != other.sizes[other.traj.trajectory_dim]:
99100
raise ValueError(
100-
f"There must be the same number of trajectories in the two datasets that are compared. This dataset: {self.ds.sizes['trajectory']}, other: {other.sizes['trajectory']}."
101+
f"There must be the same number of trajectories in the two datasets that are compared. This dataset: {self.ds.sizes[self.trajectory_dim]}, other: {other.sizes[other.traj.trajectory_dim]}."
101102
)
102103

103104
diff = np.max(
@@ -110,31 +111,31 @@ def skill(self, other, method='liu-weissberg', **kwargs):
110111
f"The two datasets must have approximately equal time coordinates, maximum difference: {diff} seconds. Consider using `gridtime` to interpolate one of the datasets."
111112
)
112113

113-
s = np.zeros((self.ds.sizes['trajectory']), dtype=np.float32)
114+
s = np.zeros((self.ds.sizes[self.trajectory_dim]), dtype=np.float32)
114115

115116
# ds = self.ds.dropna(dim=self.obs_dim)
116117
# other = other.dropna(dim=other.traj.obs_dim)
117118

118-
ds = self.ds.transpose('trajectory', self.obs_dim, ...)
119-
other = other.transpose('trajectory', other.traj.obs_dim, ...)
119+
ds = self.ds.transpose(self.trajectory_dim, self.obs_dim, ...)
120+
other = other.transpose(other.traj.trajectory_dim, other.traj.obs_dim, ...)
120121

121-
lon0 = ds.traj.tlon
122+
lon0 = ds.traj.tlon # TODO should be self.tlon ?
122123
lat0 = ds.traj.tlat
123124
lon1 = other.traj.tlon
124125
lat1 = other.traj.tlat
125126

126127
for ti in range(0, len(s)):
127128
if method == 'liu-weissberg':
128-
s[ti] = skill.liu_weissberg(lon0.isel(trajectory=ti),
129-
lat0.isel(trajectory=ti),
130-
lon1.isel(trajectory=ti),
131-
lat1.isel(trajectory=ti), **kwargs)
129+
s[ti] = skill.liu_weissberg(lon0.isel({self.trajectory_dim: ti}),
130+
lat0.isel({self.trajectory_dim: ti}),
131+
lon1.isel({self.trajectory_dim: ti}),
132+
lat1.isel({self.trajectory_dim: ti}), **kwargs)
132133
else:
133134
raise ValueError(f"Unknown skill-score method: {method}.")
134135

135136
return xr.DataArray(s,
136137
name='Skillscore',
137-
coords={'trajectory': self.ds.trajectory},
138+
coords={self.trajectory_dim: self.ds.trajectory},
138139
attrs={'method': method})
139140

140141
def seltime(self, t0=None, t1=None):

0 commit comments

Comments
 (0)