Description
Bug description
I have encountered a hard crash when using a non-scalar torchmetrics metric alongside a scalar one while training with the `ddp` strategy across multiple nodes. This is extremely frustrating, because it effectively means non-scalar torchmetrics metrics cannot be used in multi-node distributed training. Multiple scalar metrics without a non-scalar one appear to work fine, and the non-scalar metric on its own also works fine. The crash only occurs when running on multiple nodes.
Tested on macOS and CentOS 8.
What version are you seeing the problem on?
master
How to reproduce the bug
This is a minimal example I've created to demonstrate the crash. I use the ROC curve as an example, but the same crash occurs if I implement my own custom metric with a non-scalar state variable. Please note the comments in `training_step` and `on_train_epoch_end`.
```python
import numpy as np
import torch
from pytorch_lightning import LightningModule, LightningDataModule
from pytorch_lightning.cli import LightningCLI
from torch import nn
from torch.nn import Sequential
from torch.utils.data import DataLoader, Dataset
from torchmetrics import SumMetric, ROC


class RandomDS(Dataset):
    def __init__(self):
        self.data = np.random.random((1000, 1)).astype(np.float32)
        self.labels = (self.data < 0.5).astype(np.float32)

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, item):
        return self.data[item], self.labels[item]


class SimpleDataModule(LightningDataModule):
    def setup(self, stage: str):
        self.ds = RandomDS()

    def train_dataloader(self):
        return DataLoader(
            self.ds,
            batch_size=16,
            shuffle=True,
        )


class SimpleModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.model = Sequential(
            nn.Linear(1, 10),
            nn.ReLU(),
            nn.Linear(10, 1),
            nn.Sigmoid(),
        )
        self.loss_fn = nn.BCELoss()
        self.train_output_distribution = ROC('binary')
        self.some_other_scalar = SumMetric()
        self.some_other_scalar2 = SumMetric()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        yhat = self(x)
        loss = self.loss_fn(yhat, y)
        self.log('train_loss', loss, prog_bar=True, sync_dist=True)
        self.train_output_distribution.update(yhat, y.int())
        self.some_other_scalar.update(yhat)
        self.some_other_scalar2.update(y)
        # Commenting out the below 2 lines OR setting on_epoch=False stops the crash from happening
        # OR commenting out the add_figure call in on_train_epoch_end below
        self.log('some_other_scalar', self.some_other_scalar, on_step=True, on_epoch=True)
        self.log('some_other_scalar2', self.some_other_scalar2, on_step=True, on_epoch=True)
        return loss

    def on_train_epoch_end(self):
        # Commenting out the below add_figure call stops the crash. self.train_output_distribution.compute() would also crash here
        self.logger.experiment.add_figure(
            'Train Output Dist',
            self.train_output_distribution.plot()[0],
            global_step=self.current_epoch
        )
        self.train_output_distribution.reset()


if __name__ == '__main__':
    LightningCLI(SimpleModule, SimpleDataModule)
```
Here are the CLI calls for ease of testing, assuming the example above is saved as `bug_report.py`:
```bash
export MASTER_PORT=8008; export MASTER_ADDR=127.0.0.1; export WORLD_SIZE=2; export NODE_RANK=0; python bug_report.py fit --trainer.accelerator cpu --trainer.max_epochs 25 --trainer.num_nodes 2 --trainer.strategy ddp
```

and in a second shell:

```bash
export MASTER_PORT=8008; export MASTER_ADDR=127.0.0.1; export WORLD_SIZE=2; export NODE_RANK=1; python bug_report.py fit --trainer.accelerator cpu --trainer.max_epochs 25 --trainer.num_nodes 2 --trainer.strategy ddp
```
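If it is more convenient to start both "nodes" from a single terminal, the two shell commands can be wrapped in a small Python launcher. This is a sketch under stated assumptions. `build_node_commands` and `launch_all` are hypothetical helper names. The launcher assumes one process per node, so `WORLD_SIZE` equals the number of nodes, as in the commands above. It reuses exactly the same environment variables and CLI flags.

```python
import os
import subprocess
import sys
from typing import Dict, List, Tuple


def build_node_commands(
    script: str = "bug_report.py",
    num_nodes: int = 2,
    master_addr: str = "127.0.0.1",
    master_port: int = 8008,
    max_epochs: int = 25,
) -> List[Tuple[Dict[str, str], List[str]]]:
    """Return one (env, argv) pair per simulated node, mirroring the shell commands."""
    commands = []
    for node_rank in range(num_nodes):
        # Start from the current environment and add the rendezvous variables.
        env = dict(os.environ)
        env.update(
            MASTER_ADDR=master_addr,
            MASTER_PORT=str(master_port),
            WORLD_SIZE=str(num_nodes),  # assumes one process per node
            NODE_RANK=str(node_rank),
        )
        argv = [
            sys.executable, script, "fit",
            "--trainer.accelerator", "cpu",
            "--trainer.max_epochs", str(max_epochs),
            "--trainer.num_nodes", str(num_nodes),
            "--trainer.strategy", "ddp",
        ]
        commands.append((env, argv))
    return commands


def launch_all(**kwargs) -> List[int]:
    """Start every node process concurrently, then wait and return their exit codes."""
    procs = [subprocess.Popen(argv, env=env) for env, argv in build_node_commands(**kwargs)]
    return [p.wait() for p in procs]
```

Calling `launch_all()` from the directory that contains `bug_report.py` starts both processes at once. This matters because each rank blocks at rendezvous until the other one joins.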
Error messages and logs
At the end of the training epoch, the following exception is printed and the whole Python process hard-crashes:
```
libc++abi: terminating with uncaught exception of type gloo::EnforceNotMet: [enforce fail at /Users/runner/work/pytorch/pytorch/pytorch/third_party/gloo/gloo/transport/uv/pair.cc:248] op.nread == op.preamble.nbytes.
```
On exit, the following warning is also printed, which may be relevant:
```
/Applications/Xcode.app/Contents/Developer/Library/Frameworks/Python3.framework/Versions/3.9/lib/python3.9/multiprocessing/resource_tracker.py:216: UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
```
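Read literally, the failed gloo check `op.nread == op.preamble.nbytes` means that the number of bytes actually read for a message differed from the size announced in that message's preamble. For context, list states such as those behind `ROC('binary')` usually hold a different number of elements on each rank, so they cannot be all-gathered in one shot. To our understanding, torchmetrics first exchanges the per-rank sizes and pads to the largest before gathering. The single-process simulation below illustrates that exchange-sizes-then-pad pattern with plain torch. It is a sketch only: `simulate_uneven_all_gather` is a hypothetical name, it handles 1-D tensors only, it is not torchmetrics' actual code, and it does not by itself explain this crash.

```python
from typing import List

import torch


def simulate_uneven_all_gather(per_rank: List[torch.Tensor]) -> torch.Tensor:
    """Simulate gathering 1-D tensors of different lengths; per_rank[i] is rank i's local tensor."""
    # Step 1: each rank shares its local length (a fixed-size message on every rank).
    sizes = [t.numel() for t in per_rank]
    max_len = max(sizes)
    # Step 2: pad every local tensor to the common length so that all buffers
    # exchanged in a real all_gather have the same number of bytes.
    padded = [torch.cat([t, t.new_zeros(max_len - t.numel())]) for t in per_rank]
    # Step 3 (simulated all_gather): every rank would receive all padded buffers.
    gathered = torch.stack(padded)
    # Step 4: strip the padding using the exchanged sizes, then concatenate,
    # which is the result dist_reduce_fx="cat" asks for.
    return torch.cat([row[:n] for row, n in zip(gathered, sizes)])
```

For example, gathering `[1., 2., 3.]` from one rank and `[4.]` from another yields `[1., 2., 3., 4.]`, even though the raw local buffers differ in size.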
Environment
<details>
<summary>Current environment</summary>
* CUDA:
- GPU: None
- available: False
- version: None
* Lightning:
- lightning: 2.0.6
- lightning-cloud: 0.5.37
- lightning-utilities: 0.9.0
- pytorch-lightning: 2.0.6
- torch: 2.0.1
- torchmetrics: 1.0.3
* Packages:
- absl-py: 1.4.0
- aiohttp: 3.8.5
- aiosignal: 1.3.1
- annotated-types: 0.5.0
- antlr4-python3-runtime: 4.9.3
- anyio: 3.7.1
- arrow: 1.2.3
- async-timeout: 4.0.2
- attrs: 23.1.0
- awkward: 2.1.3
- awkward-cpp: 14
- backoff: 2.2.1
- beautifulsoup4: 4.12.2
- blessed: 1.20.0
- cachetools: 5.3.1
- certifi: 2023.7.22
- charset-normalizer: 3.2.0
- click: 8.1.6
- contourpy: 1.0.7
- croniter: 1.4.1
- cycler: 0.11.0
- dateutils: 0.6.12
- deepdiff: 6.3.1
- docstring-parser: 0.15
- exceptiongroup: 1.1.2
- fastapi: 0.100.1
- filelock: 3.12.0
- fonttools: 4.39.3
- frozenlist: 1.4.0
- fsspec: 2023.6.0
- google-auth: 2.22.0
- google-auth-oauthlib: 1.0.0
- grpcio: 1.56.2
- h11: 0.14.0
- htmap: 0.6.1
- hydra-core: 1.3.2
- idna: 3.4
- importlib-metadata: 6.8.0
- importlib-resources: 5.12.0
- inquirer: 3.1.3
- itsdangerous: 2.1.2
- jinja2: 3.1.2
- jsonargparse: 4.22.1
- kiwisolver: 1.4.4
- lightning: 2.0.6
- lightning-cloud: 0.5.37
- lightning-utilities: 0.9.0
- markdown: 3.4.4
- markdown-it-py: 3.0.0
- markupsafe: 2.1.2
- matplotlib: 3.7.1
- mdurl: 0.1.2
- mplhep: 0.3.27
- mplhep-data: 0.0.3
- mpmath: 1.3.0
- multidict: 6.0.4
- networkx: 3.1
- numpy: 1.24.3
- oauthlib: 3.2.2
- omegaconf: 2.3.0
- ordered-set: 4.1.0
- packaging: 23.1
- pandas: 2.0.3
- pillow: 9.5.0
- pip: 23.2.1
- plotagain: 1.0.6
- protobuf: 4.23.4
- psutil: 5.9.5
- pyasn1: 0.5.0
- pyasn1-modules: 0.3.0
- pydantic: 2.0.3
- pydantic-core: 2.3.0
- pygments: 2.15.1
- pyjwt: 2.8.0
- pyparsing: 3.0.9
- python-dateutil: 2.8.2
- python-editor: 1.0.4
- python-multipart: 0.0.6
- pytorch-lightning: 2.0.6
- pytz: 2023.3
- pyyaml: 6.0.1
- readchar: 4.0.5
- requests: 2.31.0
- requests-oauthlib: 1.3.1
- rich: 13.5.2
- rsa: 4.9
- scipy: 1.10.1
- setuptools: 60.2.0
- six: 1.16.0
- sniffio: 1.3.0
- soupsieve: 2.4.1
- starlette: 0.27.0
- starsessions: 1.3.0
- sympy: 1.11.1
- tensorboard: 2.13.0
- tensorboard-data-server: 0.7.1
- tensorboardx: 2.6.2
- torch: 2.0.1
- torchmetrics: 1.0.3
- tqdm: 4.65.0
- traitlets: 5.9.0
- typeshed-client: 2.3.0
- typing-extensions: 4.7.1
- tzdata: 2023.3
- uhi: 0.3.3
- uproot: 5.0.7
- urllib3: 1.26.15
- uvicorn: 0.23.2
- wcwidth: 0.2.6
- websocket-client: 1.6.1
- websockets: 11.0.3
- werkzeug: 2.3.6
- wheel: 0.37.1
- yarl: 1.9.2
- zipp: 3.15.0
* System:
- OS: Darwin
- architecture:
- 64bit
-
- processor: arm
- python: 3.9.6
- release: 22.1.0
- version: Darwin Kernel Version 22.1.0: Sun Oct 9 20:15:09 PDT 2022; root:xnu-8792.41.9~2/RELEASE_ARM64_T6000
</details>
Also tested in the following environment:
<details>
<summary>Current environment</summary>
* CUDA:
- GPU: None
- available: False
- version: 11.7
* Lightning:
- lightning: 2.0.6
- lightning-cloud: 0.5.37
- lightning-utilities: 0.9.0
- pytorch-lightning: 2.0.6
- torch: 2.0.1
- torchmetrics: 1.0.1
* Packages:
- absl-py: 1.4.0
- aiohttp: 3.8.5
- aiosignal: 1.3.1
- annotated-types: 0.5.0
- antlr4-python3-runtime: 4.9.3
- anyio: 3.7.1
- arrow: 1.2.3
- async-timeout: 4.0.2
- attrs: 23.1.0
- awkward: 2.3.1
- awkward-cpp: 21
- backoff: 2.2.1
- beautifulsoup4: 4.12.2
- blessed: 1.20.0
- cachetools: 5.3.1
- certifi: 2023.7.22
- charset-normalizer: 3.2.0
- click: 8.1.6
- click-didyoumean: 0.3.0
- cloudpickle: 2.2.1
- cmake: 3.27.0
- colorama: 0.4.6
- contourpy: 1.1.0
- croniter: 1.4.1
- cycler: 0.11.0
- dateutils: 0.6.12
- deepdiff: 6.3.1
- docstring-parser: 0.15
- exceptiongroup: 1.1.2
- fastapi: 0.100.1
- filelock: 3.12.2
- fonttools: 4.41.1
- frozenlist: 1.4.0
- fsspec: 2023.6.0
- gloo: 0.1.2
- google-auth: 2.22.0
- google-auth-oauthlib: 1.0.0
- grpcio: 1.56.2
- h11: 0.14.0
- halo: 0.0.31
- htcondor: 10.7.0
- htmap: 0.6.1
- hydra-core: 1.3.2
- idna: 3.4
- importlib-metadata: 6.8.0
- importlib-resources: 6.0.0
- inquirer: 3.1.3
- itsdangerous: 2.1.2
- jinja2: 3.1.2
- jsonargparse: 4.22.1
- kiwisolver: 1.4.4
- lightning: 2.0.6
- lightning-cloud: 0.5.37
- lightning-utilities: 0.9.0
- lit: 16.0.6
- log-symbols: 0.0.14
- markdown: 3.4.4
- markdown-it-py: 3.0.0
- markupsafe: 2.1.3
- matplotlib: 3.7.2
- mdurl: 0.1.2
- mplhep: 0.3.28
- mplhep-data: 0.0.3
- mpmath: 1.3.0
- multidict: 6.0.4
- networkx: 3.1
- numpy: 1.25.2
- nvidia-cublas-cu11: 11.10.3.66
- nvidia-cuda-cupti-cu11: 11.7.101
- nvidia-cuda-nvrtc-cu11: 11.7.99
- nvidia-cuda-runtime-cu11: 11.7.99
- nvidia-cudnn-cu11: 8.5.0.96
- nvidia-cufft-cu11: 10.9.0.58
- nvidia-curand-cu11: 10.2.10.91
- nvidia-cusolver-cu11: 11.4.0.1
- nvidia-cusparse-cu11: 11.7.4.91
- nvidia-nccl-cu11: 2.14.3
- nvidia-nvtx-cu11: 11.7.91
- oauthlib: 3.2.2
- omegaconf: 2.3.0
- ordered-set: 4.1.0
- packaging: 23.1
- pandas: 2.0.3
- pillow: 10.0.0
- pip: 22.0.4
- plotagain: 1.0.6
- protobuf: 4.23.4
- psutil: 5.9.5
- pyasn1: 0.5.0
- pyasn1-modules: 0.3.0
- pydantic: 2.0.3
- pydantic-core: 2.3.0
- pygments: 2.15.1
- pyjwt: 2.8.0
- pyparsing: 3.0.9
- python-dateutil: 2.8.2
- python-editor: 1.0.4
- python-multipart: 0.0.6
- pytorch-lightning: 2.0.6
- pytz: 2023.3
- pyyaml: 6.0.1
- readchar: 4.0.5
- requests: 2.31.0
- requests-oauthlib: 1.3.1
- rich: 13.5.2
- rsa: 4.9
- scipy: 1.11.1
- setuptools: 58.1.0
- six: 1.16.0
- sniffio: 1.3.0
- soupsieve: 2.4.1
- spinners: 0.0.24
- starlette: 0.27.0
- starsessions: 1.3.0
- sympy: 1.12
- tensorboard: 2.13.0
- tensorboard-data-server: 0.7.1
- tensorboardx: 2.6.2
- termcolor: 2.3.0
- toml: 0.10.2
- torch: 2.0.1
- torchmetrics: 1.0.1
- tqdm: 4.65.0
- traitlets: 5.9.0
- triton: 2.0.0
- typeshed-client: 2.3.0
- typing-extensions: 4.7.1
- tzdata: 2023.3
- uhi: 0.3.3
- uproot: 5.0.10
- urllib3: 1.26.16
- uvicorn: 0.23.2
- wcwidth: 0.2.6
- websocket-client: 1.6.1
- websockets: 11.0.3
- werkzeug: 2.3.6
- wheel: 0.41.0
- yarl: 1.9.2
- zipp: 3.16.2
* System:
- OS: Linux
- architecture:
- 64bit
- ELF
- processor: x86_64
- python: 3.9.14
- release: 3.10.0-1160.88.1.el7.x86_64
- version: #1 SMP Tue Mar 7 15:41:52 UTC 2023
</details>