From 0493392bdab8e9b299df24a1d8b72b595b8b48b9 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Thu, 20 Mar 2025 13:45:03 -0500 Subject: [PATCH 1/5] Better tree handling --- opencosmo/spatial/tree.py | 65 ++++++++++++++++++++++++--------------- test/test_spatial.py | 13 +++++--- 2 files changed, 49 insertions(+), 29 deletions(-) diff --git a/opencosmo/spatial/tree.py b/opencosmo/spatial/tree.py index baf4e957..f384728a 100644 --- a/opencosmo/spatial/tree.py +++ b/opencosmo/spatial/tree.py @@ -17,19 +17,19 @@ def read_tree(file: h5py.File | h5py.Group, header: OpenCosmoHeader): index and a slice into the data. """ max_level = header.reformat.max_level - data_indices = OrderedDict() + starts = {} + sizes = {} + for level in range(max_level + 1): group = file[f"index/level_{level}"] - starts = group["start"][()] - sizes = group["size"][()] - level_indices = {} - for i, (start, size) in enumerate(zip(starts, sizes)): - level_indices[i] = slice(start, start + size) - data_indices[level] = level_indices + level_starts = group["start"][()] + level_sizes = group["size"][()] + starts[level] = level_starts + sizes[level] = level_sizes spatial_index = OctTreeIndex(header.simulation, max_level) - return Tree(spatial_index, data_indices) + return Tree(spatial_index, starts, sizes) def write_tree(file: h5py.File, tree: Tree, dataset_name: str = "index"): @@ -43,9 +43,11 @@ class Tree: spatial queries """ - def __init__(self, index: SpatialIndex, slices: dict[int, dict[int, slice]]): + def __init__(self, index: SpatialIndex, starts: dict[int], sizes: dict[int]): self.__index = index - self.__slices = slices + self.__starts = starts + self.__sizes = sizes + def apply_mask(self, mask: np.ndarray) -> Tree: """ @@ -55,19 +57,34 @@ def apply_mask(self, mask: np.ndarray) -> Tree: The mask will have the same shape as the original data. """ + if np.all(mask): return self - new_slices = {} - for level, slices in self.__slices.items(): - lengths = [np.sum(mask[s]) for s in slices.values()] - new_starts = np.cumsum([0] + lengths[:-1]) - new_slices[level] = { - i: slice(new_starts[i], new_starts[i] + lengths[i]) - for i in range(len(lengths)) - if lengths[i] > 0 - } - return Tree(self.__index, new_slices) + output_starts = {} + output_sizes = {} + for level in self.__starts: + start = self.__starts[level] + size = self.__sizes[level] + offsets = np.zeros_like(size) + for i in range(len(start)): + # Create a slice object for the current level + s = slice(start[i], start[i] + size[i]) + slice_mask = mask[s] # Apply the slice to the mask + offsets[i] = np.sum(slice_mask) # Count the number of True values + level_starts = np.cumsum(np.insert(offsets, 0, 0))[:-1] # Cumulative sum to get new starts + level_sizes = offsets + output_starts[level] = level_starts + output_sizes[level] = level_sizes + + return Tree(self.__index, output_starts, output_sizes) + + + + # Apply the mask to get the new slice + + + def write(self, file: h5py.File, dataset_name: str = "index"): """ Write the tree to an HDF5 file. Note that this function @@ -76,9 +93,7 @@ def write(self, file: h5py.File, dataset_name: str = "index"): necessary. """ group = file.require_group(dataset_name) - for level, slices in self.__slices.items(): + for level in self.__starts: level_group = group.require_group(f"level_{level}") - start = np.array([s.start for s in slices.values()]) - size = np.array([s.stop - s.start for s in slices.values()]) - level_group.create_dataset("start", data=start) - level_group.create_dataset("size", data=size) + level_group.create_dataset("start", data=self.__starts[level]) + level_group.create_dataset("size", data=self.__sizes[level]) diff --git a/test/test_spatial.py b/test/test_spatial.py index 97f124d0..1b6d96ae 100644 --- a/test/test_spatial.py +++ b/test/test_spatial.py @@ -1,4 +1,5 @@ import pytest +import numpy as np import opencosmo as oc @@ -16,14 +17,18 @@ def test_filter_write(input_path, tmp_path): size_unfiltered = len(f.data) ds = oc.read(tmp_file) - slices = ds._Dataset__handler._InMemoryHandler__tree._Tree__slices + starts = ds._Dataset__handler._InMemoryHandler__tree._Tree__starts + sizes = ds._Dataset__handler._InMemoryHandler__tree._Tree__sizes size = len(ds.data) assert size < size_unfiltered def is_valid(sl, size): return sl.stop > sl.start and sl.start >= 0 and sl.stop <= size - for level in slices: - slice_total = sum((s.stop - s.start) for s in slices[level].values()) + for level in range(len(starts)): + slice_total = np.sum(sizes[level]) + + assert slice_total == size - assert all(is_valid(s, size) for s in slices[level].values()) + assert np.all(np.cumsum(np.insert(sizes[level], 0, 0))[:-1] == starts[level]) + From d56638119b0e213b756348b26b9ffda27be01677 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Fri, 21 Mar 2025 13:24:38 -0500 Subject: [PATCH 2/5] Better parallel masking algorithm --- opencosmo/handler/mpi.py | 31 +++-------------- opencosmo/handler/oom.py | 5 +-- opencosmo/spatial/tree.py | 71 ++++++++++++++++++++++++++++++++++----- test/parallel/test_mpi.py | 15 +++++++++ 4 files changed, 84 insertions(+), 38 deletions(-) diff --git a/opencosmo/handler/mpi.py b/opencosmo/handler/mpi.py index 8c395512..6671596a 100644 --- a/opencosmo/handler/mpi.py +++ b/opencosmo/handler/mpi.py @@ -8,7 +8,7 @@ from opencosmo.file import get_data_structure from opencosmo.handler import InMemoryHandler -from opencosmo.spatial.tree import Tree +from opencosmo.spatial.tree import Tree, pack_masked_ranges def verify_input(comm: MPI.Comm, require: Iterable[str] = [], **kwargs) -> dict: @@ -162,32 +162,9 @@ def write( data = data[mask] data_group[column][rank_start:rank_end] = data - displacements = np.insert(np.cumsum(all_input_lengths[:-1]), 0, 0) - if rank == 0: - recvbuf = np.empty(sum(all_input_lengths), dtype=np.uint8) - self.__comm.Gatherv( - sendbuf=mask.view(np.uint8), - recvbuf=( - recvbuf.view(np.uint8), - all_input_lengths, - displacements, - MPI.BYTE, - ), - root=0, - ) - else: - self.__comm.Gatherv( - sendbuf=mask.view(np.uint8), recvbuf=(None, None, None, None), root=0 - ) - - if rank == 0: - mask = recvbuf.astype(bool) - tree = self.__tree.apply_mask(mask) - else: - tree = None - tree = self.__comm.bcast(tree, root=0) - # - tree.write(group) # type: ignore + new_tree = self.__tree.apply_mask(mask, self.__comm, self.elem_range()) + + new_tree.write(group) # type: ignore self.__comm.Barrier() diff --git a/opencosmo/handler/oom.py b/opencosmo/handler/oom.py index 9b6fa2eb..6a1c9536 100644 --- a/opencosmo/handler/oom.py +++ b/opencosmo/handler/oom.py @@ -89,9 +89,10 @@ def get_data( raise ValueError("This file has already been closed") output = {} for column, builder in builders.items(): - data = self.__group[column][()] if mask is not None: - data = data[mask] + data = self.__group[column][mask] + else: + data = self.__group[column][()] col = Column(data, name=column) output[column] = builder.build(col) diff --git a/opencosmo/spatial/tree.py b/opencosmo/spatial/tree.py index f384728a..adbb1c8c 100644 --- a/opencosmo/spatial/tree.py +++ b/opencosmo/spatial/tree.py @@ -5,6 +5,11 @@ import h5py import numpy as np +try: + from mpi4py import MPI +except ImportError: + MPI = None + from opencosmo.header import OpenCosmoHeader from opencosmo.spatial.index import SpatialIndex from opencosmo.spatial.octree import OctTreeIndex @@ -35,6 +40,46 @@ def read_tree(file: h5py.File | h5py.Group, header: OpenCosmoHeader): def write_tree(file: h5py.File, tree: Tree, dataset_name: str = "index"): tree.write(file, dataset_name) +def apply_range_mask(mask: np.ndarray, range_: tuple[int, int], starts: dict[int, np.ndarray], sizes: dict[int, np.ndarray]) -> dict[int, tuple[int, np.ndarray]]: + """ + Given an index range, apply a mask of the same size to produces new sizes. + """ + output_sizes = {} + for level, st in starts.items(): + ends = st + sizes[level] + # Not in range if the end is less than start, or the start is greater than end + overlaps_mask = ~((st > range_[1]) | (ends < range_[0])) + # The first start may be less thank the range start so + first_start_index = np.argmax(overlaps_mask) + st = st[overlaps_mask] + st[0] = range_[0] + st = st - range_[0] + # Determine how many true values are in the mask in the ranges + new_sizes = np.fromiter((np.sum(a) for a in np.split(mask, st[1:])), dtype=int) + output_sizes[level] = (first_start_index, new_sizes) + return output_sizes + +def pack_masked_ranges(old_starts: dict[int, np.ndarray], new_sizes: list[dict[int, tuple[int, np.ndarray]]]) -> dict[int, np.ndarray]: + """ + Given a list of masked ranges, pack them into a new set of sizes. + """ + output_starts = {} + output_sizes = {} + for level in new_sizes[0]: + new_level_sizes = np.zeros_like(old_starts[level]) + new_start_info = [rm[level] for rm in new_sizes] + for (first_idx, sizes) in new_start_info: + new_level_sizes[first_idx:first_idx + len(sizes)] += sizes + output_sizes[level] = new_level_sizes + output_starts[level] = np.cumsum(np.insert(new_level_sizes, 0, 0))[:-1] + + return output_starts, output_sizes + + + + + + class Tree: """ @@ -43,13 +88,13 @@ class Tree: spatial queries """ - def __init__(self, index: SpatialIndex, starts: dict[int], sizes: dict[int]): + def __init__(self, index: SpatialIndex, starts: dict[int, np.ndarray], sizes: dict[int, np.ndarray]): self.__index = index self.__starts = starts self.__sizes = sizes - def apply_mask(self, mask: np.ndarray) -> Tree: + def apply_mask(self, mask: np.ndarray, comm: MPI.Comm = None, range_ = None) -> Tree: """ Given a boolean mask, create a new tree with slices adjusted to only include the elements where the mask is True. This is used @@ -58,6 +103,8 @@ def apply_mask(self, mask: np.ndarray) -> Tree: The mask will have the same shape as the original data. """ + if comm is not None: + return self.__apply_rank_mask(mask, comm, range_) if np.all(mask): return self output_starts = {} @@ -78,18 +125,24 @@ def apply_mask(self, mask: np.ndarray) -> Tree: return Tree(self.__index, output_starts, output_sizes) - - - # Apply the mask to get the new slice - - - + def __apply_rank_mask(self, mask: np.ndarray, comm: MPI.Comm, range_: tuple[int, int]) -> Tree: + """ + Given a range and a mask, apply the mask to the tree. The mask + will have the same shape as the original data. + """ + new_sizes = apply_range_mask(mask, range_, self.__starts, self.__sizes) + all_new_sizes = comm.allgather(new_sizes) + new_starts, new_sizes = pack_masked_ranges(self.__starts, all_new_sizes) + return Tree(self.__index, new_starts, new_sizes) def write(self, file: h5py.File, dataset_name: str = "index"): """ Write the tree to an HDF5 file. Note that this function is not responsible for applying masking. The routine calling this - function should first create a new tree with apply_mask if + funct + MPI = None + MPI = None + MPI = Noneion should first create a new tree with apply_mask if necessary. """ group = file.require_group(dataset_name) diff --git a/test/parallel/test_mpi.py b/test/parallel/test_mpi.py index 125a4bfe..40f5ec33 100644 --- a/test/parallel/test_mpi.py +++ b/test/parallel/test_mpi.py @@ -91,13 +91,28 @@ def test_filter_write(input_path, tmp_path): ds = oc.open(input_path) ds = ds.filter(oc.col("sod_halo_mass") > 0) + + + oc.write(temporary_path, ds) data = ds.collect().data ds.close() ds = oc.read(temporary_path) written_data = ds.data + + handler = ds._Dataset__handler + tree = handler._InMemoryHandler__tree + starts = tree._Tree__starts + sizes = tree._Tree__sizes + parallel_assert(lambda: np.all(data == written_data)) + for level in sizes: + parallel_assert(lambda: np.sum(sizes[level]) == len(handler)) + parallel_assert(lambda: starts[level][0] == 0) + if level > 0: + sizes_from_starts = np.diff(np.append(starts[level], len(handler))) + parallel_assert(lambda: np.all(sizes_from_starts == sizes[level])) @pytest.mark.parallel(nprocs=4) From a6dc4f12c2d014199b656c457e933b6e03b4ede6 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Fri, 21 Mar 2025 14:01:56 -0500 Subject: [PATCH 3/5] Linting and mypy errors --- opencosmo/handler/mpi.py | 5 ++-- opencosmo/spatial/tree.py | 63 +++++++++++++++++++++++++-------------- test/parallel/test_mpi.py | 2 -- test/test_spatial.py | 4 +-- 4 files changed, 43 insertions(+), 31 deletions(-) diff --git a/opencosmo/handler/mpi.py b/opencosmo/handler/mpi.py index 6671596a..4b9915cc 100644 --- a/opencosmo/handler/mpi.py +++ b/opencosmo/handler/mpi.py @@ -8,7 +8,7 @@ from opencosmo.file import get_data_structure from opencosmo.handler import InMemoryHandler -from opencosmo.spatial.tree import Tree, pack_masked_ranges +from opencosmo.spatial.tree import Tree def verify_input(comm: MPI.Comm, require: Iterable[str] = [], **kwargs) -> dict: @@ -128,7 +128,6 @@ def write( rank_output_length = np.sum(mask) all_output_lengths = self.__comm.allgather(rank_output_length) - all_input_lengths = self.__comm.allgather(len(mask)) rank = self.__comm.Get_rank() @@ -163,7 +162,7 @@ def write( data_group[column][rank_start:rank_end] = data new_tree = self.__tree.apply_mask(mask, self.__comm, self.elem_range()) - + new_tree.write(group) # type: ignore self.__comm.Barrier() diff --git a/opencosmo/spatial/tree.py b/opencosmo/spatial/tree.py index adbb1c8c..67303a49 100644 --- a/opencosmo/spatial/tree.py +++ b/opencosmo/spatial/tree.py @@ -1,14 +1,14 @@ from __future__ import annotations -from collections import OrderedDict import h5py import numpy as np +from typing import Optional try: from mpi4py import MPI except ImportError: - MPI = None + MPI = None # type: ignore from opencosmo.header import OpenCosmoHeader from opencosmo.spatial.index import SpatialIndex @@ -25,7 +25,6 @@ def read_tree(file: h5py.File | h5py.Group, header: OpenCosmoHeader): starts = {} sizes = {} - for level in range(max_level + 1): group = file[f"index/level_{level}"] level_starts = group["start"][()] @@ -40,7 +39,13 @@ def read_tree(file: h5py.File | h5py.Group, header: OpenCosmoHeader): def write_tree(file: h5py.File, tree: Tree, dataset_name: str = "index"): tree.write(file, dataset_name) -def apply_range_mask(mask: np.ndarray, range_: tuple[int, int], starts: dict[int, np.ndarray], sizes: dict[int, np.ndarray]) -> dict[int, tuple[int, np.ndarray]]: + +def apply_range_mask( + mask: np.ndarray, + range_: tuple[int, int], + starts: dict[int, np.ndarray], + sizes: dict[int, np.ndarray], +) -> dict[int, tuple[int, np.ndarray]]: """ Given an index range, apply a mask of the same size to produces new sizes. """ @@ -50,7 +55,7 @@ def apply_range_mask(mask: np.ndarray, range_: tuple[int, int], starts: dict[int # Not in range if the end is less than start, or the start is greater than end overlaps_mask = ~((st > range_[1]) | (ends < range_[0])) # The first start may be less thank the range start so - first_start_index = np.argmax(overlaps_mask) + first_start_index = int(np.argmax(overlaps_mask)) st = st[overlaps_mask] st[0] = range_[0] st = st - range_[0] @@ -59,7 +64,11 @@ def apply_range_mask(mask: np.ndarray, range_: tuple[int, int], starts: dict[int output_sizes[level] = (first_start_index, new_sizes) return output_sizes -def pack_masked_ranges(old_starts: dict[int, np.ndarray], new_sizes: list[dict[int, tuple[int, np.ndarray]]]) -> dict[int, np.ndarray]: + +def pack_masked_ranges( + old_starts: dict[int, np.ndarray], + new_sizes: list[dict[int, tuple[int, np.ndarray]]], +) -> tuple[dict[int, np.ndarray], dict[int, np.ndarray]]: """ Given a list of masked ranges, pack them into a new set of sizes. """ @@ -68,17 +77,12 @@ def pack_masked_ranges(old_starts: dict[int, np.ndarray], new_sizes: list[dict[i for level in new_sizes[0]: new_level_sizes = np.zeros_like(old_starts[level]) new_start_info = [rm[level] for rm in new_sizes] - for (first_idx, sizes) in new_start_info: - new_level_sizes[first_idx:first_idx + len(sizes)] += sizes + for first_idx, sizes in new_start_info: + new_level_sizes[first_idx : first_idx + len(sizes)] += sizes output_sizes[level] = new_level_sizes output_starts[level] = np.cumsum(np.insert(new_level_sizes, 0, 0))[:-1] - - return output_starts, output_sizes - - - - + return output_starts, output_sizes class Tree: @@ -88,13 +92,22 @@ class Tree: spatial queries """ - def __init__(self, index: SpatialIndex, starts: dict[int, np.ndarray], sizes: dict[int, np.ndarray]): + def __init__( + self, + index: SpatialIndex, + starts: dict[int, np.ndarray], + sizes: dict[int, np.ndarray], + ): self.__index = index self.__starts = starts self.__sizes = sizes - - def apply_mask(self, mask: np.ndarray, comm: MPI.Comm = None, range_ = None) -> Tree: + def apply_mask( + self, + mask: np.ndarray, + comm: Optional[MPI.Comm] = None, + range_: Optional[tuple] = None, + ) -> Tree: """ Given a boolean mask, create a new tree with slices adjusted to only include the elements where the mask is True. This is used @@ -103,7 +116,7 @@ def apply_mask(self, mask: np.ndarray, comm: MPI.Comm = None, range_ = None) -> The mask will have the same shape as the original data. """ - if comm is not None: + if comm is not None and range_ is not None: return self.__apply_rank_mask(mask, comm, range_) if np.all(mask): return self @@ -118,23 +131,27 @@ def apply_mask(self, mask: np.ndarray, comm: MPI.Comm = None, range_ = None) -> s = slice(start[i], start[i] + size[i]) slice_mask = mask[s] # Apply the slice to the mask offsets[i] = np.sum(slice_mask) # Count the number of True values - level_starts = np.cumsum(np.insert(offsets, 0, 0))[:-1] # Cumulative sum to get new starts + level_starts = np.cumsum(np.insert(offsets, 0, 0))[ + :-1 + ] # Cumulative sum to get new starts level_sizes = offsets output_starts[level] = level_starts output_sizes[level] = level_sizes return Tree(self.__index, output_starts, output_sizes) - def __apply_rank_mask(self, mask: np.ndarray, comm: MPI.Comm, range_: tuple[int, int]) -> Tree: + def __apply_rank_mask( + self, mask: np.ndarray, comm: MPI.Comm, range_: tuple[int, int] + ) -> Tree: """ Given a range and a mask, apply the mask to the tree. The mask will have the same shape as the original data. """ new_sizes = apply_range_mask(mask, range_, self.__starts, self.__sizes) all_new_sizes = comm.allgather(new_sizes) - new_starts, new_sizes = pack_masked_ranges(self.__starts, all_new_sizes) - return Tree(self.__index, new_starts, new_sizes) - + output_starts, output_sizes = pack_masked_ranges(self.__starts, all_new_sizes) + return Tree(self.__index, output_starts, output_sizes) + def write(self, file: h5py.File, dataset_name: str = "index"): """ Write the tree to an HDF5 file. Note that this function diff --git a/test/parallel/test_mpi.py b/test/parallel/test_mpi.py index 40f5ec33..2ef70ca8 100644 --- a/test/parallel/test_mpi.py +++ b/test/parallel/test_mpi.py @@ -92,8 +92,6 @@ def test_filter_write(input_path, tmp_path): ds = oc.open(input_path) ds = ds.filter(oc.col("sod_halo_mass") > 0) - - oc.write(temporary_path, ds) data = ds.collect().data ds.close() diff --git a/test/test_spatial.py b/test/test_spatial.py index 1b6d96ae..c15ef916 100644 --- a/test/test_spatial.py +++ b/test/test_spatial.py @@ -1,5 +1,5 @@ -import pytest import numpy as np +import pytest import opencosmo as oc @@ -28,7 +28,5 @@ def is_valid(sl, size): for level in range(len(starts)): slice_total = np.sum(sizes[level]) - assert slice_total == size assert np.all(np.cumsum(np.insert(sizes[level], 0, 0))[:-1] == starts[level]) - From ad8100b7577736b320083ae1b9a695a28bcaffb9 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Fri, 21 Mar 2025 14:23:02 -0500 Subject: [PATCH 4/5] And one more fix --- opencosmo/spatial/tree.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/opencosmo/spatial/tree.py b/opencosmo/spatial/tree.py index 67303a49..55da7124 100644 --- a/opencosmo/spatial/tree.py +++ b/opencosmo/spatial/tree.py @@ -20,13 +20,23 @@ def read_tree(file: h5py.File | h5py.Group, header: OpenCosmoHeader): Read a tree from an HDF5 file and the associated header. The tree is just a mapping between a spatial index and a slice into the data. + + Note: The max level in the header may not actually match + the max level in the file. When a large dataset is filtered down, + we may reduce the tree level to save space in the output file. + + The max level in the header is the maximum level in the full + dataset, so this is the HIGHEST it can be. """ max_level = header.reformat.max_level starts = {} sizes = {} for level in range(max_level + 1): - group = file[f"index/level_{level}"] + try: + group = file[f"index/level_{level}"] + except KeyError: + break level_starts = group["start"][()] level_sizes = group["size"][()] starts[level] = level_starts @@ -50,7 +60,17 @@ def apply_range_mask( Given an index range, apply a mask of the same size to produces new sizes. """ output_sizes = {} + max_level = 0 + for level in starts: + level_sizes = sizes[level] + nonzero = level_sizes[level_sizes > 0] + if np.average(nonzero) < 500: + break + max_level += 1 + for level, st in starts.items(): + if level > max_level: + break ends = st + sizes[level] # Not in range if the end is less than start, or the start is greater than end overlaps_mask = ~((st > range_[1]) | (ends < range_[0])) From 17ef328b74a16456c3dc2a580fd18793aea26b6e Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Fri, 21 Mar 2025 14:34:32 -0500 Subject: [PATCH 5/5] Another fix --- opencosmo/spatial/tree.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/opencosmo/spatial/tree.py b/opencosmo/spatial/tree.py index 55da7124..29cfb955 100644 --- a/opencosmo/spatial/tree.py +++ b/opencosmo/spatial/tree.py @@ -1,9 +1,9 @@ from __future__ import annotations +from typing import Optional import h5py import numpy as np -from typing import Optional try: from mpi4py import MPI @@ -60,17 +60,8 @@ def apply_range_mask( Given an index range, apply a mask of the same size to produces new sizes. """ output_sizes = {} - max_level = 0 - for level in starts: - level_sizes = sizes[level] - nonzero = level_sizes[level_sizes > 0] - if np.average(nonzero) < 500: - break - max_level += 1 for level, st in starts.items(): - if level > max_level: - break ends = st + sizes[level] # Not in range if the end is less than start, or the start is greater than end overlaps_mask = ~((st > range_[1]) | (ends < range_[0])) @@ -88,9 +79,17 @@ def apply_range_mask( def pack_masked_ranges( old_starts: dict[int, np.ndarray], new_sizes: list[dict[int, tuple[int, np.ndarray]]], + min_level_size: int = 500, ) -> tuple[dict[int, np.ndarray], dict[int, np.ndarray]]: """ Given a list of masked ranges, pack them into a new set of sizes. + This is used when working with MPI, and allows us to avoid sending + very large masks between ranks. + + For queries that return a small fraction of the data, we can end up + writing a lot of zeros in the lower levels of the tree. So we can + dynamically choose to stop writing levels when the average size of + the level is below a certain threshold """ output_starts = {} output_sizes = {} @@ -99,6 +98,10 @@ def pack_masked_ranges( new_start_info = [rm[level] for rm in new_sizes] for first_idx, sizes in new_start_info: new_level_sizes[first_idx : first_idx + len(sizes)] += sizes + + avg_size = np.mean(new_level_sizes[new_level_sizes > 0]) + if avg_size < min_level_size: + break output_sizes[level] = new_level_sizes output_starts[level] = np.cumsum(np.insert(new_level_sizes, 0, 0))[:-1]