From a28d114b2857ec23eb62bf51eb6b9da71f589c7f Mon Sep 17 00:00:00 2001
From: Patrick Wells
Date: Thu, 10 Apr 2025 07:41:41 -0700
Subject: [PATCH 01/11] Add filtering

---
 opencosmo/link/collection.py | 33 +++++++++++++++++++++++++++++++---
 1 file changed, 30 insertions(+), 3 deletions(-)

diff --git a/opencosmo/link/collection.py b/opencosmo/link/collection.py
index ecd8f69e..a8a45555 100644
--- a/opencosmo/link/collection.py
+++ b/opencosmo/link/collection.py
@@ -9,6 +9,23 @@
 from opencosmo import link as l
 
 
+def filter_properties_by_dataset(
+    dataset: oc.Dataset,
+    properties: oc.Dataset,
+    *masks
+) -> oc.Dataset:
+    masked_dataset = dataset.filter(*masks)
+    if properties.header.file.data_type == "halo_properties":
+        linked_column = "fof_halo_tag"
+    elif properties.header.file.data_type == "galaxy_properties":
+        linked_column = "gal_tag"
+    else:
+        raise ValueError("Properties must be halo_properties or galaxy_properties")
+
+    tags = masked_dataset.select(linked_column).data
+    new_properties = properties.filter(oc.col(linked_column).isin(tags))
+    return new_properties
+
 class StructureCollection:
     """
     A collection of datasets that contain both high-level properties
@@ -24,6 +41,7 @@ def __init__(
         self,
         properties: oc.Dataset,
         handlers: dict[str, l.LinkHandler],
+        filters: Optional[dict[str, Any]] = {},
         *args,
         **kwargs,
     ):
@@ -34,6 +52,7 @@
         self.__properties = properties
         self.__handlers = handlers
         self.__idxs = self.__properties.indices
+        self.__filters = filters
 
     def __repr__(self):
         structure_type = self.__properties.header.file.data_type.split("_")[0] + "s"
@@ -87,7 +106,8 @@
             return self.__properties
         elif key not in self.__handlers:
             raise KeyError(f"Dataset {key} not found in collection.")
-        return self.__handlers[key].get_all_data()
+        indices = self.__properties.indices
+        return self.__handlers[key].get_data(indices)
 
     def __enter__(self):
         return self
@@ -119,13 +139,20 @@
             self.__properties, {**self.__handlers, dataset: new_handler}
         )
 
-    def filter(self, *masks):
+    def filter(self, *masks, dataset: Optional[str] = None) -> StructureCollection:
         """
         Apply a filter to the properties dataset and propagate it to the linked
         datasets
         """
         if not masks:
             return self
-        filtered = self.__properties.filter(*masks)
+        if dataset is None:
+            filtered = self.__properties.filter(*masks)
+        elif dataset not in self.__handlers:
+            raise ValueError(f"Dataset {dataset} not found in collection.")
+        else:
+            filtered = filter_properties_by_dataset(
+                self[dataset], self.__properties, *masks
+            )
         return StructureCollection(
             filtered,
             self.__handlers,

From 7869dd1ab33ed67176e890f679f0155341b8ace7 Mon Sep 17 00:00:00 2001
From: Patrick Wells
Date: Fri, 11 Apr 2025 11:03:15 -0500
Subject: [PATCH 02/11] Small bugfix

---
 opencosmo/handler/mpi.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/opencosmo/handler/mpi.py b/opencosmo/handler/mpi.py
index 4e6dd9fd..d3ce640d 100644
--- a/opencosmo/handler/mpi.py
+++ b/opencosmo/handler/mpi.py
@@ -121,13 +121,13 @@ def write(
         selected: Optional[np.ndarray] = None,
     ) -> None:
         columns = list(columns)
-        input = verify_input(
+        input_ = verify_input(
             comm=self.__comm,
             columns=columns,
             dataset_name=dataset_name,
             require=["dataset_name"],
         )
-        columns = input["columns"]
+        columns = input_["columns"]
         rank_range = self.elem_range()
         # indices = redistribute_indices(indices, rank_range)

From f244bf32b29ab1f330e83513eb308d332976f85c Mon Sep 17 00:00:00 2001
From: Patrick Wells
Date: Fri, 11 Apr 2025 11:06:21 -0500
Subject:
[PATCH 03/11] Add kwargs --- opencosmo/collection/collection.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/opencosmo/collection/collection.py b/opencosmo/collection/collection.py index 50844c7a..3ed9d26f 100644 --- a/opencosmo/collection/collection.py +++ b/opencosmo/collection/collection.py @@ -178,7 +178,7 @@ def __map(self, method, *args, **kwargs): output = {k: getattr(v, method)(*args, **kwargs) for k, v in self.items()} return SimulationCollection(output) - def filter(self, *masks: Mask) -> SimulationCollection: + def filter(self, *masks: Mask, **kwargs) -> SimulationCollection: """ Filter the datasets in the collection. This method behaves exactly like :meth:`opencosmo.Dataset.filter`, except that @@ -196,7 +196,7 @@ def filter(self, *masks: Mask) -> SimulationCollection: A new collection with the same datasets, but only the particles that pass the filter. """ - return self.__map("filter", *masks) + return self.__map("filter", *masks, **kwargs) def select(self, *args, **kwargs) -> SimulationCollection: """ From 8dd34a599790b2a0cbc839db591c04f5c9b43bfc Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Sun, 13 Apr 2025 11:11:30 -0500 Subject: [PATCH 04/11] New indexing tools --- opencosmo/dataset/index.py | 157 +++++++++++++++++++++++++++++++++++++ 1 file changed, 157 insertions(+) create mode 100644 opencosmo/dataset/index.py diff --git a/opencosmo/dataset/index.py b/opencosmo/dataset/index.py new file mode 100644 index 00000000..1db151c2 --- /dev/null +++ b/opencosmo/dataset/index.py @@ -0,0 +1,157 @@ +from __future__ import annotations +from typing import Protocol + +import numpy as np +import h5py + + +class DataIndex(Protocol): + + @classmethod + def from_size(cls, size: int) -> DataIndex: ... + def get_data(self, data: h5py.Dataset) -> np.ndarray: ... + def take(self, n: int, at: str = "random") -> DataIndex: ... + def mask(self, mask: np.ndarray) -> DataIndex: ... + def __len__(self) -> int: ... + + +class SimpleIndex: + """ + An index of integers. + """ + + def __init__(self, index: np.ndarray) -> None: + self.__index = np.sort(index) + + @classmethod + def from_size(cls, size: int) -> SimpleIndex: + return SimpleIndex(np.arange(size)) + + def __len__(self) -> int: + return len(self.__index) + + def take(self, n: int, at: str = "random") -> SimpleIndex: + """ + Take n elements from the index. + """ + if n > len(self): + raise ValueError(f"Cannot take {n} elements from index of size {len(self)}") + if at == "random": + return SimpleIndex(np.random.choice(self.__index, n, replace=False)) + elif at == "start": + return SimpleIndex(self.__index[:n]) + elif at == "end": + return SimpleIndex(self.__index[-n:]) + else: + raise ValueError(f"Unknown value for 'at': {at}") + + def mask(self, mask: np.ndarray) -> SimpleIndex: + if mask.shape != self.__index.shape: + raise ValueError(f"Mask shape {mask.shape} does not match index size {len(self)}") + + if mask.dtype != bool: + raise ValueError(f"Mask dtype {mask.dtype} is not boolean") + + if not mask.any(): + raise ValueError("Mask is all False") + + if mask.all(): + return self + + return SimpleIndex(self.__index[mask]) + + def get_data(self, data: h5py.Dataset) -> np.ndarray: + """ + Get the data from the dataset using the index. 
+ """ + if not isinstance(data, h5py.Dataset): + raise ValueError("Data must be a h5py.Dataset") + + min_index = self.__index.min() + max_index = self.__index.max() + output = data[min_index:max_index + 1] + indices_into_output = self.__index - min_index + return output[indices_into_output] + +class ChunkedIndex: + + def __init__(self, starts: np.ndarray, sizes: np.ndarray) -> None: + self.__starts = starts + self.__sizes = sizes + + @classmethod + def from_size(cls, size: int) -> ChunkedIndex: + """ + Create a ChunkedIndex from a size. + """ + if size <= 0: + raise ValueError(f"Size must be positive, got {size}") + # Create an array of chunk sizes + + starts = np.array([0]) + sizes = np.array([size]) + return ChunkedIndex(starts, sizes) + + def __len__(self) -> int: + """ + Get the total size of the index. + """ + return np.sum(self.__sizes) + + def take(self, n: int, at: str = "random") -> DataIndex: + if n > len(self): + raise ValueError(f"Cannot take {n} elements from index of size {len(self)}") + + if at == "random": + idxs = np.concatenate([np.arange(start, start + size) for start, size in zip(self.__starts, self.__sizes)]) + idxs = np.random.choice(idxs, n, replace=False) + return SimpleIndex(idxs) + elif at == "start": + last_chunk_in_range = np.searchsorted(np.sum(self.__sizes), n) + new_starts = self.__starts[:last_chunk_in_range] + new_sizes = self.__sizes[:last_chunk_in_range] + new_sizes[-1] = n - np.sum(new_sizes[:-1]) + return ChunkedIndex(new_starts, new_sizes) + elif at == "end": + first_chunk_in_range = np.searchsorted(np.sum(self.__sizes), len(self) - n) + new_starts = self.__starts[first_chunk_in_range:] + new_sizes = self.__sizes[first_chunk_in_range:] + new_sizes[0] = n - np.sum(new_sizes[1:]) + return ChunkedIndex(new_starts, new_sizes) + + def mask(self, mask: np.ndarray) -> DataIndex: + """ + Mask the index with a boolean mask. + """ + if mask.shape != (len(self),): + raise ValueError(f"Mask shape {mask.shape} does not match index size {len(self)}") + + if mask.dtype != bool: + raise ValueError(f"Mask dtype {mask.dtype} is not boolean") + + if not mask.any(): + raise ValueError("Mask is all False") + + if mask.all(): + return self + + # Get the indices of the chunks that are masked + idxs = np.concatenate([np.arange(start, start + size) for start, size in zip(self.__starts, self.__sizes)]) + masked_idxs = idxs[mask] + + return SimpleIndex(masked_idxs) + + + def get_data(self, data: h5py.Dataset) -> np.ndarray: + """ + Get the data from the dataset using the index. 
+ """ + if not isinstance(data, h5py.Dataset): + raise ValueError("Data must be a h5py.Dataset") + + output = np.zeros(len(self), dtype=data.dtype) + running_index = 0 + for start, size in zip(self.__starts, self.__sizes): + output[running_index:running_index + size] = data[start:start + size] + running_index += size + return output From 227cd47c4708290efbcc563f61e3e36adec9815b Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Mon, 14 Apr 2025 10:53:41 -0500 Subject: [PATCH 05/11] Full new index working --- opencosmo/collection/collection.py | 6 +- opencosmo/dataset/dataset.py | 57 +++---- opencosmo/dataset/index.py | 230 +++++++++++++++++++++++++++-- opencosmo/dataset/mask.py | 13 +- opencosmo/handler/handler.py | 8 +- opencosmo/handler/im.py | 30 ++-- opencosmo/handler/mpi.py | 126 ++++------------ opencosmo/handler/oom.py | 59 ++------ opencosmo/io.py | 15 +- opencosmo/link/builder.py | 18 +-- opencosmo/link/collection.py | 11 +- opencosmo/link/handler.py | 45 +++--- opencosmo/link/mpi.py | 55 +++---- opencosmo/utils.py | 27 +--- test/test_collection.py | 7 +- 15 files changed, 374 insertions(+), 333 deletions(-) diff --git a/opencosmo/collection/collection.py b/opencosmo/collection/collection.py index 3ed9d26f..1766c42e 100644 --- a/opencosmo/collection/collection.py +++ b/opencosmo/collection/collection.py @@ -20,6 +20,7 @@ from opencosmo.link import StructureCollection from opencosmo.spatial import read_tree from opencosmo.transformations import units as u +from opencosmo.dataset.index import ChunkedIndex class Collection(Protocol): @@ -95,7 +96,6 @@ def write_with_unique_headers(collection: Collection, file: h5py.File): group = file.create_group(key) collection[key].write(group) - def verify_datasets_exist(file: h5py.File, datasets: Iterable[str]): """ Verify a set of datasets exist in a given file. 
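The two index types introduced in the previous patch describe the same selection in different ways: a SimpleIndex stores explicit integer row numbers, while a ChunkedIndex stores (start, size) runs, which stays compact when selections are contiguous. A minimal sketch of the equivalence, with illustrative values (this patch series also lets get_data accept a plain numpy array in place of an h5py.Dataset):

    import numpy as np

    from opencosmo.dataset.index import ChunkedIndex, SimpleIndex

    data = np.arange(1000)

    # Rows 0-4 and 100-109, described as two (start, size) runs...
    chunked = ChunkedIndex(np.array([0, 100]), np.array([5, 10]))
    # ...and the same rows spelled out one integer at a time.
    simple = SimpleIndex(np.concatenate([np.arange(0, 5), np.arange(100, 110)]))

    assert np.array_equal(chunked.get_data(data), simple.get_data(data))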
@@ -278,5 +278,5 @@ def read_single_dataset( builders, base_unit_transformations = u.get_default_unit_transformations( file[dataset_key], header ) - mask = np.arange(len(handler)) - return oc.Dataset(handler, header, builders, base_unit_transformations, mask) + index = ChunkedIndex.from_size(len(handler)) + return oc.Dataset(handler, header, builders, base_unit_transformations, index) diff --git a/opencosmo/dataset/dataset.py b/opencosmo/dataset/dataset.py index 1d324702..a9902115 100644 --- a/opencosmo/dataset/dataset.py +++ b/opencosmo/dataset/dataset.py @@ -13,6 +13,7 @@ from opencosmo.dataset.mask import Mask, apply_masks from opencosmo.handler import OpenCosmoDataHandler from opencosmo.header import OpenCosmoHeader, write_header +from opencosmo.dataset.index import DataIndex, ChunkedIndex class Dataset: @@ -22,21 +23,21 @@ def __init__( header: OpenCosmoHeader, builders: dict[str, ColumnBuilder], unit_transformations: dict[t.TransformationType, list[t.Transformation]], - indices: np.ndarray, + index: DataIndex, ): self.__handler = handler self.__header = header self.__builders = builders self.__base_unit_transformations = unit_transformations - self.__indices = indices + self.__index = index @property def header(self) -> OpenCosmoHeader: return self.__header @property - def indices(self) -> np.ndarray: - return self.__indices + def index(self) -> DataIndex: + return self.__index def __repr__(self): """ @@ -54,7 +55,7 @@ def __repr__(self): return head + cosmo_repr + table_head + table_repr def __len__(self): - return len(self.__indices) + return len(self.__index) def __enter__(self): # Need to write tests @@ -74,7 +75,7 @@ def cosmology(self): def data(self): # should rename this, dataset.data can get confusing # Also the point is that there's MORE data than just the table - return self.__handler.get_data(builders=self.__builders, indices=self.__indices) + return self.__handler.get_data(builders=self.__builders, index=self.__index) def write( self, @@ -103,7 +104,7 @@ def write( if with_header: write_header(file, self.__header, dataset_name) - self.__handler.write(file, self.indices, self.__builders.keys(), dataset_name) + self.__handler.write(file, self.__index, self.__builders.keys(), dataset_name) def rows(self) -> Generator[dict[str, float | units.Quantity]]: """ @@ -163,14 +164,14 @@ def take_range(self, start: int, end: int) -> Table: if start < 0 or end > len(self): raise ValueError("start and end must be within the bounds of the dataset.") - new_indices = self.__indices[start:end] + new_index = self.__index.take_range(start, end) return Dataset( self.__handler, self.__header, self.__builders, self.__base_unit_transformations, - new_indices, + new_index, ) def filter(self, *masks: Mask) -> Dataset: @@ -195,19 +196,19 @@ def filter(self, *masks: Mask) -> Dataset: """ - new_indices = apply_masks( - self.__handler, self.__builders, masks, self.__indices + new_index = apply_masks( + self.__handler, self.__builders, masks, self.__index ) - if len(new_indices) == 0: - raise ValueError("Filter returned zero rows!") + if len(new_index) == 0: + raise ValueError("The filter returned no rows!") return Dataset( self.__handler, self.__header, self.__builders, self.__base_unit_transformations, - new_indices, + new_index, ) def select(self, columns: str | Iterable[str]) -> Dataset: @@ -250,7 +251,7 @@ def select(self, columns: str | Iterable[str]) -> Dataset: self.__header, new_builders, self.__base_unit_transformations, - self.__indices, + self.__index, ) def with_units(self, convention: 
str) -> Dataset:
@@ -279,7 +280,7 @@
             self.__header,
             new_builders,
             self.__base_unit_transformations,
-            self.__indices,
+            self.__index,
         )
 
     def collect(self) -> Dataset:
@@ -305,13 +306,14 @@
 
         If working in an MPI context, all ranks will receive the same data.
         """
-        new_handler = self.__handler.collect(self.__builders.keys(), self.__indices)
+        new_handler = self.__handler.collect(self.__builders.keys(), self.__index)
+        new_index = ChunkedIndex.from_size(len(new_handler))
         return Dataset(
             new_handler,
             self.__header,
             self.__builders,
             self.__base_unit_transformations,
-            np.arange(len(new_handler)),
+            new_index,
         )
 
     def take(self, n: int, at: str = "start") -> Dataset:
@@ -341,28 +343,13 @@
             or if 'at' is invalid.
 
         """
+        new_index = self.__index.take(n, at)
 
-        if n < 0 or n > len(self):
-            raise ValueError(
-                "Invalid value for 'n', must be between 0 and the length of the dataset"
-            )
-        if at == "start":
-            new_indices = self.__indices[:n]
-        elif at == "end":
-            new_indices = self.__indices[-n:]
-        elif at == "random":
-            new_indices = np.random.choice(self.__indices, n, replace=False)
-            new_indices.sort()
-
-        else:
-            raise ValueError(
-                "Invalid value for 'at'. Must be one of 'start', 'end', or 'random'."
-            )
         return Dataset(
             self.__handler,
             self.__header,
             self.__builders,
             self.__base_unit_transformations,
-            new_indices,
+            new_index,
         )
diff --git a/opencosmo/dataset/index.py b/opencosmo/dataset/index.py
index 1db151c2..46e51840 100644
--- a/opencosmo/dataset/index.py
+++ b/opencosmo/dataset/index.py
@@ -1,18 +1,24 @@
 from __future__ import annotations
-from typing import Protocol
+from typing import Protocol, TypeVar, Any
 
 import numpy as np
 import h5py
 
+T = TypeVar("T", np.ndarray, h5py.Dataset)
 
 class DataIndex(Protocol):
 
     @classmethod
     def from_size(cls, size: int) -> DataIndex: ...
-    def get_data(self, data: h5py.Dataset) -> np.ndarray: ...
+    def set_data(self, data: T, value: Any) -> T: ...
+    def get_data(self, data: h5py.Dataset | np.ndarray) -> np.ndarray: ...
     def take(self, n: int, at: str = "random") -> DataIndex: ...
+    def take_range(self, start: int, end: int) -> DataIndex: ...
     def mask(self, mask: np.ndarray) -> DataIndex: ...
+    def range(self) -> tuple[int, int]: ...
+    def concatenate(self, *others: DataIndex) -> DataIndex: ...
     def __len__(self) -> int: ...
+    def __getitem__(self, item: int) -> int: ...
 
 
 class SimpleIndex:
@@ -30,6 +36,34 @@
     def __len__(self) -> int:
         return len(self.__index)
 
+    def range(self) -> tuple[int, int]:
+        """
+        The index is guaranteed to be sorted.
+        """
+        return self.__index[0], self.__index[-1]
+
+    def concatenate(self, *others: SimpleIndex) -> SimpleIndex:
+        if len(others) == 0:
+            return self
+        if all(isinstance(other, SimpleIndex) for other in others):
+            new_index = np.concatenate([self.__index] + [other.__index for other in others])
+            new_index = np.sort(np.unique(new_index))
+            return SimpleIndex(new_index)
+        else:
+            simple_indices = map(lambda x: x.to_simple_index() if isinstance(x, ChunkedIndex) else x, others)
+            return self.concatenate(*simple_indices)
+
+    def set_data(self, data: np.ndarray, value: bool) -> np.ndarray:
+        """
+        Set the data at the index to the given value.
+        """
+        if not isinstance(data, np.ndarray):
+            raise ValueError("Data must be a numpy array")
+
+        data[self.__index] = value
+        return data
+
+
     def take(self, n: int, at: str = "random") -> SimpleIndex:
         """
         Take n elements from the index.
@@ -45,6 +79,18 @@ def take(self, n: int, at: str = "random") -> SimpleIndex: else: raise ValueError(f"Unknown value for 'at': {at}") + def take_range(self, start: int, end: int) -> SimpleIndex: + """ + Take a range of elements from the index. + """ + if start < 0 or end > len(self): + raise ValueError(f"Range {start}:{end} is out of bounds for index of size {len(self)}") + + if start >= end: + raise ValueError(f"Start {start} must be less than end {end}") + + return SimpleIndex(self.__index[start:end]) + def mask(self, mask: np.ndarray) -> SimpleIndex: if mask.shape != self.__index.shape: raise ValueError(f"Mask shape {mask.shape} does not match index size {len(self)}") @@ -64,8 +110,10 @@ def get_data(self, data: h5py.Dataset) -> np.ndarray: """ Get the data from the dataset using the index. """ - if not isinstance(data, h5py.Dataset): + if not isinstance(data, (h5py.Dataset, np.ndarray)): raise ValueError("Data must be a h5py.Dataset") + if len(self) == 0: + return np.array([], dtype=data.dtype) min_index = self.__index.min() max_index = self.__index.max() @@ -73,12 +121,76 @@ def get_data(self, data: h5py.Dataset) -> np.ndarray: indices_into_output = self.__index - min_index return output[indices_into_output] + def __getitem__(self, item: int) -> SimpleIndex: + """ + Get an item from the index. + """ + if item < 0 or item >= len(self): + raise IndexError(f"Index {item} out of bounds for index of size {len(self)}") + val = self.__index[item] + return SimpleIndex(np.array([val])) + +def pack(start: np.ndarray, size: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + """ + Combine adjacent chunks into a single chunk. + """ + + # Calculate the end of each chunk + end = start + size + + # Determine where a new chunk should start (i.e., not adjacent to previous) + # We prepend True for the first chunk to always start a group + new_group = np.ones(len(start), dtype=bool) + new_group[1:] = start[1:] != end[:-1] + + # Assign a group ID for each segment + group_ids = np.cumsum(new_group) + + # Combine chunks by group + combined_start = np.zeros(group_ids[-1], dtype=start.dtype) + combined_size = np.zeros_like(combined_start) + + np.add.at(combined_start, group_ids - 1, np.where(new_group, start, 0)) + np.add.at(combined_size, group_ids - 1, size) + + return combined_start, combined_size + class ChunkedIndex: def __init__(self, starts: np.ndarray, sizes: np.ndarray) -> None: + # sort the starts and sizes + # pack the starts and sizes self.__starts = starts self.__sizes = sizes + def range(self) -> tuple[int, int]: + """ + Get the range of the index. + """ + return self.__starts[0], self.__starts[-1] + self.__sizes[-1] - 1 + + def to_simple_index(self) -> SimpleIndex: + """ + Convert the ChunkedIndex to a SimpleIndex. 
+ """ + idxs = np.concatenate([np.arange(start, start + size) for start, size in zip(self.__starts, self.__sizes)]) + idxs = np.unique(idxs) + return SimpleIndex(idxs) + + def concatenate(self, *others: DataIndex) -> DataIndex: + if len(others) == 0: + return self + if all(isinstance(other, ChunkedIndex) for other in others): + new_starts = np.concatenate([self.__starts] + [other.__starts for other in others]) + new_sizes = np.concatenate([self.__sizes] + [other.__sizes for other in others]) + return ChunkedIndex(new_starts, new_sizes) + + else: + simple_indices = map(lambda x: x.to_simple_index() if isinstance(x, ChunkedIndex) else x, others) + return self.concatenate(*simple_indices) + + + @classmethod def from_size(cls, size: int) -> ChunkedIndex: """ @@ -91,6 +203,30 @@ def from_size(cls, size: int) -> ChunkedIndex: starts = np.array([0]) sizes = np.array([size]) return ChunkedIndex(starts, sizes) + + @classmethod + def single_chunk(cls, start: int, size: int) -> ChunkedIndex: + """ + Create a ChunkedIndex with a single chunk. + """ + if size <= 0: + raise ValueError(f"Size must be positive, got {size}") + if start < 0: + raise ValueError(f"Start must be non-negative, got {start}") + starts = np.array([start]) + sizes = np.array([size]) + return ChunkedIndex(starts, sizes) + + def set_data(self, data: np.ndarray, value: bool) -> np.ndarray: + """ + Set the data at the index to the given value. + """ + if not isinstance(data, np.ndarray): + raise ValueError("Data must be a numpy array") + + for start, size in zip(self.__starts, self.__sizes): + data[start:start + size] = value + return data def __len__(self) -> int: """ @@ -107,18 +243,35 @@ def take(self, n: int, at: str = "random") -> DataIndex: idxs = np.random.choice(idxs, n, replace=False) return SimpleIndex(idxs) elif at == "start": - last_chunk_in_range = np.searchsorted(np.sum(self.__sizes), n) - new_starts = self.__starts[:last_chunk_in_range] - new_sizes = self.__sizes[:last_chunk_in_range] + last_chunk_in_range = np.searchsorted(np.cumsum(self.__sizes), n) + new_starts = self.__starts[:last_chunk_in_range+1].copy() + new_sizes = self.__sizes[:last_chunk_in_range+1].copy() new_sizes[-1] = n - np.sum(new_sizes[:-1]) return ChunkedIndex(new_starts, new_sizes) elif at == "end": - first_chunk_in_range = np.searchsorted(np.sum(self.__sizes), len(self) - n) - new_starts = self.__starts[first_chunk_in_range:] - new_sizes = self.__sizes[first_chunk_in_range:] + starting_chunk = np.searchsorted(np.cumsum(self.__sizes), len(self) - n) + new_sizes = self.__sizes[starting_chunk:].copy() + new_starts = self.__starts[starting_chunk:].copy() new_sizes[0] = n - np.sum(new_sizes[1:]) + new_starts[0] = self.__starts[starting_chunk] + self.__sizes[starting_chunk] - new_sizes[0] return ChunkedIndex(new_starts, new_sizes) + def take_range(self, start: int, end: int) -> DataIndex: + """ + Take a range of elements from the index. + """ + if start < 0 or end > len(self): + raise ValueError(f"Range {start}:{end} is out of bounds for index of size {len(self)}") + + if start >= end: + raise ValueError(f"Start {start} must be less than end {end}") + + # Get the indices of the chunks that are in the range + idxs = np.concatenate([np.arange(start, start + size) for start, size in zip(self.__starts, self.__sizes)]) + range_idxs = idxs[start:end] + + return SimpleIndex(range_idxs) + def mask(self, mask: np.ndarray) -> DataIndex: """ Mask the index with a boolean mask. 
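The change from np.sum to np.cumsum above is what makes the searchsorted call meaningful: it must compare n against the running totals of the chunk sizes, not against a single grand total. A worked example of the arithmetic, with illustrative values:

    import numpy as np

    sizes = np.array([3, 4, 3])           # three chunks, ten elements total
    np.cumsum(sizes)                      # -> [3, 7, 10]
    np.searchsorted(np.cumsum(sizes), 5)  # -> 1: element 5 falls in chunk 1

    # take(5, at="start") therefore keeps chunks 0 and 1, and shrinks the
    # last kept chunk to 5 - 3 = 2 elements.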
@@ -142,16 +295,63 @@
 
         return SimpleIndex(masked_idxs)
 
-
-    def get_data(self, data: h5py.Dataset) -> np.ndarray:
+    def get_data(self, data: h5py.Dataset | np.ndarray) -> np.ndarray:
         """
-        Get the data from the dataset using the index.
+        Get the data from the dataset using the index. We want to perform as few reads
+        as possible. However, the chunks may not be contiguous. This method sorts the
+        chunks so it can read the data in the largest possible blocks, then
+        reshuffles the data to match the original order.
+
+        For large numbers of chunks, this is much faster than reading each chunk
+        in the order it appears in the index (the naive approach proved far slower).
         """
-        if not isinstance(data, h5py.Dataset):
+        if not isinstance(data, (h5py.Dataset, np.ndarray)):
             raise ValueError("Data must be a h5py.Dataset")
-        output = np.zeros(len(self), dtype=data.dtype)
+
+        if len(self) == 0:
+            return np.array([], dtype=data.dtype)
+        if len(self.__starts) == 1:
+            return data[self.__starts[0]:self.__starts[0] + self.__sizes[0]]
+
+        sorted_start_index = np.argsort(self.__starts)
+        new_starts = self.__starts[sorted_start_index]
+        new_sizes = self.__sizes[sorted_start_index]
+
+        packed_starts, packed_sizes = pack(new_starts, new_sizes)
+
+
+        shape = (len(self),) + data.shape[1:]
+        temp = np.zeros(shape, dtype=data.dtype)
         running_index = 0
-        for start, size in zip(self.__starts, self.__sizes):
-            output[running_index:running_index + size] = data[start:start + size]
+        for start, size in zip(packed_starts, packed_sizes):
+            temp[running_index:running_index + size] = data[start:start + size]
             running_index += size
+
+        output = np.zeros(shape, dtype=data.dtype)
+        cumulative_sorted_sizes = np.insert(np.cumsum(new_sizes), 0, 0)
+        cumulative_original_sizes = np.insert(np.cumsum(self.__sizes), 0, 0)
+
+        # reshuffle the output to match the original order
+        for i, sorted_index in enumerate(sorted_start_index):
+            start = cumulative_original_sizes[sorted_index]
+            end = cumulative_original_sizes[sorted_index + 1]
+            chunk = temp[cumulative_sorted_sizes[i]:cumulative_sorted_sizes[i + 1]]
+            output[start:end] = chunk
+
+
+        return output
+
+    def __getitem__(self, item: int) -> SimpleIndex:
+        """
+        Get an item from the index.
+        """
+        if item < 0 or item >= len(self):
+            raise IndexError(f"Index {item} out of bounds for index of size {len(self)}")
+        sums = np.cumsum(self.__sizes)
+        index = np.searchsorted(sums, item, side="right")
+        start = self.__starts[index]
+        offset = item - sums[index - 1] if index > 0 else item
+        return SimpleIndex(np.array([start + offset]))
diff --git a/opencosmo/dataset/mask.py b/opencosmo/dataset/mask.py
index 904efaad..267d110b 100644
--- a/opencosmo/dataset/mask.py
+++ b/opencosmo/dataset/mask.py
@@ -11,6 +11,7 @@
 
 from opencosmo.dataset.column import ColumnBuilder
 from opencosmo.handler import OpenCosmoDataHandler
+from opencosmo.dataset.index import DataIndex
 
 Comparison = Callable[[float, float], bool]
@@ -23,7 +24,7 @@ def apply_masks(
     handler: OpenCosmoDataHandler,
     column_builders: dict[str, ColumnBuilder],
     masks: Iterable[Mask],
-    indices: np.ndarray,
+    index: DataIndex,
 ) -> np.ndarray:
     masks_by_column = defaultdict(list)
     for f in masks:
@@ -36,16 +37,16 @@
             "masks were applied to columns that do not exist in the dataset: "
             f"{mask_column_names - column_names}"
         )
-    output_indices = indices
+    output_index = index
     for column_name, column_masks in masks_by_column.items():
-        column_mask = np.ones(len(output_indices), dtype=bool)
+        column_mask = np.ones(len(output_index), dtype=bool)
         builder = column_builders[column_name]
-        column = handler.get_data({column_name: builder}, output_indices)
+        column = handler.get_data({column_name: builder}, output_index)
         for f in column_masks:
             column_mask &= f.apply(column)
-        output_indices = output_indices[column_mask]
-    return output_indices
+        output_index = output_index.mask(column_mask)
+    return output_index
 
 
 class Column:
diff --git a/opencosmo/handler/handler.py b/opencosmo/handler/handler.py
index 47bc148b..cf2ff469 100644
--- a/opencosmo/handler/handler.py
+++ b/opencosmo/handler/handler.py
@@ -8,6 +8,7 @@
 from astropy.table import Column, Table  # type: ignore
 
 from opencosmo.dataset.column import ColumnBuilder
+from opencosmo.dataset.index import DataIndex
 
 
 class OpenCosmoDataHandler(Protocol):
@@ -43,15 +44,12 @@ def collect(
     def write(
         self,
         file: h5py.File,
-        indices: np.ndarray,
+        index: DataIndex,
         columns: Iterable[str],
         dataset_name: Optional[str] = None,
     ) -> None: ...
     def get_data(
         self,
         column_builders: dict[str, ColumnBuilder],
-        indices: np.ndarray,
+        index: DataIndex,
     ) -> Column | Table: ...
-    def take_indices(
-        self, n: int, strategy: str, indices: np.ndarray
-    ) -> np.ndarray: ...
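With take_indices removed from the handler protocol, handlers no longer implement any selection logic of their own: they receive a DataIndex and ask it for the rows. A sketch of the new division of labor, using a plain numpy array to stand in for an h5py column (which get_data also accepts after this patch):

    import numpy as np

    from opencosmo.dataset.index import ChunkedIndex

    column = np.arange(20) * 2
    index = ChunkedIndex.from_size(20)  # selects every row
    subset = index.take(5, at="start")  # selection now lives on the index
    print(subset.get_data(column))      # -> [0 2 4 6 8]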
diff --git a/opencosmo/handler/im.py b/opencosmo/handler/im.py index d0c80ec0..7f389ec1 100644 --- a/opencosmo/handler/im.py +++ b/opencosmo/handler/im.py @@ -8,6 +8,7 @@ from opencosmo.file import get_data_structure from opencosmo.spatial.tree import Tree +from opencosmo.dataset.index import DataIndex, ChunkedIndex class InMemoryHandler: @@ -24,7 +25,7 @@ def __init__( tree: Tree, group_name: Optional[str] = None, columns: Optional[Iterable[str]] = None, - indices: Optional[np.ndarray] = None, + index: Optional[DataIndex] = None, ): if group_name is None: group = file["data"] @@ -34,11 +35,10 @@ def __init__( if columns is not None: self.__columns = {n: u for n, u in self.__columns.items() if n in columns} self.__tree = tree - self.__data = {colname: group[colname][()] for colname in self.__columns} - if indices is not None: - self.__data = { - colname: self.__data[colname][indices] for colname in self.__columns - } + if index is None: + length = len(next(iter(group.values()))) + index = ChunkedIndex.from_size(length) + self.__data = {colname: index.get_data(group[colname]) for colname in self.__columns} def __len__(self) -> int: return len(next(iter(self.__data.values()))) @@ -63,7 +63,7 @@ def collect(self, columns: Iterable[str], indices: np.ndarray) -> InMemoryHandle def write( self, file: h5py.File, - indices: np.ndarray, + index: DataIndex, columns: Iterable[str], dataset_name: Optional[str] = None, ) -> None: @@ -76,7 +76,7 @@ def write( group = file.require_group(dataset_name) data_group = group.require_group("data") for column in columns: - data_group.create_dataset(column, data=self.__data[column][indices]) + data_group.create_dataset(column, data=index.get_data(self.__data[column])) if self.__columns[column] is not None: data_group[column].attrs["unit"] = self.__columns[column] mask = np.zeros(len(self), dtype=bool) @@ -86,7 +86,7 @@ def write( def get_data( self, builders: dict, - indices: np.ndarray, + index: DataIndex, ) -> Column | Table: """ Get data from the in-memory storage with optional masking and column @@ -95,19 +95,9 @@ def get_data( output = {} for column, builder in builders.items(): - col = self.__data[column][indices] + col = index.get_data(self.__data[column]) output[column] = builder.build(Column(col, name=column)) if len(output) == 1: return next(iter(output.values())) return Table(output) - - def take_indices(self, n: int, strategy: str, indices: np.ndarray): - if n < 0 or n > len(indices): - raise ValueError("n must be between 0 and then number of available rows") - if strategy == "random": - return np.sort(np.random.choice(indices, n, replace=False)) - elif strategy == "start": - return indices[:n] - elif strategy == "end": - return indices[-n:] diff --git a/opencosmo/handler/mpi.py b/opencosmo/handler/mpi.py index d3ce640d..706df35e 100644 --- a/opencosmo/handler/mpi.py +++ b/opencosmo/handler/mpi.py @@ -9,7 +9,22 @@ from opencosmo.file import get_data_structure from opencosmo.handler import InMemoryHandler from opencosmo.spatial.tree import Tree -from opencosmo.utils import read_indices +from opencosmo.dataset.index import DataIndex + + +def partition(comm: MPI.Comm, length: int) -> Tuple[int, int]: + nranks = comm.Get_size() + rank = comm.Get_rank() + if rank == nranks - 1: + start = rank * (length // nranks) + size = length - start + return (start, size) + + start = rank * (length // nranks) + end = (rank + 1) * (length // nranks) + size = end - start + return (start, size) + def verify_input(comm: MPI.Comm, require: Iterable[str] = [], **kwargs) -> 
dict: @@ -54,7 +69,6 @@ def __init__( tree: Tree, group_name: Optional[str] = None, comm=MPI.COMM_WORLD, - rank_range: Optional[Tuple[int, int]] = None, ): self.__file = file self.__group_name = group_name @@ -65,25 +79,9 @@ def __init__( self.__columns = get_data_structure(self.__group) self.__comm = comm self.__tree = tree - self.__elem_range = rank_range - - def elem_range(self) -> Tuple[int, int]: - """ - The full dataset will be split into equal parts by rank. - """ - if self.__elem_range is not None: - return self.__elem_range - nranks = self.__comm.Get_size() - rank = self.__comm.Get_rank() - n = self.__group[next(iter(self.__columns))].shape[0] - - if rank == nranks - 1: - return (rank * (n // nranks), n) - return (rank * (n // nranks), (rank + 1) * (n // nranks)) def __len__(self) -> int: - range_ = self.elem_range() - return range_[1] - range_[0] + return next(iter(self.__group.values())).shape[0] def __enter__(self): return self @@ -93,29 +91,27 @@ def __exit__(self, *exec_details): self.__columns = None return self.__file.close() - def collect(self, columns: Iterable[str], indices: np.ndarray) -> InMemoryHandler: + def collect(self, columns: Iterable[str], index: DataIndex) -> InMemoryHandler: # concatenate the masks from all ranks columns = list(columns) columns = verify_input(comm=self.__comm, columns=columns)["columns"] - range_ = self.elem_range() - rank_indices = indices + range_[0] - all_indices = self.__comm.allgather(rank_indices) + all_indices = self.__comm.allgather(index) file_path = self.__file.filename - all_indices = np.concatenate(all_indices) + all_indices = all_indices[0].concatenate(*all_indices[1:]) with h5py.File(file_path, "r") as file: return InMemoryHandler( file, tree=self.__tree, columns=columns, - indices=all_indices, + index=all_indices, group_name=self.__group_name, ) def write( self, file: h5py.File, - indices: np.ndarray, + index: DataIndex, columns: Iterable[str], dataset_name: Optional[str] = None, selected: Optional[np.ndarray] = None, @@ -129,10 +125,9 @@ def write( ) columns = input_["columns"] - rank_range = self.elem_range() # indices = redistribute_indices(indices, rank_range) - rank_output_length = len(indices) + rank_output_length = len(index) all_output_lengths = self.__comm.allgather(rank_output_length) @@ -156,12 +151,7 @@ def write( for column in columns: # This step has to be done by all ranks, per documentation - shape: Tuple[int, ...] - if len(self.__group[column].shape) != 1: - shape = (full_output_length, self.__group[column].shape[1]) - else: - shape = (full_output_length,) - + shape = (full_output_length,) + self.__group[column].shape[1:] data_group.create_dataset(column, shape, dtype=self.__group[column].dtype) if self.__columns[column] is not None: data_group[column].attrs["unit"] = self.__columns[column] @@ -170,15 +160,13 @@ def write( if rank_output_length != 0: for column in columns: - data = self.__group[column][rank_range[0] : rank_range[1]][()] - data = data[indices] - + data = index.get_data(self.__group[column]) data_group[column][rank_start:rank_end] = data mask = np.zeros(len(self), dtype=bool) - mask[indices] = True + mask = index.set_data(mask, True) - new_tree = self.__tree.apply_mask(mask, self.__comm, self.elem_range()) + new_tree = self.__tree.apply_mask(mask, self.__comm, index.range()) new_tree.write(group) # type: ignore @@ -187,7 +175,7 @@ def write( def get_data( self, builders: dict, - indices: np.ndarray, + index: DataIndex ) -> Column | Table: """ Get data from the file in the range for this rank. 
@@ -199,11 +187,7 @@ def get_data( output = {} for column in builder_keys: - col = read_indices( - self.__group[column], - indices, - self.elem_range(), - ) + col = Column(index.get_data(self.__group[column])) output[column] = builders[column].build(col) if len(output) == 1: return next(iter(output.values())) @@ -215,55 +199,3 @@ def take_range(self, start: int, end: int, indices: np.ndarray) -> np.ndarray: return indices[start:end] - def take_indices(self, n: int, strategy: str, indices: np.ndarray) -> np.ndarray: - """ - masks are localized to each rank. For "start" and "end" it's just a matter of - figuring out how many elements each rank is responsible for. For "random" we - need to be more clever. - """ - - rank_length = len(indices) - rank_lengths = self.__comm.allgather(rank_length) - - total_length = np.sum(rank_lengths) - if n > total_length: - # All ranks crash - raise ValueError( - f"Requested {n} elements, but only {total_length} are available." - ) - n = total_length - - if self.__comm.Get_rank() == 0: - if strategy == "random": - take_indices = np.random.choice(total_length, n, replace=False) - take_indices = np.sort(take_indices) - elif strategy == "start": - take_indices = np.arange(n) - elif strategy == "end": - take_indices = np.arange(total_length - n, total_length) - # Distribute the indices to the ranks - else: - take_indices = None - take_indices = self.__comm.bcast(take_indices, root=0) - - if take_indices is None: - # Should not happen, but this is for mypy - raise ValueError("Indices should not be None.") - - rank_start_index = self.__comm.Get_rank() - if rank_start_index: - rank_start_index = np.sum(rank_lengths[: self.__comm.Get_rank()]) - rank_end_index = rank_start_index + rank_length - - rank_indicies = take_indices[ - (take_indices >= rank_start_index) & (take_indices < rank_end_index) - ] - if len(rank_indicies) == 0: - # This rank doesn't have enough data - warn( - "This take operation will return no data for rank " - f"{self.__comm.Get_rank()}" - ) - return np.array([], dtype=int) - - return rank_indicies - rank_start_index diff --git a/opencosmo/handler/oom.py b/opencosmo/handler/oom.py index 36be2aaa..f00dc5e8 100644 --- a/opencosmo/handler/oom.py +++ b/opencosmo/handler/oom.py @@ -9,7 +9,8 @@ from opencosmo.dataset.column import ColumnBuilder from opencosmo.handler import InMemoryHandler from opencosmo.spatial.tree import Tree -from opencosmo.utils import read_indices, write_indices +from opencosmo.utils import write_index +from opencosmo.dataset.index import DataIndex class OutOfMemoryHandler: @@ -38,13 +39,13 @@ def __exit__(self, *exec_details): self.__group = None return self.__file.close() - def collect(self, columns: Iterable[str], indices: np.ndarray) -> InMemoryHandler: + def collect(self, columns: Iterable[str], index: DataIndex) -> InMemoryHandler: file_path = self.__file.filename - if len(indices) == len(self): + if len(index) == len(self): tree = self.__tree else: mask = np.zeros(len(self), dtype=bool) - mask[indices] = True + mask = index.set_data(mask, True) tree = self.__tree.apply_mask(mask) with h5py.File(file_path, "r") as file: @@ -53,13 +54,13 @@ def collect(self, columns: Iterable[str], indices: np.ndarray) -> InMemoryHandle tree, group_name=self.__group_name, columns=columns, - indices=indices, + index=index, ) def write( self, file: h5py.File, - indices: np.ndarray, + index: DataIndex, columns: Iterable[str], dataset_name: Optional[str] = None, ) -> None: @@ -71,64 +72,28 @@ def write( group = file.require_group(dataset_name) 
data_group = group.create_group("data") for column in columns: - write_indices(self.__group[column], data_group, indices) + write_index(self.__group[column], data_group, index) tree_mask = np.zeros(len(self), dtype=bool) - tree_mask[indices] = True + tree_mask = index.set_data(tree_mask, True) + tree = self.__tree.apply_mask(tree_mask) tree.write(group) - def get_data(self, builders: dict, indices: np.ndarray) -> Column | Table: + def get_data(self, builders: dict, index: DataIndex) -> Column | Table: """ """ if self.__group is None: raise ValueError("This file has already been closed") output = {} for column, builder in builders.items(): - col = read_indices(self.__group[column], indices) + col = Column(index.get_data(self.__group[column])) output[column] = builder.build(col) if len(output) == 1: return next(iter(output.values())) return Table(output) - def get_range( - self, - start: int, - end: int, - builders: dict[str, ColumnBuilder], - indices: np.ndarray, - ) -> dict[str, tuple[float, float]]: - if self.__group is None: - raise ValueError("This file has already been closed") - output = {} - start_idx = indices[start] - end_idx = indices[end] + 1 - for column, builder in builders.items(): - data = self.__group[column][start_idx:end_idx] - data = data[indices[start:end]] - col = Column(data, name=column) - output[column] = builder.build(col) - - return Table(output) - def take_range(self, start: int, end: int, indices: np.ndarray) -> np.ndarray: if start < 0 or end > len(indices): raise ValueError("Indices out of range") return indices[start:end] - - def take_indices(self, n: int, strategy: str, indices: np.ndarray) -> np.ndarray: - if n > (length := len(indices)): - raise ValueError( - f"Requested {n} elements, but only {length} are available." 
- ) - - if strategy == "start": - return indices[:n] - elif strategy == "end": - return indices[-n:] - elif strategy == "random": - return np.sort(np.random.choice(indices, n, replace=False)) - else: - raise ValueError( - "Strategy for `take` must be one of 'start', 'end', or 'random'" - ) diff --git a/opencosmo/io.py b/opencosmo/io.py index f94d2c88..c1ed0dde 100644 --- a/opencosmo/io.py +++ b/opencosmo/io.py @@ -12,15 +12,16 @@ MPI = None # type: ignore from typing import Iterable, Optional -import numpy as np import opencosmo as oc from opencosmo import collection from opencosmo.file import FileExistance, file_reader, file_writer, resolve_path from opencosmo.handler import InMemoryHandler, OpenCosmoDataHandler, OutOfMemoryHandler +from opencosmo.handler.mpi import partition from opencosmo.header import read_header from opencosmo.spatial import read_tree from opencosmo.transformations import units as u +from opencosmo.dataset.index import ChunkedIndex, DataIndex def open( @@ -89,19 +90,21 @@ def open( raise ValueError("Asked for multiple datasets, but file has only one") handler: OpenCosmoDataHandler + index: DataIndex if MPI is not None and MPI.COMM_WORLD.Get_size() > 1: handler = MPIHandler( file_handle, group_name=datasets, tree=tree, comm=MPI.COMM_WORLD ) + start, size = partition(MPI.COMM_WORLD, len(handler)) + index = ChunkedIndex.single_chunk(start, size) else: handler = OutOfMemoryHandler(file_handle, group_name=datasets, tree=tree) - + index = ChunkedIndex.from_size(len(handler)) builders, base_unit_transformations = u.get_default_unit_transformations( group, header ) - mask = np.arange(len(handler)) - dataset = oc.Dataset(handler, header, builders, base_unit_transformations, mask) + dataset = oc.Dataset(handler, header, builders, base_unit_transformations, index) return dataset @@ -147,12 +150,12 @@ def read( header = read_header(file) tree = read_tree(file, header) handler = InMemoryHandler(file, tree, group_name=datasets) - mask = np.arange(len(handler)) + index = ChunkedIndex.from_size(len(handler)) builders, base_unit_transformations = u.get_default_unit_transformations( group, header ) - return oc.Dataset(handler, header, builders, base_unit_transformations, mask) + return oc.Dataset(handler, header, builders, base_unit_transformations, index) @file_writer diff --git a/opencosmo/link/builder.py b/opencosmo/link/builder.py index a8521f52..131b2eaf 100644 --- a/opencosmo/link/builder.py +++ b/opencosmo/link/builder.py @@ -11,6 +11,7 @@ from opencosmo.header import OpenCosmoHeader from opencosmo.spatial import read_tree from opencosmo.transformations import units as u +from opencosmo.dataset.index import DataIndex, ChunkedIndex, SimpleIndex try: from mpi4py import MPI @@ -100,7 +101,7 @@ def build( self, file: File | Group, header: OpenCosmoHeader, - indices: Optional[np.ndarray] = None, + index: Optional[DataIndex] = None, ) -> Dataset: tree = read_tree(file, header) builders, base_unit_transformations = u.get_default_unit_transformations( @@ -122,23 +123,14 @@ def build( handler = OutOfMemoryHandler(file, tree=tree) - if indices is None: - indices_ = np.arange(len(handler)) - - elif len(indices) > 0: - if indices[0] < 0 or indices[-1] >= len(handler): - raise ValueError( - "Indices must be within 0 and the length of the dataset." 
- ) - indices_ = indices - else: - indices_ = indices + if index is None: + index = ChunkedIndex.from_size(len(handler)) dataset = Dataset( handler, header, builders, base_unit_transformations, - indices_, + index, ) return dataset diff --git a/opencosmo/link/collection.py b/opencosmo/link/collection.py index a8a45555..9293d59e 100644 --- a/opencosmo/link/collection.py +++ b/opencosmo/link/collection.py @@ -49,7 +49,7 @@ def __init__( self.__properties = properties self.__handlers = handlers - self.__idxs = self.__properties.indices + self.__index = self.__properties.index self.__filters = filters def __repr__(self): @@ -104,8 +104,8 @@ def __getitem__(self, key: str) -> oc.Dataset: return self.__properties elif key not in self.__handlers: raise KeyError(f"Dataset {key} not found in collection.") - indices = self.__properties.indices - return self.__handlers[key].get_data(indices) + index = self.__properties.index + return self.__handlers[key].get_data(index) def __enter__(self): return self @@ -192,7 +192,7 @@ def objects( handlers = {dt: self.__handlers[dt] for dt in data_types} for i, row in enumerate(self.__properties.rows()): - index = np.array(self.__properties.indices[i]) + index = self.__properties.index[i] output = {key: handler.get_data(index) for key, handler in handlers.items()} if not any(len(v) for v in output.values()): continue @@ -210,4 +210,5 @@ def write(self, file: File | Group): keys.sort() for key in keys: handler = self.__handlers[key] - handler.write(file, link_group, key, self.__idxs) + handler.write(file, link_group, key, self.__index) + diff --git a/opencosmo/link/handler.py b/opencosmo/link/handler.py index a5d59f2b..343beaff 100644 --- a/opencosmo/link/handler.py +++ b/opencosmo/link/handler.py @@ -11,6 +11,7 @@ from opencosmo.link.builder import DatasetBuilder, OomDatasetBuilder from opencosmo.spatial import read_tree from opencosmo.transformations import units as u +from opencosmo.dataset.index import DataIndex, SimpleIndex, ChunkedIndex def build_dataset( @@ -54,7 +55,7 @@ def __init__( """ pass - def get_data(self, indices: int | np.ndarray) -> oc.Dataset: + def get_data(self, indices: int | DataIndex) -> oc.Dataset: """ Given a index or a set of indices, return the data from the linked dataset that corresponds to the halo/galaxy at that index in the properties file. 
@@ -119,31 +120,23 @@
     def get_all_data(self) -> oc.Dataset:
         return build_dataset(self.file, self.header)
 
-    def get_data(self, indices: int | np.ndarray) -> oc.Dataset:
-        if isinstance(indices, int):
-            indices = np.array([indices], dtype=int)
-        min_idx = np.min(indices)
-        max_idx = np.max(indices)
-
+    def get_data(self, index: DataIndex) -> oc.Dataset:
         if isinstance(self.link, tuple):
-            start = self.link[0][min_idx : max_idx + 1][indices - min_idx]
-            size = self.link[1][min_idx : max_idx + 1][indices - min_idx]
+            start = index.get_data(self.link[0])
+            size = index.get_data(self.link[1])
             valid_rows = size > 0
             start = start[valid_rows]
             size = size[valid_rows]
             if not start.size:
-                indices_into_data = np.array([], dtype=int)
+                new_index = SimpleIndex(np.array([], dtype=int))
             else:
-                indices_into_data = np.concatenate(
-                    [np.arange(idx, idx + length) for idx, length in zip(start, size)]
-                )
+                new_index = ChunkedIndex(start, size)
         else:
-            indices_into_data = self.link[min_idx : max_idx + 1][indices - min_idx]
-            indices_into_data = np.array(indices_into_data[indices_into_data >= 0])
+            indices_into_data = index.get_data(self.link)
+            indices_into_data = indices_into_data[indices_into_data >= 0]
+            new_index = SimpleIndex(indices_into_data)
-        if not indices_into_data.size:
-            indices_into_data = np.array([], dtype=int)
 
-        return self.builder.build(self.file, self.header, indices_into_data)
+        return self.builder.build(self.file, self.header, new_index)
 
     def select(self, columns: str | Iterable[str]) -> OomLinkHandler:
         if isinstance(columns, str):
@@ -165,24 +158,22 @@ def with_units(self, convention: str) -> OomLinkHandler:
         )
 
     def write(
-        self, group: Group, link_group: Group, name: str, indices: int | np.ndarray
+        self, group: Group, link_group: Group, name: str, index: DataIndex
     ):
-        if isinstance(indices, int):
-            indices = np.array([indices])
         # Pack the indices
         if not isinstance(self.link, tuple):
-            new_idxs = np.full(len(indices), -1)
-            current_values = self.link[indices[0] : indices[-1] + 1]
-            current_values = current_values[indices - indices[0]]
+            new_idxs = np.full(len(index), -1)
+            current_values = index.get_data(self.link)
             has_data = current_values >= 0
             new_idxs[has_data] = np.arange(sum(has_data))
             link_group.create_dataset("sod_profile_idx", data=new_idxs, dtype=int)
         else:
-            lengths = self.link[1][indices]
+            lengths = index.get_data(self.link[1])
             new_starts = np.insert(np.cumsum(lengths), 0, 0)[:-1]
             link_group.create_dataset(f"{name}_start", data=new_starts, dtype=int)
             link_group.create_dataset(f"{name}_size", data=lengths, dtype=int)
-        dataset = self.get_data(indices)
-        if dataset is not None:
+        dataset = self.get_data(index)
+
+        if len(dataset) > 0:
             dataset.write(group, name)
diff --git a/opencosmo/link/mpi.py b/opencosmo/link/mpi.py
index 69742e65..3f7dfe1b 100644
--- a/opencosmo/link/mpi.py
+++ b/opencosmo/link/mpi.py
@@ -9,11 +9,13 @@
 import opencosmo as oc
 from opencosmo.dataset.column import ColumnBuilder, get_column_builders
 from opencosmo.handler import MPIHandler
+from opencosmo.handler.mpi import partition
 from opencosmo.header import OpenCosmoHeader
 from opencosmo.link.builder import DatasetBuilder
 from opencosmo.spatial import Tree, read_tree
 from opencosmo.transformations import TransformationDict
 from opencosmo.transformations import units as u
+from opencosmo.dataset.index import DataIndex, SimpleIndex, ChunkedIndex
 
 
 def build_dataset(
@@ -81,32 +83,28 @@ def get_all_data(self) -> oc.Dataset:
             self.header,
         )
 
-
def get_data(self, indices: int | np.ndarray) -> oc.Dataset: - if isinstance(indices, int): - indices = np.array([indices], dtype=int) - + def get_data(self, index: DataIndex) -> oc.Dataset: if isinstance(self.link, tuple): - start = self.link[0][indices + self.offset] - size = self.link[1][indices + self.offset] + start = index.get_data(self.link[0]) + size = index.get_data(self.link[1]) valid_rows = size > 0 start = start[valid_rows] size = size[valid_rows] if len(start) == 0: - indices_into_data = np.array([], dtype=int) + new_index = SimpleIndex(np.array([], dtype=int)) else: - indices_into_data = np.concatenate( - [np.arange(idx, idx + length) for idx, length in zip(start, size)] - ) + new_index = ChunkedIndex(start, size) else: - indices_into_data = self.link[indices + self.offset] + indices_into_data = index.get_data(self.link) indices_into_data = indices_into_data[indices_into_data >= 0] if len(indices_into_data) == 0: indices_into_data = np.array([], dtype=int) + new_index = SimpleIndex(indices_into_data) dataset = self.builder.build( self.file, self.header, - indices=indices_into_data, + index=new_index, ) return dataset @@ -132,12 +130,10 @@ def select(self, columns: str | Iterable[str]) -> MpiLinkHandler: ) def write( - self, data_group: Group, link_group: Group, name: str, indices: int | np.ndarray + self, data_group: Group, link_group: Group, name: str, index: DataIndex ) -> None: # Pack the indices - if isinstance(indices, int): - indices = np.array([indices]) - sizes = self.comm.allgather(len(indices)) + sizes = self.comm.allgather(len(index)) shape = (sum(sizes),) if sum(sizes) == 0: return @@ -145,10 +141,7 @@ def write( if not isinstance(self.link, tuple): link_group.create_dataset("sod_profile_idx", shape=shape, dtype=int) self.comm.Barrier() - start = indices[0] - end = indices[-1] + 1 - indices_into_data = self.link[self.offset + start : self.offset + end] - indices_into_data = indices_into_data[indices - start] + indices_into_data = index.get_data(self.link) nonzero = indices_into_data >= 0 nonzero = self.comm.gather(nonzero) @@ -157,11 +150,12 @@ def write( sod_profile_idx = np.full(len(nonzero), -1) sod_profile_idx[nonzero] = np.arange(sum(nonzero)) link_group["sod_profile_idx"][:] = sod_profile_idx + else: link_group.create_dataset(f"{name}_start", shape=shape, dtype=int) link_group.create_dataset(f"{name}_size", shape=shape, dtype=int) self.comm.Barrier() - rank_sizes = self.link[1][self.offset + indices] + rank_sizes = index.get_data(self.link[1]) all_rank_sizes = self.comm.gather(rank_sizes) if self.comm.Get_rank() == 0: if all_rank_sizes is None: @@ -173,8 +167,8 @@ def write( link_group[f"{name}_start"][:] = starts link_group[f"{name}_size"][:] = all_sizes - dataset = self.get_data(indices) - + self.comm.Barrier() + dataset = self.get_data(index) if dataset is not None: dataset.write(data_group, name) @@ -239,7 +233,7 @@ def build( self, file: File | Group, header: OpenCosmoHeader, - indices: Optional[np.ndarray] = None, + index: Optional[DataIndex] = None, ) -> oc.Dataset: builders, base_unit_transformations = u.get_default_unit_transformations( file, header @@ -256,23 +250,20 @@ def build( builders = {key: builders[key] for key in selected} - rank_range = None - if indices is not None and len(indices) > 0: - rank_range = (indices.min(), indices.max() + 1) - indices = indices - rank_range[0] handler = MPIHandler( - file, tree=self.tree, comm=self.comm, rank_range=rank_range + file, tree=self.tree, comm=self.comm ) - if indices is None: - indices = 
np.arange(len(handler)) + if index is None: + start, size = partition(self.comm, len(handler)) + index = ChunkedIndex.single_chunk(start, size) dataset = oc.Dataset( handler, header, builders, base_unit_transformations, - indices, + index, ) return dataset diff --git a/opencosmo/utils.py b/opencosmo/utils.py index 4a4c4279..be164adc 100644 --- a/opencosmo/utils.py +++ b/opencosmo/utils.py @@ -1,6 +1,7 @@ """ I/O utilities for hdf5 """ +from opencosmo.dataset.index import DataIndex from typing import Optional @@ -10,33 +11,15 @@ from h5py import Dataset, Group -def read_indices( - ds: Dataset, indices: np.ndarray, range_: Optional[tuple[int, int]] = None -) -> Column: - if len(indices) == 0: - return Column([], name=ds.name) - indices_into_data = indices - if range_ is not None: - indices_into_data = indices_into_data + range_[0] - else: - range_ = (0, indices_into_data.max()) - - if indices_into_data.max() > range_[1]: - raise ValueError("Tried to get indices outside the range of the dataset") - - data = ds[range_[0] : range_[1] + 1] - return Column(data[indices_into_data - range_[0]], name=ds.name) - - -def write_indices( +def write_index( input_ds: Dataset, output_group: Group, - indices: np.ndarray, + index: DataIndex, range_: Optional[tuple[int, int]] = None, ): - if len(indices) == 0: + if len(index) == 0: raise ValueError("No indices provided to write") - data = read_indices(input_ds, indices, range_).data + data = index.get_data(input_ds) output_name = input_ds.name.split("/")[-1] compression = hdf5plugin.Blosc2(cname="lz4", filters=hdf5plugin.Blosc2.BITSHUFFLE) diff --git a/test/test_collection.py b/test/test_collection.py index 670425dc..776d9e13 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -179,16 +179,21 @@ def test_collection_of_linked(galaxy_paths, galaxy_paths_2, tmp_path): datasets = {"scidac_01": galaxies_1, "scidac_02": galaxies_2} collection = SimulationCollection(datasets) + collection = collection.filter(oc.col("gal_mass") > 10**12).take(50, at="start") + oc.write(tmp_path / "galaxies.hdf5", collection) dataset = oc.open(tmp_path / "galaxies.hdf5") - dataset = dataset.filter(oc.col("gal_mass") > 10**12).take(10, at="random") + j = 0 for ds in dataset.values(): for props, particles in ds.objects(): gal_tag = props["gal_tag"] gal_tags = set(particles.data["gal_tag"]) assert len(gal_tags) == 1 assert gal_tags.pop() == gal_tag + j += 1 + + assert j == 100 def test_multiple_properties(galaxy_paths, halo_paths): From e3f9725c968957309bcd48afc47427d271e64a1d Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Mon, 14 Apr 2025 10:54:41 -0500 Subject: [PATCH 06/11] Formatting --- opencosmo/collection/collection.py | 3 +- opencosmo/dataset/dataset.py | 8 +- opencosmo/dataset/index.py | 131 +++++++++++++++++++---------- opencosmo/dataset/mask.py | 2 +- opencosmo/handler/im.py | 6 +- opencosmo/handler/mpi.py | 10 +-- opencosmo/handler/oom.py | 5 +- opencosmo/io.py | 3 +- opencosmo/link/builder.py | 2 +- opencosmo/link/collection.py | 9 +- opencosmo/link/handler.py | 8 +- opencosmo/link/mpi.py | 7 +- opencosmo/utils.py | 5 +- 13 files changed, 111 insertions(+), 88 deletions(-) diff --git a/opencosmo/collection/collection.py b/opencosmo/collection/collection.py index 1766c42e..665c197c 100644 --- a/opencosmo/collection/collection.py +++ b/opencosmo/collection/collection.py @@ -14,13 +14,13 @@ import numpy as np import opencosmo as oc +from opencosmo.dataset.index import ChunkedIndex from opencosmo.dataset.mask import Mask from opencosmo.handler import 
InMemoryHandler, OpenCosmoDataHandler, OutOfMemoryHandler from opencosmo.header import OpenCosmoHeader, read_header from opencosmo.link import StructureCollection from opencosmo.spatial import read_tree from opencosmo.transformations import units as u -from opencosmo.dataset.index import ChunkedIndex class Collection(Protocol): @@ -96,6 +96,7 @@ def write_with_unique_headers(collection: Collection, file: h5py.File): group = file.create_group(key) collection[key].write(group) + def verify_datasets_exist(file: h5py.File, datasets: Iterable[str]): """ Verify a set of datasets exist in a given file. diff --git a/opencosmo/dataset/dataset.py b/opencosmo/dataset/dataset.py index a9902115..34bfb273 100644 --- a/opencosmo/dataset/dataset.py +++ b/opencosmo/dataset/dataset.py @@ -3,17 +3,16 @@ from typing import Generator, Iterable, Optional import h5py -import numpy as np from astropy import units # type: ignore from astropy.table import Table # type: ignore import opencosmo.transformations as t import opencosmo.transformations.units as u from opencosmo.dataset.column import ColumnBuilder, get_column_builders +from opencosmo.dataset.index import ChunkedIndex, DataIndex from opencosmo.dataset.mask import Mask, apply_masks from opencosmo.handler import OpenCosmoDataHandler from opencosmo.header import OpenCosmoHeader, write_header -from opencosmo.dataset.index import DataIndex, ChunkedIndex class Dataset: @@ -196,9 +195,7 @@ def filter(self, *masks: Mask) -> Dataset: """ - new_index = apply_masks( - self.__handler, self.__builders, masks, self.__index - ) + new_index = apply_masks(self.__handler, self.__builders, masks, self.__index) if len(new_index) == 0: raise ValueError("The filter returned no rows!") @@ -345,7 +342,6 @@ def take(self, n: int, at: str = "start") -> Dataset: """ new_index = self.__index.take(n, at) - return Dataset( self.__handler, self.__header, diff --git a/opencosmo/dataset/index.py b/opencosmo/dataset/index.py index 46e51840..d1a2be2a 100644 --- a/opencosmo/dataset/index.py +++ b/opencosmo/dataset/index.py @@ -1,13 +1,14 @@ from __future__ import annotations -from typing import Protocol, TypeVar, Any -import numpy as np -import h5py +from typing import Any, Protocol, TypeVar +import h5py +import numpy as np T = TypeVar("T", np.ndarray, h5py.Dataset) -class DataIndex(Protocol): + +class DataIndex(Protocol): @classmethod def from_size(cls, size: int) -> DataIndex: ... def set_data(self, data: T, value: Any) -> T: ... @@ -23,7 +24,7 @@ def __getitem__(self, item: int) -> int: ... class SimpleIndex: """ - An index of integers. + An index of integers. 
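+    An index of integers.
+
+    A minimal sketch of the intended semantics (the array values here are
+    assumed purely for illustration):
+
+        >>> idx = SimpleIndex(np.array([5, 2, 7]))
+        >>> idx.get_data(np.arange(10) * 10)  # reads rows 2, 5 and 7
+        array([20, 50, 70])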
""" def __init__(self, index: np.ndarray) -> None: @@ -46,11 +47,16 @@ def concatenate(self, *others: SimpleIndex) -> SimpleIndex: if len(others) == 0: return self if all(isinstance(other, SimpleIndex) for other in others): - new_index = np.concatenate([self.__index] + [other.__index for other in others]) + new_index = np.concatenate( + [self.__index] + [other.__index for other in others] + ) new_index = np.sort(np.unique(new_index)) return SimpleIndex(new_index) else: - simple_indices = map(lambda x: x.to_simple_index() if isinstance(x, ChunkedIndex) else x, others) + simple_indices = map( + lambda x: x.to_simple_index() if isinstance(x, ChunkedIndex) else x, + others, + ) return self.concatenate(*simple_indices) def set_data(self, data: np.ndarray, value: bool) -> np.ndarray: @@ -63,7 +69,6 @@ def set_data(self, data: np.ndarray, value: bool) -> np.ndarray: data[self.__index] = value return data - def take(self, n: int, at: str = "random") -> SimpleIndex: """ Take n elements from the index. @@ -84,7 +89,9 @@ def take_range(self, start: int, end: int) -> SimpleIndex: Take a range of elements from the index. """ if start < 0 or end > len(self): - raise ValueError(f"Range {start}:{end} is out of bounds for index of size {len(self)}") + raise ValueError( + f"Range {start}:{end} is out of bounds for index of size {len(self)}" + ) if start >= end: raise ValueError(f"Start {start} must be less than end {end}") @@ -93,7 +100,9 @@ def take_range(self, start: int, end: int) -> SimpleIndex: def mask(self, mask: np.ndarray) -> SimpleIndex: if mask.shape != self.__index.shape: - raise ValueError(f"Mask shape {mask.shape} does not match index size {len(self)}") + raise ValueError( + f"Mask shape {mask.shape} does not match index size {len(self)}" + ) if mask.dtype != bool: raise ValueError(f"Mask dtype {mask.dtype} is not boolean") @@ -103,7 +112,7 @@ def mask(self, mask: np.ndarray) -> SimpleIndex: if mask.all(): return self - + return SimpleIndex(self.__index[mask]) def get_data(self, data: h5py.Dataset) -> np.ndarray: @@ -117,7 +126,7 @@ def get_data(self, data: h5py.Dataset) -> np.ndarray: min_index = self.__index.min() max_index = self.__index.max() - output = data[min_index:max_index + 1] + output = data[min_index : max_index + 1] indices_into_output = self.__index - min_index return output[indices_into_output] @@ -126,14 +135,17 @@ def __getitem__(self, item: int) -> SimpleIndex: Get an item from the index. """ if item < 0 or item >= len(self): - raise IndexError(f"Index {item} out of bounds for index of size {len(self)}") + raise IndexError( + f"Index {item} out of bounds for index of size {len(self)}" + ) val = self.__index[item] return SimpleIndex(np.array([val])) + def pack(start: np.ndarray, size: np.ndarray) -> tuple[np.ndarray, np.ndarray]: """ Combine adjacent chunks into a single chunk. - """ + """ # Calculate the end of each chunk end = start + size @@ -155,8 +167,8 @@ def pack(start: np.ndarray, size: np.ndarray) -> tuple[np.ndarray, np.ndarray]: return combined_start, combined_size -class ChunkedIndex: +class ChunkedIndex: def __init__(self, starts: np.ndarray, sizes: np.ndarray) -> None: # sort the starts and sizes # pack the starts and sizes @@ -173,7 +185,12 @@ def to_simple_index(self) -> SimpleIndex: """ Convert the ChunkedIndex to a SimpleIndex. 
""" - idxs = np.concatenate([np.arange(start, start + size) for start, size in zip(self.__starts, self.__sizes)]) + idxs = np.concatenate( + [ + np.arange(start, start + size) + for start, size in zip(self.__starts, self.__sizes) + ] + ) idxs = np.unique(idxs) return SimpleIndex(idxs) @@ -181,16 +198,21 @@ def concatenate(self, *others: DataIndex) -> DataIndex: if len(others) == 0: return self if all(isinstance(other, ChunkedIndex) for other in others): - new_starts = np.concatenate([self.__starts] + [other.__starts for other in others]) - new_sizes = np.concatenate([self.__sizes] + [other.__sizes for other in others]) + new_starts = np.concatenate( + [self.__starts] + [other.__starts for other in others] + ) + new_sizes = np.concatenate( + [self.__sizes] + [other.__sizes for other in others] + ) return ChunkedIndex(new_starts, new_sizes) - + else: - simple_indices = map(lambda x: x.to_simple_index() if isinstance(x, ChunkedIndex) else x, others) + simple_indices = map( + lambda x: x.to_simple_index() if isinstance(x, ChunkedIndex) else x, + others, + ) return self.concatenate(*simple_indices) - - @classmethod def from_size(cls, size: int) -> ChunkedIndex: """ @@ -225,9 +247,9 @@ def set_data(self, data: np.ndarray, value: bool) -> np.ndarray: raise ValueError("Data must be a numpy array") for start, size in zip(self.__starts, self.__sizes): - data[start:start + size] = value + data[start : start + size] = value return data - + def __len__(self) -> int: """ Get the total size of the index. @@ -239,13 +261,18 @@ def take(self, n: int, at: str = "random") -> DataIndex: raise ValueError(f"Cannot take {n} elements from index of size {len(self)}") if at == "random": - idxs = np.concatenate([np.arange(start, start + size) for start, size in zip(self.__starts, self.__sizes)]) + idxs = np.concatenate( + [ + np.arange(start, start + size) + for start, size in zip(self.__starts, self.__sizes) + ] + ) idxs = np.random.choice(idxs, n, replace=False) return SimpleIndex(idxs) elif at == "start": last_chunk_in_range = np.searchsorted(np.cumsum(self.__sizes), n) - new_starts = self.__starts[:last_chunk_in_range+1].copy() - new_sizes = self.__sizes[:last_chunk_in_range+1].copy() + new_starts = self.__starts[: last_chunk_in_range + 1].copy() + new_sizes = self.__sizes[: last_chunk_in_range + 1].copy() new_sizes[-1] = n - np.sum(new_sizes[:-1]) return ChunkedIndex(new_starts, new_sizes) elif at == "end": @@ -253,7 +280,11 @@ def take(self, n: int, at: str = "random") -> DataIndex: new_sizes = self.__sizes[starting_chunk:].copy() new_starts = self.__starts[starting_chunk:].copy() new_sizes[0] = n - np.sum(new_sizes[1:]) - new_starts[0] = self.__starts[starting_chunk] + self.__sizes[starting_chunk] - new_sizes[0] + new_starts[0] = ( + self.__starts[starting_chunk] + + self.__sizes[starting_chunk] + - new_sizes[0] + ) return ChunkedIndex(new_starts, new_sizes) def take_range(self, start: int, end: int) -> DataIndex: @@ -261,13 +292,20 @@ def take_range(self, start: int, end: int) -> DataIndex: Take a range of elements from the index. 
""" if start < 0 or end > len(self): - raise ValueError(f"Range {start}:{end} is out of bounds for index of size {len(self)}") + raise ValueError( + f"Range {start}:{end} is out of bounds for index of size {len(self)}" + ) if start >= end: raise ValueError(f"Start {start} must be less than end {end}") # Get the indices of the chunks that are in the range - idxs = np.concatenate([np.arange(start, start + size) for start, size in zip(self.__starts, self.__sizes)]) + idxs = np.concatenate( + [ + np.arange(start, start + size) + for start, size in zip(self.__starts, self.__sizes) + ] + ) range_idxs = idxs[start:end] return SimpleIndex(range_idxs) @@ -277,7 +315,9 @@ def mask(self, mask: np.ndarray) -> DataIndex: Mask the index with a boolean mask. """ if mask.shape != (len(self),): - raise ValueError(f"Mask shape {mask.shape} does not match index size {len(self)}") + raise ValueError( + f"Mask shape {mask.shape} does not match index size {len(self)}" + ) if mask.dtype != bool: raise ValueError(f"Mask dtype {mask.dtype} is not boolean") @@ -289,17 +329,21 @@ def mask(self, mask: np.ndarray) -> DataIndex: return self # Get the indices of the chunks that are masked - idxs = np.concatenate([np.arange(start, start + size) for start, size in zip(self.__starts, self.__sizes)]) + idxs = np.concatenate( + [ + np.arange(start, start + size) + for start, size in zip(self.__starts, self.__sizes) + ] + ) masked_idxs = idxs[mask] return SimpleIndex(masked_idxs) - def get_data(self, data: h5py.Dataset | np.ndarray) -> np.ndarray: """ Get the data from the dataset using the index. We want to perform as few reads as possible. However, the chunks may not be continuous. This method sorts the - chunks so it can read the data in the largest possible chunks, it then + chunks so it can read the data in the largest possible chunks, it then reshuffles the data to match the original order. 
For large numbers of chunks, this is much, much faster than reading each chunk
 individually.
         """
         if not isinstance(data, (h5py.Dataset, np.ndarray)):
             raise ValueError("Data must be an h5py.Dataset or a numpy array")
 
-
         if len(self) == 0:
             return np.array([], dtype=data.dtype)
         if len(self.__starts) == 1:
-            return data[self.__starts[0]:self.__starts[0] + self.__sizes[0]]
+            return data[self.__starts[0] : self.__starts[0] + self.__sizes[0]]
 
         sorted_start_index = np.argsort(self.__starts)
         new_starts = self.__starts[sorted_start_index]
@@ -320,27 +363,24 @@ def get_data(self, data: h5py.Dataset | np.ndarray) -> np.ndarray:
 
         packed_starts, packed_sizes = pack(new_starts, new_sizes)
 
-
         shape = (len(self),) + data.shape[1:]
         temp = np.zeros(shape, dtype=data.dtype)
         running_index = 0
         for i, (start, size) in enumerate(zip(packed_starts, packed_sizes)):
-            temp[running_index:running_index + size] = data[start:start + size]
+            temp[running_index : running_index + size] = data[start : start + size]
             running_index += size
 
         output = np.zeros(len(self), dtype=data.dtype)
         cumulative_sorted_sizes = np.insert(np.cumsum(new_sizes), 0, 0)
         cumulative_original_sizes = np.insert(np.cumsum(self.__sizes), 0, 0)
-
+
         # reshuffle the output to match the original order
         for i, sorted_index in enumerate(sorted_start_index):
             start = cumulative_original_sizes[sorted_index]
             end = cumulative_original_sizes[sorted_index + 1]
-            data = temp[cumulative_sorted_sizes[i]:cumulative_sorted_sizes[i + 1]]
-            output[start: end] = data
+            data = temp[cumulative_sorted_sizes[i] : cumulative_sorted_sizes[i + 1]]
+            output[start:end] = data
 
-
-
         return output
 
     def __getitem__(self, item: int) -> SimpleIndex:
@@ -348,10 +388,11 @@ 
         Get an item from the index. 
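+
+        A sketch (chunk layout assumed): with chunks [0, 10] of sizes [2, 3],
+        item 3 resolves to the fourth flattened row, i.e. row 11:
+
+            >>> ci = ChunkedIndex(np.array([0, 10]), np.array([2, 3]))
+            >>> ci[3].get_data(np.arange(20))
+            array([11])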
""" if item < 0 or item >= len(self): - raise IndexError(f"Index {item} out of bounds for index of size {len(self)}") + raise IndexError( + f"Index {item} out of bounds for index of size {len(self)}" + ) sums = np.cumsum(self.__sizes) index = np.searchsorted(sums, item) start = self.__starts[index] offset = item - sums[index - 1] if index > 0 else item return SimpleIndex(np.array([start + offset])) - diff --git a/opencosmo/dataset/mask.py b/opencosmo/dataset/mask.py index 267d110b..e497baa7 100644 --- a/opencosmo/dataset/mask.py +++ b/opencosmo/dataset/mask.py @@ -10,8 +10,8 @@ from astropy import table # type: ignore from opencosmo.dataset.column import ColumnBuilder -from opencosmo.handler import OpenCosmoDataHandler from opencosmo.dataset.index import DataIndex +from opencosmo.handler import OpenCosmoDataHandler Comparison = Callable[[float, float], bool] diff --git a/opencosmo/handler/im.py b/opencosmo/handler/im.py index 7f389ec1..846a2cda 100644 --- a/opencosmo/handler/im.py +++ b/opencosmo/handler/im.py @@ -6,9 +6,9 @@ import numpy as np from astropy.table import Column, Table # type: ignore +from opencosmo.dataset.index import ChunkedIndex, DataIndex from opencosmo.file import get_data_structure from opencosmo.spatial.tree import Tree -from opencosmo.dataset.index import DataIndex, ChunkedIndex class InMemoryHandler: @@ -38,7 +38,9 @@ def __init__( if index is None: length = len(next(iter(group.values()))) index = ChunkedIndex.from_size(length) - self.__data = {colname: index.get_data(group[colname]) for colname in self.__columns} + self.__data = { + colname: index.get_data(group[colname]) for colname in self.__columns + } def __len__(self) -> int: return len(next(iter(self.__data.values()))) diff --git a/opencosmo/handler/mpi.py b/opencosmo/handler/mpi.py index 706df35e..c5ab83bb 100644 --- a/opencosmo/handler/mpi.py +++ b/opencosmo/handler/mpi.py @@ -6,10 +6,10 @@ from astropy.table import Column, Table # type: ignore from mpi4py import MPI +from opencosmo.dataset.index import DataIndex from opencosmo.file import get_data_structure from opencosmo.handler import InMemoryHandler from opencosmo.spatial.tree import Tree -from opencosmo.dataset.index import DataIndex def partition(comm: MPI.Comm, length: int) -> Tuple[int, int]: @@ -26,7 +26,6 @@ def partition(comm: MPI.Comm, length: int) -> Tuple[int, int]: return (start, size) - def verify_input(comm: MPI.Comm, require: Iterable[str] = [], **kwargs) -> dict: """ Verify that the input is the same on all ranks. @@ -172,11 +171,7 @@ def write( self.__comm.Barrier() - def get_data( - self, - builders: dict, - index: DataIndex - ) -> Column | Table: + def get_data(self, builders: dict, index: DataIndex) -> Column | Table: """ Get data from the file in the range for this rank. 
""" @@ -198,4 +193,3 @@ def take_range(self, start: int, end: int, indices: np.ndarray) -> np.ndarray: raise ValueError("Requested range is not within the rank's range.") return indices[start:end] - diff --git a/opencosmo/handler/oom.py b/opencosmo/handler/oom.py index f00dc5e8..9ab3351b 100644 --- a/opencosmo/handler/oom.py +++ b/opencosmo/handler/oom.py @@ -6,11 +6,10 @@ import numpy as np from astropy.table import Column, Table # type: ignore -from opencosmo.dataset.column import ColumnBuilder +from opencosmo.dataset.index import DataIndex from opencosmo.handler import InMemoryHandler from opencosmo.spatial.tree import Tree from opencosmo.utils import write_index -from opencosmo.dataset.index import DataIndex class OutOfMemoryHandler: @@ -76,7 +75,7 @@ def write( tree_mask = np.zeros(len(self), dtype=bool) tree_mask = index.set_data(tree_mask, True) - + tree = self.__tree.apply_mask(tree_mask) tree.write(group) diff --git a/opencosmo/io.py b/opencosmo/io.py index c1ed0dde..91841dc0 100644 --- a/opencosmo/io.py +++ b/opencosmo/io.py @@ -12,16 +12,15 @@ MPI = None # type: ignore from typing import Iterable, Optional - import opencosmo as oc from opencosmo import collection +from opencosmo.dataset.index import ChunkedIndex, DataIndex from opencosmo.file import FileExistance, file_reader, file_writer, resolve_path from opencosmo.handler import InMemoryHandler, OpenCosmoDataHandler, OutOfMemoryHandler from opencosmo.handler.mpi import partition from opencosmo.header import read_header from opencosmo.spatial import read_tree from opencosmo.transformations import units as u -from opencosmo.dataset.index import ChunkedIndex, DataIndex def open( diff --git a/opencosmo/link/builder.py b/opencosmo/link/builder.py index 131b2eaf..8d138741 100644 --- a/opencosmo/link/builder.py +++ b/opencosmo/link/builder.py @@ -7,11 +7,11 @@ from opencosmo import Dataset from opencosmo.dataset.column import get_column_builders +from opencosmo.dataset.index import ChunkedIndex, DataIndex from opencosmo.handler import OutOfMemoryHandler from opencosmo.header import OpenCosmoHeader from opencosmo.spatial import read_tree from opencosmo.transformations import units as u -from opencosmo.dataset.index import DataIndex, ChunkedIndex, SimpleIndex try: from mpi4py import MPI diff --git a/opencosmo/link/collection.py b/opencosmo/link/collection.py index 9293d59e..8cf0231b 100644 --- a/opencosmo/link/collection.py +++ b/opencosmo/link/collection.py @@ -2,7 +2,6 @@ from typing import Any, Iterable, Optional -import numpy as np from h5py import File, Group import opencosmo as oc @@ -10,9 +9,7 @@ def filter_properties_by_dataset( - dataset: oc.Dataset, - properties: oc.Dataset, - *masks + dataset: oc.Dataset, properties: oc.Dataset, *masks ) -> oc.Dataset: masked_dataset = dataset.filter(*masks) if properties.header.file.data_type == "halo_properties": @@ -24,6 +21,7 @@ def filter_properties_by_dataset( new_properties = properties.filter(oc.col(linked_column).isin(tags)) return new_properties + class StructureCollection: """ A collection of datasets that contain both high-level properties @@ -146,7 +144,7 @@ def filter(self, *masks, dataset: Optional[str] = None) -> StructureCollection: if dataset is None: filtered = self.__properties.filter(*masks) elif dataset not in self.__handlers: - raise ValueError(f"Dataset {dataset} not found in collection.") + raise ValueError(f"Dataset {dataset} not found in collection.") else: filtered = filter_properties_by_dataset( self[dataset], self.__properties, *masks @@ -211,4 +209,3 @@ def 
write(self, file: File | Group): for key in keys: handler = self.__handlers[key] handler.write(file, link_group, key, self.__index) - diff --git a/opencosmo/link/handler.py b/opencosmo/link/handler.py index 343beaff..0dfa02e1 100644 --- a/opencosmo/link/handler.py +++ b/opencosmo/link/handler.py @@ -6,12 +6,12 @@ from h5py import File, Group import opencosmo as oc +from opencosmo.dataset.index import ChunkedIndex, DataIndex, SimpleIndex from opencosmo.handler import OutOfMemoryHandler from opencosmo.header import OpenCosmoHeader from opencosmo.link.builder import DatasetBuilder, OomDatasetBuilder from opencosmo.spatial import read_tree from opencosmo.transformations import units as u -from opencosmo.dataset.index import DataIndex, SimpleIndex, ChunkedIndex def build_dataset( @@ -122,7 +122,7 @@ def get_all_data(self) -> oc.Dataset: def get_data(self, index: DataIndex) -> oc.Dataset: if isinstance(self.link, tuple): - start = index.get_data(self.link[0]) + start = index.get_data(self.link[0]) size = index.get_data(self.link[1]) valid_rows = size > 0 start = start[valid_rows] @@ -159,9 +159,7 @@ def with_units(self, convention: str) -> OomLinkHandler: self.builder.with_units(convention), ) - def write( - self, group: Group, link_group: Group, name: str, index: DataIndex - ): + def write(self, group: Group, link_group: Group, name: str, index: DataIndex): # Pack the indices if not isinstance(self.link, tuple): new_idxs = np.full(len(index), -1) diff --git a/opencosmo/link/mpi.py b/opencosmo/link/mpi.py index 3f7dfe1b..755a176c 100644 --- a/opencosmo/link/mpi.py +++ b/opencosmo/link/mpi.py @@ -8,6 +8,7 @@ import opencosmo as oc from opencosmo.dataset.column import ColumnBuilder, get_column_builders +from opencosmo.dataset.index import ChunkedIndex, DataIndex, SimpleIndex from opencosmo.handler import MPIHandler from opencosmo.handler.mpi import partition from opencosmo.header import OpenCosmoHeader @@ -15,7 +16,6 @@ from opencosmo.spatial import Tree, read_tree from opencosmo.transformations import TransformationDict from opencosmo.transformations import units as u -from opencosmo.dataset.index import DataIndex, SimpleIndex, ChunkedIndex def build_dataset( @@ -250,10 +250,7 @@ def build( builders = {key: builders[key] for key in selected} - - handler = MPIHandler( - file, tree=self.tree, comm=self.comm - ) + handler = MPIHandler(file, tree=self.tree, comm=self.comm) if index is None: start, size = partition(self.comm, len(handler)) index = ChunkedIndex.single_chunk(start, size) diff --git a/opencosmo/utils.py b/opencosmo/utils.py index be164adc..7d53ce3c 100644 --- a/opencosmo/utils.py +++ b/opencosmo/utils.py @@ -1,15 +1,14 @@ """ I/O utilities for hdf5 """ -from opencosmo.dataset.index import DataIndex from typing import Optional import hdf5plugin # type: ignore -import numpy as np -from astropy.table import Column # type: ignore from h5py import Dataset, Group +from opencosmo.dataset.index import DataIndex + def write_index( input_ds: Dataset, From 77e417a22f806729598d327197811338b8317780 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Mon, 14 Apr 2025 11:29:41 -0500 Subject: [PATCH 07/11] Fix mypy issues --- opencosmo/collection/collection.py | 5 ++- opencosmo/dataset/index.py | 46 ++++++++++++++++++++-------- opencosmo/dataset/mask.py | 2 +- opencosmo/handler/handler.py | 3 +- opencosmo/handler/im.py | 8 +++-- opencosmo/link/builder.py | 14 ++------- opencosmo/link/handler.py | 28 +++++------------ opencosmo/link/mpi.py | 49 ++++++------------------------ 8 files changed, 61 
insertions(+), 94 deletions(-) diff --git a/opencosmo/collection/collection.py b/opencosmo/collection/collection.py index 665c197c..9df0dae2 100644 --- a/opencosmo/collection/collection.py +++ b/opencosmo/collection/collection.py @@ -11,7 +11,6 @@ import h5py -import numpy as np import opencosmo as oc from opencosmo.dataset.index import ChunkedIndex @@ -258,8 +257,8 @@ def open_single_dataset( builders, base_unit_transformations = u.get_default_unit_transformations( file[dataset_key], header ) - mask = np.arange(len(handler)) - return oc.Dataset(handler, header, builders, base_unit_transformations, mask) + index = ChunkedIndex.from_size(len(handler)) + return oc.Dataset(handler, header, builders, base_unit_transformations, index) def read_single_dataset( diff --git a/opencosmo/dataset/index.py b/opencosmo/dataset/index.py index d1a2be2a..af521e9d 100644 --- a/opencosmo/dataset/index.py +++ b/opencosmo/dataset/index.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Protocol, TypeVar +from typing import Any, Protocol, TypeGuard, TypeVar import h5py import numpy as np @@ -8,10 +8,26 @@ T = TypeVar("T", np.ndarray, h5py.Dataset) +def all_are_chunked( + others: tuple[DataIndex, ...], +) -> TypeGuard[tuple[ChunkedIndex, ...]]: + """ + Check if all elements in the tuple are instances of ChunkedIndex. + """ + return all(isinstance(other, ChunkedIndex) for other in others) + + +def all_are_simple(others: tuple[DataIndex, ...]) -> TypeGuard[tuple[SimpleIndex, ...]]: + """ + Check if all elements in the tuple are instances of SimpleIndex. + """ + return all(isinstance(other, SimpleIndex) for other in others) + + class DataIndex(Protocol): @classmethod def from_size(cls, size: int) -> DataIndex: ... - def set_data(self, data: T, value: Any) -> T: ... + def set_data(self, data: np.ndarray, value: Any) -> np.ndarray: ... def get_data(self, data: h5py.Dataset | np.ndarray) -> np.ndarray: ... def take(self, n: int, at: str = "random") -> DataIndex: ... def take_range(self, start: int, end: int) -> DataIndex: ... @@ -19,7 +35,7 @@ def mask(self, mask: np.ndarray) -> DataIndex: ... def range(self) -> tuple[int, int]: ... def concatenate(self, *others: DataIndex) -> DataIndex: ... def __len__(self) -> int: ... - def __getitem__(self, item: int) -> int: ... + def __getitem__(self, item: int) -> DataIndex: ... class SimpleIndex: @@ -31,7 +47,7 @@ def __init__(self, index: np.ndarray) -> None: self.__index = np.sort(index) @classmethod - def from_size(cls, size: int) -> SimpleIndex: + def from_size(cls, size: int) -> DataIndex: return SimpleIndex(np.arange(size)) def __len__(self) -> int: @@ -43,10 +59,10 @@ def range(self) -> tuple[int, int]: """ return self.__index[0], self.__index[-1] - def concatenate(self, *others: SimpleIndex) -> SimpleIndex: + def concatenate(self, *others: DataIndex) -> DataIndex: if len(others) == 0: return self - if all(isinstance(other, SimpleIndex) for other in others): + if all_are_simple(others): new_index = np.concatenate( [self.__index] + [other.__index for other in others] ) @@ -69,7 +85,7 @@ def set_data(self, data: np.ndarray, value: bool) -> np.ndarray: data[self.__index] = value return data - def take(self, n: int, at: str = "random") -> SimpleIndex: + def take(self, n: int, at: str = "random") -> DataIndex: """ Take n elements from the index. 
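+
+        A sketch (values assumed): taking at="start" keeps the n smallest
+        row numbers, since a SimpleIndex stores its rows sorted:
+
+            >>> idx = SimpleIndex(np.array([9, 2, 4]))
+            >>> idx.take(2, at="start").get_data(np.arange(10))
+            array([2, 4])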
""" @@ -84,7 +100,7 @@ def take(self, n: int, at: str = "random") -> SimpleIndex: else: raise ValueError(f"Unknown value for 'at': {at}") - def take_range(self, start: int, end: int) -> SimpleIndex: + def take_range(self, start: int, end: int) -> DataIndex: """ Take a range of elements from the index. """ @@ -98,7 +114,7 @@ def take_range(self, start: int, end: int) -> SimpleIndex: return SimpleIndex(self.__index[start:end]) - def mask(self, mask: np.ndarray) -> SimpleIndex: + def mask(self, mask: np.ndarray) -> DataIndex: if mask.shape != self.__index.shape: raise ValueError( f"Mask shape {mask.shape} does not match index size {len(self)}" @@ -130,7 +146,7 @@ def get_data(self, data: h5py.Dataset) -> np.ndarray: indices_into_output = self.__index - min_index return output[indices_into_output] - def __getitem__(self, item: int) -> SimpleIndex: + def __getitem__(self, item: int) -> DataIndex: """ Get an item from the index. """ @@ -197,7 +213,7 @@ def to_simple_index(self) -> SimpleIndex: def concatenate(self, *others: DataIndex) -> DataIndex: if len(others) == 0: return self - if all(isinstance(other, ChunkedIndex) for other in others): + if all_are_chunked(others): new_starts = np.concatenate( [self.__starts] + [other.__starts for other in others] ) @@ -214,7 +230,7 @@ def concatenate(self, *others: DataIndex) -> DataIndex: return self.concatenate(*simple_indices) @classmethod - def from_size(cls, size: int) -> ChunkedIndex: + def from_size(cls, size: int) -> DataIndex: """ Create a ChunkedIndex from a size. """ @@ -269,12 +285,14 @@ def take(self, n: int, at: str = "random") -> DataIndex: ) idxs = np.random.choice(idxs, n, replace=False) return SimpleIndex(idxs) + elif at == "start": last_chunk_in_range = np.searchsorted(np.cumsum(self.__sizes), n) new_starts = self.__starts[: last_chunk_in_range + 1].copy() new_sizes = self.__sizes[: last_chunk_in_range + 1].copy() new_sizes[-1] = n - np.sum(new_sizes[:-1]) return ChunkedIndex(new_starts, new_sizes) + elif at == "end": starting_chunk = np.searchsorted(np.cumsum(self.__sizes), len(self) - n) new_sizes = self.__sizes[starting_chunk:].copy() @@ -286,6 +304,8 @@ def take(self, n: int, at: str = "random") -> DataIndex: - new_sizes[0] ) return ChunkedIndex(new_starts, new_sizes) + else: + raise ValueError(f"Unknown value for 'at': {at}") def take_range(self, start: int, end: int) -> DataIndex: """ @@ -383,7 +403,7 @@ def get_data(self, data: h5py.Dataset | np.ndarray) -> np.ndarray: return output - def __getitem__(self, item: int) -> SimpleIndex: + def __getitem__(self, item: int) -> DataIndex: """ Get an item from the index. """ diff --git a/opencosmo/dataset/mask.py b/opencosmo/dataset/mask.py index e497baa7..e5e660ca 100644 --- a/opencosmo/dataset/mask.py +++ b/opencosmo/dataset/mask.py @@ -25,7 +25,7 @@ def apply_masks( column_builders: dict[str, ColumnBuilder], masks: Iterable[Mask], index: DataIndex, -) -> np.ndarray: +) -> DataIndex: masks_by_column = defaultdict(list) for f in masks: masks_by_column[f.column_name].append(f) diff --git a/opencosmo/handler/handler.py b/opencosmo/handler/handler.py index cf2ff469..99251f3c 100644 --- a/opencosmo/handler/handler.py +++ b/opencosmo/handler/handler.py @@ -4,7 +4,6 @@ from typing import Iterable, Optional, Protocol import h5py -import numpy as np from astropy.table import Column, Table # type: ignore from opencosmo.dataset.column import ColumnBuilder @@ -39,7 +38,7 @@ def __enter__(self): ... def __exit__(self, *exc_details): ... def __len__(self) -> int: ... 
def collect( - self, columns: Iterable[str], mask: np.ndarray + self, columns: Iterable[str], index: DataIndex ) -> OpenCosmoDataHandler: ... def write( self, diff --git a/opencosmo/handler/im.py b/opencosmo/handler/im.py index 846a2cda..12f4b637 100644 --- a/opencosmo/handler/im.py +++ b/opencosmo/handler/im.py @@ -51,14 +51,16 @@ def __enter__(self): def __exit__(self, *exec_details): return False - def collect(self, columns: Iterable[str], indices: np.ndarray) -> InMemoryHandler: + def collect(self, columns: Iterable[str], index: DataIndex) -> InMemoryHandler: """ Create a new InMemoryHandler with only the specified columns and the specified mask applied. """ - new_data = {colname: self.__data[colname][indices] for colname in columns} + new_data = { + colname: index.get_data(self.__data[colname]) for colname in columns + } mask = np.zeros(len(self), dtype=bool) - mask[indices] = True + mask = index.set_data(mask, True) tree = self.__tree.apply_mask(mask) return InMemoryHandler(new_data, tree) diff --git a/opencosmo/link/builder.py b/opencosmo/link/builder.py index 8d138741..3b66fc0d 100644 --- a/opencosmo/link/builder.py +++ b/opencosmo/link/builder.py @@ -2,7 +2,6 @@ from typing import Iterable, Optional, Protocol, Self -import numpy as np from h5py import File, Group from opencosmo import Dataset @@ -26,26 +25,17 @@ class DatasetBuilder(Protocol): the data. """ - def __init__( - self, - selected: Optional[set[str]] = None, - unit_convention: Optional[str] = None, - *args, - **kwargs, - ): - pass - def with_units(self, convention: str) -> Self: pass - def select(self, selected: Iterable[str]) -> Self: + def select(self, selected: str | Iterable[str]) -> Self: pass def build( self, file: File | Group, header: OpenCosmoHeader, - indices: Optional[np.ndarray] = None, + index: Optional[DataIndex] = None, ) -> Dataset: pass diff --git a/opencosmo/link/handler.py b/opencosmo/link/handler.py index 0dfa02e1..6254ce92 100644 --- a/opencosmo/link/handler.py +++ b/opencosmo/link/handler.py @@ -7,24 +7,8 @@ import opencosmo as oc from opencosmo.dataset.index import ChunkedIndex, DataIndex, SimpleIndex -from opencosmo.handler import OutOfMemoryHandler from opencosmo.header import OpenCosmoHeader from opencosmo.link.builder import DatasetBuilder, OomDatasetBuilder -from opencosmo.spatial import read_tree -from opencosmo.transformations import units as u - - -def build_dataset( - file: File | Group, header: OpenCosmoHeader, indices: Optional[np.ndarray] = None -) -> oc.Dataset: - tree = read_tree(file, header) - builders, base_unit_transformations = u.get_default_unit_transformations( - file, header - ) - handler = OutOfMemoryHandler(file, tree=tree) - if indices is None: - indices = np.arange(len(handler)) - return oc.Dataset(handler, header, builders, base_unit_transformations, indices) class LinkHandler(Protocol): @@ -55,7 +39,7 @@ def __init__( """ pass - def get_data(self, indices: int | DataIndex) -> oc.Dataset: + def get_data(self, index: DataIndex) -> oc.Dataset: """ Given a index or a set of indices, return the data from the linked dataset that corresponds to the halo/galaxy at that index in the properties file. @@ -71,7 +55,7 @@ def get_all_data(self) -> oc.Dataset: pass def write( - self, data_group: Group, link_group: Group, name: str, indices: int | np.ndarray + self, data_group: Group, link_group: Group, name: str, index: DataIndex ) -> None: """ Write the linked data for the given indices to data_group. 
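+
+        A sketch of the expected call (the group handles and the dataset
+        name here are assumptions for illustration):
+
+            handler.write(data_group, link_group, "dm_particles", index)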
@@ -118,7 +102,10 @@ def __init__( self.builder = builder def get_all_data(self) -> oc.Dataset: - return build_dataset(self.file, self.header) + return self.builder.build( + self.file, + self.header, + ) def get_data(self, index: DataIndex) -> oc.Dataset: if isinstance(self.link, tuple): @@ -127,6 +114,7 @@ def get_data(self, index: DataIndex) -> oc.Dataset: valid_rows = size > 0 start = start[valid_rows] size = size[valid_rows] + new_index: DataIndex if not start.size: new_index = SimpleIndex(np.array([], dtype=int)) else: @@ -135,8 +123,6 @@ def get_data(self, index: DataIndex) -> oc.Dataset: indices_into_data = index.get_data(self.link) indices_into_data = indices_into_data[indices_into_data >= 0] new_index = SimpleIndex(indices_into_data) - if not indices_into_data.size: - indices_into_data = SimpleIndex(np.array([], dtype=int)) return self.builder.build(self.file, self.header, new_index) diff --git a/opencosmo/link/mpi.py b/opencosmo/link/mpi.py index 755a176c..bbbe57ef 100644 --- a/opencosmo/link/mpi.py +++ b/opencosmo/link/mpi.py @@ -7,50 +7,16 @@ from mpi4py import MPI import opencosmo as oc -from opencosmo.dataset.column import ColumnBuilder, get_column_builders +from opencosmo.dataset.column import get_column_builders from opencosmo.dataset.index import ChunkedIndex, DataIndex, SimpleIndex from opencosmo.handler import MPIHandler from opencosmo.handler.mpi import partition from opencosmo.header import OpenCosmoHeader from opencosmo.link.builder import DatasetBuilder from opencosmo.spatial import Tree, read_tree -from opencosmo.transformations import TransformationDict from opencosmo.transformations import units as u -def build_dataset( - file: File | Group, - indices: np.ndarray, - header: OpenCosmoHeader, - comm: MPI.Comm, - tree: Tree, - base_transformations: TransformationDict, - builders: dict[str, ColumnBuilder], -) -> oc.Dataset: - if len(indices) > 0: - index_range = (indices.min(), indices.max() + 1) - indices = indices - index_range[0] - else: - index_range = None - - handler = MPIHandler(file, tree=tree, comm=comm, rank_range=index_range) - return oc.Dataset(handler, header, builders, base_transformations, indices) - - -def build_full_dataset( - file: File | Group, - header: OpenCosmoHeader, - comm: MPI.Comm, - tree: Tree, - base_transformations: TransformationDict, - builders: dict[str, ColumnBuilder], -) -> oc.Dataset: - handler = MPIHandler(file, tree=tree, comm=comm) - return oc.Dataset( - handler, header, builders, base_transformations, np.arange(len(handler)) - ) - - class MpiLinkHandler: def __init__( self, @@ -68,8 +34,9 @@ def __init__( self.comm = comm if builder is None: tree = read_tree(file, self.header) - builder = MpiDatasetBuilder(tree, comm=comm) - self.builder = builder + self.builder: DatasetBuilder = MpiDatasetBuilder(tree, comm=comm) + else: + self.builder = builder if isinstance(self.link, tuple): n_per_rank = self.link[0].shape[0] // self.comm.Get_size() self.offset = n_per_rank * self.comm.Get_rank() @@ -90,6 +57,7 @@ def get_data(self, index: DataIndex) -> oc.Dataset: valid_rows = size > 0 start = start[valid_rows] size = size[valid_rows] + new_index: DataIndex if len(start) == 0: new_index = SimpleIndex(np.array([], dtype=int)) else: @@ -143,10 +111,13 @@ def write( self.comm.Barrier() indices_into_data = index.get_data(self.link) nonzero = indices_into_data >= 0 - nonzero = self.comm.gather(nonzero) + all_nonzero = self.comm.gather(nonzero) if self.comm.Get_rank() == 0: - nonzero = np.concatenate(nonzero) + if all_nonzero is None: + # should 
never happen, but mypy... + raise ValueError("No data to write") + nonzero = np.concatenate(all_nonzero) sod_profile_idx = np.full(len(nonzero), -1) sod_profile_idx[nonzero] = np.arange(sum(nonzero)) link_group["sod_profile_idx"][:] = sod_profile_idx From 94f400e0efcdc3f322da51c3e9691f592b768482 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Mon, 14 Apr 2025 13:34:53 -0500 Subject: [PATCH 08/11] Fix for large trees --- opencosmo/link/io.py | 2 +- opencosmo/spatial/tree.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/opencosmo/link/io.py b/opencosmo/link/io.py index fa14a3dc..92b95ba2 100644 --- a/opencosmo/link/io.py +++ b/opencosmo/link/io.py @@ -18,7 +18,7 @@ LINK_ALIASES = { # Left: Name in file, right: Name in collection "sodbighaloparticles_star_particles": "star_particles", "sodbighaloparticles_dm_particles": "dm_particles", - "sodbighaloparticles_gravity_particles": "dm_particles", + "sodbighaloparticles_gravity_particles": "gravity_particles", "sodbighaloparticles_agn_particles": "agn_particles", "sodbighaloparticles_gas_particles": "gas_particles", "sod_profile": "halo_profiles", diff --git a/opencosmo/spatial/tree.py b/opencosmo/spatial/tree.py index 29cfb955..8ab431d0 100644 --- a/opencosmo/spatial/tree.py +++ b/opencosmo/spatial/tree.py @@ -71,7 +71,7 @@ def apply_range_mask( st[0] = range_[0] st = st - range_[0] # Determine how many true values are in the mask in the ranges - new_sizes = np.fromiter((np.sum(a) for a in np.split(mask, st[1:])), dtype=int) + new_sizes = np.add.reduceat(mask, st) output_sizes[level] = (first_start_index, new_sizes) return output_sizes From b1783a06efdad54df15ce160083654cdbabf7087 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Tue, 15 Apr 2025 08:52:18 -0500 Subject: [PATCH 09/11] Disable trees (for now) --- opencosmo/collection/collection.py | 11 +++++++---- opencosmo/handler/im.py | 12 ++++++++---- opencosmo/handler/mpi.py | 9 ++++----- opencosmo/handler/oom.py | 20 +++++++++++--------- opencosmo/io.py | 6 ++++-- opencosmo/link/builder.py | 4 ++-- opencosmo/link/mpi.py | 5 +++-- test/parallel/test_mpi.py | 1 + test/test_spatial.py | 2 +- 9 files changed, 41 insertions(+), 29 deletions(-) diff --git a/opencosmo/collection/collection.py b/opencosmo/collection/collection.py index 9df0dae2..55b9090c 100644 --- a/opencosmo/collection/collection.py +++ b/opencosmo/collection/collection.py @@ -18,7 +18,6 @@ from opencosmo.handler import InMemoryHandler, OpenCosmoDataHandler, OutOfMemoryHandler from opencosmo.header import OpenCosmoHeader, read_header from opencosmo.link import StructureCollection -from opencosmo.spatial import read_tree from opencosmo.transformations import units as u @@ -245,7 +244,9 @@ def open_single_dataset( if header is None: header = read_header(file[dataset_key]) - tree = read_tree(file[dataset_key], header) + + # tree = read_tree(file[dataset_key], header) + tree = None handler: OpenCosmoDataHandler if MPI is not None and MPI.COMM_WORLD.Get_size() > 1: handler = MPIHandler( @@ -272,8 +273,10 @@ def read_single_dataset( if header is None: header = read_header(file[dataset_key]) - - tree = read_tree(file[dataset_key], header) + + # tree = read_tree(file[dataset_key], header) + tree = None + handler = InMemoryHandler(file, tree, dataset_key) builders, base_unit_transformations = u.get_default_unit_transformations( file[dataset_key], header diff --git a/opencosmo/handler/im.py b/opencosmo/handler/im.py index 12f4b637..a3b30651 100644 --- a/opencosmo/handler/im.py +++ b/opencosmo/handler/im.py @@ 
-22,7 +22,7 @@ class InMemoryHandler: def __init__( self, file: h5py.File, - tree: Tree, + tree: Optional[Tree] = None, group_name: Optional[str] = None, columns: Optional[Iterable[str]] = None, index: Optional[DataIndex] = None, @@ -61,7 +61,10 @@ def collect(self, columns: Iterable[str], index: DataIndex) -> InMemoryHandler: } mask = np.zeros(len(self), dtype=bool) mask = index.set_data(mask, True) - tree = self.__tree.apply_mask(mask) + if self.__tree is not None: + tree = self.__tree.apply_mask(mask) + else: + tree = None return InMemoryHandler(new_data, tree) def write( @@ -84,8 +87,9 @@ def write( if self.__columns[column] is not None: data_group[column].attrs["unit"] = self.__columns[column] mask = np.zeros(len(self), dtype=bool) - tree = self.__tree.apply_mask(mask) - tree.write(group, dataset_name="index") + if self.__tree is not None: + tree = self.__tree.apply_mask(mask) + tree.write(group, dataset_name="index") def get_data( self, diff --git a/opencosmo/handler/mpi.py b/opencosmo/handler/mpi.py index c5ab83bb..29718a13 100644 --- a/opencosmo/handler/mpi.py +++ b/opencosmo/handler/mpi.py @@ -65,7 +65,7 @@ class MPIHandler: def __init__( self, file: h5py.File, - tree: Tree, + tree: Optional[Tree] = None, group_name: Optional[str] = None, comm=MPI.COMM_WORLD, ): @@ -164,10 +164,9 @@ def write( mask = np.zeros(len(self), dtype=bool) mask = index.set_data(mask, True) - - new_tree = self.__tree.apply_mask(mask, self.__comm, index.range()) - - new_tree.write(group) # type: ignore + if self.__tree is not None: + new_tree = self.__tree.apply_mask(mask, self.__comm, index.range()) + new_tree.write(group) # type: ignore self.__comm.Barrier() diff --git a/opencosmo/handler/oom.py b/opencosmo/handler/oom.py index 9ab3351b..de4edeb7 100644 --- a/opencosmo/handler/oom.py +++ b/opencosmo/handler/oom.py @@ -18,7 +18,7 @@ class OutOfMemoryHandler: disk until needed """ - def __init__(self, file: h5py.File, tree: Tree, group_name: Optional[str] = None): + def __init__(self, file: h5py.File, tree: Optional[Tree] = None, group_name: Optional[str] = None): self.__group_name = group_name self.__file = file if group_name is None: @@ -40,12 +40,14 @@ def __exit__(self, *exec_details): def collect(self, columns: Iterable[str], index: DataIndex) -> InMemoryHandler: file_path = self.__file.filename - if len(index) == len(self): - tree = self.__tree - else: + tree: Optional[Tree] = None + if self.__tree is not None and len(index) == len(self): mask = np.zeros(len(self), dtype=bool) mask = index.set_data(mask, True) tree = self.__tree.apply_mask(mask) + + else: + tree = self.__tree with h5py.File(file_path, "r") as file: return InMemoryHandler( @@ -73,11 +75,11 @@ def write( for column in columns: write_index(self.__group[column], data_group, index) - tree_mask = np.zeros(len(self), dtype=bool) - tree_mask = index.set_data(tree_mask, True) - - tree = self.__tree.apply_mask(tree_mask) - tree.write(group) + if self.__tree is not None: + tree_mask = np.zeros(len(self), dtype=bool) + tree_mask = index.set_data(tree_mask, True) + tree = self.__tree.apply_mask(tree_mask) + tree.write(group) def get_data(self, builders: dict, index: DataIndex) -> Column | Table: """ """ diff --git a/opencosmo/io.py b/opencosmo/io.py index 91841dc0..a134d8d9 100644 --- a/opencosmo/io.py +++ b/opencosmo/io.py @@ -84,7 +84,8 @@ def open( group = file_handle header = read_header(file_handle) - tree = read_tree(file_handle, header) + # tree = read_tree(file_handle, header) + tree = None if datasets is not None and not 
isinstance(datasets, str): raise ValueError("Asked for multiple datasets, but file has only one") @@ -147,7 +148,8 @@ def read( if datasets is not None and not isinstance(datasets, str): raise ValueError("Asked for multiple datasets, but file has only one") header = read_header(file) - tree = read_tree(file, header) + # tree = read_tree(file, header) + tree = None handler = InMemoryHandler(file, tree, group_name=datasets) index = ChunkedIndex.from_size(len(handler)) builders, base_unit_transformations = u.get_default_unit_transformations( diff --git a/opencosmo/link/builder.py b/opencosmo/link/builder.py index 3b66fc0d..4b675e4a 100644 --- a/opencosmo/link/builder.py +++ b/opencosmo/link/builder.py @@ -9,7 +9,6 @@ from opencosmo.dataset.index import ChunkedIndex, DataIndex from opencosmo.handler import OutOfMemoryHandler from opencosmo.header import OpenCosmoHeader -from opencosmo.spatial import read_tree from opencosmo.transformations import units as u try: @@ -93,7 +92,8 @@ def build( header: OpenCosmoHeader, index: Optional[DataIndex] = None, ) -> Dataset: - tree = read_tree(file, header) + # tree = read_tree(file, header) + tree = None builders, base_unit_transformations = u.get_default_unit_transformations( file, header ) diff --git a/opencosmo/link/mpi.py b/opencosmo/link/mpi.py index bbbe57ef..93cb3a63 100644 --- a/opencosmo/link/mpi.py +++ b/opencosmo/link/mpi.py @@ -33,7 +33,8 @@ def __init__( self.header = header self.comm = comm if builder is None: - tree = read_tree(file, self.header) + # tree = read_tree(file, self.header) + tree = None self.builder: DatasetBuilder = MpiDatasetBuilder(tree, comm=comm) else: self.builder = builder @@ -154,7 +155,7 @@ class MpiDatasetBuilder: def __init__( self, - tree: Tree, + tree: Optional[Tree] = None, selected: Optional[set[str]] = None, unit_convention: Optional[str] = None, comm: MPI.Comm = MPI.COMM_WORLD, diff --git a/test/parallel/test_mpi.py b/test/parallel/test_mpi.py index 051ce901..77beaf3f 100644 --- a/test/parallel/test_mpi.py +++ b/test/parallel/test_mpi.py @@ -102,6 +102,7 @@ def test_filters(input_path): parallel_assert(all(data["sod_halo_mass"] > 0)) +@pytest.mark.skip("Trees are not fully implemented") @pytest.mark.parallel(nprocs=4) def test_filter_write(input_path, tmp_path): comm = mpi4py.MPI.COMM_WORLD diff --git a/test/test_spatial.py b/test/test_spatial.py index c15ef916..db80af83 100644 --- a/test/test_spatial.py +++ b/test/test_spatial.py @@ -8,7 +8,7 @@ def input_path(data_path): return data_path / "haloproperties.hdf5" - +@pytest.mark.skip("Trees are not fully implemented yet") def test_filter_write(input_path, tmp_path): tmp_file = tmp_path / "filtered_data.hdf5" with oc.open(input_path) as f: From d4fd2f5460d2b6cc883c2930241dd635837abdae Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Tue, 15 Apr 2025 09:01:37 -0500 Subject: [PATCH 10/11] Linting --- opencosmo/collection/collection.py | 5 ++--- opencosmo/handler/oom.py | 9 +++++++-- opencosmo/io.py | 1 - opencosmo/link/mpi.py | 2 +- test/test_spatial.py | 1 + 5 files changed, 11 insertions(+), 7 deletions(-) diff --git a/opencosmo/collection/collection.py b/opencosmo/collection/collection.py index 55b9090c..258531fc 100644 --- a/opencosmo/collection/collection.py +++ b/opencosmo/collection/collection.py @@ -244,7 +244,6 @@ def open_single_dataset( if header is None: header = read_header(file[dataset_key]) - # tree = read_tree(file[dataset_key], header) tree = None handler: OpenCosmoDataHandler @@ -273,10 +272,10 @@ def read_single_dataset( if header is None: 
header = read_header(file[dataset_key]) - + # tree = read_tree(file[dataset_key], header) tree = None - + handler = InMemoryHandler(file, tree, dataset_key) builders, base_unit_transformations = u.get_default_unit_transformations( file[dataset_key], header diff --git a/opencosmo/handler/oom.py b/opencosmo/handler/oom.py index de4edeb7..5a8e1596 100644 --- a/opencosmo/handler/oom.py +++ b/opencosmo/handler/oom.py @@ -18,7 +18,12 @@ class OutOfMemoryHandler: disk until needed """ - def __init__(self, file: h5py.File, tree: Optional[Tree] = None, group_name: Optional[str] = None): + def __init__( + self, + file: h5py.File, + tree: Optional[Tree] = None, + group_name: Optional[str] = None, + ): self.__group_name = group_name self.__file = file if group_name is None: @@ -45,7 +50,7 @@ def collect(self, columns: Iterable[str], index: DataIndex) -> InMemoryHandler: mask = np.zeros(len(self), dtype=bool) mask = index.set_data(mask, True) tree = self.__tree.apply_mask(mask) - + else: tree = self.__tree diff --git a/opencosmo/io.py b/opencosmo/io.py index a134d8d9..fb39ee13 100644 --- a/opencosmo/io.py +++ b/opencosmo/io.py @@ -19,7 +19,6 @@ from opencosmo.handler import InMemoryHandler, OpenCosmoDataHandler, OutOfMemoryHandler from opencosmo.handler.mpi import partition from opencosmo.header import read_header -from opencosmo.spatial import read_tree from opencosmo.transformations import units as u diff --git a/opencosmo/link/mpi.py b/opencosmo/link/mpi.py index 93cb3a63..226b82ee 100644 --- a/opencosmo/link/mpi.py +++ b/opencosmo/link/mpi.py @@ -13,7 +13,7 @@ from opencosmo.handler.mpi import partition from opencosmo.header import OpenCosmoHeader from opencosmo.link.builder import DatasetBuilder -from opencosmo.spatial import Tree, read_tree +from opencosmo.spatial import Tree from opencosmo.transformations import units as u diff --git a/test/test_spatial.py b/test/test_spatial.py index db80af83..d5e4e9f6 100644 --- a/test/test_spatial.py +++ b/test/test_spatial.py @@ -8,6 +8,7 @@ def input_path(data_path): return data_path / "haloproperties.hdf5" + @pytest.mark.skip("Trees are not fully implemented yet") def test_filter_write(input_path, tmp_path): tmp_file = tmp_path / "filtered_data.hdf5" From a75a08828c49e3da8aef2aaeea34fe0c8da54263 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Tue, 15 Apr 2025 09:04:44 -0500 Subject: [PATCH 11/11] Fix an import --- opencosmo/io.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/opencosmo/io.py b/opencosmo/io.py index fb39ee13..40e1a24d 100644 --- a/opencosmo/io.py +++ b/opencosmo/io.py @@ -8,6 +8,7 @@ from mpi4py import MPI from opencosmo.handler import MPIHandler + from opencosmo.handler.mpi import partition except ImportError: MPI = None # type: ignore from typing import Iterable, Optional @@ -17,7 +18,6 @@ from opencosmo.dataset.index import ChunkedIndex, DataIndex from opencosmo.file import FileExistance, file_reader, file_writer, resolve_path from opencosmo.handler import InMemoryHandler, OpenCosmoDataHandler, OutOfMemoryHandler -from opencosmo.handler.mpi import partition from opencosmo.header import read_header from opencosmo.transformations import units as u
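
For reference, the user-facing flow that this series exercises looks roughly
like the sketch below. The file names are assumptions; the column name and
the filter/take calls are taken from test_collection.py above.

    import opencosmo as oc

    # Filtering and taking now flow through a DataIndex (SimpleIndex or
    # ChunkedIndex) instead of a raw numpy index array.
    with oc.open("galaxies.hdf5") as collection:
        subset = collection.filter(oc.col("gal_mass") > 10**12).take(50, at="start")
        oc.write("galaxies_subset.hdf5", subset)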