Commit 3a74a84

xunnanxu authored and facebook-github-bot committed
fix custom AR typing in DMPC (pytorch#2815)
Summary:
Pull Request resolved: pytorch#2815

The custom all reduce hook takes a list of tensors rather than a single tensor, so the parameter is now typed Callable[[List[torch.Tensor]], None].

Reviewed By: iamzainhuda

Differential Revision: D71131639

fbshipit-source-id: 3701912c4ba286e25a5e806310818b2fabb0c471
1 parent 7476e8e commit 3a74a84
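
With this change, a user-supplied all reduce hook receives every tensor in the coalesced bucket at once rather than a single tensor. Below is a minimal sketch of a hook matching the new Callable[[List[torch.Tensor]], None] signature; the function names, the sum-then-average strategy, and the factory that binds the process group are illustrative assumptions, not code from this commit.

from functools import partial
from typing import Callable, List

import torch
import torch.distributed as dist


def _average_across_replicas(
    tensors: List[torch.Tensor], replica_pg: dist.ProcessGroup
) -> None:
    # Reduce each tensor in the bucket across replicas, then average in place.
    world_size = dist.get_world_size(group=replica_pg)
    for t in tensors:
        dist.all_reduce(t, op=dist.ReduceOp.SUM, group=replica_pg)
        t.div_(world_size)


def make_all_reduce_hook(
    replica_pg: dist.ProcessGroup,
) -> Callable[[List[torch.Tensor]], None]:
    # The returned callable matches the new parameter type and could be passed
    # as custom_all_reduce=... to DMPCollection or installed via set_all_reduce_hook.
    return partial(_average_across_replicas, replica_pg=replica_pg)

Binding the process group in a factory keeps the hook's public signature exactly what DMPCollection now declares, while leaving the reduction strategy up to the caller.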

2 files changed: 14 additions & 18 deletions


torchrec/distributed/model_parallel.py

Lines changed: 13 additions & 17 deletions
@@ -691,7 +691,7 @@ def __init__(
         init_parameters: bool = True,
         data_parallel_wrapper: Optional[DataParallelWrapper] = None,
         use_inter_host_allreduce: bool = False,
-        custom_all_reduce: Optional[Callable[[torch.Tensor], None]] = None,
+        custom_all_reduce: Optional[Callable[[List[torch.Tensor]], None]] = None,
     ) -> None:
         assert device.type == "cuda", "DMPCollection only supports CUDA"
         self._device = device
@@ -701,9 +701,7 @@ def __init__(
         self._sharding_pg: dist.ProcessGroup = None  # pyre-ignore[8]
         self._replica_pg: dist.ProcessGroup = None  # pyre-ignore[8]
         self._global_rank: int = dist.get_rank(global_pg)
-        self._custom_all_reduce: Optional[Callable[[torch.Tensor], None]] = (
-            custom_all_reduce
-        )
+        self._custom_all_reduce = custom_all_reduce

         self._device_mesh, self._sharding_pg, self._replica_pg = (
             self._create_process_groups(
@@ -790,25 +788,23 @@ def _allreduce_tensors(
         We perform all reduce per tensor dtype per collective constraints.
         """

-        def custom_all_reduce(tensors: List[torch.Tensor]) -> None:
-            # pyre-ignore[29]
-            self._custom_all_reduce(tensors)
+        custom_all_reduce = self._custom_all_reduce
+        if custom_all_reduce is not None:

-        def default_allreduce(tensor_list: List[torch.Tensor]) -> None:
-            self._replica_pg.allreduce_coalesced(tensor_list, opts=opts).wait()
+            def _all_reduce(tensors: List[torch.Tensor]) -> None:
+                custom_all_reduce(tensors)

-        allreduce = (
-            custom_all_reduce
-            if self._custom_all_reduce is not None
-            else default_allreduce
-        )
+        else:
+
+            def _all_reduce(tensors: List[torch.Tensor]) -> None:
+                self._replica_pg.allreduce_coalesced(tensors, opts=opts).wait()

         for tensor_list in tensors_dict.values():
-            allreduce(tensor_list)
+            _all_reduce(tensor_list)

     def set_all_reduce_hook(
         self,
-        reduce_hook: Callable[[torch.Tensor], None],
+        reduce_hook: Callable[[List[torch.Tensor]], None],
     ) -> None:
         """
         Replace default all reduce with custom callable. Users can alternatively
@@ -817,7 +813,7 @@ def set_all_reduce_hook(
         process group, and stream synchronization.

         Args:
-            reduce_hook (Callable[[torch.Tensor], torch.Tensor]): The custom all reduce function to use for
+            reduce_hook (Callable[[List[torch.Tensor]], torch.Tensor]): The custom all reduce function to use for
                 embedding weights and optimizer states
         """
         if self._custom_all_reduce is not None:
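
The _allreduce_tensors docstring above states that the all reduce runs per tensor dtype to satisfy collective constraints, and the refactor binds self._custom_all_reduce to a local before defining _all_reduce, which is why the pyre-ignore[29] suppression disappears. A minimal sketch of the dtype grouping idea is shown here; the helper name and signature are hypothetical, not torchrec internals.

from collections import defaultdict
from typing import Dict, List

import torch


def group_tensors_by_dtype(
    tensors: List[torch.Tensor],
) -> Dict[torch.dtype, List[torch.Tensor]]:
    # Bucket tensors so each coalesced collective (or custom hook call)
    # operates on a list with a single, uniform dtype.
    tensors_dict: Dict[torch.dtype, List[torch.Tensor]] = defaultdict(list)
    for t in tensors:
        tensors_dict[t.dtype].append(t)
    return tensors_dict

Each bucket produced this way corresponds to one tensor_list handed to _all_reduce in the loop above.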

torchrec/distributed/test_utils/test_sharding.py

Lines changed: 1 addition & 1 deletion
@@ -504,7 +504,7 @@ def _custom_hook(input: List[torch.Tensor]) -> None:
             sharders=sharders,
             device=ctx.device,
             use_inter_host_allreduce=use_inter_host_allreduce,
-            custom_all_reduce=all_reduce_func,  # pyre-ignore[6]
+            custom_all_reduce=all_reduce_func,
         )
     else:
         local_model = DistributedModelParallel(
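
The pyre-ignore[6] suppression is no longer needed because the test hook already takes a List[torch.Tensor], which now matches the declared type of custom_all_reduce. A hedged stand-in for such a hook is sketched below; the body is illustrative only, and the real test's _custom_hook may behave differently.

from typing import List

import torch

bucket_sizes: List[int] = []


def all_reduce_func(input: List[torch.Tensor]) -> None:
    # Record how many tensors each call received; perform no communication.
    bucket_sizes.append(len(input))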
