diff --git a/opencosmo/handler/mpi.py b/opencosmo/handler/mpi.py index 8c395512..4b9915cc 100644 --- a/opencosmo/handler/mpi.py +++ b/opencosmo/handler/mpi.py @@ -128,7 +128,6 @@ def write( rank_output_length = np.sum(mask) all_output_lengths = self.__comm.allgather(rank_output_length) - all_input_lengths = self.__comm.allgather(len(mask)) rank = self.__comm.Get_rank() @@ -162,32 +161,9 @@ def write( data = data[mask] data_group[column][rank_start:rank_end] = data - displacements = np.insert(np.cumsum(all_input_lengths[:-1]), 0, 0) - if rank == 0: - recvbuf = np.empty(sum(all_input_lengths), dtype=np.uint8) - self.__comm.Gatherv( - sendbuf=mask.view(np.uint8), - recvbuf=( - recvbuf.view(np.uint8), - all_input_lengths, - displacements, - MPI.BYTE, - ), - root=0, - ) - else: - self.__comm.Gatherv( - sendbuf=mask.view(np.uint8), recvbuf=(None, None, None, None), root=0 - ) + new_tree = self.__tree.apply_mask(mask, self.__comm, self.elem_range()) - if rank == 0: - mask = recvbuf.astype(bool) - tree = self.__tree.apply_mask(mask) - else: - tree = None - tree = self.__comm.bcast(tree, root=0) - # - tree.write(group) # type: ignore + new_tree.write(group) # type: ignore self.__comm.Barrier() diff --git a/opencosmo/handler/oom.py b/opencosmo/handler/oom.py index 9b6fa2eb..6a1c9536 100644 --- a/opencosmo/handler/oom.py +++ b/opencosmo/handler/oom.py @@ -89,9 +89,10 @@ def get_data( raise ValueError("This file has already been closed") output = {} for column, builder in builders.items(): - data = self.__group[column][()] if mask is not None: - data = data[mask] + data = self.__group[column][mask] + else: + data = self.__group[column][()] col = Column(data, name=column) output[column] = builder.build(col) diff --git a/opencosmo/spatial/tree.py b/opencosmo/spatial/tree.py index baf4e957..29cfb955 100644 --- a/opencosmo/spatial/tree.py +++ b/opencosmo/spatial/tree.py @@ -1,10 +1,15 @@ from __future__ import annotations -from collections import OrderedDict +from typing import 
Optional import h5py import numpy as np +try: + from mpi4py import MPI +except ImportError: + MPI = None # type: ignore + from opencosmo.header import OpenCosmoHeader from opencosmo.spatial.index import SpatialIndex from opencosmo.spatial.octree import OctTreeIndex @@ -15,27 +20,94 @@ def read_tree(file: h5py.File | h5py.Group, header: OpenCosmoHeader): Read a tree from an HDF5 file and the associated header. The tree is just a mapping between a spatial index and a slice into the data. + + Note: The max level in the header may not actually match + the max level in the file. When a large dataset is filtered down, + we may reduce the tree level to save space in the output file. + + The max level in the header is the maximum level in the full + dataset, so this is the HIGHEST it can be. """ max_level = header.reformat.max_level - data_indices = OrderedDict() + starts = {} + sizes = {} for level in range(max_level + 1): - group = file[f"index/level_{level}"] - starts = group["start"][()] - sizes = group["size"][()] - level_indices = {} - for i, (start, size) in enumerate(zip(starts, sizes)): - level_indices[i] = slice(start, start + size) - data_indices[level] = level_indices + try: + group = file[f"index/level_{level}"] + except KeyError: + break + level_starts = group["start"][()] + level_sizes = group["size"][()] + starts[level] = level_starts + sizes[level] = level_sizes spatial_index = OctTreeIndex(header.simulation, max_level) - return Tree(spatial_index, data_indices) + return Tree(spatial_index, starts, sizes) def write_tree(file: h5py.File, tree: Tree, dataset_name: str = "index"): tree.write(file, dataset_name) +def apply_range_mask( + mask: np.ndarray, + range_: tuple[int, int], + starts: dict[int, np.ndarray], + sizes: dict[int, np.ndarray], +) -> dict[int, tuple[int, np.ndarray]]: + """ + Given an index range, apply a mask of the same size to produce new sizes.
+ """ + output_sizes = {} + + for level, st in starts.items(): + ends = st + sizes[level] + # Not in range if the end is less than start, or the start is greater than end + overlaps_mask = ~((st > range_[1]) | (ends < range_[0])) + # The first start may be less than the range start, so clamp it to the range start + first_start_index = int(np.argmax(overlaps_mask)) + st = st[overlaps_mask] + st[0] = range_[0] + st = st - range_[0] + # Determine how many true values are in the mask in the ranges + new_sizes = np.fromiter((np.sum(a) for a in np.split(mask, st[1:])), dtype=int) + output_sizes[level] = (first_start_index, new_sizes) + return output_sizes + + +def pack_masked_ranges( + old_starts: dict[int, np.ndarray], + new_sizes: list[dict[int, tuple[int, np.ndarray]]], + min_level_size: int = 500, +) -> tuple[dict[int, np.ndarray], dict[int, np.ndarray]]: + """ + Given a list of masked ranges, pack them into a new set of sizes. + This is used when working with MPI, and allows us to avoid sending + very large masks between ranks. + + For queries that return a small fraction of the data, we can end up + writing a lot of zeros in the lower levels of the tree. So we can + dynamically choose to stop writing levels when the average size of + the level is below a certain threshold + """ + output_starts = {} + output_sizes = {} + for level in new_sizes[0]: + new_level_sizes = np.zeros_like(old_starts[level]) + new_start_info = [rm[level] for rm in new_sizes] + for first_idx, sizes in new_start_info: + new_level_sizes[first_idx : first_idx + len(sizes)] += sizes + + avg_size = np.mean(new_level_sizes[new_level_sizes > 0]) + if avg_size < min_level_size: + break + output_sizes[level] = new_level_sizes + output_starts[level] = np.cumsum(np.insert(new_level_sizes, 0, 0))[:-1] + + return output_starts, output_sizes + + class Tree: """ The Tree handles the spatial indexing of the data.
As of right now, it's only @@ -43,11 +115,22 @@ class Tree: spatial queries """ - def __init__(self, index: SpatialIndex, slices: dict[int, dict[int, slice]]): + def __init__( + self, + index: SpatialIndex, + starts: dict[int, np.ndarray], + sizes: dict[int, np.ndarray], + ): self.__index = index - self.__slices = slices + self.__starts = starts + self.__sizes = sizes - def apply_mask(self, mask: np.ndarray) -> Tree: + def apply_mask( + self, + mask: np.ndarray, + comm: Optional[MPI.Comm] = None, + range_: Optional[tuple] = None, + ) -> Tree: """ Given a boolean mask, create a new tree with slices adjusted to only include the elements where the mask is True. This is used @@ -55,30 +138,55 @@ def apply_mask(self, mask: np.ndarray) -> Tree: The mask will have the same shape as the original data. """ + + if comm is not None and range_ is not None: + return self.__apply_rank_mask(mask, comm, range_) if np.all(mask): return self - new_slices = {} - for level, slices in self.__slices.items(): - lengths = [np.sum(mask[s]) for s in slices.values()] - new_starts = np.cumsum([0] + lengths[:-1]) - new_slices[level] = { - i: slice(new_starts[i], new_starts[i] + lengths[i]) - for i in range(len(lengths)) - if lengths[i] > 0 - } - return Tree(self.__index, new_slices) + output_starts = {} + output_sizes = {} + for level in self.__starts: + start = self.__starts[level] + size = self.__sizes[level] + offsets = np.zeros_like(size) + for i in range(len(start)): + # Create a slice object for the current level + s = slice(start[i], start[i] + size[i]) + slice_mask = mask[s] # Apply the slice to the mask + offsets[i] = np.sum(slice_mask) # Count the number of True values + level_starts = np.cumsum(np.insert(offsets, 0, 0))[ + :-1 + ] # Cumulative sum to get new starts + level_sizes = offsets + output_starts[level] = level_starts + output_sizes[level] = level_sizes + + return Tree(self.__index, output_starts, output_sizes) + + def __apply_rank_mask( + self, mask: np.ndarray, comm: 
MPI.Comm, range_: tuple[int, int] + ) -> Tree: + """ + Given a range and a mask, apply the mask to the tree. The mask + will have the same shape as the original data. + """ + new_sizes = apply_range_mask(mask, range_, self.__starts, self.__sizes) + all_new_sizes = comm.allgather(new_sizes) + output_starts, output_sizes = pack_masked_ranges(self.__starts, all_new_sizes) + return Tree(self.__index, output_starts, output_sizes) def write(self, file: h5py.File, dataset_name: str = "index"): """ Write the tree to an HDF5 file. Note that this function is not responsible for applying masking. The routine calling this function should first create a new tree with apply_mask if necessary. """ group = file.require_group(dataset_name) - for level, slices in self.__slices.items(): + for level in self.__starts: level_group = group.require_group(f"level_{level}") - start = np.array([s.start for s in slices.values()]) - size = np.array([s.stop - s.start for s in slices.values()]) - level_group.create_dataset("start", data=start) - level_group.create_dataset("size", data=size) + level_group.create_dataset("start", data=self.__starts[level]) + level_group.create_dataset("size", data=self.__sizes[level]) diff --git a/test/parallel/test_mpi.py b/test/parallel/test_mpi.py index 125a4bfe..2ef70ca8 100644 --- a/test/parallel/test_mpi.py +++ b/test/parallel/test_mpi.py @@ -91,13 +91,26 @@ def test_filter_write(input_path, tmp_path): ds = oc.open(input_path) ds = ds.filter(oc.col("sod_halo_mass") > 0) + oc.write(temporary_path, ds) data = ds.collect().data ds.close() ds = oc.read(temporary_path) written_data = ds.data + + handler = ds._Dataset__handler + tree = handler._InMemoryHandler__tree + starts = tree._Tree__starts + sizes = tree._Tree__sizes + parallel_assert(lambda: np.all(data == written_data)) + for level in sizes: + parallel_assert(lambda: np.sum(sizes[level]) == len(handler)) +
parallel_assert(lambda: starts[level][0] == 0) + if level > 0: + sizes_from_starts = np.diff(np.append(starts[level], len(handler))) + parallel_assert(lambda: np.all(sizes_from_starts == sizes[level])) @pytest.mark.parallel(nprocs=4) diff --git a/test/test_spatial.py b/test/test_spatial.py index 97f124d0..c15ef916 100644 --- a/test/test_spatial.py +++ b/test/test_spatial.py @@ -1,3 +1,4 @@ +import numpy as np import pytest import opencosmo as oc @@ -16,14 +17,16 @@ def test_filter_write(input_path, tmp_path): size_unfiltered = len(f.data) ds = oc.read(tmp_file) - slices = ds._Dataset__handler._InMemoryHandler__tree._Tree__slices + starts = ds._Dataset__handler._InMemoryHandler__tree._Tree__starts + sizes = ds._Dataset__handler._InMemoryHandler__tree._Tree__sizes size = len(ds.data) assert size < size_unfiltered def is_valid(sl, size): return sl.stop > sl.start and sl.start >= 0 and sl.stop <= size - for level in slices: - slice_total = sum((s.stop - s.start) for s in slices[level].values()) + for level in range(len(starts)): + slice_total = np.sum(sizes[level]) + assert slice_total == size - assert all(is_valid(s, size) for s in slices[level].values()) + assert np.all(np.cumsum(np.insert(sizes[level], 0, 0))[:-1] == starts[level])