Bug description
Gradient synchronisation in `fabric.backward()` is broken after moving a model to CPU and then back to GPU. Moving a model temporarily to CPU is useful when GPU resources need to be freed for another task (e.g., building a faiss index). The issue happens silently (no error message) and causes the model to be effectively trained on a single device: gradients are never synchronised unless this is done explicitly.
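For context, "done explicitly" above means averaging the gradients by hand after the backward pass. A minimal sketch of such an explicit synchronisation (not part of the original report; `sync_grads_manually` is a hypothetical helper) could look like this:

```python
import torch
import torch.distributed as dist


def sync_grads_manually(model: torch.nn.Module) -> None:
    """Stop-gap: average gradients across all ranks by hand, which is roughly
    what DDP is expected to do automatically inside fabric.backward()."""
    world_size = dist.get_world_size()
    for param in model.parameters():
        if param.grad is not None:
            # sum the gradient over all processes, then divide to get the mean
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad /= world_size
```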
What version are you seeing the problem on?
master
How to reproduce the bug
Run the following snippet to trigger the problem. The code works as expected when the model stays on the GPU, but fails after moving it to CPU and back to GPU:
```python
import lightning as L
import torch
import torch.nn as nn
import torch.optim as optim


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net1 = torch.nn.Linear(10, 10)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(10, 5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.relu(self.net1(x))
        return self.net2(x)


def check_grads_synchronized(fabric: L.Fabric, move_model_to_cpu_and_back: bool = True) -> None:
    """Create a dummy model, run a forward and backward pass, and check that the gradients are synchronized."""
    model = ToyModel()
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001)
    lit_model, lit_optimizer = fabric.setup(model, optimizer)

    if move_model_to_cpu_and_back:
        # move the model to CPU and back to GPU
        # this is to test that the gradients are synchronized even if the model is moved to CPU
        lit_model.cpu()
        fabric.to_device(lit_model)

    # create data
    x = torch.randn(20, 10)
    labels = torch.randn(20, 5)

    # forward pass
    x = fabric.to_device(x)
    labels = fabric.to_device(labels)
    outputs = lit_model(x)

    # backward pass
    lit_optimizer.zero_grad()
    loss = loss_fn(outputs, labels)
    fabric.backward(loss)

    # check that every rank ended up with identical gradients
    for k, v in lit_model.named_parameters():
        if v.grad is not None:
            # gather gradients from all processes
            grad_list = [torch.zeros_like(v.grad) for _ in range(fabric.world_size)]
            torch.distributed.all_gather(grad_list, v.grad, async_op=False)
            for grad in grad_list:
                if not torch.allclose(v.grad, grad):
                    raise RuntimeError(
                        f"move_model_to_cpu_and_back={move_model_to_cpu_and_back}. "
                        f"Gradients are not equal across processes (p={k})"
                    )


if __name__ == "__main__":
    fabric = L.Fabric(strategy="ddp", devices=2)
    fabric.launch()

    check_grads_synchronized(fabric, move_model_to_cpu_and_back=False)
    if fabric.is_global_zero:
        print("\nmove_model_to_cpu_and_back=False: SUCCESS!\n")

    check_grads_synchronized(fabric, move_model_to_cpu_and_back=True)
    if fabric.is_global_zero:
        print("\nmove_model_to_cpu_and_back=True: SUCCESS!\n")
```
Error messages and logs
Lightning fails silently here; no error messages or logs are produced.
Environment
- CUDA:
- GPU:
- NVIDIA A100-SXM4-40GB
- NVIDIA A100-SXM4-40GB
- available: True
- version: 11.7
- Lightning:
- lightning: 2.1.0.dev0
- lightning-cloud: 0.5.37
- lightning-utilities: 0.8.0
- pytorch-lightning: 2.0.4
- torch: 2.0.1
- torchmetrics: 0.11.4
- Packages:
- aiohttp: 3.8.4
- aiosignal: 1.3.1
- anyio: 3.7.0
- arrow: 1.2.3
- async-timeout: 4.0.2
- attrs: 23.1.0
- beautifulsoup4: 4.12.2
- blessed: 1.20.0
- certifi: 2023.5.7
- charset-normalizer: 3.1.0
- click: 8.1.3
- cmake: 3.26.4
- croniter: 1.3.15
- dateutils: 0.6.12
- deepdiff: 6.3.0
- exceptiongroup: 1.1.1
- fastapi: 0.98.0
- filelock: 3.12.2
- frozenlist: 1.3.3
- fsspec: 2023.6.0
- h11: 0.14.0
- idna: 3.4
- inquirer: 3.1.3
- itsdangerous: 2.1.2
- jinja2: 3.1.2
- lightning: 2.1.0.dev0
- lightning-cloud: 0.5.37
- lightning-utilities: 0.8.0
- lit: 16.0.6
- markdown-it-py: 3.0.0
- markupsafe: 2.1.3
- mdurl: 0.1.2
- mpmath: 1.3.0
- multidict: 6.0.4
- networkx: 3.1
- numpy: 1.25.0
- nvidia-cublas-cu11: 11.10.3.66
- nvidia-cuda-cupti-cu11: 11.7.101
- nvidia-cuda-nvrtc-cu11: 11.7.99
- nvidia-cuda-runtime-cu11: 11.7.99
- nvidia-cudnn-cu11: 8.5.0.96
- nvidia-cufft-cu11: 10.9.0.58
- nvidia-curand-cu11: 10.2.10.91
- nvidia-cusolver-cu11: 11.4.0.1
- nvidia-cusparse-cu11: 11.7.4.91
- nvidia-nccl-cu11: 2.14.3
- nvidia-nvtx-cu11: 11.7.91
- ordered-set: 4.1.0
- packaging: 23.1
- pip: 23.1.2
- psutil: 5.9.5
- pydantic: 1.10.9
- pygments: 2.15.1
- pyjwt: 2.7.0
- python-dateutil: 2.8.2
- python-editor: 1.0.4
- python-multipart: 0.0.6
- pytorch-lightning: 2.0.4
- pytz: 2023.3
- pyyaml: 6.0
- readchar: 4.0.5
- requests: 2.31.0
- rich: 13.4.2
- setuptools: 68.0.0
- six: 1.16.0
- sniffio: 1.3.0
- soupsieve: 2.4.1
- starlette: 0.27.0
- starsessions: 1.3.0
- sympy: 1.12
- torch: 2.0.1
- torchmetrics: 0.11.4
- tqdm: 4.65.0
- traitlets: 5.9.0
- triton: 2.0.0
- typing-extensions: 4.6.3
- urllib3: 2.0.3
- uvicorn: 0.22.0
- wcwidth: 0.2.6
- websocket-client: 1.6.1
- websockets: 11.0.3
- wheel: 0.40.0
- yarl: 1.9.2
- System:
- OS: Linux
- architecture:
- 64bit
- ELF
- processor:
- python: 3.10.12
- release: 4.19.0-24-cloud-amd64
- version: #1 SMP Debian 4.19.282-1 (2023-04-29)
More info
No response