diff --git a/changes/77.feature.rst b/changes/77.feature.rst new file mode 100644 index 00000000..7b863d12 --- /dev/null +++ b/changes/77.feature.rst @@ -0,0 +1 @@ +Columns can now be dropped with `Dataset.drop ` (inverse of "select") diff --git a/opencosmo/collection/lightcone/lightcone.py b/opencosmo/collection/lightcone/lightcone.py index 849e799c..0ee84ac9 100644 --- a/opencosmo/collection/lightcone/lightcone.py +++ b/opencosmo/collection/lightcone/lightcone.py @@ -436,6 +436,37 @@ def select(self, columns: str | Iterable[str]) -> Self: hide_redshift = True return self.__map("select", columns, hide_redshift=hide_redshift) + def drop(self, columns: str | Iterable[str]) -> Self: + """ + Produce a new dataset by dropping columns from this dataset. + + Parameters + ---------- + columns : str or list[str] + The column or columns to drop. + + Returns + ------- + dataset : Dataset + The new dataset without the dropped columns + + Raises + ------ + ValueError + If any of the given columns are not in the dataset. + """ + if isinstance(columns, str): + columns = [columns] + + dropped_columns = set(columns) + current_columns = set(self.columns) + if missing := dropped_columns.difference(current_columns): + raise ValueError( + f"Tried to drop columns that are not in this dataset: {missing}" + ) + kept_columns = current_columns - dropped_columns + return self.select(kept_columns) + def take(self, n: int, at: str = "random") -> "Lightcone": """ Create a new dataset from some number of rows from this dataset. diff --git a/opencosmo/collection/simulation/simulation.py b/opencosmo/collection/simulation/simulation.py index e3c655a3..e2eeb846 100644 --- a/opencosmo/collection/simulation/simulation.py +++ b/opencosmo/collection/simulation/simulation.py @@ -188,7 +188,7 @@ def filter(self, *masks: Mask, **kwargs) -> Self: def select(self, *args, **kwargs) -> Self: """ - Select a subset of the datasets in the collection. This method + Select a set of columns in the datasets in this collection. This method calls the underlying method in :class:`opencosmo.Dataset`, or :class:`opencosmo.Collection` depending on the context. As such its behavior and arguments can vary depending on what this collection @@ -206,6 +206,26 @@ def select(self, *args, **kwargs) -> Self: """ return self.__map("select", *args, **kwargs) + def drop(self, *args, **kwargs) -> Self: + """ + Drop a set of columns from the datasets in the collection. This method + calls the underlying method in :class:`opencosmo.Dataset`, or + :class:`opencosmo.Collection` depending on the context. As such + its behavior and arguments can vary depending on what this collection + contains. + + Parameters + ---------- + args: + The arguments to pass to the select method. This is + usually a list of column names to drop. + kwargs: + The keyword arguments to pass to the select method. + This is usually a dictionary of column names to select. + + """ + return self.__map("drop", *args, **kwargs) + def take(self, n: int, at: str = "random") -> Self: """ Take a subest of rows from all datasets or collections in this collection. diff --git a/opencosmo/collection/structure/collection.py b/opencosmo/collection/structure/collection.py index 7b321132..9ce0f029 100644 --- a/opencosmo/collection/structure/collection.py +++ b/opencosmo/collection/structure/collection.py @@ -264,8 +264,10 @@ def select( ) -> StructureCollection: """ Update the linked collection to only include the columns specified - in the given dataset. If no dataset is specified, the properties dataset - is used. + in the given dataset. If no dataset is specified the properties of the + structure will be used. For example, if this collection contains halos, + calling this function without a "dataset" argument will select columns + from the halo_properties dataset. Parameters ---------- @@ -302,6 +304,51 @@ def select( self.__links, ) + def drop(self, columns: str | Iterable[str], dataset: Optional[str] = None): + """ + Update the linked collection by dropping the specified columns + in the given dataset. If no dataset is specified, the properties dataset + is used. For example, if this collection contains galaxies, + calling this function without a "dataset" argument will select columns + from the galaxy_properties dataset. + + + Parameters + ---------- + columns : str | Iterable[str] + The columns to select from the dataset. + + dataset : str, optional + The dataset to select from. If None, the properties dataset is used. + + Returns + ------- + StructureCollection + A new collection with only the selected columns for the specified dataset. + + Raises + ------- + ValueError + If the specified dataset is not found in the collection. + """ + + if dataset is None or dataset == self.__header.file.data_type: + new_source = self.__source.drop(columns) + return StructureCollection( + new_source, self.__header, self.__datasets, self.__links + ) + + elif dataset not in self.__datasets: + raise ValueError(f"Dataset {dataset} not found in collection.") + output_ds = self.__datasets[dataset] + new_dataset = output_ds.drop(columns) + return StructureCollection( + self.__source, + self.__header, + {**self.__datasets, dataset: new_dataset}, + self.__links, + ) + def with_units(self, convention: str): """ Apply the given unit convention to the collection. diff --git a/opencosmo/dataset/dataset.py b/opencosmo/dataset/dataset.py index 97fe684a..cfe186d7 100644 --- a/opencosmo/dataset/dataset.py +++ b/opencosmo/dataset/dataset.py @@ -402,6 +402,37 @@ def select(self, columns: str | Iterable[str]) -> Dataset: self.__tree, ) + def drop(self, columns: str | Iterable[str]) -> Dataset: + """ + Create a new dataset without the provided columns. + + Parameters + ---------- + columns : str or list[str] + The columns to drop + + Returns + ------- + dataset : Dataset + The new dataset without the droppedcolumns + + Raises + ------ + ValueError + If any of the provided columns are not in the dataset. + + """ + if isinstance(columns, str): + columns = [columns] + + current_columns = set(self.__state.columns) + dropped_columns = set(columns) + + if missing := dropped_columns.difference(current_columns): + raise ValueError(f"Columns {missing} are not in this dataset") + + return self.select(current_columns - dropped_columns) + def take(self, n: int, at: str = "random") -> Dataset: """ Create a new dataset from some number of rows from this dataset. diff --git a/test/test_collection.py b/test/test_collection.py index 0e6edecc..528ae5a1 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -167,6 +167,27 @@ def test_data_link_selection(halo_paths): assert found_dm_particles +def test_data_link_drop(halo_paths): + collection = oc.open(*halo_paths) + collection = collection.filter(oc.col("sod_halo_mass") > 10**13).take( + 10, at="random" + ) + collection = collection.drop(["x", "y", "z"], dataset="dm_particles") + collection = collection.drop(["fof_halo_tag", "sod_halo_mass"]) + found_dm_particles = False + for halo in collection.objects(): + properties = halo["halo_properties"] + assert not set(properties.keys()).intersection( + {"fof_halo_tag", "sod_halo_mass"} + ) + + if halo["dm_particles"] is not None: + dm_particles = halo["dm_particles"] + found_dm_particles = True + assert not set(dm_particles.columns).intersection({"x", "y", "z"}) + assert found_dm_particles + + def test_link_halos_to_galaxies(halo_paths, galaxy_paths): galaxy_path = galaxy_paths[0] collection = oc.open(*halo_paths, galaxy_path) diff --git a/test/test_dataset.py b/test/test_dataset.py index 6a7a296b..b8eabb44 100644 --- a/test/test_dataset.py +++ b/test/test_dataset.py @@ -81,6 +81,32 @@ def test_take_oom(input_path): assert len(data) == 10 +def test_drop(input_path): + with oc.open(input_path) as ds: + data = ds.data + cols = list(data.columns) + # select 10 columns at random + dropped_cols = np.random.choice(cols, 10, replace=False) + selected = ds.drop(dropped_cols) + selected_data = selected.data + + dropped_cols = set(dropped_cols) + remaining_cols = set(selected_data.colnames) + assert not dropped_cols.intersection(remaining_cols) + + +def test_drop_single(input_path): + with oc.open(input_path) as ds: + data = ds.data + cols = list(data.columns) + # select 10 columns at random + dropped_col = np.random.choice(cols) + remaining = ds.drop(dropped_col) + remaining_data = remaining.data + + assert dropped_col not in remaining_data.colnames + + def test_select_oom(input_path): with oc.open(input_path) as ds: data = ds.data diff --git a/test/test_lightcone.py b/test/test_lightcone.py index d8a95fa0..ebcb689b 100644 --- a/test/test_lightcone.py +++ b/test/test_lightcone.py @@ -200,6 +200,17 @@ def test_lc_collection_select( assert columns_found == to_select +def test_lc_collection_drop(haloproperties_600_path, haloproperties_601_path, tmp_path): + ds = oc.open(haloproperties_600_path, haloproperties_601_path) + columns = ds.columns + to_drop = set(random.choice(columns, 10)) + + ds = ds.drop(to_drop) + columns_found = set(ds.data.columns) + + assert not columns_found.intersection(to_drop) + + def test_lc_collection_take(haloproperties_600_path, haloproperties_601_path, tmp_path): ds = oc.open(haloproperties_600_path, haloproperties_601_path) n_to_take = int(0.75 * len(ds))