@@ -691,7 +691,7 @@ def __init__(
         init_parameters: bool = True,
         data_parallel_wrapper: Optional[DataParallelWrapper] = None,
         use_inter_host_allreduce: bool = False,
-        custom_all_reduce: Optional[Callable[[torch.Tensor], None]] = None,
+        custom_all_reduce: Optional[Callable[[List[torch.Tensor]], None]] = None,
     ) -> None:
         assert device.type == "cuda", "DMPCollection only supports CUDA"
         self._device = device
@@ -701,9 +701,7 @@ def __init__(
         self._sharding_pg: dist.ProcessGroup = None  # pyre-ignore[8]
         self._replica_pg: dist.ProcessGroup = None  # pyre-ignore[8]
         self._global_rank: int = dist.get_rank(global_pg)
-        self._custom_all_reduce: Optional[Callable[[torch.Tensor], None]] = (
-            custom_all_reduce
-        )
+        self._custom_all_reduce = custom_all_reduce
 
         self._device_mesh, self._sharding_pg, self._replica_pg = (
             self._create_process_groups(
@@ -790,25 +788,23 @@ def _allreduce_tensors(
         We perform all reduce per tensor dtype per collective constraints.
         """
 
-        def custom_all_reduce(tensors: List[torch.Tensor]) -> None:
-            # pyre-ignore[29]
-            self._custom_all_reduce(tensors)
+        custom_all_reduce = self._custom_all_reduce
+        if custom_all_reduce is not None:
 
-        def default_allreduce(tensor_list: List[torch.Tensor]) -> None:
-            self._replica_pg.allreduce_coalesced(tensor_list, opts=opts).wait()
+            def _all_reduce(tensors: List[torch.Tensor]) -> None:
+                custom_all_reduce(tensors)
 
-        allreduce = (
-            custom_all_reduce
-            if self._custom_all_reduce is not None
-            else default_allreduce
-        )
+        else:
+
+            def _all_reduce(tensors: List[torch.Tensor]) -> None:
+                self._replica_pg.allreduce_coalesced(tensors, opts=opts).wait()
 
         for tensor_list in tensors_dict.values():
-            allreduce(tensor_list)
+            _all_reduce(tensor_list)
 
     def set_all_reduce_hook(
         self,
-        reduce_hook: Callable[[torch.Tensor], None],
+        reduce_hook: Callable[[List[torch.Tensor]], None],
     ) -> None:
         """
         Replace default all reduce with custom callable. Users can alternatively
@@ -817,7 +813,7 @@ def set_all_reduce_hook(
         process group, and stream synchronization.
 
         Args:
-            reduce_hook (Callable[[torch.Tensor], torch.Tensor]): The custom all reduce function to use for
+            reduce_hook (Callable[[List[torch.Tensor]], torch.Tensor]): The custom all reduce function to use for
                 embedding weights and optimizer states
         """
        if self._custom_all_reduce is not None:
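For context, a minimal sketch of a hook written against the widened signature: the diff changes the callable from taking a single `torch.Tensor` to taking the full `List[torch.Tensor]` for a dtype group. The hook body (sum over the default process group, then average) and the commented-out constructor call are illustrative assumptions; only `custom_all_reduce`, `set_all_reduce_hook`, and the list-of-tensors signature come from the diff above.

```python
from typing import List

import torch
import torch.distributed as dist


def replica_average_hook(tensors: List[torch.Tensor]) -> None:
    # Hypothetical hook: with the new signature it receives the whole
    # coalesced list of same-dtype tensors at once, not one tensor per call.
    world_size = dist.get_world_size()
    for t in tensors:
        dist.all_reduce(t, op=dist.ReduceOp.SUM)  # sum across replicas
        t.div_(world_size)                        # average in place


# Pass it at construction time (other DMPCollection arguments elided) ...
#   dmp = DMPCollection(..., custom_all_reduce=replica_average_hook)
# ... or register it afterwards:
#   dmp.set_all_reduce_hook(replica_average_hook)
```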