
Crash when using a non-scalar metric together with scalar metrics (torchmetrics) in distributed training #18290

@jjhw3


Bug description

I have encountered a hard crash when using a non-scalar torchmetrics metric alongside a scalar one while training with the DDP strategy across multiple nodes. This is a serious problem, as it effectively means non-scalar torchmetrics metrics cannot be used in multi-node distributed training. Multiple scalar metrics without a non-scalar one appear to work fine, and the non-scalar metric on its own also works fine. The crash only occurs when running on multiple nodes.

Tested on macOS and CentOS 8.

What version are you seeing the problem on?

master

How to reproduce the bug

Below is a minimal example that reproduces the crash. It uses ROC as the non-scalar metric, but the same crash occurs with my own custom metric that has a non-scalar state variable. Please note the comments in `training_step` and `on_train_epoch_end`.

```python
import numpy as np
import torch
from pytorch_lightning import LightningModule, LightningDataModule
from pytorch_lightning.cli import LightningCLI
from torch import nn
from torch.nn import Sequential
from torch.utils.data import DataLoader, Dataset
from torchmetrics import SumMetric, ROC


class RandomDS(Dataset):
    def __init__(self):
        self.data = np.random.random((1000, 1)).astype(np.float32)
        self.labels = (self.data < 0.5).astype(np.float32)

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, item):
        return self.data[item], self.labels[item]


class SimpleDataModule(LightningDataModule):
    def setup(self, stage: str):
        self.ds = RandomDS()

    def train_dataloader(self):
        return DataLoader(
            self.ds,
            batch_size=16,
            shuffle=True,
        )


class SimpleModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.model = Sequential(
            nn.Linear(1, 10),
            nn.ReLU(),
            nn.Linear(10, 1),
            nn.Sigmoid(),
        )
        self.loss_fn = nn.BCELoss()
        self.train_output_distribution = ROC('binary')
        self.some_other_scalar = SumMetric()
        self.some_other_scalar2 = SumMetric()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        yhat = self(x)
        loss = self.loss_fn(yhat, y)
        self.log('train_loss', loss, prog_bar=True, sync_dist=True)
        self.train_output_distribution.update(yhat, y.int())
        self.some_other_scalar.update(yhat)
        self.some_other_scalar2.update(y)

        # Commenting out the two lines below OR setting on_epoch=False prevents the crash,
        # as does commenting out the add_figure call in on_train_epoch_end.
        self.log('some_other_scalar', self.some_other_scalar, on_step=True, on_epoch=True)
        self.log('some_other_scalar2', self.some_other_scalar2, on_step=True, on_epoch=True)

        return loss

    def on_train_epoch_end(self):
        # Commenting out the add_figure call below prevents the crash.
        # Calling self.train_output_distribution.compute() here would also crash.
        self.logger.experiment.add_figure(
            'Train Output Dist',
            self.train_output_distribution.plot()[0],
            global_step=self.current_epoch
        )

        self.train_output_distribution.reset()


if __name__ == '__main__':
    LightningCLI(SimpleModule, SimpleDataModule)
```

To reproduce, save the example above as `bug_report.py` and run the following in one shell:

```bash
export MASTER_PORT=8008; export MASTER_ADDR=127.0.0.1; export WORLD_SIZE=2; export NODE_RANK=0; python bug_report.py fit --trainer.accelerator cpu --trainer.max_epochs 25 --trainer.num_nodes 2 --trainer.strategy ddp
```

and the following in a second shell:

```bash
export MASTER_PORT=8008; export MASTER_ADDR=127.0.0.1; export WORLD_SIZE=2; export NODE_RANK=1; python bug_report.py fit --trainer.accelerator cpu --trainer.max_epochs 25 --trainer.num_nodes 2 --trainer.strategy ddp
```
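For convenience, the two shells could also be driven from a single script. This is a minimal sketch, assuming the example is saved as `bug_report.py` in the working directory and that each node runs one process (so `WORLD_SIZE` equals the number of nodes, as in the exports). Passing `--debug` additionally sets PyTorch's `TORCH_DISTRIBUTED_DEBUG=DETAIL`, which may produce a more descriptive error than the bare gloo enforce failure. The helper names `node_env` and `launch` are hypothetical.

```python
import os
import subprocess
import sys
from pathlib import Path

# Values mirror the CLI calls for the two simulated nodes.
MASTER_ADDR = "127.0.0.1"
MASTER_PORT = "8008"
NUM_NODES = 2  # assumption: one process per node, so WORLD_SIZE == NUM_NODES
TRAINER_ARGS = [
    "fit",
    "--trainer.accelerator", "cpu",
    "--trainer.max_epochs", "25",
    "--trainer.num_nodes", str(NUM_NODES),
    "--trainer.strategy", "ddp",
]


def node_env(node_rank: int, debug: bool = False) -> dict:
    """Build the environment for one simulated node (same variables as the exports)."""
    env = dict(os.environ)
    env.update(
        MASTER_ADDR=MASTER_ADDR,
        MASTER_PORT=MASTER_PORT,
        WORLD_SIZE=str(NUM_NODES),
        NODE_RANK=str(node_rank),
    )
    if debug:
        # Enables PyTorch's extra consistency checks on collective calls.
        env["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
    return env


def launch(script: str = "bug_report.py", debug: bool = False) -> list:
    """Start one process per node, wait for all of them, and return their exit codes."""
    procs = [
        subprocess.Popen([sys.executable, script, *TRAINER_ARGS], env=node_env(rank, debug))
        for rank in range(NUM_NODES)
    ]
    return [p.wait() for p in procs]


if __name__ == "__main__" and Path("bug_report.py").exists():
    print(launch(debug="--debug" in sys.argv))
```

Running both nodes from one parent process keeps their output interleaved in a single terminal, which can make it easier to see which rank fails first.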

Error messages and logs

At the end of the training epoch, the following exception is printed and the whole Python process crashes hard:

```
libc++abi: terminating with uncaught exception of type gloo::EnforceNotMet: [enforce fail at /Users/runner/work/pytorch/pytorch/pytorch/third_party/gloo/gloo/transport/uv/pair.cc:248] op.nread == op.preamble.nbytes.
```

On exit, the following warning is also printed, which may be relevant:

```
/Applications/Xcode.app/Contents/Developer/Library/Frameworks/Python3.framework/Versions/3.9/lib/python3.9/multiprocessing/resource_tracker.py:216: UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
```

Environment

<details>
  <summary>Current environment</summary>

* CUDA:
        - GPU:               None
        - available:         False
        - version:           None
* Lightning:
        - lightning:         2.0.6
        - lightning-cloud:   0.5.37
        - lightning-utilities: 0.9.0
        - pytorch-lightning: 2.0.6
        - torch:             2.0.1
        - torchmetrics:      1.0.3
* Packages:
        - absl-py:           1.4.0
        - aiohttp:           3.8.5
        - aiosignal:         1.3.1
        - annotated-types:   0.5.0
        - antlr4-python3-runtime: 4.9.3
        - anyio:             3.7.1
        - arrow:             1.2.3
        - async-timeout:     4.0.2
        - attrs:             23.1.0
        - awkward:           2.1.3
        - awkward-cpp:       14
        - backoff:           2.2.1
        - beautifulsoup4:    4.12.2
        - blessed:           1.20.0
        - cachetools:        5.3.1
        - certifi:           2023.7.22
        - charset-normalizer: 3.2.0
        - click:             8.1.6
        - contourpy:         1.0.7
        - croniter:          1.4.1
        - cycler:            0.11.0
        - dateutils:         0.6.12
        - deepdiff:          6.3.1
        - docstring-parser:  0.15
        - exceptiongroup:    1.1.2
        - fastapi:           0.100.1
        - filelock:          3.12.0
        - fonttools:         4.39.3
        - frozenlist:        1.4.0
        - fsspec:            2023.6.0
        - google-auth:       2.22.0
        - google-auth-oauthlib: 1.0.0
        - grpcio:            1.56.2
        - h11:               0.14.0
        - htmap:             0.6.1
        - hydra-core:        1.3.2
        - idna:              3.4
        - importlib-metadata: 6.8.0
        - importlib-resources: 5.12.0
        - inquirer:          3.1.3
        - itsdangerous:      2.1.2
        - jinja2:            3.1.2
        - jsonargparse:      4.22.1
        - kiwisolver:        1.4.4
        - lightning:         2.0.6
        - lightning-cloud:   0.5.37
        - lightning-utilities: 0.9.0
        - markdown:          3.4.4
        - markdown-it-py:    3.0.0
        - markupsafe:        2.1.2
        - matplotlib:        3.7.1
        - mdurl:             0.1.2
        - mplhep:            0.3.27
        - mplhep-data:       0.0.3
        - mpmath:            1.3.0
        - multidict:         6.0.4
        - networkx:          3.1
        - numpy:             1.24.3
        - oauthlib:          3.2.2
        - omegaconf:         2.3.0
        - ordered-set:       4.1.0
        - packaging:         23.1
        - pandas:            2.0.3
        - pillow:            9.5.0
        - pip:               23.2.1
        - plotagain:         1.0.6
        - protobuf:          4.23.4
        - psutil:            5.9.5
        - pyasn1:            0.5.0
        - pyasn1-modules:    0.3.0
        - pydantic:          2.0.3
        - pydantic-core:     2.3.0
        - pygments:          2.15.1
        - pyjwt:             2.8.0
        - pyparsing:         3.0.9
        - python-dateutil:   2.8.2
        - python-editor:     1.0.4
        - python-multipart:  0.0.6
        - pytorch-lightning: 2.0.6
        - pytz:              2023.3
        - pyyaml:            6.0.1
        - readchar:          4.0.5
        - requests:          2.31.0
        - requests-oauthlib: 1.3.1
        - rich:              13.5.2
        - rsa:               4.9
        - scipy:             1.10.1
        - setuptools:        60.2.0
        - six:               1.16.0
        - sniffio:           1.3.0
        - soupsieve:         2.4.1
        - starlette:         0.27.0
        - starsessions:      1.3.0
        - sympy:             1.11.1
        - tensorboard:       2.13.0
        - tensorboard-data-server: 0.7.1
        - tensorboardx:      2.6.2
        - torch:             2.0.1
        - torchmetrics:      1.0.3
        - tqdm:              4.65.0
        - traitlets:         5.9.0
        - typeshed-client:   2.3.0
        - typing-extensions: 4.7.1
        - tzdata:            2023.3
        - uhi:               0.3.3
        - uproot:            5.0.7
        - urllib3:           1.26.15
        - uvicorn:           0.23.2
        - wcwidth:           0.2.6
        - websocket-client:  1.6.1
        - websockets:        11.0.3
        - werkzeug:          2.3.6
        - wheel:             0.37.1
        - yarl:              1.9.2
        - zipp:              3.15.0
* System:
        - OS:                Darwin
        - architecture:
                - 64bit
                - 
        - processor:         arm
        - python:            3.9.6
        - release:           22.1.0
        - version:           Darwin Kernel Version 22.1.0: Sun Oct  9 20:15:09 PDT 2022; root:xnu-8792.41.9~2/RELEASE_ARM64_T6000

</details>

Also tested in the following environment

<details>
  <summary>Current environment</summary>

* CUDA:
	- GPU:               None
	- available:         False
	- version:           11.7
* Lightning:
	- lightning:         2.0.6
	- lightning-cloud:   0.5.37
	- lightning-utilities: 0.9.0
	- pytorch-lightning: 2.0.6
	- torch:             2.0.1
	- torchmetrics:      1.0.1
* Packages:
	- absl-py:           1.4.0
	- aiohttp:           3.8.5
	- aiosignal:         1.3.1
	- annotated-types:   0.5.0
	- antlr4-python3-runtime: 4.9.3
	- anyio:             3.7.1
	- arrow:             1.2.3
	- async-timeout:     4.0.2
	- attrs:             23.1.0
	- awkward:           2.3.1
	- awkward-cpp:       21
	- backoff:           2.2.1
	- beautifulsoup4:    4.12.2
	- blessed:           1.20.0
	- cachetools:        5.3.1
	- certifi:           2023.7.22
	- charset-normalizer: 3.2.0
	- click:             8.1.6
	- click-didyoumean:  0.3.0
	- cloudpickle:       2.2.1
	- cmake:             3.27.0
	- colorama:          0.4.6
	- contourpy:         1.1.0
	- croniter:          1.4.1
	- cycler:            0.11.0
	- dateutils:         0.6.12
	- deepdiff:          6.3.1
	- docstring-parser:  0.15
	- exceptiongroup:    1.1.2
	- fastapi:           0.100.1
	- filelock:          3.12.2
	- fonttools:         4.41.1
	- frozenlist:        1.4.0
	- fsspec:            2023.6.0
	- gloo:              0.1.2
	- google-auth:       2.22.0
	- google-auth-oauthlib: 1.0.0
	- grpcio:            1.56.2
	- h11:               0.14.0
	- halo:              0.0.31
	- htcondor:          10.7.0
	- htmap:             0.6.1
	- hydra-core:        1.3.2
	- idna:              3.4
	- importlib-metadata: 6.8.0
	- importlib-resources: 6.0.0
	- inquirer:          3.1.3
	- itsdangerous:      2.1.2
	- jinja2:            3.1.2
	- jsonargparse:      4.22.1
	- kiwisolver:        1.4.4
	- lightning:         2.0.6
	- lightning-cloud:   0.5.37
	- lightning-utilities: 0.9.0
	- lit:               16.0.6
	- log-symbols:       0.0.14
	- markdown:          3.4.4
	- markdown-it-py:    3.0.0
	- markupsafe:        2.1.3
	- matplotlib:        3.7.2
	- mdurl:             0.1.2
	- mplhep:            0.3.28
	- mplhep-data:       0.0.3
	- mpmath:            1.3.0
	- multidict:         6.0.4
	- networkx:          3.1
	- numpy:             1.25.2
	- nvidia-cublas-cu11: 11.10.3.66
	- nvidia-cuda-cupti-cu11: 11.7.101
	- nvidia-cuda-nvrtc-cu11: 11.7.99
	- nvidia-cuda-runtime-cu11: 11.7.99
	- nvidia-cudnn-cu11: 8.5.0.96
	- nvidia-cufft-cu11: 10.9.0.58
	- nvidia-curand-cu11: 10.2.10.91
	- nvidia-cusolver-cu11: 11.4.0.1
	- nvidia-cusparse-cu11: 11.7.4.91
	- nvidia-nccl-cu11:  2.14.3
	- nvidia-nvtx-cu11:  11.7.91
	- oauthlib:          3.2.2
	- omegaconf:         2.3.0
	- ordered-set:       4.1.0
	- packaging:         23.1
	- pandas:            2.0.3
	- pillow:            10.0.0
	- pip:               22.0.4
	- plotagain:         1.0.6
	- protobuf:          4.23.4
	- psutil:            5.9.5
	- pyasn1:            0.5.0
	- pyasn1-modules:    0.3.0
	- pydantic:          2.0.3
	- pydantic-core:     2.3.0
	- pygments:          2.15.1
	- pyjwt:             2.8.0
	- pyparsing:         3.0.9
	- python-dateutil:   2.8.2
	- python-editor:     1.0.4
	- python-multipart:  0.0.6
	- pytorch-lightning: 2.0.6
	- pytz:              2023.3
	- pyyaml:            6.0.1
	- readchar:          4.0.5
	- requests:          2.31.0
	- requests-oauthlib: 1.3.1
	- rich:              13.5.2
	- rsa:               4.9
	- scipy:             1.11.1
	- setuptools:        58.1.0
	- six:               1.16.0
	- sniffio:           1.3.0
	- soupsieve:         2.4.1
	- spinners:          0.0.24
	- starlette:         0.27.0
	- starsessions:      1.3.0
	- sympy:             1.12
	- tensorboard:       2.13.0
	- tensorboard-data-server: 0.7.1
	- tensorboardx:      2.6.2
	- termcolor:         2.3.0
	- toml:              0.10.2
	- torch:             2.0.1
	- torchmetrics:      1.0.1
	- tqdm:              4.65.0
	- traitlets:         5.9.0
	- triton:            2.0.0
	- typeshed-client:   2.3.0
	- typing-extensions: 4.7.1
	- tzdata:            2023.3
	- uhi:               0.3.3
	- uproot:            5.0.10
	- urllib3:           1.26.16
	- uvicorn:           0.23.2
	- wcwidth:           0.2.6
	- websocket-client:  1.6.1
	- websockets:        11.0.3
	- werkzeug:          2.3.6
	- wheel:             0.41.0
	- yarl:              1.9.2
	- zipp:              3.16.2
* System:
	- OS:                Linux
	- architecture:
		- 64bit
		- ELF
	- processor:         x86_64
	- python:            3.9.14
	- release:           3.10.0-1160.88.1.el7.x86_64
	- version:           #1 SMP Tue Mar 7 15:41:52 UTC 2023

</details>


cc @justusschock @awaelchli
