Skip to content

Commit 2c512f2

Browse files
committed
opt
1 parent 0773c65 commit 2c512f2

File tree

1 file changed

+2
-30
lines changed

1 file changed

+2
-30
lines changed

paddlenlp/trainer/auto_trainer.py

Lines changed: 2 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,6 @@
7171
is_split_model = False
7272
local_stage = None
7373

74-
group0 = None
75-
group1 = None
76-
group2 = None
77-
group3 = None
78-
7974
def manual_model_split(model,stage_idx,group):
8075
global is_split_model
8176
global local_stage
@@ -741,34 +736,11 @@ def get_mesh(pp_idx=0):
741736
if "pp" in mesh.dim_names:
742737
mesh = mesh.get_mesh_with_dim("pp", pp_idx)
743738
return mesh
744-
global group0, group1, group2, group3
745-
if group0 is None:
746-
group0 = paddle.distributed.new_group([0, 4])
747-
if group1 is None:
748-
group1 = paddle.distributed.new_group([1, 5])
749-
if group2 is None:
750-
group2 = paddle.distributed.new_group([2, 6])
751-
if group3 is None:
752-
group3 = paddle.distributed.new_group([3, 7])
753739
rank = dist.get_rank()
754740
if rank == 0 or rank == 1 or rank == 2 or rank == 3:
755-
if rank == 0:
756-
stage = manual_model_split(model, 0, group0)
757-
elif rank == 1:
758-
stage = manual_model_split(model, 0, group1)
759-
elif rank == 2:
760-
stage = manual_model_split(model, 0, group2)
761-
else:
762-
stage = manual_model_split(model, 0, group3)
741+
stage = manual_model_split(model, 0, self.comm_group_in_pp)
763742
else:
764-
if rank == 4:
765-
stage = manual_model_split(model, 1, group0)
766-
elif rank == 5:
767-
stage = manual_model_split(model, 1, group1)
768-
elif rank == 6:
769-
stage = manual_model_split(model, 1, group2)
770-
else:
771-
stage = manual_model_split(model, 1, group3)
743+
stage = manual_model_split(model, 1, self.comm_group_in_pp)
772744

773745
schedule = Schedule1F1B(stage, n_microbatches = 2, loss_fn=self.criterion)
774746

0 commit comments

Comments
 (0)