
DDP: moving model to CPU and back to GPU breaks gradient synchronization #17937

@vlievin

Bug description

Gradient synchronization in fabric.backward() is broken after moving a model to CPU and back to GPU.

Moving a model temporarily to CPU is useful when the GPU needs to be freed for another task (e.g., building a faiss index). The issue happens silently (no error message) and causes each process to effectively train its own copy of the model: gradients are never synchronized unless done explicitly.
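
A minimal sketch of the explicit synchronization mentioned above, assuming a Fabric-set-up model whose DDP gradient hooks are no longer firing: average the gradients by hand with torch.distributed.all_reduce after fabric.backward() and before the optimizer step. This is a workaround sketch, not the intended fix for the bug.

import torch
import torch.distributed as dist


def manually_average_grads(model: torch.nn.Module, world_size: int) -> None:
    """Average gradients across all ranks by hand (workaround sketch).

    Only needed when DDP's built-in gradient synchronization is not
    happening, as in the bug described above.
    """
    for param in model.parameters():
        if param.grad is not None:
            # Sum the gradient over all processes, then divide by the
            # world size to reproduce DDP's gradient averaging.
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad.div_(world_size)

Called right after fabric.backward(loss), this should restore identical gradients across processes, at the cost of one extra all-reduce per parameter.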

What version are you seeing the problem on?

master

How to reproduce the bug

Run the following snippet to trigger the problem. The code works as expected when the model stays on the GPU, but fails when the model is moved to CPU and then back to GPU:

import lightning as L
import torch
import torch.nn as nn
import torch.optim as optim


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = torch.nn.Linear(10, 10)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(10, 5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.relu(self.net1(x))
        return self.net2(x)


def check_grads_synchronized(fabric: L.Fabric, move_model_to_cpu_and_back: bool = True) -> None:
    """Create a dummy mode, run a forward pass and backward pass, and check that the gradients are synchronized."""
    model = ToyModel()
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001)
    lit_model, lit_optimizer = fabric.setup(model, optimizer)

    if move_model_to_cpu_and_back:
        # move the model to CPU and back to GPU
        # this is to test that the gradients are synchronized even if the model is moved to CPU
        lit_model.cpu()
        fabric.to_device(lit_model)

    # create data
    x = torch.randn(20, 10)
    labels = torch.randn(20, 5)

    # forward pass
    x = fabric.to_device(x)
    labels = fabric.to_device(labels)
    outputs = lit_model(x)

    lit_optimizer.zero_grad()
    loss = loss_fn(outputs, labels)
    fabric.backward(loss)

    for k, v in lit_model.named_parameters():
        if v.grad is not None:
            # gather gradients from all processes
            grad_list = [torch.zeros_like(v.grad) for _ in range(fabric.world_size)]
            torch.distributed.all_gather(grad_list, v.grad, async_op=False)
            for grad in grad_list:
                if not torch.allclose(v.grad, grad):
                    raise RuntimeError(
                        f"move_model_to_cpu_and_back={move_model_to_cpu_and_back}. "
                        f"Gradients are not equal across processes (p={k})")


if __name__ == "__main__":
    fabric = L.Fabric(strategy="ddp", devices=2)
    fabric.launch()
    check_grads_synchronized(fabric, move_model_to_cpu_and_back=False)
    if fabric.is_global_zero:
        print("\nmove_model_to_cpu_and_back=False: SUCCESS!\n")
    check_grads_synchronized(fabric, move_model_to_cpu_and_back=True)
    if fabric.is_global_zero:
        print("\nmove_model_to_cpu_and_back=True: SUCCESS!\n")

Error messages and logs

This causes Lightning to fail silently; there are no error messages or logs to report.

Environment

Current environment
  • CUDA:
    • GPU:
      • NVIDIA A100-SXM4-40GB
      • NVIDIA A100-SXM4-40GB
    • available: True
    • version: 11.7
  • Lightning:
    • lightning: 2.1.0.dev0
    • lightning-cloud: 0.5.37
    • lightning-utilities: 0.8.0
    • pytorch-lightning: 2.0.4
    • torch: 2.0.1
    • torchmetrics: 0.11.4
  • Packages:
    • aiohttp: 3.8.4
    • aiosignal: 1.3.1
    • anyio: 3.7.0
    • arrow: 1.2.3
    • async-timeout: 4.0.2
    • attrs: 23.1.0
    • beautifulsoup4: 4.12.2
    • blessed: 1.20.0
    • certifi: 2023.5.7
    • charset-normalizer: 3.1.0
    • click: 8.1.3
    • cmake: 3.26.4
    • croniter: 1.3.15
    • dateutils: 0.6.12
    • deepdiff: 6.3.0
    • exceptiongroup: 1.1.1
    • fastapi: 0.98.0
    • filelock: 3.12.2
    • frozenlist: 1.3.3
    • fsspec: 2023.6.0
    • h11: 0.14.0
    • idna: 3.4
    • inquirer: 3.1.3
    • itsdangerous: 2.1.2
    • jinja2: 3.1.2
    • lightning: 2.1.0.dev0
    • lightning-cloud: 0.5.37
    • lightning-utilities: 0.8.0
    • lit: 16.0.6
    • markdown-it-py: 3.0.0
    • markupsafe: 2.1.3
    • mdurl: 0.1.2
    • mpmath: 1.3.0
    • multidict: 6.0.4
    • networkx: 3.1
    • numpy: 1.25.0
    • nvidia-cublas-cu11: 11.10.3.66
    • nvidia-cuda-cupti-cu11: 11.7.101
    • nvidia-cuda-nvrtc-cu11: 11.7.99
    • nvidia-cuda-runtime-cu11: 11.7.99
    • nvidia-cudnn-cu11: 8.5.0.96
    • nvidia-cufft-cu11: 10.9.0.58
    • nvidia-curand-cu11: 10.2.10.91
    • nvidia-cusolver-cu11: 11.4.0.1
    • nvidia-cusparse-cu11: 11.7.4.91
    • nvidia-nccl-cu11: 2.14.3
    • nvidia-nvtx-cu11: 11.7.91
    • ordered-set: 4.1.0
    • packaging: 23.1
    • pip: 23.1.2
    • psutil: 5.9.5
    • pydantic: 1.10.9
    • pygments: 2.15.1
    • pyjwt: 2.7.0
    • python-dateutil: 2.8.2
    • python-editor: 1.0.4
    • python-multipart: 0.0.6
    • pytorch-lightning: 2.0.4
    • pytz: 2023.3
    • pyyaml: 6.0
    • readchar: 4.0.5
    • requests: 2.31.0
    • rich: 13.4.2
    • setuptools: 68.0.0
    • six: 1.16.0
    • sniffio: 1.3.0
    • soupsieve: 2.4.1
    • starlette: 0.27.0
    • starsessions: 1.3.0
    • sympy: 1.12
    • torch: 2.0.1
    • torchmetrics: 0.11.4
    • tqdm: 4.65.0
    • traitlets: 5.9.0
    • triton: 2.0.0
    • typing-extensions: 4.6.3
    • urllib3: 2.0.3
    • uvicorn: 0.22.0
    • wcwidth: 0.2.6
    • websocket-client: 1.6.1
    • websockets: 11.0.3
    • wheel: 0.40.0
    • yarl: 1.9.2
  • System:
    • OS: Linux
    • architecture:
      • 64bit
      • ELF
    • processor:
    • python: 3.10.12
    • release: 4.19.0-24-cloud-amd64
    • version: #1 SMP Debian 4.19.282-1 (2023-04-29)

More info

No response

cc @carmocca @justusschock @awaelchli

Labels

bug (Something isn't working), fabric (lightning.fabric.Fabric), strategy: ddp (DistributedDataParallel), ver: 2.1.x
