diff --git a/changes/+83c06d0d.improvement.rst b/changes/+83c06d0d.improvement.rst
new file mode 100644
index 00000000..9ad8a26b
--- /dev/null
+++ b/changes/+83c06d0d.improvement.rst
@@ -0,0 +1 @@
+:py:meth:`StructureCollection.select` and :py:meth:`StructureCollection.drop` now follow the same semantics as :py:meth:`StructureCollection.evaluate`
diff --git a/... b/...
@@ ... @@ ... -> StructureCollection:
             filtered, self.__header, self.__datasets, self.__links, self.__hide_source
         )
 
-    def select(
-        self,
-        columns: str | Iterable[str],
-        dataset: str,
-    ) -> StructureCollection:
+    def select(self, **column_selections: str | Iterable[str]) -> StructureCollection:
         """
         Update a dataset in the collection collection to only include the
-        columns specified.
+        columns specified. The names of the arguments to this function should be
+        dataset names. For example:
+
+        .. code-block:: python
+
+            collection = collection.select(
+                halo_properties=["fof_halo_mass", "sod_halo_mass", "sod_halo_cdelta"],
+                dm_particles=["x", "y", "z"],
+            )
+
+        Datasets that do not appear in the argument list will not be modified. You can
+        remove entire datasets from the collection with
+        :py:meth:`with_datasets`.
+
         Parameters
         ----------
-        columns : str | Iterable[str]
-            The columns to select from the dataset.
+        **column_selections : str | Iterable[str]
+            The columns to select from a given dataset.
 
         dataset : str
             The dataset to select from.
@@ -457,40 +466,49 @@ def select(
         ValueError
             If the specified dataset is not found in the collection.
         """
-        if dataset == self.__header.file.data_type:
-            new_source = self.__source.select(columns)
-            return StructureCollection(
-                new_source, self.__header, self.__datasets, self.__links
-            )
+        if not column_selections:
+            return self
+        new_source = self.__source
+        new_datasets = {}
+        for dataset, columns in column_selections.items():
+            if dataset == self.__header.file.data_type:
+                new_source = self.__source.select(columns)
+                continue
 
-        elif dataset not in self.__datasets:
-            raise ValueError(f"Dataset {dataset} not found in collection.")
-        output_ds = self.__datasets[dataset]
-        if not isinstance(output_ds, oc.Dataset):
-            raise NotImplementedError
+            elif dataset not in self.__datasets:
+                raise ValueError(f"Dataset {dataset} not found in collection.")
+
+            output_ds = self.__datasets[dataset]
+
+            if not isinstance(output_ds, oc.Dataset):
+                raise NotImplementedError
+
+            new_dataset = output_ds.select(columns)
+            new_datasets[dataset] = new_dataset
 
-        new_dataset = output_ds.select(columns)
         return StructureCollection(
-            self.__source,
+            new_source,
             self.__header,
-            {**self.__datasets, dataset: new_dataset},
+            self.__datasets | new_datasets,
             self.__links,
             self.__hide_source,
         )
 
-    def drop(self, columns: str | Iterable[str], dataset: Optional[str] = None):
+    def drop(self, **columns_to_drop: str | Iterable[str]) -> StructureCollection:
         """
         Update the linked collection by dropping the specified columns
-        in the given dataset. If no dataset is specified, the properties dataset
-        is used. For example, if this collection contains galaxies,
-        calling this function without a "dataset" argument will select columns
-        from the galaxy_properties dataset.
+        in the specified datasets. Follows the exact same semantics as
+        :py:meth:`StructureCollection.select`.
+        Argument names should be the names of datasets in this collection, and
+        the argument values should be a string or a list of strings.
+        Datasets that are not included will not be modified. You can drop
+        entire datasets with :py:meth:`with_datasets`.
 
         Parameters
         ----------
-        columns : str | Iterable[str]
-            The columns to select from the dataset.
+        **columns_to_drop : str | Iterable[str]
+            The columns to drop from the dataset.
 
         dataset : str, optional
             The dataset to select from. If None, the properties dataset
             is used.
@@ -505,21 +523,25 @@ def drop(self, columns: str | Iterable[str], dataset: Optional[str] = None):
         ValueError
             If the specified dataset is not found in the collection.
         """
+        if not columns_to_drop:
+            return self
+        new_source = self.__source
+        new_datasets = {}
 
-        if dataset is None or dataset == self.__header.file.data_type:
-            new_source = self.__source.drop(columns)
-            return StructureCollection(
-                new_source, self.__header, self.__datasets, self.__links
-            )
+        for dataset_name, columns in columns_to_drop.items():
+            if dataset_name == self.__header.file.data_type:
+                new_source = self.__source.drop(columns)
+                continue
+
+            elif dataset_name not in self.__datasets:
+                raise ValueError(f"Dataset {dataset_name} not found in collection.")
+            new_ds = self.__datasets[dataset_name].drop(columns)
+            new_datasets[dataset_name] = new_ds
 
-        elif dataset not in self.__datasets:
-            raise ValueError(f"Dataset {dataset} not found in collection.")
-        output_ds = self.__datasets[dataset]
-        new_dataset = output_ds.drop(columns)
         return StructureCollection(
-            self.__source,
+            new_source,
             self.__header,
-            {**self.__datasets, dataset: new_dataset},
+            self.__datasets | new_datasets,
             self.__links,
             self.__hide_source,
         )
diff --git a/test/test_collection.py b/test/test_collection.py
index 4e14e097..1f9a2fa7 100644
--- a/test/test_collection.py
+++ b/test/test_collection.py
@@ -558,8 +558,9 @@ def test_data_link_selection(halo_paths):
     collection = collection.filter(oc.col("sod_halo_mass") > 10**13).take(
         10, at="random"
     )
-    collection = collection.select(["x", "y", "z"], dataset="dm_particles")
-    collection = collection.select(["fof_halo_tag", "sod_halo_mass"], "halo_properties")
+    collection = collection.select(
+        dm_particles=["x", "y", "z"], halo_properties=["fof_halo_tag", "sod_halo_mass"]
+    )
     found_dm_particles = False
     for halo in collection.objects():
         properties = halo["halo_properties"]
@@ -578,8 +579,9 @@ def test_data_link_drop(halo_paths):
     collection = collection.filter(oc.col("sod_halo_mass") > 10**13).take(
         10, at="random"
     )
-    collection = collection.drop(["x", "y", "z"], dataset="dm_particles")
-    collection = collection.drop(["fof_halo_tag", "sod_halo_mass"])
+    collection = collection.drop(
+        dm_particles=["x", "y", "z"], halo_properties=["fof_halo_tag", "sod_halo_mass"]
+    )
     found_dm_particles = False
     for halo in collection.objects():
         properties = halo["halo_properties"]
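The net effect of this diff: `select` and `drop` no longer take positional `(columns, dataset)` arguments; instead, each keyword argument names a dataset in the collection and its value gives the columns to keep or drop, so several datasets can be updated in one call. A minimal usage sketch of the new call pattern (the `oc.open` call and file path are hypothetical; `halo_properties` and `dm_particles` are the dataset names used by the test fixtures above):

.. code-block:: python

    import opencosmo as oc

    # Hypothetical input file; the tests build the collection from the
    # halo_paths fixture instead.
    collection = oc.open("haloproperties.hdf5")

    # Keyword names are dataset names; values are the column(s) to keep.
    collection = collection.select(
        halo_properties=["fof_halo_tag", "sod_halo_mass"],
        dm_particles=["x", "y", "z"],
    )

    # Same semantics for drop: remove columns per dataset in one call.
    collection = collection.drop(dm_particles=["z"])

    # Datasets not named in a call are left untouched, and calling either
    # method with no arguments returns the collection unchanged (both
    # methods short-circuit on empty kwargs).

One design note: the `self.__datasets | new_datasets` merge uses the dict union operator, so this code assumes Python 3.9 or later.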