Skip to content

Commit 7909367

Browse files
committed
Fix MPI writing
1 parent 0d605a9 commit 7909367

File tree

2 files changed

+26
-6
lines changed

2 files changed

+26
-6
lines changed

opencosmo/handler/mpi.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,10 @@ def write(
126126

127127
rank_range = self.elem_range()
128128
rank_output_length = np.sum(mask)
129+
129130
all_output_lengths = self.__comm.allgather(rank_output_length)
131+
all_input_lengths = self.__comm.allgather(len(mask))
132+
130133
rank = self.__comm.Get_rank()
131134

132135
# Determine the number of elements this rank is responsible for
@@ -158,15 +161,33 @@ def write(
158161
data = self.__group[column][rank_range[0] : rank_range[1]][()]
159162
data = data[mask]
160163
data_group[column][rank_start:rank_end] = data
161-
162-
masks = self.__comm.gather(mask, root=0)
164+
165+
displacements = np.insert(np.cumsum(all_input_lengths[:-1]), 0, 0)
163166
if rank == 0:
164-
mask = np.concatenate(masks)
167+
recvbuf = np.empty(sum(all_input_lengths), dtype=np.uint8)
168+
self.__comm.Gatherv(
169+
sendbuf=mask.view(np.uint8),
170+
recvbuf=(
171+
recvbuf.view(np.uint8),
172+
all_input_lengths,
173+
displacements,
174+
MPI.BYTE,
175+
),
176+
root=0,
177+
)
178+
else:
179+
self.__comm.Gatherv(
180+
sendbuf=mask.view(np.uint8), recvbuf=(None, None, None, None), root=0
181+
)
182+
183+
if rank == 0:
184+
mask = recvbuf.astype(bool)
165185
tree = self.__tree.apply_mask(mask)
166186
else:
167187
tree = None
168188
tree = self.__comm.bcast(tree, root=0)
169-
tree.write(group)
189+
#
190+
tree.write(group) # type: ignore
170191

171192
self.__comm.Barrier()
172193

@@ -192,7 +213,7 @@ def get_data(
192213
data = data[mask]
193214
col = Column(data, name=column)
194215
output[column] = builder.build(col)
195-
self.__comm.Barrier()
216+
self.__comm.Barrier()
196217

197218
if len(output) == 1:
198219
return next(iter(output.values()))

test/parallel/test_mpi.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,4 +161,3 @@ def test_write_particles(particle_path, tmp_path):
161161
for model in models:
162162
key = f"_OpenCosmoHeader__{model}"
163163
assert getattr(header, key) == getattr(read_header, key)
164-

0 commit comments

Comments
 (0)