Skip to content

Commit 55f52d4

Browse files
committed
updating kernel dataset methods to return objects
1 parent b9b8ae1 commit 55f52d4

File tree

2 files changed

+81
-60
lines changed

2 files changed

+81
-60
lines changed

mtpy/processing/kernel_dataset.py

Lines changed: 58 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,17 @@ def __str__(self):
157157
def __repr__(self):
158158
return self.__str__()
159159

160+
# def __iter__(self):
161+
# """
162+
# Iterate over rows in the dataframe
163+
164+
# :return: DESCRIPTION
165+
# :rtype: TYPE
166+
167+
# """
168+
169+
# return self.df.iterrows()[0]
170+
160171
@property
161172
def df(self):
162173
return self._df
@@ -182,7 +193,9 @@ def df(self, value):
182193

183194
raise TypeError(msg)
184195

185-
self._df = self._set_datetime_columns(self._add_columns(value))
196+
self._df = self._add_duration_column(
197+
self._set_datetime_columns(self._add_columns(value)), inplace=False
198+
)
186199

187200
def _set_datetime_columns(self, df):
188201
"""
@@ -217,15 +230,15 @@ def _add_columns(self, df):
217230
raise ValueError(
218231
f"{col} must be a filled column in the dataframe"
219232
)
233+
234+
if isinstance(dtype, object):
235+
df[col] = None
220236
else:
221-
if isinstance(dtype, object):
222-
df[col] = None
223-
else:
224-
df[col] = dtype(0)
225-
logger.warning(
226-
f"KernelDataset DataFrame needs column {col}, adding "
227-
f"and setting dtype to {dtype}."
228-
)
237+
df[col] = dtype(0)
238+
logger.warning(
239+
f"KernelDataset DataFrame needs column {col}, adding "
240+
f"and setting dtype to {dtype}."
241+
)
229242
return df
230243

231244
def from_run_summary(
@@ -313,19 +326,32 @@ def local_survey_metadata(self) -> mt_metadata.timeseries.Survey:
313326
logger.warning(msg)
314327
return self.survey_metadata["0"]
315328

316-
def _add_duration_column(self) -> None:
329+
def _add_duration_column(self, df, inplace=True) -> None:
317330
"""adds a column to self.df with times end-start (in seconds)"""
318-
timedeltas = self.df.end - self.df.start
331+
332+
timedeltas = df.end - df.start
319333
durations = [x.total_seconds() for x in timedeltas]
320-
self.df["duration"] = durations
321-
return
334+
if inplace:
335+
df["duration"] = durations
336+
return df
337+
else:
338+
new_df = df.copy()
339+
new_df["duration"] = durations
340+
return new_df
322341

323-
def _update_duration_column(self) -> None:
342+
def _update_duration_column(self, inplace=True) -> None:
324343
"""calls add_duration_column (after possible manual manipulation of start/end"""
325-
self._add_duration_column()
344+
345+
if inplace:
346+
self._df = self._add_duration_column(self._df, inplace)
347+
else:
348+
return self._add_duration_column(self._df, inplace)
326349

327350
def drop_runs_shorter_than(
328-
self, minimum_duration: float, units="s"
351+
self,
352+
minimum_duration: float,
353+
units="s",
354+
inplace=True,
329355
) -> None:
330356
"""
331357
Drop runs from df that are inconsequentially short
@@ -344,12 +370,18 @@ def drop_runs_shorter_than(
344370
if units != "s":
345371
msg = "Expected units are seconds : units='s'"
346372
raise NotImplementedError(msg)
347-
if "duration" not in self.df.columns:
348-
self._add_duration_column()
373+
349374
drop_cond = self.df.duration < minimum_duration
350-
self.df.drop(self.df[drop_cond].index, inplace=True)
351-
self.df.reset_index(drop=True, inplace=True)
352-
return
375+
if inplace:
376+
self._update_duration_column(inplace)
377+
self.df.drop(self.df[drop_cond].index, inplace=inplace)
378+
self.df.reset_index(drop=True, inplace=True)
379+
return
380+
else:
381+
new_df = self._update_duration_column(inplace)
382+
new_df = self.df.drop(self.df[drop_cond].index)
383+
new_df.reset_index(drop=True, inplace=True)
384+
return new_df
353385

354386
def select_station_runs(
355387
self,
@@ -525,17 +557,13 @@ def get_station_metadata(self, local_station_id: str):
525557
run_ids = sub_df.run.unique()
526558
assert len(run_ids) == len(sub_df)
527559

528-
# iterate over these runs, packing metadata into
529-
# get run metadata from the group object instead of loading the runTS
530-
# object, should be much faster.
531-
station_metadata = None
560+
station_metadata = sub_df.mth5_obj[0].from_reference(
561+
sub_df.station_hdf5_reference[0]
562+
)
563+
station_metadata.runs = ListDict()
532564
for i, row in sub_df.iterrows():
533565
local_run_obj = self.get_run_object(row)
534-
if station_metadata is None:
535-
station_metadata = local_run_obj.station_metadata
536-
station_metadata.runs = ListDict()
537-
run_metadata = local_run_obj.metadata
538-
station_metadata.add_run(run_metadata)
566+
station_metadata.add_run(local_run_obj.metadata)
539567
return station_metadata
540568

541569
def get_run_object(
@@ -697,36 +725,6 @@ def add_columns_for_processing(self, mth5_objs) -> None:
697725
for i, station_id in enumerate(self.df["station"]):
698726
mth5_obj_column[i] = mth5_objs[station_id]
699727
self.df["mth5_obj"] = mth5_obj_column
700-
# for column_name in columns_to_add:
701-
# self.df[column_name] = None
702-
703-
# def get_run_object(
704-
# self, index_or_row: Union[int, pd.Series]
705-
# ) -> mt_metadata.timeseries.Run:
706-
# """
707-
# Gets the run object associated with a row of the df
708-
709-
# Development Notes:
710-
# TODO: This appears to be unused except by get_station_metadata.
711-
# Delete or integrate if desired.
712-
# - This has likely been deprecated by direct calls to
713-
# run_obj = row.mth5_obj.from_reference(row.run_reference) in pipelines.
714-
715-
# Parameters
716-
# ----------
717-
# index_or_row: integer index of df, or pd.Series object
718-
719-
# Returns
720-
# -------
721-
# run_obj: mt_metadata.timeseries.Run
722-
# The run associated with the row of the df.
723-
# """
724-
# if isinstance(index_or_row, int):
725-
# row = self.df.loc[index_or_row]
726-
# else:
727-
# row = index_or_row
728-
# run_obj = row.mth5_obj.from_reference(row.run_reference)
729-
# return run_obj
730728

731729
def close_mth5s(self) -> None:
732730
"""

tests/processing/test_kernel_dataset.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,10 @@ def test_mini_summary(self):
8383
sorted(self.kd._mini_summary_columns), sorted(mini_df.columns)
8484
)
8585

86+
def test_str(self):
87+
mini_df = self.kd.mini_summary
88+
self.assertEqual(mini_df.head(), str(self.kd))
89+
8690
# @classmethod
8791
# def tearDownClass(self):
8892
# self.mth5_path.unlink()
@@ -193,13 +197,32 @@ def test_from_run_summary(self):
193197
self.assertIn("remote", self.kd.df.columns)
194198
with self.subTest("has fc column"):
195199
self.assertIn("fc", self.kd.df.columns)
200+
with self.subTest("has_duration"):
201+
self.assertFalse((self.kd.df.duration == 0).all())
202+
203+
with self.subTest("has_all_columns"):
204+
self.assertListEqual(
205+
sorted(self.kd.df.columns), sorted(KERNEL_DATASET_COLUMNS)
206+
)
196207

197208
def test_num_sample_rates(self):
198209
self.assertEqual(self.kd.num_sample_rates, 1)
199210

200211
def test_sample_rate(self):
201212
self.assertEqual(self.kd.sample_rate, 1)
202213

214+
def test_drop_runs_shorter_than(self):
215+
self.kd.drop_runs_shorter_than(8000)
216+
self.assertEqual((2, 20), self.kd.df.shape)
217+
218+
def test_survey_id(self):
219+
self.assertEqual(self.kd.local_survey_id, "test")
220+
221+
def test_update_duration_column_not_inplace(self):
222+
new_df = self.kd._update_duration_column(inplace=False)
223+
224+
self.assertTrue((new_df.duration == self.kd.df.duration).all())
225+
203226

204227
class TestKernelDatasetMethodsFail(unittest.TestCase):
205228
def setUp(self):

0 commit comments

Comments
 (0)