From 2b2292636f5e62e1fc883e656ddedb62666eb786 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Thu, 4 Sep 2025 10:21:07 -0500 Subject: [PATCH 1/6] Fix writing sorted data --- src/opencosmo/dataset/dataset.py | 9 +++------ src/opencosmo/dataset/state.py | 32 ++++++++++++++++++++++++++++---- test/test_dataset.py | 11 +++++++++++ 3 files changed, 42 insertions(+), 10 deletions(-) diff --git a/src/opencosmo/dataset/dataset.py b/src/opencosmo/dataset/dataset.py index 467f91e3..dbd02d3e 100644 --- a/src/opencosmo/dataset/dataset.py +++ b/src/opencosmo/dataset/dataset.py @@ -15,7 +15,7 @@ import numpy as np from astropy import units # type: ignore from astropy.cosmology import Cosmology # type: ignore -from astropy.table import QTable # type: ignore +from astropy.table import Column, QTable # type: ignore from opencosmo.dataset.column import ColumnMask, DerivedColumn from opencosmo.dataset.state import DatasetState @@ -246,14 +246,11 @@ def get_data(self, output="astropy", unpack=True) -> OpenCosmoData: data = data[cn] if output == "numpy": - if isinstance(data, u.Quantity): + if isinstance(data, (u.Quantity, Column)): data = data.value elif isinstance(data, (QTable, dict)): data = dict(data) - is_quantity = filter( - lambda v: isinstance(data[v], u.Quantity), data.keys() - ) - for colname in is_quantity: + for colname in data: data[colname] = data[colname].value if isinstance(data, dict) and len(data) == 1: diff --git a/src/opencosmo/dataset/state.py b/src/opencosmo/dataset/state.py index 61bb3e55..3e005bc9 100644 --- a/src/opencosmo/dataset/state.py +++ b/src/opencosmo/dataset/state.py @@ -19,6 +19,17 @@ from opencosmo.dataset.handler import DatasetHandler +def make_sorted_index( + state: "DatasetState", handler: "DatasetHandler", column: str, invert: bool +): + sort_by_column = state.select(column).get_data(handler, ignore_sort=True) + idx = np.argsort(sort_by_column) + if invert: + idx = idx[::-1] + index = SimpleIndex(idx) + return index + + class DatasetState: """ 
Holds mutable state required by the dataset. Cleans up the dataset to mostly focus @@ -89,13 +100,13 @@ def get_data( data = self.__build_derived_columns(data) data_columns = set(data.columns) + if not ignore_sort and self.__sort_by is not None: + data.sort(self.__sort_by[0], reverse=self.__sort_by[1]) if ( self.__hidden and not self.__hidden.intersection(data_columns) == data_columns ): data.remove_columns(self.__hidden) - if not ignore_sort and self.__sort_by is not None: - data.sort(self.__sort_by[0], reverse=self.__sort_by[1]) return data def with_index(self, index: DataIndex): @@ -126,7 +137,13 @@ def make_schema( self, handler: "DatasetHandler", header: Optional[OpenCosmoHeader] = None ): builder_names = set(self.__builder.columns) - schema = handler.prep_write(self.__index, builder_names - self.__hidden, header) + if self.__sort_by is not None: + index = make_sorted_index(self, handler, *self.__sort_by) + else: + index = self.__index + + schema = handler.prep_write(index, builder_names - self.__hidden, header) + derived_names = set(self.__derived.keys()) - self.__hidden derived_data = ( self.select(derived_names) @@ -254,11 +271,17 @@ def select(self, columns: str | Iterable[str]): if isinstance(columns, str): columns = [columns] + hide_sort = False + columns = set(columns) known_builders = set(self.__builder.columns) known_derived = set(self.__derived.keys()) known_im = set(self.__im_handler.keys()) + if self.__sort_by is not None: + hide_sort = self.__sort_by[0] not in columns + columns.add(self.__sort_by[0]) + unknown_columns = columns - known_builders - known_derived - known_im if unknown_columns: raise ValueError( @@ -295,7 +318,8 @@ def select(self, columns: str | Iterable[str]): new_im_handler = self.__im_handler.with_columns(required_im) new_hidden = all_required - columns - if self.__sort_by is not None and self.__sort_by[0] not in columns: + if hide_sort: + assert self.__sort_by is not None new_hidden.add(self.__sort_by[0]) return DatasetState( 
diff --git a/test/test_dataset.py b/test/test_dataset.py index 2b5c4c34..86a4a4ad 100644 --- a/test/test_dataset.py +++ b/test/test_dataset.py @@ -110,6 +110,17 @@ def test_take_sorted_inverted(input_path): assert fof_masses.max() == toolkit_sorted_fof_masses[0] +def test_write_after_sorted(input_path, tmp_path): + dataset = oc.open(input_path) + dataset = dataset.sort_by("fof_halo_mass", invert=True) + halo_tags = dataset.select("fof_halo_tag").get_data("numpy") + oc.write(tmp_path / "test.hdf5", dataset) + new_dataset = oc.open(tmp_path / "test.hdf5") + to_check = new_dataset.select(("fof_halo_mass", "fof_halo_tag")).get_data("numpy") + assert np.all(to_check["fof_halo_mass"][:-1] >= to_check["fof_halo_mass"][1:]) + assert np.all(to_check["fof_halo_tag"] == halo_tags) + + def test_sort_by_derived(input_path): n = 150 ds = oc.open(input_path) From 407186f2577aa93b2b41dbe69c3ecf8ce34321d1 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Thu, 4 Sep 2025 10:21:54 -0500 Subject: [PATCH 2/6] Add changelog --- changes/+f495fef9.bugfix.rst | 1 + 1 file changed, 1 insertion(+) create mode 100644 changes/+f495fef9.bugfix.rst diff --git a/changes/+f495fef9.bugfix.rst b/changes/+f495fef9.bugfix.rst new file mode 100644 index 00000000..1fc7c5b4 --- /dev/null +++ b/changes/+f495fef9.bugfix.rst @@ -0,0 +1 @@ +Fix a bug that could cause some data to be returned as an astropy.table.Column when data was requested as numpy From a458f1066f2eccbde278102f99f4d99486527c93 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Thu, 4 Sep 2025 11:40:55 -0500 Subject: [PATCH 3/6] Add tests, re-enable eager sorting --- .../collection/structure/collection.py | 3 +- src/opencosmo/dataset/dataset.py | 6 ++- src/opencosmo/dataset/state.py | 53 ++++--------------- src/opencosmo/index/chunked.py | 4 ++ src/opencosmo/index/protocols.py | 1 + src/opencosmo/index/simple.py | 19 ++++--- test/test_collection.py | 17 ++++++ test/test_dataset.py | 11 +++- 8 files changed, 58 insertions(+), 56 deletions(-) 
diff --git a/src/opencosmo/collection/structure/collection.py b/src/opencosmo/collection/structure/collection.py index ef5179de..ba012d60 100644 --- a/src/opencosmo/collection/structure/collection.py +++ b/src/opencosmo/collection/structure/collection.py @@ -177,7 +177,8 @@ def __getitem__(self, key: str) -> oc.Dataset | oc.StructureCollection: return self.__source index = self.__links[key].make_index(self.__index) - return self.__datasets[key].with_index(index) + dataset = self.__datasets[key].with_index(index) + return dataset def __enter__(self): return self diff --git a/src/opencosmo/dataset/dataset.py b/src/opencosmo/dataset/dataset.py index dbd02d3e..cd9efdfe 100644 --- a/src/opencosmo/dataset/dataset.py +++ b/src/opencosmo/dataset/dataset.py @@ -251,7 +251,8 @@ def get_data(self, output="astropy", unpack=True) -> OpenCosmoData: elif isinstance(data, (QTable, dict)): data = dict(data) for colname in data: - data[colname] = data[colname].value + if isinstance(data[colname], (u.Quantity, Column)): + data[colname] = data[colname].value if isinstance(data, dict) and len(data) == 1: return next(iter(data.values())) @@ -558,6 +559,9 @@ def sort_by(self, column: str, invert: bool = False) -> Dataset: .sort_by("fof_halo_mass") .take(100, at="start") + Note that sorting does not persist when a dataset is written, because the order + of data is used for the spatial index. 
+ Parameters ---------- column : str diff --git a/src/opencosmo/dataset/state.py b/src/opencosmo/dataset/state.py index 3e005bc9..26a0f5d4 100644 --- a/src/opencosmo/dataset/state.py +++ b/src/opencosmo/dataset/state.py @@ -22,11 +22,12 @@ def make_sorted_index( state: "DatasetState", handler: "DatasetHandler", column: str, invert: bool ): - sort_by_column = state.select(column).get_data(handler, ignore_sort=True) + sort_by_column = state.select(column).get_data(handler) idx = np.argsort(sort_by_column) if invert: idx = idx[::-1] - index = SimpleIndex(idx) + + index = SimpleIndex(state.index.into_array()[idx]) return index @@ -45,7 +46,6 @@ def __init__( region: Region, header: OpenCosmoHeader, im_handler: InMemoryColumnHandler, - sort_by: Optional[tuple[str, bool]] = None, hidden: set[str] = set(), derived: dict[str, DerivedColumn] = {}, ): @@ -57,7 +57,6 @@ def __init__( self.__header = header self.__hidden = hidden self.__index = index - self.__sort_by = sort_by self.__region = region @property @@ -89,9 +88,7 @@ def columns(self) -> list[str]: ) return list(columns - self.__hidden) - def get_data( - self, handler: "DatasetHandler", ignore_sort: bool = False - ) -> table.QTable: + def get_data(self, handler: "DatasetHandler") -> table.QTable: """ Get the data for a given handler. 
""" @@ -100,8 +97,6 @@ def get_data( data = self.__build_derived_columns(data) data_columns = set(data.columns) - if not ignore_sort and self.__sort_by is not None: - data.sort(self.__sort_by[0], reverse=self.__sort_by[1]) if ( self.__hidden and not self.__hidden.intersection(data_columns) == data_columns @@ -124,7 +119,6 @@ def with_index(self, index: DataIndex): self.__region, self.__header, new_cache, - self.__sort_by, self.__hidden, self.__derived, ) @@ -137,12 +131,9 @@ def make_schema( self, handler: "DatasetHandler", header: Optional[OpenCosmoHeader] = None ): builder_names = set(self.__builder.columns) - if self.__sort_by is not None: - index = make_sorted_index(self, handler, *self.__sort_by) - else: - index = self.__index - - schema = handler.prep_write(index, builder_names - self.__hidden, header) + schema = handler.prep_write( + self.__index.sorted(), builder_names - self.__hidden, header + ) derived_names = set(self.__derived.keys()) - self.__hidden derived_data = ( @@ -222,7 +213,6 @@ def with_new_columns( self.__region, self.__header, new_im_handler, - self.__sort_by, self.__hidden, new_derived, ) @@ -253,7 +243,6 @@ def with_region(self, region: Region): region, self.__header, self.__im_handler, - self.__sort_by, self.__hidden, self.__derived, ) @@ -271,16 +260,11 @@ def select(self, columns: str | Iterable[str]): if isinstance(columns, str): columns = [columns] - hide_sort = False - columns = set(columns) known_builders = set(self.__builder.columns) known_derived = set(self.__derived.keys()) known_im = set(self.__im_handler.keys()) - if self.__sort_by is not None: - hide_sort = self.__sort_by[0] not in columns - columns.add(self.__sort_by[0]) unknown_columns = columns - known_builders - known_derived - known_im if unknown_columns: @@ -318,9 +302,6 @@ def select(self, columns: str | Iterable[str]): new_im_handler = self.__im_handler.with_columns(required_im) new_hidden = all_required - columns - if hide_sort: - assert self.__sort_by is not None - 
new_hidden.add(self.__sort_by[0]) return DatasetState( self.__base_unit_transformations, @@ -330,21 +311,20 @@ def select(self, columns: str | Iterable[str]): self.__region, self.__header, new_im_handler, - self.__sort_by, new_hidden, new_derived, ) def sort_by(self, column_name: str, handler: "DatasetHandler", invert: bool): + index = make_sorted_index(self, handler, column_name, invert) return DatasetState( self.__base_unit_transformations, self.__builder, - self.__index, + index, self.__convention, self.__region, self.__header, self.__im_handler, - (column_name, invert), self.__hidden, self.__derived, ) @@ -353,19 +333,7 @@ def take(self, n: int, at: str, handler): """ Take rows from the dataset. """ - if self.__sort_by is not None: - column = self.select(self.__sort_by[0]).get_data(handler, ignore_sort=True)[ - self.__sort_by[0] - ] - sorted = np.argsort(column) - if self.__sort_by[1]: - sorted = sorted[::-1] - - index: DataIndex = SimpleIndex(sorted) - else: - index = self.__index - - new_index = index.take(n, at) + new_index = self.__index.take(n, at) return self.with_index(new_index) def take_range(self, start: int, end: int): @@ -408,7 +376,6 @@ def with_units( self.__region, self.__header, self.__im_handler, - self.__sort_by, self.__hidden, self.__derived, ) diff --git a/src/opencosmo/index/chunked.py b/src/opencosmo/index/chunked.py index c0244842..2ef1fd16 100644 --- a/src/opencosmo/index/chunked.py +++ b/src/opencosmo/index/chunked.py @@ -45,6 +45,10 @@ def range(self) -> tuple[int, int]: """ return self.__starts[0], self.__starts[-1] + self.__sizes[-1] - 1 + def sorted(self): + sorted_indices = np.argsort(self.__starts) + return ChunkedIndex(self.__starts[sorted_indices], self.__sizes[sorted_indices]) + def into_array(self) -> NDArray[np.int_]: """ Convert the ChunkedIndex to a SimpleIndex. 
diff --git a/src/opencosmo/index/protocols.py b/src/opencosmo/index/protocols.py index 75ec53ba..d5e3bf62 100644 --- a/src/opencosmo/index/protocols.py +++ b/src/opencosmo/index/protocols.py @@ -20,6 +20,7 @@ def n_in_range( ) -> NDArray[np.int_]: ... def take(self, n: int, at: str = "random") -> "DataIndex": ... def take_range(self, start: int, end: int) -> "DataIndex": ... + def sorted(self) -> "DataIndex": ... def intersection(self, other: "DataIndex") -> "DataIndex": ... def projection(self, other: "DataIndex") -> "DataIndex": ... def into_mask(self) -> NDArray[np.bool_]: ... diff --git a/src/opencosmo/index/simple.py b/src/opencosmo/index/simple.py index 0892f16d..a9390abe 100644 --- a/src/opencosmo/index/simple.py +++ b/src/opencosmo/index/simple.py @@ -24,6 +24,9 @@ def from_size(cls, size: int) -> "SimpleIndex": def empty(cls): return SimpleIndex(np.array([], dtype=int)) + def sorted(self): + return SimpleIndex(np.sort(self.__index)) + def __len__(self) -> int: return len(self.__index) @@ -33,12 +36,9 @@ def into_array(self, copy: bool = False) -> NDArray[np.int_]: return self.__index def range(self) -> tuple[int, int]: - """ - Guranteed to be sorted - """ if len(self) == 0: return 0, 0 - return self.__index[0], self.__index[-1] + return (int(np.min(self.__index)), int(np.max(self.__index))) def into_mask(self): mask = np.zeros(self.__index[-1] + 1, dtype=bool) @@ -64,8 +64,9 @@ def n_in_range( return np.zeros_like(start) ends = start + size - start_idxs = np.searchsorted(self.__index, start, "left") - end_idxs = np.searchsorted(self.__index, ends, "left") + self_sorted = np.sort(self.__index) + start_idxs = np.searchsorted(self_sorted, start, "left") + end_idxs = np.searchsorted(self_sorted, ends, "left") return end_idxs - start_idxs def set_data(self, data: np.ndarray, value: bool) -> np.ndarray: @@ -127,10 +128,8 @@ def projection(self, other: DataIndex) -> DataIndex: where the second index is true. 
""" other_idxs = other.into_array() - idxs = np.searchsorted(self.__index, other_idxs) - idxs = idxs[idxs != len(self.__index)] - idxs = idxs[other_idxs == self.__index[idxs]] - return SimpleIndex(idxs) + is_in = np.isin(other_idxs, self.__index) + return SimpleIndex(self.__index[is_in]) def mask(self, mask: np.ndarray) -> DataIndex: if mask.shape != self.__index.shape: diff --git a/test/test_collection.py b/test/test_collection.py index 727b6bae..cb18e51b 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -416,6 +416,23 @@ def test_data_link_sort(halo_paths): assert np.all(fof_halo_tags == halo["halo_properties"]["fof_halo_tag"]) +def test_data_link_sort_write(halo_paths, tmp_path): + collection = oc.open(halo_paths) + collection = collection.filter(oc.col("sod_halo_mass") > 10**14).sort_by( + "fof_halo_mass" + ) + oc.write(tmp_path / "temp.hdf5", collection) + new_collection = oc.open(tmp_path / "temp.hdf5").take(10) + assert np.all( + collection["halo_properties"].select("sod_halo_mass").get_data("numpy") > 10**14 + ) + for halo in new_collection.objects(("halo_profiles",)): + assert np.all( + halo["halo_properties"]["fof_halo_tag"] + == halo["halo_profiles"].select("fof_halo_bin_tag").get_data("numpy")[0] + ) + + def test_data_link_selection(halo_paths): collection = oc.open(*halo_paths) collection = collection.filter(oc.col("sod_halo_mass") > 10**13).take( diff --git a/test/test_dataset.py b/test/test_dataset.py index 86a4a4ad..cff56665 100644 --- a/test/test_dataset.py +++ b/test/test_dataset.py @@ -80,6 +80,15 @@ def test_take_oom(input_path): assert len(data) == 10 +def test_sort_after_filter(input_path): + dataset = oc.open(input_path) + dataset = dataset.filter(oc.col("fof_halo_mass") > 1e13) + dataset = dataset.sort_by("sod_halo_mass") + data = dataset.select(("fof_halo_mass", "sod_halo_mass")).get_data("numpy") + assert np.all(data["fof_halo_mass"] > 1e13) + assert np.all(data["sod_halo_mass"][:-1] <= data["sod_halo_mass"][1:]) + + 
def test_take_sorted(input_path): n = 150 dataset = oc.open(input_path) @@ -115,7 +124,7 @@ def test_write_after_sorted(input_path, tmp_path): dataset = dataset.sort_by("fof_halo_mass", invert=True) halo_tags = dataset.select("fof_halo_tag").get_data("numpy") oc.write(tmp_path / "test.hdf5", dataset) - new_dataset = oc.open(tmp_path / "test.hdf5") + new_dataset = oc.open(tmp_path / "test.hdf5").sort_by("fof_halo_mass", invert=True) to_check = new_dataset.select(("fof_halo_mass", "fof_halo_tag")).get_data("numpy") assert np.all(to_check["fof_halo_mass"][:-1] >= to_check["fof_halo_mass"][1:]) assert np.all(to_check["fof_halo_tag"] == halo_tags) From d74f3789235039ab6baf1455f16444dc27c673c6 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Thu, 4 Sep 2025 11:58:04 -0500 Subject: [PATCH 4/6] Fix writing linked datasets --- src/opencosmo/collection/structure/collection.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/opencosmo/collection/structure/collection.py b/src/opencosmo/collection/structure/collection.py index ba012d60..b871e76c 100644 --- a/src/opencosmo/collection/structure/collection.py +++ b/src/opencosmo/collection/structure/collection.py @@ -801,10 +801,12 @@ def galaxies(self, *args, **kwargs): raise AttributeError("This collection does not contain galaxies!") def make_schema(self) -> StructCollectionSchema: + sorted_index = self.__source.index.sorted() + to_write = self.with_index(sorted_index) schema = StructCollectionSchema(self.__header) source_name = self.__source.dtype - for name, dataset in self.items(): + for name, dataset in to_write.items(): ds_schema = dataset.make_schema() if name == "galaxies": name = "galaxy_properties" @@ -813,7 +815,7 @@ def make_schema(self) -> StructCollectionSchema: for name, handler in self.__links.items(): if name == "galaxies": name = "galaxy_properties" - link_schema = handler.make_schema(name, self.__index) + link_schema = handler.make_schema(name, sorted_index) 
schema.insert(link_schema, f"{source_name}.{name}") return schema From 0d4c3c7a8ee58689ed2e49807d1a5388935b06f9 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Thu, 4 Sep 2025 13:08:34 -0500 Subject: [PATCH 5/6] Several bug fixes --- src/opencosmo/dataset/state.py | 2 +- src/opencosmo/index/chunked.py | 1 + src/opencosmo/index/simple.py | 8 ++++++-- test/test_im_col.py | 1 + 4 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/opencosmo/dataset/state.py b/src/opencosmo/dataset/state.py index 26a0f5d4..33cf3a87 100644 --- a/src/opencosmo/dataset/state.py +++ b/src/opencosmo/dataset/state.py @@ -324,7 +324,7 @@ def sort_by(self, column_name: str, handler: "DatasetHandler", invert: bool): self.__convention, self.__region, self.__header, - self.__im_handler, + self.__im_handler.with_index(index), self.__hidden, self.__derived, ) diff --git a/src/opencosmo/index/chunked.py b/src/opencosmo/index/chunked.py index 2ef1fd16..df6d8929 100644 --- a/src/opencosmo/index/chunked.py +++ b/src/opencosmo/index/chunked.py @@ -169,6 +169,7 @@ def take(self, n: int, at: str = "random") -> DataIndex: ] ) idxs = np.random.choice(idxs, n, replace=False) + idxs.sort() return simple.SimpleIndex(idxs) elif at == "start": diff --git a/src/opencosmo/index/simple.py b/src/opencosmo/index/simple.py index a9390abe..157004d1 100644 --- a/src/opencosmo/index/simple.py +++ b/src/opencosmo/index/simple.py @@ -89,7 +89,11 @@ def take(self, n: int, at: str = "random") -> DataIndex: return SimpleIndex.empty() if at == "random": - return SimpleIndex(np.random.choice(self.__index, n, replace=False)) + indices = np.arange(len(self.__index)) + indices = np.random.choice(indices, n, replace=False) + indices.sort() + + return SimpleIndex(self.__index[indices]) elif at == "start": return SimpleIndex(self.__index[:n]) elif at == "end": @@ -148,7 +152,7 @@ def mask(self, mask: np.ndarray) -> DataIndex: return SimpleIndex(self.__index[mask]) - def get_data(self, data: h5py.Dataset) -> np.ndarray: 
+ def get_data(self, data: h5py.Dataset, debug: bool = False) -> np.ndarray: """ Get the data from the dataset using the index. """ diff --git a/test/test_im_col.py b/test/test_im_col.py index 9586c609..8e5ae37c 100644 --- a/test/test_im_col.py +++ b/test/test_im_col.py @@ -138,4 +138,5 @@ def test_add_order(properties_path): ds = ds.sort_by("test_random") data = ds.get_data("numpy") test_random = data["test_random"] + print(test_random) assert np.all(test_random[:-1] <= test_random[1:]) From 62015a80632ab5b428affc73bcf2495e88d65794 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Thu, 4 Sep 2025 14:16:36 -0500 Subject: [PATCH 6/6] Sorted unit handling update --- src/opencosmo/dataset/state.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/opencosmo/dataset/state.py b/src/opencosmo/dataset/state.py index 33cf3a87..69fd08b2 100644 --- a/src/opencosmo/dataset/state.py +++ b/src/opencosmo/dataset/state.py @@ -131,9 +131,12 @@ def make_schema( self, handler: "DatasetHandler", header: Optional[OpenCosmoHeader] = None ): builder_names = set(self.__builder.columns) - schema = handler.prep_write( - self.__index.sorted(), builder_names - self.__hidden, header - ) + if isinstance(self.__index, SimpleIndex): + index = self.__index.sorted() + else: + index = self.__index + + schema = handler.prep_write(index, builder_names - self.__hidden, header) derived_names = set(self.__derived.keys()) - self.__hidden derived_data = (