diff --git a/src/lightning/fabric/strategies/ddp.py b/src/lightning/fabric/strategies/ddp.py
index 0ec5df1a6b0ae..5269c03ba6b41 100644
--- a/src/lightning/fabric/strategies/ddp.py
+++ b/src/lightning/fabric/strategies/ddp.py
@@ -32,15 +32,11 @@
 from lightning.fabric.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
 from lightning.fabric.strategies.parallel import ParallelStrategy
 from lightning.fabric.strategies.registry import _StrategyRegistry
-from lightning.fabric.strategies.strategy import TBroadcast, _BackwardSyncControl
+from lightning.fabric.strategies.strategy import _BackwardSyncControl
 from lightning.fabric.utilities.distributed import (
-    ReduceOp,
-    _distributed_is_initialized,
     _get_default_process_group_backend_for_device,
     _init_dist_connection,
-    _sync_ddp_if_available,
 )
-from lightning.fabric.utilities.distributed import group as _group
 from lightning.fabric.utilities.rank_zero import rank_zero_only

 _DDP_FORK_ALIASES = (
@@ -132,44 +128,6 @@ def setup_module(self, module: Module) -> DistributedDataParallel:
     def module_to_device(self, module: Module) -> None:
         module.to(self.root_device)

-    @override
-    def all_reduce(
-        self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"
-    ) -> Tensor:
-        """Reduces a tensor from several distributed processes to one aggregated tensor.
-
-        Args:
-            tensor: the tensor to sync and reduce
-            group: the process group to gather results from. Defaults to all processes (world)
-            reduce_op: the reduction operation. Defaults to 'mean'/'avg'.
-                Can also be a string 'sum' to calculate the sum during reduction.
-
-        Return:
-            reduced value, except when the input was not a tensor the output remains is unchanged
-
-        """
-        if isinstance(tensor, Tensor):
-            return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
-        return tensor
-
-    @override
-    def barrier(self, *args: Any, **kwargs: Any) -> None:
-        if not _distributed_is_initialized():
-            return
-        if torch.distributed.get_backend() == "nccl":
-            torch.distributed.barrier(device_ids=self._determine_ddp_device_ids())
-        else:
-            torch.distributed.barrier()
-
-    @override
-    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
-        if not _distributed_is_initialized():
-            return obj
-
-        obj = [obj]
-        torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD)
-        return obj[0]
-
     @override
     def get_module_state_dict(self, module: Module) -> Dict[str, Union[Any, Tensor]]:
         if isinstance(module, DistributedDataParallel):
diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py
index bd3548a22be9b..bc09870425ac4 100644
--- a/src/lightning/fabric/strategies/fsdp.py
+++ b/src/lightning/fabric/strategies/fsdp.py
@@ -48,20 +48,15 @@
 from lightning.fabric.strategies.parallel import ParallelStrategy
 from lightning.fabric.strategies.registry import _StrategyRegistry
 from lightning.fabric.strategies.strategy import (
-    TBroadcast,
     _apply_filter,
     _BackwardSyncControl,
     _Sharded,
     _validate_keys_for_strict_loading,
 )
 from lightning.fabric.utilities.distributed import (
-    ReduceOp,
-    _distributed_is_initialized,
     _get_default_process_group_backend_for_device,
     _init_dist_connection,
-    _sync_ddp_if_available,
 )
-from lightning.fabric.utilities.distributed import group as _group
 from lightning.fabric.utilities.imports import (
     _TORCH_GREATER_EQUAL_2_1,
     _TORCH_GREATER_EQUAL_2_2,
@@ -351,32 +346,6 @@ def module_sharded_context(self) -> ContextManager:
             **self._fsdp_kwargs,
         )

-    @override
-    def all_reduce(
-        self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"
-    ) -> Tensor:
-        if isinstance(tensor, Tensor):
-            return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
-        return tensor
-
-    @override
-    def barrier(self, *args: Any, **kwargs: Any) -> None:
-        if not _distributed_is_initialized():
-            return
-        if torch.distributed.get_backend() == "nccl":
-            torch.distributed.barrier(device_ids=[self.root_device.index])
-        else:
-            torch.distributed.barrier()
-
-    @override
-    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
-        if not _distributed_is_initialized():
-            return obj
-
-        obj = [obj]
-        torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD)
-        return obj[0]
-
     @override
     def clip_gradients_norm(
         self,
diff --git a/src/lightning/fabric/strategies/model_parallel.py b/src/lightning/fabric/strategies/model_parallel.py
index 4141ea454ca51..e2b2fc29eda11 100644
--- a/src/lightning/fabric/strategies/model_parallel.py
+++ b/src/lightning/fabric/strategies/model_parallel.py
@@ -20,7 +20,6 @@

 import torch
 from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only
-from torch import Tensor
 from torch.nn import Module
 from torch.optim import Optimizer
 from typing_extensions import TypeGuard, override
@@ -37,19 +36,14 @@
 from lightning.fabric.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
 from lightning.fabric.strategies.parallel import ParallelStrategy
 from lightning.fabric.strategies.strategy import (
-    TBroadcast,
     _apply_filter,
     _BackwardSyncControl,
     _validate_keys_for_strict_loading,
 )
 from lightning.fabric.utilities.distributed import (
-    ReduceOp,
-    _distributed_is_initialized,
     _get_default_process_group_backend_for_device,
     _init_dist_connection,
-    _sync_ddp_if_available,
 )
-from lightning.fabric.utilities.distributed import group as _group
 from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3, _TORCH_GREATER_EQUAL_2_4
 from lightning.fabric.utilities.init import _materialize_distributed_module
 from lightning.fabric.utilities.load import _METADATA_FILENAME, _lazy_load, _move_state_into
@@ -198,32 +192,6 @@ def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManag
             stack.enter_context(precision_init_ctx)
         return stack

-    @override
-    def all_reduce(
-        self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"
-    ) -> Tensor:
-        if isinstance(tensor, Tensor):
-            return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
-        return tensor
-
-    @override
-    def barrier(self, *args: Any, **kwargs: Any) -> None:
-        if not _distributed_is_initialized():
-            return
-        if torch.distributed.get_backend() == "nccl":
-            torch.distributed.barrier(device_ids=[self.root_device.index])
-        else:
-            torch.distributed.barrier()
-
-    @override
-    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
-        if not _distributed_is_initialized():
-            return obj
-
-        obj = [obj]
-        torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD)
-        return obj[0]
-
     @override
     def save_checkpoint(
         self,
diff --git a/src/lightning/fabric/strategies/parallel.py b/src/lightning/fabric/strategies/parallel.py
index a12a0611c90ab..1aa047a59f327 100644
--- a/src/lightning/fabric/strategies/parallel.py
+++ b/src/lightning/fabric/strategies/parallel.py
@@ -12,18 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from abc import ABC
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union

 import torch
 from torch import Tensor
 from typing_extensions import override

 from lightning.fabric.accelerators.accelerator import Accelerator
+from lightning.fabric.plugins.collectives import Collective, TorchCollective
 from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
 from lightning.fabric.plugins.io.checkpoint_io import CheckpointIO
 from lightning.fabric.plugins.precision import Precision
-from lightning.fabric.strategies.strategy import Strategy
-from lightning.fabric.utilities.distributed import _all_gather_ddp_if_available
+from lightning.fabric.strategies.strategy import Strategy, TBroadcast
+from lightning.fabric.utilities.distributed import _all_gather_if_available, _all_reduce_if_available
 from lightning.fabric.utilities.types import ReduceOp


@@ -37,10 +38,12 @@ def __init__(
         cluster_environment: Optional[ClusterEnvironment] = None,
         checkpoint_io: Optional[CheckpointIO] = None,
         precision: Optional[Precision] = None,
+        collective: Optional[Collective] = None,
     ):
         super().__init__(accelerator=accelerator, checkpoint_io=checkpoint_io, precision=precision)
         self.parallel_devices = parallel_devices
         self.cluster_environment: Optional[ClusterEnvironment] = cluster_environment
+        self.collective: Collective = collective if collective is not None else TorchCollective()

     @property
     def global_rank(self) -> int:
@@ -82,8 +85,28 @@ def distributed_sampler_kwargs(self) -> Optional[Dict[str, Any]]:

     @override
     def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
-        """Perform a all_gather on all processes."""
-        return _all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads)
+        return _all_gather_if_available(tensor, collective=self.collective, sync_grads=sync_grads)
+
+    @override
+    def all_reduce(
+        self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"
+    ) -> Tensor:
+        return _all_reduce_if_available(tensor, collective=self.collective, reduce_op=reduce_op)
+
+    @override
+    def barrier(self, *args: Any, **kwargs: Any) -> None:
+        if not self.collective.is_initialized():
+            return
+        self.collective.barrier(device_ids=([self.root_device.index] if self.root_device.index is not None else None))
+
+    @override
+    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
+        if not self.collective.is_initialized():
+            return obj
+
+        object_list = [obj]
+        self.collective.broadcast_object_list(object_list=object_list, src=src, device=self.root_device)
+        return object_list[0]

     @override
     def reduce_boolean_decision(self, decision: bool, all: bool = True) -> bool:
@@ -111,4 +134,5 @@ def reduce_boolean_decision(self, decision: bool, all: bool = True) -> bool:
     def teardown(self) -> None:
         assert self.cluster_environment is not None
         self.cluster_environment.teardown()
+        self.collective.teardown()  # TODO: is this desired?
         return super().teardown()
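Note: with the hunks above, `all_reduce`, `barrier`, and `broadcast` now live only on `ParallelStrategy` and are driven by the attached `Collective` (defaulting to `TorchCollective`), so DDP, FSDP, and the model-parallel strategy inherit identical behavior. The following is a minimal sketch of how this is expected to surface through the public Fabric API; it assumes a plain two-process CPU launch run as a script and is not part of the patch itself.

# sketch.py -- illustration only, not part of this diff
import torch

from lightning.fabric import Fabric

fabric = Fabric(accelerator="cpu", strategy="ddp", devices=2)
fabric.launch()  # once the process group is up, the default TorchCollective should report is_initialized()

loss = torch.tensor(float(fabric.global_rank), device=fabric.device)
avg = fabric.all_reduce(loss, reduce_op="mean")   # routed through ParallelStrategy.all_reduce -> _all_reduce_if_available
run_id = fabric.broadcast("rank0-run-id", src=0)  # routed through Collective.broadcast_object_list
fabric.barrier()                                  # Collective.barrier(); a no-op before the group exists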
""" + if not collective.is_initialized(): + return tensor + divide_by_world_size = False - group = torch.distributed.group.WORLD if group is None else group op: Optional[ReduceOp] if isinstance(reduce_op, str): reduce_op = "avg" if reduce_op == "mean" else reduce_op - if reduce_op.lower() == "avg" and torch.distributed.get_backend(group) == "gloo": + if reduce_op.lower() == "avg" and torch.distributed.get_backend() == "gloo": # The GLOO backend does not support the `ReduceOp.AVG` operation op = ReduceOp.SUM # type: ignore[assignment] divide_by_world_size = True @@ -209,46 +196,44 @@ def _sync_ddp(result: Tensor, group: Optional[Any] = None, reduce_op: Optional[U if ( package_available("habana_frameworks") and os.environ.get("HCCL_DISTRIBUTED_BACKEND") == "1" - and result.type() + and tensor.type() in ( "torch.LongTensor", "torch.hpu.LongTensor", ) ): rank_zero_info("Long tensor unsupported on HPU, casting to float") - result = result.float() + tensor = tensor.float() # Sync all processes before reduction - torch.distributed.barrier(group=group) - torch.distributed.all_reduce(result, op=op, group=group, async_op=False) - world_size = torch.distributed.get_world_size(group) + collective.barrier() + collective.all_reduce(tensor, op=op) if not divide_by_world_size: - return result - # `torch.distributed.all_reduce` is in-place, so we should do the division in-place to leave the modified tensors + return tensor + # `all_reduce` is in-place, so we should do the division in-place to leave the modified tensors # with the expected value - if not torch.is_floating_point(result): - return result.copy_(result / world_size) - return result.div_(world_size) + if not torch.is_floating_point(tensor): + return tensor.copy_(tensor / collective.world_size) + return tensor.div_(collective.world_size) -def _all_gather_ddp_if_available( - tensor: Tensor, group: Optional["torch.distributed.ProcessGroup"] = None, sync_grads: bool = False -) -> Tensor: +def _all_gather_if_available(tensor: Tensor, collective: Collective, sync_grads: bool = False) -> Tensor: """Function to gather a tensor from several distributed processes. Args: tensor: Tensor of shape (batch, ...) - group: The process group to gather results from. Defaults to all processes (world) + collective: The collective backend to use for the all-gather. sync_grads: Flag that allows users to synchronize gradients for all_gather op Return: A tensor of shape (world_size, batch, ...) """ - if not _distributed_is_initialized(): + if not collective.is_initialized(): return tensor + # TODO: Enable all-gather with grads in TorchCollective from torch.distributed.nn.functional import all_gather tensor = tensor.contiguous() # https://github.com/pytorch/pytorch/issues/73515