Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/+83c06d0d.improvement.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
:py:meth:`StructureCollection.select <opencosmo.StructureCollection.select>` and :py:meth:`StructureCollection.drop <opencosmo.StructureCollection.drop>` now follow the same semantics as :py:meth:`StructureCollection.evaluate <opencosmo.StructureCollection.evaluate>` for passing columns from multiple datasets in a single function call.
19 changes: 19 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,22 @@ future-annotations = true
preview = true
select = ["TC"]

[tool.towncrier.fragment.improvement]
name = "Improvements"
showcontent = true

[tool.towncrier.fragment.feature]
name = "New Features"

[tool.towncrier.fragment.bugfix]
name = "Bugfixes"

[tool.towncrier.fragment.doc]
name = "Documentation"

[tool.towncrier.fragment.removal]
name = "Deprecations and Removals"

[tool.towncrier.fragment.misc]
name = "Miscellaneous"

7 changes: 3 additions & 4 deletions src/opencosmo/collection/structure/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,9 @@ def __prepare_collection(
raise NotImplementedError
else:
collection = collection.with_datasets(list(spec.keys()))
for ds_name, columns in spec.items():
if columns is None:
continue
collection = collection.select(columns, dataset=ds_name)

selections = {ds_name: cols for ds_name, cols in spec.items() if cols is not None}
collection = collection.select(**selections)
return collection


Expand Down
100 changes: 61 additions & 39 deletions src/opencosmo/collection/structure/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,20 +429,29 @@ def filter(self, *masks, on_galaxies: bool = False) -> StructureCollection:
filtered, self.__header, self.__datasets, self.__links, self.__hide_source
)

def select(
self,
columns: str | Iterable[str],
dataset: str,
) -> StructureCollection:
def select(self, **column_selections: str | Iterable[str]) -> StructureCollection:
"""
Update datasets in the collection to only include the
columns specified. The names of the arguments to this function should be
dataset names. For example:

.. code-block:: python

collection = collection.select(
halo_properties = ["fof_halo_mass", "sod_halo_mass", "sod_halo_cdelta"],
dm_particles = ["x", "y", "z"]
)

Datasets that do not appear in the argument list will not be modified. You can
remove entire datasets from the collection with
:py:meth:`with_datasets <opencosmo.StructureCollection.with_datasets>`



Parameters
----------
columns : str | Iterable[str]
The columns to select from the dataset.
**column_selections : str | Iterable[str]
The columns to select from a given dataset

dataset : str
The dataset to select from.
Expand All @@ -457,40 +466,49 @@ def select(
ValueError
If the specified dataset is not found in the collection.
"""
if dataset == self.__header.file.data_type:
new_source = self.__source.select(columns)
return StructureCollection(
new_source, self.__header, self.__datasets, self.__links
)
if not column_selections:
return self
new_source = self.__source
new_datasets = {}
for dataset, columns in column_selections.items():
if dataset == self.__header.file.data_type:
new_source = self.__source.select(columns)
continue

elif dataset not in self.__datasets:
raise ValueError(f"Dataset {dataset} not found in collection.")
output_ds = self.__datasets[dataset]
if not isinstance(output_ds, oc.Dataset):
raise NotImplementedError
elif dataset not in self.__datasets:
raise ValueError(f"Dataset {dataset} not found in collection.")

output_ds = self.__datasets[dataset]

if not isinstance(output_ds, oc.Dataset):
raise NotImplementedError

new_dataset = output_ds.select(columns)
new_datasets[dataset] = new_dataset

new_dataset = output_ds.select(columns)
return StructureCollection(
self.__source,
new_source,
self.__header,
{**self.__datasets, dataset: new_dataset},
self.__datasets | new_datasets,
self.__links,
self.__hide_source,
)

def drop(self, columns: str | Iterable[str], dataset: Optional[str] = None):
def drop(self, **columns_to_drop):
"""
Update the linked collection by dropping the specified columns
in the given dataset. If no dataset is specified, the properties dataset
is used. For example, if this collection contains galaxies,
calling this function without a "dataset" argument will select columns
from the galaxy_properties dataset.
in the specified datasets. Follows the exact same semantics as
:py:meth:`StructureCollection.select <opencosmo.StructureCollection.select>`.
Argument names should be datasets in this collection, and the argument
values should be a string or list of strings.

Datasets that are not included will not be modified. You can drop
entire datasets with :py:meth:`with_datasets <opencosmo.StructureCollection.with_datasets>`

Parameters
----------
columns : str | Iterable[str]
The columns to select from the dataset.
**columns_to_drop : str | Iterable[str]
The columns to drop from each named dataset.

dataset : str, optional
The dataset to select from. If None, the properties dataset is used.
Expand All @@ -505,21 +523,25 @@ def drop(self, columns: str | Iterable[str], dataset: Optional[str] = None):
ValueError
If the specified dataset is not found in the collection.
"""
if not columns_to_drop:
return self
new_source = self.__source
new_datasets = {}

if dataset is None or dataset == self.__header.file.data_type:
new_source = self.__source.drop(columns)
return StructureCollection(
new_source, self.__header, self.__datasets, self.__links
)
for dataset_name, columns in columns_to_drop.items():
if dataset_name == self.__header.file.data_type:
new_source = self.__source.drop(columns)
continue

elif dataset_name not in self.__datasets:
raise ValueError(f"Dataset {dataset_name} not found in collection.")
new_ds = self.__datasets[dataset_name].drop(columns)
new_datasets[dataset_name] = new_ds

elif dataset not in self.__datasets:
raise ValueError(f"Dataset {dataset} not found in collection.")
output_ds = self.__datasets[dataset]
new_dataset = output_ds.drop(columns)
return StructureCollection(
self.__source,
new_source,
self.__header,
{**self.__datasets, dataset: new_dataset},
self.__datasets | new_datasets,
self.__links,
self.__hide_source,
)
Expand Down
10 changes: 6 additions & 4 deletions test/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,8 +558,9 @@ def test_data_link_selection(halo_paths):
collection = collection.filter(oc.col("sod_halo_mass") > 10**13).take(
10, at="random"
)
collection = collection.select(["x", "y", "z"], dataset="dm_particles")
collection = collection.select(["fof_halo_tag", "sod_halo_mass"], "halo_properties")
collection = collection.select(
dm_particles=["x", "y", "z"], halo_properties=["fof_halo_tag", "sod_halo_mass"]
)
found_dm_particles = False
for halo in collection.objects():
properties = halo["halo_properties"]
Expand All @@ -578,8 +579,9 @@ def test_data_link_drop(halo_paths):
collection = collection.filter(oc.col("sod_halo_mass") > 10**13).take(
10, at="random"
)
collection = collection.drop(["x", "y", "z"], dataset="dm_particles")
collection = collection.drop(["fof_halo_tag", "sod_halo_mass"])
collection = collection.drop(
dm_particles=["x", "y", "z"], halo_properties=["fof_halo_tag", "sod_halo_mass"]
)
found_dm_particles = False
for halo in collection.objects():
properties = halo["halo_properties"]
Expand Down
Loading