Skip to content
43 changes: 43 additions & 0 deletions examples/example_find_positions_at_obs_times.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""
Finding the closeset or interpolated positions at given times (e.g. wave observations).
=======================================================================================
"""

from pathlib import Path
from trajan.readers.omb import read_omb_csv
import xarray as xr
import coloredlogs

coloredlogs.install(level='debug')

#%%
# Read the data
data = Path.cwd().parent / "tests" / "test_data" / "csv" / "omb3.csv"
ds = read_omb_csv(data)
print(ds)

#%%
# The wave data in the variable `pHs0` is given along a different observation dimension.
# Because it is also a observation, i.e. in the style of 2D trajectory
# datasets, we need to iterate over the trajectories:


def gridwaves(tds):
t = tds[['lat', 'lon',
'time']].traj.gridtime(tds['time_waves_imu'].squeeze())
return t.traj.to_2d(obsdim='obs_waves_imu')


dsw = ds.groupby('trajectory').map(gridwaves)

print(dsw)

#%%
# We now have the positions interpolated to the IMU (wave) observations. We
# could also merge these together to one dataset again:

ds = xr.merge((ds, dsw.rename({
'lon': 'lon_waves',
'lat': 'lat_waves'
}).drop('time')))
print(ds)
58 changes: 58 additions & 0 deletions examples/example_plot_spectra_accessors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
"""
Plotting wave spectra data, using the accessor syntax
================================================
"""

# %%

exit()

# %%

ipython3

# %%

from pathlib import Path
from trajan.readers.omb import read_omb_csv
from trajan.plot.spectra import plot_trajan_spectra
import coloredlogs
import datetime
import matplotlib.pyplot as plt

# adjust the level of information printed
# coloredlogs.install(level='error')
coloredlogs.install(level='debug')

# %%

# load the data from an example file with several buoys and a bit of wave spectra data
path_to_test_data = Path.cwd().parent / "tests" / "test_data" / "csv" / "omb3.csv"
xr_data = read_omb_csv(path_to_test_data)

# %%

# if no axis is provided, an axis will be generated automatically

xr_data.isel(trajectory=0).processed_elevation_energy_spectrum.wave.plot(
xr_data.isel(trajectory=0).time_waves_imu.squeeze(),
)

plt.show()

# %%

# it is also possible to provide an axis on which to plot

# a plot with 3 lines, 2 columns
fig, ax = plt.subplots(3, 2)

ax_out = xr_data.isel(trajectory=0).processed_elevation_energy_spectrum.wave.plot(
xr_data.isel(trajectory=0).time_waves_imu.squeeze(),
# plot on the second line, first column
ax=ax[1, 0]
)

plt.show()

# %%
15 changes: 15 additions & 0 deletions tests/test_convert_datalayout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from numpy.testing import assert_almost_equal
import numpy as np
import trajan as ta
import xarray as xr
import pandas as pd

def test_to2d(barents):
# print(barents)
gr = barents.traj.gridtime('1H')
# print(gr)

assert gr.traj.is_1d()

b2d = gr.traj.to_2d()
assert b2d.traj.is_2d()
17 changes: 17 additions & 0 deletions tests/test_plot_spectra.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,20 @@ def test_plot_spectra_withargs(test_data, tmpdir, plot):
plt.show()

plt.close('all')


def test_plot_spectra_accessor(test_data, plot):
csv_in = test_data / 'csv/omb3.csv'
ds = read_omb_csv(csv_in)
print(ds)
print(ds.elevation_energy_spectrum)
print(ds.frequencies_waves_imu)

plt.figure()
ds.isel(trajectory=0).elevation_energy_spectrum.wave.plot(
ds.isel(trajectory=0).time_waves_imu.squeeze())

if plot:
plt.show()

plt.close('all')
2 changes: 2 additions & 0 deletions trajan/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from . import skill as _
from . import readers as _

from . import waves as _

logger = logging.getLogger(__name__)

__version__ = importlib.metadata.version("trajan")
Expand Down
6 changes: 3 additions & 3 deletions trajan/ragged.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(self, ds, obsdim, timedim, trajectorycoord, rowsizevar):
self.rowvar = rowsizevar
super().__init__(ds, obsdim, timedim)

def to_2d(self):
def to_2d(self, obsdim='obs'):
"""This actually converts a contiguous ragged xarray Dataset into an xarray Dataset that follows the Traj2d conventions."""
global_attrs = self.ds.attrs

Expand Down Expand Up @@ -70,7 +70,7 @@ def to_2d(self):

# trajectory vars
'time':
xr.DataArray(dims=["trajectory", "obs"],
xr.DataArray(dims=["trajectory", obsdim],
data=array_time,
attrs={
"standard_name": "time",
Expand Down Expand Up @@ -119,7 +119,7 @@ def to_2d(self):
crrt_data_var = "lat"

ds_converted_to_traj2d[crrt_data_var] = \
xr.DataArray(dims=["trajectory", "obs"],
xr.DataArray(dims=["trajectory", obsdim],
data=crrt_var,
attrs=attrs)

Expand Down
2 changes: 1 addition & 1 deletion trajan/traj.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,7 +862,7 @@ def condense_obs(self) -> xr.Dataset:
"""

@abstractmethod
def to_2d(self) -> xr.Dataset:
def to_2d(self, obsdim='obs') -> xr.Dataset:
"""
Convert dataset into a 2D dataset from.
"""
11 changes: 11 additions & 0 deletions trajan/traj1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,17 @@ def is_1d(self):
def is_2d(self):
return False

def to_2d(self, obsdim='obs'):
ds = self.ds.copy()
time = ds[self.timedim].rename({
self.timedim: obsdim
}).expand_dims(dim={'trajectory': ds.sizes['trajectory']})
ds = ds.rename({self.timedim: obsdim})
ds[self.timedim] = time
ds[obsdim] = np.arange(0, ds.sizes[obsdim])

return ds

def time_to_next(self):
time_step = self.ds.time[1] - self.ds.time[0]
return time_step
Expand Down
28 changes: 28 additions & 0 deletions trajan/waves/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import xarray as xr
import logging
import numpy as np

from .plot import Plot

# recommended by cf-xarray
xr.set_options(keep_attrs=True)

logger = logging.getLogger(__name__)


@xr.register_dataarray_accessor('wave')
class Wave:
def __init__(self, ds):
self.ds = ds
self.__plot__ = None

@property
def plot(self) -> Plot:
"""
See :class:`trajan.waves.Plot`.
"""
if self.__plot__ is None:
logger.debug(f'Setting up new plot object.')
self.__plot__ = Plot(self.ds)

return self.__plot__
119 changes: 119 additions & 0 deletions trajan/waves/plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import logging
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import numpy as np
import xarray as xr
import cf_xarray as _

from trajan.accessor import detect_time_dim

logger = logging.getLogger(__name__)
logging.getLogger('matplotlib.font_manager').disabled = True


class Plot:
ds: xr.Dataset

# A lon-lat projection with the currently used globe.
gcrs = None

DEFAULT_LINE_COLOR = 'gray'

def __init__(self, ds):
self.ds = ds

def __call__(self, *args, **kwargs):
if self.ds.attrs['standard_name'] == 'sea_surface_wave_variance_spectral_density':
return self.spectra(*args, **kwargs)
else:
raise ValueError('Unknown wave variable')

def spectra(self, time, *args, **kwargs):
"""
Plot the wave spectra information from a trajan compatible xarray.

Args:

time: DataArray with times.

vrange: can be either:
- None to use the default log range [-3.0, 1.0]
- a tuple of float to set the log range explicitely

`nseconds_gap`: float
Number of seconds between 2 consecutive
spectra for one instrument above which we consider that there is a
data loss that should be filled with NaN. This is to avoid "stretching"
neighboring spectra over long times if an instrument gets offline.

Returns:

ax: plt.Axes
"""
vrange = kwargs.pop('vrange', None)
nseconds_gap = kwargs.pop('nseconds_gap', 6 * 3600)

# TODO: is there a better solution for the following?
# NOTE: we would rather like to do something simpler, like:
# ax = kwargs.pop('ax', plt.axes())
# NOTE: but it seems that calling plt.axes() durinig arg evaluation messes things up, so doing instead:
if 'ax' in kwargs:
ax = kwargs.pop('ax')
else:
ax = plt.axes()

if vrange is None:
vmin_pcolor = -3.0
vmax_pcolor = 1.0
else:
vmin_pcolor = vrange[0]
vmax_pcolor = vrange[1]

spectra_frequencies = self.ds.cf['wave_frequency']

crrt_spectra = self.ds.to_numpy()
# crrt_spectra_times = detect_time_dim(self.ds, 'obs_waves_imu').to_numpy()
crrt_spectra_times = time.to_numpy()

list_datetimes = []
list_spectra = []

# avoid streching at the left
list_datetimes.append(
crrt_spectra_times[0] - np.timedelta64(2, 'm'))
list_spectra.append(np.full(len(spectra_frequencies), np.nan))

for crrt_spectra_ind in range(1, crrt_spectra.shape[0], 1):
if np.isnan(crrt_spectra_times[crrt_spectra_ind]):
continue

# if a gap with more than nseconds_gap seconds, fill with NaNs
# to avoid stretching neighbors over missing data
seconds_after_previous = float(
crrt_spectra_times[crrt_spectra_ind] - crrt_spectra_times[crrt_spectra_ind-1]) / 1e9
if seconds_after_previous > nseconds_gap:
logger.debug(
f"spectrum index {crrt_spectra_ind} is {seconds_after_previous} seconds \
after the previous one; insert nan spectra in between to avoid stretching")
list_datetimes.append(
crrt_spectra_times[crrt_spectra_ind-1] + np.timedelta64(2, 'h'))
list_spectra.append(
np.full(len(spectra_frequencies), np.nan))
list_datetimes.append(
crrt_spectra_times[crrt_spectra_ind] - np.timedelta64(2, 'h'))
list_spectra.append(
np.full(len(spectra_frequencies), np.nan))

list_spectra.append(crrt_spectra[crrt_spectra_ind, :])
list_datetimes.append(crrt_spectra_times[crrt_spectra_ind])

# avoid stretching at the right
last_datetime = list_datetimes[-1]
list_datetimes.append(last_datetime + np.timedelta64(2, 'm'))
list_spectra.append(np.full(len(spectra_frequencies), np.nan))

pclr = ax.pcolor(list_datetimes, spectra_frequencies, np.log10(
np.transpose(np.array(list_spectra))), vmin=vmin_pcolor, vmax=vmax_pcolor)

return ax

Loading