Skip to content

Commit 822248f

Browse files
committed
Mapping over simulation collection is working
1 parent fe969c8 commit 822248f

File tree

3 files changed

+61
-19
lines changed

3 files changed

+61
-19
lines changed

src/opencosmo/collection/simulation/simulation.py

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -82,14 +82,30 @@ def make_schema(self) -> DataSchema:
8282

8383
return schema
8484

85-
def __map(self, method, *args, **kwargs):
85+
def __map(self, method, *args, construct=True, **kwargs):
8686
"""
8787
This type of collection will only ever be constructed if all the underlying
8888
datasets have the same data type, so it is always safe to map operations
8989
across all of them.
9090
"""
91-
output = {k: getattr(v, method)(*args, **kwargs) for k, v in self.items()}
92-
return SimulationCollection(output)
91+
regular_kwargs = {}
92+
mapped_kwargs = {}
93+
known_datasets = set(self.keys())
94+
for name, value in kwargs.items():
95+
if isinstance(value, dict) and set(value.keys()) == known_datasets:
96+
mapped_kwargs[name] = value
97+
else:
98+
regular_kwargs[name] = value
99+
100+
output = {}
101+
for name, dataset in self.items():
102+
dataset_mapped_kwargs = {key: kw[name] for key, kw in mapped_kwargs.items()}
103+
output[name] = getattr(dataset, method)(
104+
*args, **regular_kwargs, **dataset_mapped_kwargs
105+
)
106+
if construct:
107+
return SimulationCollection(output)
108+
return output
93109

94110
def __map_attribute(self, attribute):
95111
return {k: getattr(v, attribute) for k, v in self.items()}
@@ -321,21 +337,16 @@ def evaluate(
321337
else:
322338
datasets = list(datasets)
323339

324-
results = {
325-
ds_name: self[ds_name].evaluate(
326-
func,
327-
vectorize=vectorize,
328-
insert=insert,
329-
format=format,
330-
**evaluate_kwargs,
331-
)
332-
for ds_name in datasets
333-
}
334-
if not insert:
335-
return results
336-
else:
337-
output = {**self, **results}
338-
return SimulationCollection(output)
340+
results = self.__map(
341+
"evaluate",
342+
func,
343+
vectorize=vectorize,
344+
insert=insert,
345+
format=format,
346+
construct=insert,
347+
**evaluate_kwargs,
348+
)
349+
return results
339350

340351
def with_units(self, convention: str) -> Self:
341352
"""

src/opencosmo/dataset/visit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def __make_output(
7272
first_values = {name: first_values}
7373
storage = {}
7474
for name, value in first_values.items():
75-
shape = (n,)
75+
shape: tuple[int, ...] = (n,)
7676
dtype = type(value)
7777
if isinstance(value, np.ndarray):
7878
shape = shape + value.shape

test/test_collection.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -539,6 +539,37 @@ def fof_px(fof_halo_mass, fof_halo_com_vx):
539539
)
540540

541541

542+
def test_simulation_collection_evaluate_map_kwarg(multi_path):
    """
    Evaluate a function over a collection with one keyword argument given as
    a dict keyed by dataset name and another given as a single scalar, then
    check each dataset's result against the same expression computed by hand.
    """
    collection = oc.open(multi_path)

    # The function name and parameter names are part of the test: the result
    # is looked up under "fof_px" and the extra parameters are supplied as
    # keyword arguments to evaluate().
    def fof_px(fof_halo_mass, fof_halo_com_vx, random_value, other_value):
        return fof_halo_mass * fof_halo_com_vx * random_value / other_value

    per_dataset_values = {}
    for name, dataset in collection.items():
        per_dataset_values[name] = np.random.randint(0, 10, len(dataset))
    shared_divisor = 3.0

    output = collection.evaluate(
        fof_px,
        vectorize=True,
        insert=False,
        format="numpy",
        random_value=per_dataset_values,
        other_value=shared_divisor,
    )

    for name, dataset in collection.items():
        # insert=False was requested, so the column must not appear.
        assert "fof_px" not in dataset.columns
        columns = dataset.select(["fof_halo_mass", "fof_halo_com_vx"]).get_data(
            "numpy"
        )
        expected = (
            columns["fof_halo_mass"]
            * columns["fof_halo_com_vx"]
            * per_dataset_values[name]
            / shared_divisor
        )
        assert np.all(output[name]["fof_px"] == expected)
571+
572+
542573
def test_simulation_collection_add(multi_path):
543574
collection = oc.open(multi_path)
544575
ds_name = next(iter(collection.keys()))

0 commit comments

Comments
 (0)