Merged
22 changes: 12 additions & 10 deletions opencosmo/collection/collection.py
@@ -11,14 +11,13 @@


 import h5py
-import numpy as np

 import opencosmo as oc
+from opencosmo.dataset.index import ChunkedIndex
 from opencosmo.dataset.mask import Mask
 from opencosmo.handler import InMemoryHandler, OpenCosmoDataHandler, OutOfMemoryHandler
 from opencosmo.header import OpenCosmoHeader, read_header
 from opencosmo.link import StructureCollection
-from opencosmo.spatial import read_tree
 from opencosmo.transformations import units as u


@@ -178,7 +177,7 @@ def __map(self, method, *args, **kwargs):
         output = {k: getattr(v, method)(*args, **kwargs) for k, v in self.items()}
         return SimulationCollection(output)

-    def filter(self, *masks: Mask) -> SimulationCollection:
+    def filter(self, *masks: Mask, **kwargs) -> SimulationCollection:
         """
         Filter the datasets in the collection. This method behaves
         exactly like :meth:`opencosmo.Dataset.filter`, except that
@@ -196,7 +195,7 @@ def filter(self, *masks: Mask) -> SimulationCollection:
             A new collection with the same datasets, but only the
             particles that pass the filter.
         """
-        return self.__map("filter", *masks)
+        return self.__map("filter", *masks, **kwargs)

     def select(self, *args, **kwargs) -> SimulationCollection:
         """
@@ -245,7 +244,8 @@ def open_single_dataset(
     if header is None:
         header = read_header(file[dataset_key])

-    tree = read_tree(file[dataset_key], header)
+    # tree = read_tree(file[dataset_key], header)
+    tree = None
     handler: OpenCosmoDataHandler
     if MPI is not None and MPI.COMM_WORLD.Get_size() > 1:
         handler = MPIHandler(
@@ -257,8 +257,8 @@
     builders, base_unit_transformations = u.get_default_unit_transformations(
         file[dataset_key], header
     )
-    mask = np.arange(len(handler))
-    return oc.Dataset(handler, header, builders, base_unit_transformations, mask)
+    index = ChunkedIndex.from_size(len(handler))
+    return oc.Dataset(handler, header, builders, base_unit_transformations, index)


 def read_single_dataset(
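`ChunkedIndex.from_size(n)` replaces the eagerly materialized `np.arange(n)`. The class itself is not part of this diff; below is a minimal sketch of the semantics these call sites rely on, assuming a contiguous-chunk representation (a hypothetical implementation, not the library's actual code):

```python
import numpy as np

class ChunkedIndex:
    """Hypothetical sketch: a chunked row index standing in for np.arange."""

    def __init__(self, starts: np.ndarray, sizes: np.ndarray):
        self.starts = starts  # first row of each chunk
        self.sizes = sizes    # number of rows in each chunk

    @classmethod
    def from_size(cls, n: int) -> "ChunkedIndex":
        # One chunk covering [0, n): equivalent to the old np.arange(n),
        # but stored in O(1) memory instead of O(n).
        return cls(np.array([0]), np.array([n]))

    def __len__(self) -> int:
        return int(self.sizes.sum())
```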
@@ -273,10 +273,12 @@
     if header is None:
         header = read_header(file[dataset_key])

-    tree = read_tree(file[dataset_key], header)
+    # tree = read_tree(file[dataset_key], header)
+    tree = None

     handler = InMemoryHandler(file, tree, dataset_key)
     builders, base_unit_transformations = u.get_default_unit_transformations(
         file[dataset_key], header
     )
-    mask = np.arange(len(handler))
-    return oc.Dataset(handler, header, builders, base_unit_transformations, mask)
+    index = ChunkedIndex.from_size(len(handler))
+    return oc.Dataset(handler, header, builders, base_unit_transformations, index)
59 changes: 21 additions & 38 deletions opencosmo/dataset/dataset.py
@@ -3,13 +3,13 @@
 from typing import Generator, Iterable, Optional

 import h5py
-import numpy as np
 from astropy import units  # type: ignore
 from astropy.table import Table  # type: ignore

 import opencosmo.transformations as t
 import opencosmo.transformations.units as u
 from opencosmo.dataset.column import ColumnBuilder, get_column_builders
+from opencosmo.dataset.index import ChunkedIndex, DataIndex
 from opencosmo.dataset.mask import Mask, apply_masks
 from opencosmo.handler import OpenCosmoDataHandler
 from opencosmo.header import OpenCosmoHeader, write_header
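The import brings in both `DataIndex` (used below in type hints) and the concrete `ChunkedIndex`. The diff does not show how the two relate; one plausible reading, sketched as a structural protocol (an assumption, not the library's definition):

```python
from typing import Protocol

class DataIndex(Protocol):
    """Hypothetical interface implied by the call sites in this diff."""

    def __len__(self) -> int: ...
    def take(self, n: int, at: str = "start") -> "DataIndex": ...
    def take_range(self, start: int, end: int) -> "DataIndex": ...
```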
@@ -22,21 +22,21 @@ def __init__(
         header: OpenCosmoHeader,
         builders: dict[str, ColumnBuilder],
         unit_transformations: dict[t.TransformationType, list[t.Transformation]],
-        indices: np.ndarray,
+        index: DataIndex,
     ):
         self.__handler = handler
         self.__header = header
         self.__builders = builders
         self.__base_unit_transformations = unit_transformations
-        self.__indices = indices
+        self.__index = index

     @property
     def header(self) -> OpenCosmoHeader:
         return self.__header

     @property
-    def indices(self) -> np.ndarray:
-        return self.__indices
+    def index(self) -> DataIndex:
+        return self.__index

     def __repr__(self):
         """
@@ -54,7 +54,7 @@ def __repr__(self):
         return head + cosmo_repr + table_head + table_repr

     def __len__(self):
-        return len(self.__indices)
+        return len(self.__index)

     def __enter__(self):
         # Need to write tests
@@ -74,7 +74,7 @@ def cosmology(self):
     def data(self):
         # should rename this, dataset.data can get confusing
         # Also the point is that there's MORE data than just the table
-        return self.__handler.get_data(builders=self.__builders, indices=self.__indices)
+        return self.__handler.get_data(builders=self.__builders, index=self.__index)

     def write(
         self,
@@ -103,7 +103,7 @@ def write(
         if with_header:
             write_header(file, self.__header, dataset_name)

-        self.__handler.write(file, self.indices, self.__builders.keys(), dataset_name)
+        self.__handler.write(file, self.__index, self.__builders.keys(), dataset_name)

     def rows(self) -> Generator[dict[str, float | units.Quantity]]:
         """
@@ -163,14 +163,14 @@ def take_range(self, start: int, end: int) -> Table:
         if start < 0 or end > len(self):
             raise ValueError("start and end must be within the bounds of the dataset.")

-        new_indices = self.__indices[start:end]
+        new_index = self.__index.take_range(start, end)

         return Dataset(
             self.__handler,
             self.__header,
             self.__builders,
             self.__base_unit_transformations,
-            new_indices,
+            new_index,
         )

     def filter(self, *masks: Mask) -> Dataset:
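`take_range` now delegates slicing to the index. Under the contiguous-chunk sketch above, a range selection never materializes an index array; a hypothetical continuation of that sketch, covering the single-chunk case only:

```python
import numpy as np

# Hypothetical method on the ChunkedIndex sketch above (single-chunk case):
def take_range(self, start: int, end: int) -> "ChunkedIndex":
    # Slicing a contiguous chunk is arithmetic on the (start, size) pair,
    # so the result stays O(1) in memory.
    return ChunkedIndex(
        np.array([int(self.starts[0]) + start]),
        np.array([end - start]),
    )
```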
@@ -195,19 +195,17 @@

         """

-        new_indices = apply_masks(
-            self.__handler, self.__builders, masks, self.__indices
-        )
+        new_index = apply_masks(self.__handler, self.__builders, masks, self.__index)

-        if len(new_indices) == 0:
-            raise ValueError("Filter returned zero rows!")
+        if len(new_index) == 0:
+            raise ValueError("The filter returned no rows!")

         return Dataset(
             self.__handler,
             self.__header,
             self.__builders,
             self.__base_unit_transformations,
-            new_indices,
+            new_index,
         )

     def select(self, columns: str | Iterable[str]) -> Dataset:
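`apply_masks` now returns a `DataIndex` rather than an index array, and an empty result still raises before any new Dataset is built. A hypothetical usage sketch; the column name and `oc.col` helper are assumptions, as above:

```python
# Hypothetical: callers get either a non-empty result or an exception.
try:
    massive = ds.filter(oc.col("fof_halo_mass") > 1e14)
except ValueError:
    massive = None  # the filter matched no rows
```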
@@ -250,7 +248,7 @@ def select(self, columns: str | Iterable[str]) -> Dataset:
             self.__header,
             new_builders,
             self.__base_unit_transformations,
-            self.__indices,
+            self.__index,
         )

     def with_units(self, convention: str) -> Dataset:
@@ -279,7 +277,7 @@
             self.__header,
             new_builders,
             self.__base_unit_transformations,
-            self.__indices,
+            self.__index,
         )

     def collect(self) -> Dataset:
Expand All @@ -305,13 +303,14 @@ def collect(self) -> Dataset:

If working in an MPI context, all ranks will recieve the same data.
"""
new_handler = self.__handler.collect(self.__builders.keys(), self.__indices)
new_handler = self.__handler.collect(self.__builders.keys(), self.__index)
new_index = ChunkedIndex.from_size(len(new_handler))
return Dataset(
new_handler,
self.__header,
self.__builders,
self.__base_unit_transformations,
np.arange(len(new_handler)),
new_index,
)

def take(self, n: int, at: str = "start") -> Dataset:
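`collect` now pairs the new in-memory handler with a fresh contiguous index instead of `np.arange`. A hypothetical usage sketch; the file name is illustrative and `oc.open` is assumed to be the public entry point for `open_single_dataset`:

```python
# Hypothetical: materialize a small slice so it outlives the file handle.
with oc.open("haloproperties.hdf5") as ds:
    sample = ds.take(100, at="random").collect()
# sample now owns an in-memory handler indexed by ChunkedIndex.from_size(100)
```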
@@ -341,28 +340,12 @@
             or if 'at' is invalid.

         """
-
-        if n < 0 or n > len(self):
-            raise ValueError(
-                "Invalid value for 'n', must be between 0 and the length of the dataset"
-            )
-        if at == "start":
-            new_indices = self.__indices[:n]
-        elif at == "end":
-            new_indices = self.__indices[-n:]
-        elif at == "random":
-            new_indices = np.random.choice(self.__indices, n, replace=False)
-            new_indices.sort()
-
-        else:
-            raise ValueError(
-                "Invalid value for 'at'. Must be one of 'start', 'end', or 'random'."
-            )
+        new_index = self.__index.take(n, at)

         return Dataset(
             self.__handler,
             self.__header,
             self.__builders,
             self.__base_unit_transformations,
-            new_indices,
+            new_index,
         )
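The start/end/random branches, together with the range check on `n`, disappear from `Dataset.take` and presumably move into the index's own `take`. A hypothetical sketch of what that method must cover, mirroring the deleted logic; the flattening helper and the re-chunking strategy at the end are assumptions:

```python
import numpy as np

# Hypothetical DataIndex.take for the ChunkedIndex sketch above, mirroring
# the start/end/random branches deleted from Dataset.take:
def take(self, n: int, at: str = "start") -> "ChunkedIndex":
    if n < 0 or n > len(self):
        raise ValueError("'n' must be between 0 and the length of the index")
    # Flatten the chunks into explicit row numbers (what np.arange gave before).
    flat = np.concatenate(
        [np.arange(s, s + size) for s, size in zip(self.starts, self.sizes)]
    )
    if at == "start":
        chosen = flat[:n]
    elif at == "end":
        chosen = flat[len(flat) - n:]
    elif at == "random":
        chosen = np.sort(np.random.choice(flat, n, replace=False))
    else:
        raise ValueError("'at' must be one of 'start', 'end', or 'random'.")
    # Simplest correct re-chunking: one single-row chunk per selected row.
    return ChunkedIndex(chosen, np.ones_like(chosen))
```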