Skip to content

Commit eba40dc

Browse files
committed
Small updates
1 parent 00643da commit eba40dc

File tree

5 files changed

+182
-257
lines changed

5 files changed

+182
-257
lines changed

opencosmo/io/mpi.py

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
from mpi4py import MPI
2+
import numpy as np
3+
from .schemas import FileSchema, DatasetSchema, SimCollectionSchema, StructCollectionSchema, ColumnSchema, IdxLinkSchema, StartSizeLinkSchema
4+
from .protocols import DataSchema
5+
from copy import copy
6+
7+
def verify_structure(schemas: dict[str, DataSchema], comm: MPI.Comm):
    """Check that every rank holds schemas with matching names and types.

    Raises ValueError (on all ranks) if any rank disagrees.
    """
    # Name agreement is checked first; the type check assumes matching keys.
    for check in (verify_names, verify_types):
        check(schemas, comm)
10+
11+
12+
def verify_names(schemas: dict[str, DataSchema], comm: MPI.Comm):
    """Raise ValueError if this rank's schema names differ from any other rank's."""
    local_names = set(schemas)
    gathered = comm.allgather(local_names)
    reference = gathered[0]
    if any(other != reference for other in gathered[1:]):
        raise ValueError("Tried to combine a collection of schemas with different names!")
17+
18+
def verify_types(schemas: dict[str, DataSchema], comm: MPI.Comm):
    """Raise ValueError if this rank's schema types differ from any other rank's."""
    # Sort so the comparison is independent of dict iteration order.
    local_types = sorted(str(type(s)) for s in schemas.values())
    gathered = comm.allgather(local_types)
    reference = gathered[0]
    if any(other != reference for other in gathered[1:]):
        raise ValueError("Tried to combine a collection of schemas with different types!")
24+
25+
26+
def combine_file_schemas(schema: FileSchema, comm: MPI.Comm = MPI.COMM_WORLD):
    """Merge the per-rank FileSchemas into a single schema on rank 0.

    Collective: every rank must call this. Returns the combined FileSchema
    on rank 0 and None on all other ranks.
    """
    verify_structure(schema.children, comm)

    is_root = comm.Get_rank() == 0
    combined = FileSchema() if is_root else None

    for name in schema.children:
        # Broadcast rank 0's name so every rank combines the same child,
        # regardless of local dict iteration order.
        agreed_name = comm.bcast(name)
        combined_child = combine_file_child(schema.children[agreed_name], comm)
        if is_root:
            combined.add_child(combined_child, name)
    return combined
41+
42+
def combine_file_child(schema: DataSchema, comm: MPI.Comm):
43+
match schema:
44+
case DatasetSchema():
45+
return combine_dataset_schemas(schema, comm)
46+
case SimCollectionSchema():
47+
return combine_simcollection_schema(schema, comm)
48+
case StructCollectionSchema():
49+
return combine_structcollection_schema(schema, comm)
50+
51+
52+
def combine_dataset_schemas(schema: DatasetSchema, comm: MPI.Comm):
    """Combine per-rank dataset schemas (columns and links) onto rank 0.

    Collective: every rank must call this. Returns the combined
    DatasetSchema on rank 0 and None elsewhere.
    """
    is_root = comm.Get_rank() == 0
    verify_structure(schema.columns, comm)
    verify_structure(schema.links, comm)

    combined = DatasetSchema(schema.header) if is_root else None

    for name in schema.columns:
        # Agree on rank 0's column name so collectives stay in lockstep.
        agreed_name = comm.bcast(name)
        merged_column = combine_column_schemas(schema.columns[agreed_name], comm)
        if is_root:
            combined.add_child(merged_column, name)

    if schema.links:
        merged_links = combine_links(schema.links, comm)
        if is_root:
            for link_name, merged_link in merged_links.items():
                combined.add_child(merged_link, link_name)
    return combined
75+
76+
def combine_links(links: dict[str, StartSizeLinkSchema | IdxLinkSchema], comm: MPI.Comm):
    """Combine each link schema across ranks, keyed by its original name."""
    combined = {}
    for name, link in links.items():
        # Start/size links carry two columns; everything else is an idx link.
        combiner = (
            combine_start_size_link_schema
            if isinstance(link, StartSizeLinkSchema)
            else combine_idx_link_schema
        )
        combined[name] = combiner(link, comm)

    return combined
85+
86+
def combine_idx_link_schema(schema: IdxLinkSchema, comm: MPI.Comm):
    """Return a shallow copy of this idx link whose column spans all ranks."""
    merged_column = combine_column_schemas(schema.column, comm)
    merged = copy(schema)
    merged.column = merged_column
    return merged
91+
92+
def combine_start_size_link_schema(schema: StartSizeLinkSchema, comm: MPI.Comm):
    """Return a shallow copy of this start/size link with both columns combined.

    The start column is combined before the size column on every rank, so
    the underlying collective operations stay in lockstep.
    """
    merged_start = combine_column_schemas(schema.start, comm)
    merged_size = combine_column_schemas(schema.size, comm)
    merged = copy(schema)
    merged.start = merged_start
    merged.size = merged_size
    return merged
99+
100+
def combine_simcollection_schema(schema: SimCollectionSchema, comm: MPI.Comm):
101+
rank = comm.Get_rank()
102+
verify_structure(schema.children, comm)
103+
104+
child_names = schema.children.keys()
105+
106+
if rank == 0:
107+
new_schema = SimCollectionSchema()
108+
109+
else:
110+
new_schema = None
111+
112+
for child_name in child_names:
113+
child_name_ = comm.bcast(child_name)
114+
child = schema.children[child_name_]
115+
match child:
116+
case StructCollectionSchema():
117+
new_child = combine_structcollection_schema(child, comm)
118+
case DatasetSchema():
119+
new_child = combine_dataset_schemas(schema, comm)
120+
if rank == 0:
121+
new_schema.add_child(new_child, child_name)
122+
return new_schema
123+
124+
125+
126+
127+
def combine_structcollection_schema(schema: StructCollectionSchema, comm: MPI.Comm):
128+
rank = comm.Get_rank()
129+
child_names = set(schema.children.keys())
130+
all_child_names = comm.allgather(child_names)
131+
if not all(cns == all_child_names[0] for cns in all_child_names[1:]):
132+
raise ValueError("Tried to combine ismulation collections with different children!")
133+
134+
child_types = set(str(type(c)) for c in schema.children.values())
135+
all_child_types = comm.allgather(child_types)
136+
if not all(cts == all_child_types[0] for cts in all_child_types[1:]):
137+
raise ValueError("Tried to combine ismulation collections with different children!")
138+
139+
new_schema = StructCollectionSchema(schema.header) if rank == 0 else None
140+
child_names = list(child_names)
141+
child_names.sort()
142+
143+
144+
for i, name in enumerate(child_names):
145+
cn = comm.bcast(name)
146+
child = schema.children[cn]
147+
match child:
148+
case DatasetSchema():
149+
150+
new_child = combine_dataset_schemas(child, comm)
151+
if rank == 0:
152+
new_schema.add_child(new_child, cn)
153+
154+
return new_schema
155+
156+
157+
158+
159+
160+
161+
162+
def combine_column_schemas(schema: ColumnSchema, comm: MPI.Comm):
    """Combine the per-rank column schemas into one spanning all ranks.

    Collective: every rank must call this. Unlike the other combiners,
    this returns the combined ColumnSchema on *every* rank, and it mutates
    the input ``schema`` in place via ``set_offset``.
    """
    rank = comm.Get_rank()
    # Gather each rank's local row count, ordered by rank.
    lengths = comm.allgather(len(schema.index))
    # Exclusive prefix sum: rank_offsets[r] is the global start row of rank r.
    rank_offsets = np.insert(np.cumsum(lengths), 0, 0)[:-1]
    rank_offset = rank_offsets[rank]
    # NOTE(review): mutates the caller's schema — presumably shifts its
    # index into global coordinates; confirm set_offset semantics.
    schema.set_offset(rank_offset)

    # Every rank receives every rank's (now offset) index.
    indices = comm.allgather(schema.index)
    # NOTE(review): with a single rank this calls concatenate() with no
    # arguments — verify the index type accepts an empty argument list.
    new_index = indices[0].concatenate(*indices[1:])

    return ColumnSchema(schema.name, new_index, schema.source)
173+
174+
175+
176+

opencosmo/structure/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from .collection import StructureCollection
2-
from .handler import LinkHandler, OomLinkHandler
2+
from .handler import LinkHandler, LinkedDatasetHandler
33
from .io import open_linked_file, open_linked_files
44

55
__all__ = [
66
"StructureCollection",
77
"LinkHandler",
8+
"LinkedDatasetHandler",
89
"OomLinkHandler",
910
"open_linked_files",
1011
"open_linked_file",

opencosmo/structure/handler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def with_units(self, convention: str) -> LinkHandler:
7979
pass
8080

8181

82-
class OomLinkHandler:
82+
class LinkedDatasetHandler:
8383
"""
8484
Links are currently only supported out-of-memory.
8585
"""

opencosmo/structure/io.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,6 @@
1111

1212
from opencosmo.io import schemas as ios
1313

14-
try:
15-
from mpi4py import MPI
16-
17-
from opencosmo.structure.mpi import MpiLinkHandler
18-
except ImportError:
19-
MPI = None # type: ignore
20-
2114
LINK_ALIASES = { # Left: Name in file, right: Name in collection
2215
"sodbighaloparticles_star_particles": "star_particles",
2316
"sodbighaloparticles_dm_particles": "dm_particles",
@@ -93,6 +86,7 @@ def open_linked_files(*files: Path):
9386
"""
9487
Open a collection of files that are linked together, such as a
9588
properties file and a particle file.
89+
9690
"""
9791
if len(files) == 1 and isinstance(files[0], list):
9892
return open_linked_files(*files[0])
@@ -201,11 +195,6 @@ def get_link_handlers(
201195
raise KeyError("No linked datasets found in the file.")
202196
links = link_file["data_linked"]
203197

204-
handler: Type[s.LinkHandler]
205-
if MPI is not None and MPI.COMM_WORLD.Get_size() > 1:
206-
handler = MpiLinkHandler
207-
else:
208-
handler = s.OomLinkHandler
209198
unique_dtypes = {key.rsplit("_", 1)[0] for key in links.keys()}
210199
output_links = {}
211200
for dtype in unique_dtypes:
@@ -219,10 +208,10 @@ def get_link_handlers(
219208
start = links[f"{dtype}_start"]
220209
size = links[f"{dtype}_size"]
221210

222-
output_links[key] = handler(linked_files[key], (start, size), header)
211+
output_links[key] = s.LinkedDatasetHandler(linked_files[key], (start, size), header)
223212
except KeyError:
224213
index = links[f"{dtype}_idx"]
225-
output_links[key] = handler(linked_files[key], index, header)
214+
output_links[key] = s.LinkedDatasetHandler(linked_files[key], index, header)
226215
return output_links
227216

228217

0 commit comments

Comments
 (0)