@@ -158,20 +158,43 @@ def _split_batches_for_accumulation(self, inputs):
         if self.args.gradient_accumulation_steps == 1:
             return [inputs]
 
-        # if self.args.to_static:
         if self.args.to_static and self.args.pipeline_parallel_degree > 1:
             return [inputs]
 
         local_batches = [{} for i in range(self.args.gradient_accumulation_steps)]
+        assert isinstance(inputs, dict)
 
-        for key, value in inputs.items():
-            ori_mesh, ori_placements = value.process_mesh, value.placements
-            replicate_value = dist.reshard(value, ori_mesh, [dist.Replicate(), dist.Replicate()])
+        def split_dtensor_by_axis(dtensor, axis):
+            mesh = dtensor.process_mesh
+            placements = [dist.Replicate() for _ in range(len(mesh.shape))]
+            replicate_value = dist.reshard(dtensor, mesh, placements)
             local_datas = replicate_value.split(self.args.gradient_accumulation_steps, axis=0)
-
-            for index, data in enumerate(local_datas):
-                local_batches[index].update({key: dist.reshard(data, ori_mesh, ori_placements)})
-
+            return local_datas
+
+        for key, dtensors in inputs.items():
+            if isinstance(dtensors, paddle.Tensor):
+                mesh, placements = dtensors.process_mesh, dtensors.placements
+                local_datas = split_dtensor_by_axis(dtensors, 0)
+                for index, data in enumerate(local_datas):
+                    local_batches[index].update({key: dist.reshard(data, mesh, placements)})
+            elif isinstance(dtensors, (list, tuple)):
+                if len(dtensors) == 0:
+                    for i in range(self.args.gradient_accumulation_steps):
+                        local_batches[i].update({key: []})
+                else:
+                    for dtensor in dtensors:
+                        if isinstance(dtensor, paddle.Tensor):
+                            mesh, placements = dtensor.process_mesh, dtensor.placements
+                            local_datas = split_dtensor_by_axis(dtensor, 0)
+                            for index, data in enumerate(local_datas):
+                                if key in local_batches[index].keys():
+                                    local_batches[index][key].append(dist.reshard(data, mesh, placements))
+                                else:
+                                    local_batches[index].update({key: [dist.reshard(data, mesh, placements)]})
+                        else:
+                            raise ValueError(f"unsupported type: {type(dtensor)}")
+            else:
+                raise ValueError(f"unsupported type: {type(dtensors)}")
 
         return local_batches
 
     def _inner_training_loop(
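The change keeps the old behavior for single distributed tensors (replicate, split along axis 0 into gradient_accumulation_steps micro-batches, then reshard each chunk back to its original placements) and extends it to dict values that are lists or tuples of tensors. Below is a minimal single-process sketch of that splitting logic on dense paddle.Tensor inputs; split_for_accumulation is a hypothetical helper introduced here for illustration, and it deliberately drops the dist.reshard round-trip, which only applies to distributed tensors carrying a process mesh.

import paddle

def split_for_accumulation(inputs, acc_steps):
    # Hypothetical sketch: cut each batch entry into `acc_steps`
    # micro-batches along axis 0, keeping dict keys aligned.
    local_batches = [{} for _ in range(acc_steps)]
    for key, value in inputs.items():
        if isinstance(value, paddle.Tensor):
            for index, chunk in enumerate(paddle.split(value, acc_steps, axis=0)):
                local_batches[index][key] = chunk
        elif isinstance(value, (list, tuple)):
            # List-valued entries are split element by element; empty lists
            # simply propagate an empty list into every micro-batch.
            for index in range(acc_steps):
                local_batches[index][key] = []
            for tensor in value:
                for index, chunk in enumerate(paddle.split(tensor, acc_steps, axis=0)):
                    local_batches[index][key].append(chunk)
        else:
            raise ValueError(f"unsupported type: {type(value)}")
    return local_batches

# Example: a global batch of 8 samples becomes 4 micro-batches of 2 samples each.
batch = {
    "input_ids": paddle.arange(8 * 16).reshape([8, 16]),
    "labels": [paddle.arange(8), paddle.arange(8)],
}
micro_batches = split_for_accumulation(batch, acc_steps=4)
assert len(micro_batches) == 4
assert micro_batches[0]["input_ids"].shape == [2, 16]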