 is_split_model = False
 local_stage = None
 
-group0 = None
-group1 = None
-group2 = None
-group3 = None
-
 def manual_model_split(model,stage_idx,group):
     global is_split_model
     global local_stage

@@ -741,34 +736,11 @@ def get_mesh(pp_idx=0):
     if "pp" in mesh.dim_names:
         mesh = mesh.get_mesh_with_dim("pp", pp_idx)
     return mesh
-    global group0, group1, group2, group3
-    if group0 is None:
-        group0 = paddle.distributed.new_group([0, 4])
-    if group1 is None:
-        group1 = paddle.distributed.new_group([1, 5])
-    if group2 is None:
-        group2 = paddle.distributed.new_group([2, 6])
-    if group3 is None:
-        group3 = paddle.distributed.new_group([3, 7])
     rank = dist.get_rank()
     if rank == 0 or rank == 1 or rank == 2 or rank == 3:
-        if rank == 0:
-            stage = manual_model_split(model, 0, group0)
-        elif rank == 1:
-            stage = manual_model_split(model, 0, group1)
-        elif rank == 2:
-            stage = manual_model_split(model, 0, group2)
-        else:
-            stage = manual_model_split(model, 0, group3)
+        stage = manual_model_split(model, 0, self.comm_group_in_pp)
     else:
-        if rank == 4:
-            stage = manual_model_split(model, 1, group0)
-        elif rank == 5:
-            stage = manual_model_split(model, 1, group1)
-        elif rank == 6:
-            stage = manual_model_split(model, 1, group2)
-        else:
-            stage = manual_model_split(model, 1, group3)
+        stage = manual_model_split(model, 1, self.comm_group_in_pp)
 
     schedule = Schedule1F1B(stage, n_microbatches = 2, loss_fn=self.criterion)
 
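For reference, the deleted lines built the pipeline peer groups by hand: with the 8-rank, 2-stage layout implied by the hard-coded rank lists, rank r on stage 0 is paired with rank r + 4 on stage 1, every process creates all four two-rank groups (new_group is a collective call), and each rank then uses the group that contains it. A minimal sketch of that removed pattern, assuming that layout:

import paddle.distributed as dist

# Sketch of the hand-rolled group layout the PR removes (assumption:
# 8 ranks, 2 pipeline stages, so rank r pairs with rank r + 4).
def build_pp_pair_group(ranks_per_stage=4):
    # new_group is collective, so every rank creates all pair groups
    # and then keeps the one that contains its own rank.
    groups = [dist.new_group([r, r + ranks_per_stage]) for r in range(ranks_per_stage)]
    return groups[dist.get_rank() % ranks_per_stage]

With the change, none of these per-pair groups need to be created here; the stage split reuses self.comm_group_in_pp, a pipeline communication group presumably constructed elsewhere in the trainer.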