Commit 4cf0806

Support nested list of dict inputs (#8876)

1 parent dbf395f

1 file changed: +31 -8

paddlenlp/trainer/auto_trainer.py

Lines changed: 31 additions & 8 deletions
@@ -158,20 +158,43 @@ def _split_batches_for_accumulation(self, inputs):
         if self.args.gradient_accumulation_steps == 1:
             return [inputs]
 
-        # if self.args.to_static:
         if self.args.to_static and self.args.pipeline_parallel_degree > 1:
             return [inputs]
 
         local_batches = [{} for i in range(self.args.gradient_accumulation_steps)]
+        assert isinstance(inputs, dict)
 
-        for key, value in inputs.items():
-            ori_mesh, ori_placements = value.process_mesh, value.placements
-            replicate_value = dist.reshard(value, ori_mesh, [dist.Replicate(), dist.Replicate()])
+        def split_dtensor_by_axis(dtensor, axis):
+            mesh = dtensor.process_mesh
+            placements = [dist.Replicate() for _ in range(len(mesh.shape))]
+            replicate_value = dist.reshard(dtensor, mesh, placements)
             local_datas = replicate_value.split(self.args.gradient_accumulation_steps, axis=0)
-
-            for index, data in enumerate(local_datas):
-                local_batches[index].update({key: dist.reshard(data, ori_mesh, ori_placements)})
-
+            return local_datas
+
+        for key, dtensors in inputs.items():
+            if isinstance(dtensors, paddle.Tensor):
+                mesh, placements = dtensors.process_mesh, dtensors.placements
+                local_datas = split_dtensor_by_axis(dtensors, 0)
+                for index, data in enumerate(local_datas):
+                    local_batches[index].update({key: dist.reshard(data, mesh, placements)})
+            elif isinstance(dtensors, (list, tuple)):
+                if len(dtensors) == 0:
+                    for i in range(self.args.gradient_accumulation_steps):
+                        local_batches[i].update({key: []})
+                else:
+                    for dtensor in dtensors:
+                        if isinstance(dtensor, paddle.Tensor):
+                            mesh, placements = dtensor.process_mesh, dtensor.placements
+                            local_datas = split_dtensor_by_axis(dtensor, 0)
+                            for index, data in enumerate(local_datas):
+                                if key in local_batches[index].keys():
+                                    local_batches[index][key].append(dist.reshard(data, mesh, placements))
+                                else:
+                                    local_batches[index].update({key: [dist.reshard(data, mesh, placements)]})
+                        else:
+                            raise ValueError(f"unsupported type: {type(dtensor)}")
+            else:
+                raise ValueError(f"unsupported type: {type(dtensors)}")
         return local_batches
 
     def _inner_training_loop(
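
To make the new behavior concrete, below is a minimal single-process sketch of the control flow this commit introduces, not the committed code itself: plain Python lists stand in for paddle.Tensor, the dist.reshard round-trips are omitted, and the names split_batches / split_axis0 are hypothetical.

# Sketch of the new splitting logic (single process, no paddle):
# Python lists stand in for paddle.Tensor, and the dist.reshard
# round-trip is dropped, so only the micro-batch bookkeeping is shown.
def split_batches(inputs, acc_steps):
    local_batches = [{} for _ in range(acc_steps)]

    def split_axis0(tensor):
        # Stand-in for split_dtensor_by_axis: chop the "tensor" into
        # acc_steps equal chunks along its first axis.
        chunk = len(tensor) // acc_steps
        return [tensor[i * chunk:(i + 1) * chunk] for i in range(acc_steps)]

    for key, value in inputs.items():
        if value and isinstance(value[0], list):
            # Nested case added by this commit: a list of "tensors" is
            # split element-wise, appending each piece to its micro-batch.
            for tensor in value:
                for index, part in enumerate(split_axis0(tensor)):
                    local_batches[index].setdefault(key, []).append(part)
        elif isinstance(value, list):
            # Single "tensor": one chunk per micro-batch.
            for index, part in enumerate(split_axis0(value)):
                local_batches[index][key] = part
        else:
            raise ValueError(f"unsupported type: {type(value)}")
    return local_batches

batch = {"input_ids": [1, 2, 3, 4], "pixel_values": [[10, 20], [30, 40]]}
print(split_batches(batch, acc_steps=2))
# [{'input_ids': [1, 2], 'pixel_values': [[10], [30]]},
#  {'input_ids': [3, 4], 'pixel_values': [[20], [40]]}]

The invariant the diff preserves: tensors that arrive as a list under one key come back as a list of the same length in every micro-batch, with each element sliced along axis 0.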
