Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/+f495fef9.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix a bug that could cause some data to be returned as a astropy.table.Column when data was requested as numpy
9 changes: 6 additions & 3 deletions src/opencosmo/collection/structure/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,8 @@ def __getitem__(self, key: str) -> oc.Dataset | oc.StructureCollection:
return self.__source

index = self.__links[key].make_index(self.__index)
return self.__datasets[key].with_index(index)
dataset = self.__datasets[key].with_index(index)
return dataset

def __enter__(self):
return self
Expand Down Expand Up @@ -800,10 +801,12 @@ def galaxies(self, *args, **kwargs):
raise AttributeError("This collection does not contain galaxies!")

def make_schema(self) -> StructCollectionSchema:
sorted_index = self.__source.index.sorted()
to_write = self.with_index(sorted_index)
schema = StructCollectionSchema(self.__header)
source_name = self.__source.dtype

for name, dataset in self.items():
for name, dataset in to_write.items():
ds_schema = dataset.make_schema()
if name == "galaxies":
name = "galaxy_properties"
Expand All @@ -812,7 +815,7 @@ def make_schema(self) -> StructCollectionSchema:
for name, handler in self.__links.items():
if name == "galaxies":
name = "galaxy_properties"
link_schema = handler.make_schema(name, self.__index)
link_schema = handler.make_schema(name, sorted_index)
schema.insert(link_schema, f"{source_name}.{name}")

return schema
15 changes: 8 additions & 7 deletions src/opencosmo/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import numpy as np
from astropy import units # type: ignore
from astropy.cosmology import Cosmology # type: ignore
from astropy.table import QTable # type: ignore
from astropy.table import Column, QTable # type: ignore

from opencosmo.dataset.column import ColumnMask, DerivedColumn
from opencosmo.dataset.state import DatasetState
Expand Down Expand Up @@ -246,15 +246,13 @@ def get_data(self, output="astropy", unpack=True) -> OpenCosmoData:
data = data[cn]

if output == "numpy":
if isinstance(data, u.Quantity):
if isinstance(data, (u.Quantity, Column)):
data = data.value
elif isinstance(data, (QTable, dict)):
data = dict(data)
is_quantity = filter(
lambda v: isinstance(data[v], u.Quantity), data.keys()
)
for colname in is_quantity:
data[colname] = data[colname].value
for colname in data:
if isinstance(data[colname], (u.Quantity, Column)):
data[colname] = data[colname].value

if isinstance(data, dict) and len(data) == 1:
return next(iter(data.values()))
Expand Down Expand Up @@ -561,6 +559,9 @@ def sort_by(self, column: str, invert: bool = False) -> Dataset:
.sort_by("fof_halo_mass")
.take(100, at="start")

Note that sorting does not persist when a dataset is written, because the order
of data is used for the spatial index.

Parameters
----------
column : str
Expand Down
56 changes: 25 additions & 31 deletions src/opencosmo/dataset/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,18 @@
from opencosmo.dataset.handler import DatasetHandler


def make_sorted_index(
state: "DatasetState", handler: "DatasetHandler", column: str, invert: bool
):
sort_by_column = state.select(column).get_data(handler)
idx = np.argsort(sort_by_column)
if invert:
idx = idx[::-1]

index = SimpleIndex(state.index.into_array()[idx])
return index


class DatasetState:
"""
Holds mutable state required by the dataset. Cleans up the dataset to mostly focus
Expand All @@ -34,7 +46,6 @@ def __init__(
region: Region,
header: OpenCosmoHeader,
im_handler: InMemoryColumnHandler,
sort_by: Optional[tuple[str, bool]] = None,
hidden: set[str] = set(),
derived: dict[str, DerivedColumn] = {},
):
Expand All @@ -46,7 +57,6 @@ def __init__(
self.__header = header
self.__hidden = hidden
self.__index = index
self.__sort_by = sort_by
self.__region = region

@property
Expand Down Expand Up @@ -78,9 +88,7 @@ def columns(self) -> list[str]:
)
return list(columns - self.__hidden)

def get_data(
self, handler: "DatasetHandler", ignore_sort: bool = False
) -> table.QTable:
def get_data(self, handler: "DatasetHandler") -> table.QTable:
"""
Get the data for a given handler.
"""
Expand All @@ -94,8 +102,6 @@ def get_data(
and not self.__hidden.intersection(data_columns) == data_columns
):
data.remove_columns(self.__hidden)
if not ignore_sort and self.__sort_by is not None:
data.sort(self.__sort_by[0], reverse=self.__sort_by[1])
return data

def with_index(self, index: DataIndex):
Expand All @@ -113,7 +119,6 @@ def with_index(self, index: DataIndex):
self.__region,
self.__header,
new_cache,
self.__sort_by,
self.__hidden,
self.__derived,
)
Expand All @@ -126,7 +131,13 @@ def make_schema(
self, handler: "DatasetHandler", header: Optional[OpenCosmoHeader] = None
):
builder_names = set(self.__builder.columns)
schema = handler.prep_write(self.__index, builder_names - self.__hidden, header)
if isinstance(self.__index, SimpleIndex):
index = self.__index.sorted()
else:
index = self.__index

schema = handler.prep_write(index, builder_names - self.__hidden, header)

derived_names = set(self.__derived.keys()) - self.__hidden
derived_data = (
self.select(derived_names)
Expand Down Expand Up @@ -205,7 +216,6 @@ def with_new_columns(
self.__region,
self.__header,
new_im_handler,
self.__sort_by,
self.__hidden,
new_derived,
)
Expand Down Expand Up @@ -236,7 +246,6 @@ def with_region(self, region: Region):
region,
self.__header,
self.__im_handler,
self.__sort_by,
self.__hidden,
self.__derived,
)
Expand All @@ -259,6 +268,7 @@ def select(self, columns: str | Iterable[str]):
known_builders = set(self.__builder.columns)
known_derived = set(self.__derived.keys())
known_im = set(self.__im_handler.keys())

unknown_columns = columns - known_builders - known_derived - known_im
if unknown_columns:
raise ValueError(
Expand Down Expand Up @@ -295,8 +305,6 @@ def select(self, columns: str | Iterable[str]):
new_im_handler = self.__im_handler.with_columns(required_im)

new_hidden = all_required - columns
if self.__sort_by is not None and self.__sort_by[0] not in columns:
new_hidden.add(self.__sort_by[0])

return DatasetState(
self.__base_unit_transformations,
Expand All @@ -306,21 +314,20 @@ def select(self, columns: str | Iterable[str]):
self.__region,
self.__header,
new_im_handler,
self.__sort_by,
new_hidden,
new_derived,
)

def sort_by(self, column_name: str, handler: "DatasetHandler", invert: bool):
index = make_sorted_index(self, handler, column_name, invert)
return DatasetState(
self.__base_unit_transformations,
self.__builder,
self.__index,
index,
self.__convention,
self.__region,
self.__header,
self.__im_handler,
(column_name, invert),
self.__im_handler.with_index(index),
self.__hidden,
self.__derived,
)
Expand All @@ -329,19 +336,7 @@ def take(self, n: int, at: str, handler):
"""
Take rows from the dataset.
"""
if self.__sort_by is not None:
column = self.select(self.__sort_by[0]).get_data(handler, ignore_sort=True)[
self.__sort_by[0]
]
sorted = np.argsort(column)
if self.__sort_by[1]:
sorted = sorted[::-1]

index: DataIndex = SimpleIndex(sorted)
else:
index = self.__index

new_index = index.take(n, at)
new_index = self.__index.take(n, at)
return self.with_index(new_index)

def take_range(self, start: int, end: int):
Expand Down Expand Up @@ -384,7 +379,6 @@ def with_units(
self.__region,
self.__header,
self.__im_handler,
self.__sort_by,
self.__hidden,
self.__derived,
)
5 changes: 5 additions & 0 deletions src/opencosmo/index/chunked.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ def range(self) -> tuple[int, int]:
"""
return self.__starts[0], self.__starts[-1] + self.__sizes[-1] - 1

def sorted(self):
sorted_indices = np.argsort(self.__starts)
return ChunkedIndex(self.__starts[sorted_indices], self.__sizes[sorted_indices])

def into_array(self) -> NDArray[np.int_]:
"""
Convert the ChunkedIndex to a SimpleIndex.
Expand Down Expand Up @@ -165,6 +169,7 @@ def take(self, n: int, at: str = "random") -> DataIndex:
]
)
idxs = np.random.choice(idxs, n, replace=False)
idxs.sort()
return simple.SimpleIndex(idxs)

elif at == "start":
Expand Down
1 change: 1 addition & 0 deletions src/opencosmo/index/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def n_in_range(
) -> NDArray[np.int_]: ...
def take(self, n: int, at: str = "random") -> "DataIndex": ...
def take_range(self, start: int, end: int) -> "DataIndex": ...
def sorted(self) -> "DataIndex": ...
def intersection(self, other: "DataIndex") -> "DataIndex": ...
def projection(self, other: "DataIndex") -> "DataIndex": ...
def into_mask(self) -> NDArray[np.bool_]: ...
Expand Down
27 changes: 15 additions & 12 deletions src/opencosmo/index/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ def from_size(cls, size: int) -> "SimpleIndex":
def empty(cls):
return SimpleIndex(np.array([], dtype=int))

def sorted(self):
return SimpleIndex(np.sort(self.__index))

def __len__(self) -> int:
return len(self.__index)

Expand All @@ -33,12 +36,9 @@ def into_array(self, copy: bool = False) -> NDArray[np.int_]:
return self.__index

def range(self) -> tuple[int, int]:
"""
Guranteed to be sorted
"""
if len(self) == 0:
return 0, 0
return self.__index[0], self.__index[-1]
return (int(np.min(self.__index)), int(np.max(self.__index)))

def into_mask(self):
mask = np.zeros(self.__index[-1] + 1, dtype=bool)
Expand All @@ -64,8 +64,9 @@ def n_in_range(
return np.zeros_like(start)

ends = start + size
start_idxs = np.searchsorted(self.__index, start, "left")
end_idxs = np.searchsorted(self.__index, ends, "left")
self_sorted = np.sort(self.__index)
start_idxs = np.searchsorted(self_sorted, start, "left")
end_idxs = np.searchsorted(self_sorted, ends, "left")
return end_idxs - start_idxs

def set_data(self, data: np.ndarray, value: bool) -> np.ndarray:
Expand All @@ -88,7 +89,11 @@ def take(self, n: int, at: str = "random") -> DataIndex:
return SimpleIndex.empty()

if at == "random":
return SimpleIndex(np.random.choice(self.__index, n, replace=False))
indices = np.arange(len(self.__index))
indices = np.random.choice(indices, n, replace=False)
indices.sort()

return SimpleIndex(self.__index[indices])
elif at == "start":
return SimpleIndex(self.__index[:n])
elif at == "end":
Expand Down Expand Up @@ -127,10 +132,8 @@ def projection(self, other: DataIndex) -> DataIndex:
where the second index is true.
"""
other_idxs = other.into_array()
idxs = np.searchsorted(self.__index, other_idxs)
idxs = idxs[idxs != len(self.__index)]
idxs = idxs[other_idxs == self.__index[idxs]]
return SimpleIndex(idxs)
is_in = np.isin(other_idxs, self.__index)
return SimpleIndex(self.__index[is_in])

def mask(self, mask: np.ndarray) -> DataIndex:
if mask.shape != self.__index.shape:
Expand All @@ -149,7 +152,7 @@ def mask(self, mask: np.ndarray) -> DataIndex:

return SimpleIndex(self.__index[mask])

def get_data(self, data: h5py.Dataset) -> np.ndarray:
def get_data(self, data: h5py.Dataset, debug: bool = False) -> np.ndarray:
"""
Get the data from the dataset using the index.
"""
Expand Down
17 changes: 17 additions & 0 deletions test/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,23 @@ def test_data_link_sort(halo_paths):
assert np.all(fof_halo_tags == halo["halo_properties"]["fof_halo_tag"])


def test_data_link_sort_write(halo_paths, tmp_path):
collection = oc.open(halo_paths)
collection = collection.filter(oc.col("sod_halo_mass") > 10**14).sort_by(
"fof_halo_mass"
)
oc.write(tmp_path / "temp.hdf5", collection)
new_collection = oc.open(tmp_path / "temp.hdf5").take(10)
assert np.all(
collection["halo_properties"].select("sod_halo_mass").get_data("numpy") > 10**14
)
for halo in new_collection.objects(("halo_profiles",)):
assert np.all(
halo["halo_properties"]["fof_halo_tag"]
== halo["halo_profiles"].select("fof_halo_bin_tag").get_data("numpy")[0]
)


def test_data_link_selection(halo_paths):
collection = oc.open(*halo_paths)
collection = collection.filter(oc.col("sod_halo_mass") > 10**13).take(
Expand Down
20 changes: 20 additions & 0 deletions test/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,15 @@ def test_take_oom(input_path):
assert len(data) == 10


def test_sort_after_filter(input_path):
dataset = oc.open(input_path)
dataset = dataset.filter(oc.col("fof_halo_mass") > 1e13)
dataset = dataset.sort_by("sod_halo_mass")
data = dataset.select(("fof_halo_mass", "sod_halo_mass")).get_data("numpy")
assert np.all(data["fof_halo_mass"] > 1e13)
assert np.all(data["sod_halo_mass"][:-1] <= data["sod_halo_mass"][1:])


def test_take_sorted(input_path):
n = 150
dataset = oc.open(input_path)
Expand Down Expand Up @@ -110,6 +119,17 @@ def test_take_sorted_inverted(input_path):
assert fof_masses.max() == toolkit_sorted_fof_masses[0]


def test_write_after_sorted(input_path, tmp_path):
dataset = oc.open(input_path)
dataset = dataset.sort_by("fof_halo_mass", invert=True)
halo_tags = dataset.select("fof_halo_tag").get_data("numpy")
oc.write(tmp_path / "test.hdf5", dataset)
new_dataset = oc.open(tmp_path / "test.hdf5").sort_by("fof_halo_mass", invert=True)
to_check = new_dataset.select(("fof_halo_mass", "fof_halo_tag")).get_data("numpy")
assert np.all(to_check["fof_halo_mass"][:-1] >= to_check["fof_halo_mass"][1:])
assert np.all(to_check["fof_halo_tag"] == halo_tags)


def test_sort_by_derived(input_path):
n = 150
ds = oc.open(input_path)
Expand Down
Loading
Loading