@@ -126,7 +126,10 @@ def write(
126
126
127
127
rank_range = self .elem_range ()
128
128
rank_output_length = np .sum (mask )
129
+
129
130
all_output_lengths = self .__comm .allgather (rank_output_length )
131
+ all_input_lengths = self .__comm .allgather (len (mask ))
132
+
130
133
rank = self .__comm .Get_rank ()
131
134
132
135
# Determine the number of elements this rank is responsible for
@@ -158,15 +161,33 @@ def write(
158
161
data = self .__group [column ][rank_range [0 ] : rank_range [1 ]][()]
159
162
data = data [mask ]
160
163
data_group [column ][rank_start :rank_end ] = data
161
-
162
- masks = self . __comm . gather ( mask , root = 0 )
164
+
165
+ displacements = np . insert ( np . cumsum ( all_input_lengths [: - 1 ]), 0 , 0 )
163
166
if rank == 0 :
164
- mask = np .concatenate (masks )
167
+ recvbuf = np .empty (sum (all_input_lengths ), dtype = np .uint8 )
168
+ self .__comm .Gatherv (
169
+ sendbuf = mask .view (np .uint8 ),
170
+ recvbuf = (
171
+ recvbuf .view (np .uint8 ),
172
+ all_input_lengths ,
173
+ displacements ,
174
+ MPI .BYTE ,
175
+ ),
176
+ root = 0 ,
177
+ )
178
+ else :
179
+ self .__comm .Gatherv (
180
+ sendbuf = mask .view (np .uint8 ), recvbuf = (None , None , None , None ), root = 0
181
+ )
182
+
183
+ if rank == 0 :
184
+ mask = recvbuf .astype (bool )
165
185
tree = self .__tree .apply_mask (mask )
166
186
else :
167
187
tree = None
168
188
tree = self .__comm .bcast (tree , root = 0 )
169
- tree .write (group )
189
+ #
190
+ tree .write (group ) # type: ignore
170
191
171
192
self .__comm .Barrier ()
172
193
@@ -192,7 +213,7 @@ def get_data(
192
213
data = data [mask ]
193
214
col = Column (data , name = column )
194
215
output [column ] = builder .build (col )
195
- self .__comm .Barrier ()
216
+ self .__comm .Barrier ()
196
217
197
218
if len (output ) == 1 :
198
219
return next (iter (output .values ()))
0 commit comments