WIP: Integrate Collective into strategies #19881

Draft
wants to merge 2 commits into base: master
44 changes: 1 addition & 43 deletions src/lightning/fabric/strategies/ddp.py
@@ -32,15 +32,11 @@
from lightning.fabric.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
from lightning.fabric.strategies.parallel import ParallelStrategy
from lightning.fabric.strategies.registry import _StrategyRegistry
from lightning.fabric.strategies.strategy import TBroadcast, _BackwardSyncControl
from lightning.fabric.strategies.strategy import _BackwardSyncControl
from lightning.fabric.utilities.distributed import (
ReduceOp,
_distributed_is_initialized,
_get_default_process_group_backend_for_device,
_init_dist_connection,
_sync_ddp_if_available,
)
from lightning.fabric.utilities.distributed import group as _group
from lightning.fabric.utilities.rank_zero import rank_zero_only

_DDP_FORK_ALIASES = (
@@ -132,44 +128,6 @@ def setup_module(self, module: Module) -> DistributedDataParallel:
def module_to_device(self, module: Module) -> None:
module.to(self.root_device)

@override
def all_reduce(
self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"
) -> Tensor:
"""Reduces a tensor from several distributed processes to one aggregated tensor.

Args:
tensor: the tensor to sync and reduce
group: the process group to gather results from. Defaults to all processes (world)
reduce_op: the reduction operation. Defaults to 'mean'/'avg'.
Can also be a string 'sum' to calculate the sum during reduction.

Return:
reduced value, except when the input was not a tensor, in which case the output remains unchanged

"""
if isinstance(tensor, Tensor):
return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
return tensor

@override
def barrier(self, *args: Any, **kwargs: Any) -> None:
if not _distributed_is_initialized():
return
if torch.distributed.get_backend() == "nccl":
torch.distributed.barrier(device_ids=self._determine_ddp_device_ids())
else:
torch.distributed.barrier()

@override
def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
if not _distributed_is_initialized():
return obj

obj = [obj]
torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD)
return obj[0]

@override
def get_module_state_dict(self, module: Module) -> Dict[str, Union[Any, Tensor]]:
if isinstance(module, DistributedDataParallel):
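
Illustrative sketch (not part of this diff): with these overrides deleted, DDPStrategy is expected to fall back to the Collective-backed all_reduce / barrier / broadcast that this PR adds to ParallelStrategy (see parallel.py below). The snippet assumes this branch's behavior and is a hypothetical usage example, not code from the PR.

import torch
from lightning.fabric import Fabric

# Launch two CPU processes with the DDP strategy (gloo backend).
fabric = Fabric(accelerator="cpu", devices=2, strategy="ddp")
fabric.launch()

loss = torch.tensor(float(fabric.global_rank))
avg = fabric.strategy.all_reduce(loss, reduce_op="mean")  # mean across ranks, via the Collective
fabric.strategy.barrier()                                 # wait for all ranks
state = fabric.strategy.broadcast({"step": 0}, src=0)     # rank 0's object on every rank
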
31 changes: 0 additions & 31 deletions src/lightning/fabric/strategies/fsdp.py
@@ -48,20 +48,15 @@
from lightning.fabric.strategies.parallel import ParallelStrategy
from lightning.fabric.strategies.registry import _StrategyRegistry
from lightning.fabric.strategies.strategy import (
TBroadcast,
_apply_filter,
_BackwardSyncControl,
_Sharded,
_validate_keys_for_strict_loading,
)
from lightning.fabric.utilities.distributed import (
ReduceOp,
_distributed_is_initialized,
_get_default_process_group_backend_for_device,
_init_dist_connection,
_sync_ddp_if_available,
)
from lightning.fabric.utilities.distributed import group as _group
from lightning.fabric.utilities.imports import (
_TORCH_GREATER_EQUAL_2_1,
_TORCH_GREATER_EQUAL_2_2,
@@ -351,32 +346,6 @@ def module_sharded_context(self) -> ContextManager:
**self._fsdp_kwargs,
)

@override
def all_reduce(
self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"
) -> Tensor:
if isinstance(tensor, Tensor):
return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
return tensor

@override
def barrier(self, *args: Any, **kwargs: Any) -> None:
if not _distributed_is_initialized():
return
if torch.distributed.get_backend() == "nccl":
torch.distributed.barrier(device_ids=[self.root_device.index])
else:
torch.distributed.barrier()

@override
def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
if not _distributed_is_initialized():
return obj

obj = [obj]
torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD)
return obj[0]

@override
def clip_gradients_norm(
self,
32 changes: 0 additions & 32 deletions src/lightning/fabric/strategies/model_parallel.py
@@ -20,7 +20,6 @@

import torch
from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
from typing_extensions import TypeGuard, override
@@ -37,19 +36,14 @@
from lightning.fabric.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
from lightning.fabric.strategies.parallel import ParallelStrategy
from lightning.fabric.strategies.strategy import (
TBroadcast,
_apply_filter,
_BackwardSyncControl,
_validate_keys_for_strict_loading,
)
from lightning.fabric.utilities.distributed import (
ReduceOp,
_distributed_is_initialized,
_get_default_process_group_backend_for_device,
_init_dist_connection,
_sync_ddp_if_available,
)
from lightning.fabric.utilities.distributed import group as _group
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3, _TORCH_GREATER_EQUAL_2_4
from lightning.fabric.utilities.init import _materialize_distributed_module
from lightning.fabric.utilities.load import _METADATA_FILENAME, _lazy_load, _move_state_into
@@ -198,32 +192,6 @@ def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager:
stack.enter_context(precision_init_ctx)
return stack

@override
def all_reduce(
self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"
) -> Tensor:
if isinstance(tensor, Tensor):
return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
return tensor

@override
def barrier(self, *args: Any, **kwargs: Any) -> None:
if not _distributed_is_initialized():
return
if torch.distributed.get_backend() == "nccl":
torch.distributed.barrier(device_ids=[self.root_device.index])
else:
torch.distributed.barrier()

@override
def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
if not _distributed_is_initialized():
return obj

obj = [obj]
torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD)
return obj[0]

@override
def save_checkpoint(
self,
34 changes: 29 additions & 5 deletions src/lightning/fabric/strategies/parallel.py
@@ -12,18 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union

import torch
from torch import Tensor
from typing_extensions import override

from lightning.fabric.accelerators.accelerator import Accelerator
from lightning.fabric.plugins.collectives import Collective, TorchCollective
from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
from lightning.fabric.plugins.io.checkpoint_io import CheckpointIO
from lightning.fabric.plugins.precision import Precision
from lightning.fabric.strategies.strategy import Strategy
from lightning.fabric.utilities.distributed import _all_gather_ddp_if_available
from lightning.fabric.strategies.strategy import Strategy, TBroadcast
from lightning.fabric.utilities.distributed import _all_gather_if_available, _all_reduce_if_available
from lightning.fabric.utilities.types import ReduceOp


@@ -37,10 +38,12 @@ def __init__(
cluster_environment: Optional[ClusterEnvironment] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision: Optional[Precision] = None,
collective: Optional[Collective] = None,
):
super().__init__(accelerator=accelerator, checkpoint_io=checkpoint_io, precision=precision)
self.parallel_devices = parallel_devices
self.cluster_environment: Optional[ClusterEnvironment] = cluster_environment
self.collective: Collective = collective if collective is not None else TorchCollective()

@property
def global_rank(self) -> int:
@@ -82,8 +85,28 @@ def distributed_sampler_kwargs(self) -> Optional[Dict[str, Any]]:

@override
def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
"""Perform a all_gather on all processes."""
return _all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads)
return _all_gather_if_available(tensor, collective=self.collective, sync_grads=sync_grads)

@override
def all_reduce(
self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"
) -> Tensor:
return _all_reduce_if_available(tensor, collective=self.collective, reduce_op=reduce_op)

@override
def barrier(self, *args: Any, **kwargs: Any) -> None:
if not self.collective.is_initialized():
return
self.collective.barrier(device_ids=([self.root_device.index] if self.root_device.index is not None else None))

@override
def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
if not self.collective.is_initialized():
return obj

object_list = [obj]
self.collective.broadcast_object_list(object_list=object_list, src=src, device=self.root_device)
return object_list[0]

@override
def reduce_boolean_decision(self, decision: bool, all: bool = True) -> bool:
@@ -111,4 +134,5 @@ def reduce_boolean_decision(self, decision: bool, all: bool = True) -> bool:
def teardown(self) -> None:
assert self.cluster_environment is not None
self.cluster_environment.teardown()
self.collective.teardown() # TODO: is this desired?
return super().teardown()
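
Illustrative sketch (not part of this diff): the new collective argument makes the process-group backend pluggable, with TorchCollective as the default. The sketch below exercises only the Collective calls that appear above (is_initialized, all_reduce, barrier, broadcast_object_list, world_size); passing the reduce op as the string "sum" is an assumption about the plugin API, and the process group is assumed to have been initialized by the strategy.

import torch
from lightning.fabric.plugins.collectives import Collective, TorchCollective

def demo_collective_ops(collective: Collective) -> None:
    # Mirrors what ParallelStrategy.all_reduce / barrier / broadcast delegate to above.
    if not collective.is_initialized():
        return  # no process group yet; the strategy methods likewise degrade to no-ops

    t = torch.ones(1)
    collective.all_reduce(t, op="sum")  # in-place sum across ranks
    t.div_(collective.world_size)       # manual mean, as _all_reduce_if_available does for gloo

    collective.barrier()                # synchronize all ranks

    payload = [{"epoch": 0}]
    collective.broadcast_object_list(object_list=payload, src=0)
    # payload[0] now holds rank 0's object on every rank

# e.g. demo_collective_ops(TorchCollective()) inside a process launched by the strategy
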
63 changes: 24 additions & 39 deletions src/lightning/fabric/utilities/distributed.py
@@ -14,6 +14,7 @@
from torch.utils.data import Dataset, DistributedSampler, Sampler
from typing_extensions import Self, override

from lightning.fabric.plugins.collectives import Collective
from lightning.fabric.utilities.cloud_io import _is_local_file_protocol
from lightning.fabric.utilities.data import _num_cpus_available
from lightning.fabric.utilities.rank_zero import rank_zero_info
@@ -94,6 +95,7 @@ def is_shared_filesystem(strategy: "Strategy", path: Optional[_PATH] = None, tim
return all_found


# TODO: This function has no usages in our code base
def _gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tensor]:
"""Function to gather all tensors from several DDP processes onto a list that is broadcasted to all processes.

@@ -148,54 +150,39 @@ def _gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tensor]:
return gathered_result


# TODO: This function has no usages in our code base
def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> List[Tensor]:
gathered_result = [torch.zeros_like(result) for _ in range(world_size)]
torch.distributed.all_gather(gathered_result, result, group)
return gathered_result


def _sync_ddp_if_available(
result: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
def _all_reduce_if_available(
tensor: Tensor, collective: Collective, reduce_op: Optional[Union[ReduceOp, str]] = None
) -> Tensor:
"""Function to reduce a tensor across worker processes during distributed training.

Args:
result: The value to sync and reduce (typically tensor or number)
group: The process group to gather results from. Defaults to all processes (world)
reduce_op: The reduction operation. Defaults to sum.
Can also be a string of 'avg', 'mean' to calculate the mean during reduction.

Return:
reduced value

"""
if _distributed_is_initialized():
return _sync_ddp(result, group=group, reduce_op=reduce_op)
return result


def _sync_ddp(result: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None) -> Tensor:
"""Reduces a tensor across several distributed processes.

This operation is performed in-place, meaning the result will be placed back into the input tensor on all processes.

Args:
result: The value to sync and reduce (typically tensor or number)
group: The process group to gather results from. Defaults to all processes (world)
tensor: The value to sync and reduce (typically tensor or number)
collective: The collective backend to use for the all-reduce.
reduce_op: The reduction operation. Defaults to sum.
Can also be a string of 'avg', 'mean' to calculate the mean during reduction.

Return:
The reduced value.

"""
if not collective.is_initialized():
return tensor

divide_by_world_size = False
group = torch.distributed.group.WORLD if group is None else group

op: Optional[ReduceOp]
if isinstance(reduce_op, str):
reduce_op = "avg" if reduce_op == "mean" else reduce_op
if reduce_op.lower() == "avg" and torch.distributed.get_backend(group) == "gloo":
if reduce_op.lower() == "avg" and torch.distributed.get_backend() == "gloo":
# The GLOO backend does not support the `ReduceOp.AVG` operation
op = ReduceOp.SUM # type: ignore[assignment]
divide_by_world_size = True
@@ -209,46 +196,44 @@ def _sync_ddp(result: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None) -> Tensor:
if (
package_available("habana_frameworks")
and os.environ.get("HCCL_DISTRIBUTED_BACKEND") == "1"
and result.type()
and tensor.type()
in (
"torch.LongTensor",
"torch.hpu.LongTensor",
)
):
rank_zero_info("Long tensor unsupported on HPU, casting to float")
result = result.float()
tensor = tensor.float()

# Sync all processes before reduction
torch.distributed.barrier(group=group)
torch.distributed.all_reduce(result, op=op, group=group, async_op=False)
world_size = torch.distributed.get_world_size(group)
collective.barrier()
collective.all_reduce(tensor, op=op)

if not divide_by_world_size:
return result
# `torch.distributed.all_reduce` is in-place, so we should do the division in-place to leave the modified tensors
return tensor
# `all_reduce` is in-place, so we should do the division in-place to leave the modified tensors
# with the expected value
if not torch.is_floating_point(result):
return result.copy_(result / world_size)
return result.div_(world_size)
if not torch.is_floating_point(tensor):
return tensor.copy_(tensor / collective.world_size)
return tensor.div_(collective.world_size)


def _all_gather_ddp_if_available(
tensor: Tensor, group: Optional["torch.distributed.ProcessGroup"] = None, sync_grads: bool = False
) -> Tensor:
def _all_gather_if_available(tensor: Tensor, collective: Collective, sync_grads: bool = False) -> Tensor:
"""Function to gather a tensor from several distributed processes.

Args:
tensor: Tensor of shape (batch, ...)
group: The process group to gather results from. Defaults to all processes (world)
collective: The collective backend to use for the all-gather.
sync_grads: Flag that allows users to synchronize gradients for all_gather op

Return:
A tensor of shape (world_size, batch, ...)

"""
if not _distributed_is_initialized():
if not collective.is_initialized():
return tensor

# TODO: Enable all-gather with grads in TorchCollective
from torch.distributed.nn.functional import all_gather

tensor = tensor.contiguous() # https://github.com/pytorch/pytorch/issues/73515
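
Illustrative sketch (not part of this diff): the renamed helpers take a Collective instead of a raw process group. A minimal example of the intended call sites, assuming this branch; when no process group is initialized, both helpers simply return the input tensor.

import torch
from lightning.fabric.plugins.collectives import TorchCollective
from lightning.fabric.utilities.distributed import (
    _all_gather_if_available,
    _all_reduce_if_available,
)

collective = TorchCollective()
local = torch.tensor([1.0, 2.0])

reduced = _all_reduce_if_available(local.clone(), collective, reduce_op="mean")
# element-wise mean across ranks (in place on the passed tensor when distributed is up)

gathered = _all_gather_if_available(local, collective, sync_grads=False)
# shape (world_size, 2) when distributed is up; unchanged `local` otherwise
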