Skip to content

Commit 915aeae

Browse files
Clean up implementation a lot
1 parent a8f2090 commit 915aeae

File tree

1 file changed

+11
-10
lines changed

1 file changed

+11
-10
lines changed

intake_esm/cat.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -321,9 +321,8 @@ def _df_from_file(
321321
cat.catalog_file = csv_path
322322

323323
reader = CatalogFileDataReader(cat.catalog_file, storage_options, **read_kwargs)
324-
read = reader()
325-
self._iterable_dtype_map = reader._dtype_map
326-
return read
324+
self._iterable_dtype_map = reader.dtype_map
325+
return reader.frames
327326

328327
@property
329328
def lf(self) -> pl.LazyFrame:
@@ -610,8 +609,8 @@ def __init__(
610609
f'Expected one of {__filetypes__}'
611610
)
612611

613-
# Set default dtype_map to tuple
614-
self._dtype_map = {key: 'tuple' for key in self.read_kwargs.get('converters', {}).keys()}
612+
self._dtype_map: dict[str, str] = {}
613+
self.frames = self._read()
615614

616615
def _read_csv_pd(self) -> FramesModel:
617616
"""Read a catalog file stored as a csv using pandas"""
@@ -653,17 +652,14 @@ def _read_csv_pl(self) -> FramesModel:
653652
)
654653
.collect()
655654
.to_dicts()
656-
):
655+
): # Returns an empty list if no rows - hence walrus
657656
self._dtype_map = dtype_map[0]
658657

659658
lf = lf.with_columns(
660659
[
661660
pl.col(colname)
662661
.str.replace('^.', '[') # Replace first/last chars with [ or ].
663662
.str.replace('.$', ']') # set/tuple => list
664-
# ^ We also need to cache - probably as an attriubte on this class
665-
# what we found ie. '[' => list, '(' => tuple, etc., so we can write
666-
# the correct type back when we serialise the catalog. # TODO
667663
.str.replace_all("'", '"')
668664
.str.json_decode() # This is to do with the way polars reads json - single versus double quotes
669665
for colname in converters.keys()
@@ -680,7 +676,7 @@ def _read_parquet_pl(self) -> FramesModel:
680676
)
681677
return FramesModel(lf=lf)
682678

683-
def __call__(self):
679+
def _read(self):
684680
if self.driver == 'polars':
685681
if self.filetype == 'csv':
686682
return self._read_csv_pl()
@@ -694,3 +690,8 @@ def __call__(self):
694690
return self._read_csv_pd()
695691
else:
696692
raise ValueError(f'Unsupported file type {self.filetype} for pandas reader')
693+
694+
@property
695+
def dtype_map(self) -> dict[str, str]:
696+
"""Return a map of column names to their dtypes for columns with iterables."""
697+
return self._dtype_map

0 commit comments

Comments
 (0)