Skip to content

Commit b9b8ae1

Browse files
committed
cleaning up kernel datasets
1 parent b36074b commit b9b8ae1

File tree

2 files changed

+67
-30
lines changed

2 files changed

+67
-30
lines changed

mtpy/processing/__init__.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
1-
from mth5 import RUN_SUMMARY_COLUMNS
1+
from mth5 import RUN_SUMMARY_DTYPE, RUN_SUMMARY_COLUMNS
22

3+
ADDED_KERNEL_DATASET_DTYPE = [
4+
("fc", bool),
5+
("remote", bool),
6+
("run_dataarray", object),
7+
("stft", object),
8+
("mth5_obj", object),
9+
]
310
ADDED_KERNEL_DATASET_COLUMNS = [
4-
"fc",
5-
"remote",
6-
"run_dataarray",
7-
"stft",
8-
"mth5_obj",
11+
entry[0] for entry in ADDED_KERNEL_DATASET_DTYPE
912
]
1013

11-
KERNEL_DATASET_COLUMNS = RUN_SUMMARY_COLUMNS + ADDED_KERNEL_DATASET_COLUMNS
14+
KERNEL_DATASET_DTYPE = RUN_SUMMARY_DTYPE + ADDED_KERNEL_DATASET_DTYPE
15+
KERNEL_DATASET_COLUMNS = [entry[0] for entry in KERNEL_DATASET_DTYPE]
1216

1317
MINI_SUMMARY_COLUMNS = [
1418
"survey",

mtpy/processing/kernel_dataset.py

Lines changed: 56 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,7 @@
7474

7575
from mtpy.processing.run_summary import RunSummary
7676
from mtpy.processing import (
77-
ADDED_KERNEL_DATASET_COLUMNS,
78-
KERNEL_DATASET_COLUMNS,
77+
KERNEL_DATASET_DTYPE,
7978
MINI_SUMMARY_COLUMNS,
8079
)
8180

@@ -183,16 +182,7 @@ def df(self, value):
183182

184183
raise TypeError(msg)
185184

186-
need_columns = []
187-
for col in KERNEL_DATASET_COLUMNS:
188-
if not col in value.columns:
189-
need_columns.append(col)
190-
if need_columns:
191-
msg = f"DataFrame needs columns {', '.join(need_columns)}"
192-
logger.error(msg)
193-
raise ValueError(msg)
194-
195-
self._df = self._set_datetime_columns(value)
185+
self._df = self._set_datetime_columns(self._add_columns(value))
196186

197187
def _set_datetime_columns(self, df):
198188
"""
@@ -216,6 +206,28 @@ def clone_dataframe(self) -> pd.DataFrame:
216206
"""return a deep copy of dataframe"""
217207
return copy.deepcopy(self.df)
218208

209+
def _add_columns(self, df):
210+
"""
211+
add columns with appropriate dtypes
212+
"""
213+
214+
for col, dtype in KERNEL_DATASET_DTYPE:
215+
if not col in df.columns:
216+
if col in ["survey", "station", "run", "start", "end"]:
217+
raise ValueError(
218+
f"{col} must be a filled column in the dataframe"
219+
)
220+
else:
221+
if isinstance(dtype, object):
222+
df[col] = None
223+
else:
224+
df[col] = dtype(0)
225+
logger.warning(
226+
f"KernelDataset DataFrame needs column {col}, adding "
227+
f"and setting dtype to {dtype}."
228+
)
229+
return df
230+
219231
def from_run_summary(
220232
self,
221233
run_summary: RunSummary,
@@ -254,18 +266,14 @@ def from_run_summary(
254266
raise ValueError(msg)
255267

256268
# add columns column
257-
for col in ADDED_KERNEL_DATASET_COLUMNS:
258-
df[col] = None
259-
260-
df["fc"] = False
269+
df = self._add_columns(df)
261270

262271
# set remote reference
263-
df["remote"] = False
264272
if remote_station_id:
265273
cond = df.station == remote_station_id
266274
df.remote = cond
267275

268-
# be sure to set date time columns
276+
# be sure to set date time columns and restrict to simultaneous runs
269277
df = self._set_datetime_columns(df)
270278
if remote_station_id:
271279
df = self.restrict_run_intervals_to_simultaneous(df)
@@ -286,11 +294,6 @@ def mini_summary(self) -> pd.DataFrame:
286294
"""return a dataframe that fits in terminal"""
287295
return self.df[self._mini_summary_columns]
288296

289-
@property
290-
def print_mini_summary(self) -> None:
291-
"""prints a dataframe that (hopefully) fits in terminal"""
292-
logger.info(self.mini_summary)
293-
294297
@property
295298
def local_survey_id(self) -> str:
296299
"""return string label for local survey id"""
@@ -523,6 +526,8 @@ def get_station_metadata(self, local_station_id: str):
523526
assert len(run_ids) == len(sub_df)
524527

525528
# iterate over these runs, packing metadata into
529+
# get run metadata from the group object instead of loading the runTS
530+
# object, should be much faster.
526531
station_metadata = None
527532
for i, row in sub_df.iterrows():
528533
local_run_obj = self.get_run_object(row)
@@ -533,6 +538,34 @@ def get_station_metadata(self, local_station_id: str):
533538
station_metadata.add_run(run_metadata)
534539
return station_metadata
535540

541+
def get_run_object(
542+
self, index_or_row: Union[int, pd.Series]
543+
) -> mt_metadata.timeseries.Run:
544+
"""
545+
Gets the run object associated with a row of the df
546+
547+
Development Notes:
548+
TODO: This appears to be unused except by get_station_metadata.
549+
Delete or integrate if desired.
550+
- This has likely been deprecated by direct calls to
551+
run_obj = row.mth5_obj.from_reference(row.run_reference) in pipelines.
552+
553+
Parameters
554+
----------
555+
index_or_row: integer index of df, or pd.Series object
556+
557+
Returns
558+
-------
559+
run_obj: mt_metadata.timeseries.Run
560+
The run associated with the row of the df.
561+
"""
562+
if isinstance(index_or_row, int):
563+
row = self.df.loc[index_or_row]
564+
else:
565+
row = index_or_row
566+
run_obj = row.mth5_obj.from_reference(row.run_reference)
567+
return run_obj
568+
536569
@property
537570
def num_sample_rates(self) -> int:
538571
"""returns the number of unique sample rates in the dataframe"""

0 commit comments

Comments
 (0)