
Step when validation happens drifts for val_check_interval when gradient accumulation turned on  #17207

@hrukalive

Bug description

My task is driven by step count rather than epochs, so I run validation every fixed number of steps and save a checkpoint after each validation run. However, once I turned gradient accumulation on, with a number of batches per epoch that is not divisible by the accumulation factor, the step at which validation is performed (and therefore the step at which checkpoints are saved) started to drift.

In the example below, I override the _save_checkpoint method to print the checkpoint file name, and the step in that name turns out to drift. My settings are accumulate_grad_batches=3, val_check_interval=5*accumulate_grad_batches (intended to validate every 5 effective optimizer steps), and 67 batches per epoch, so each epoch ends with one leftover batch (67 = 22*3 + 1).

How to reproduce the bug

import numpy as np
import pathlib

import time
import torch
import torch.nn as nn
import torch.optim

import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint

class Quadratic(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.a = nn.Parameter(torch.tensor(0.0))
        self.b = nn.Parameter(torch.tensor(0.0))
        self.c = nn.Parameter(torch.tensor(0.0))

    def forward(self, x):
        time.sleep(0.02)
        return self.a * x * x + self.b * x + self.c
    
    def _common_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = torch.nn.functional.mse_loss(y_hat, y)
        return loss 

    def training_step(self, batch, batch_idx):
        loss = self._common_step(batch, batch_idx)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._common_step(batch, batch_idx)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

class CustomModelCheckpoint(ModelCheckpoint):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
        monitor_candidates = {k: v.item() if isinstance(v, torch.Tensor) else v for k, v in self._monitor_candidates(trainer).items()}
        print("\n", "Save checkpoint, global_step: ", trainer.global_step, pathlib.Path(filepath).stem, "monitor_candidates: " + str(monitor_candidates), "\n", flush=True)
        
    def _remove_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
        # print("Remove checkpoint: ", filepath, flush=True)
        pass
    
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('-a', type=float, default=2.0)
    parser.add_argument('-b', type=float, default=3.0)
    parser.add_argument('-c', type=float, default=4.0)
    parser.add_argument('--epoch', type=int, default=500)
    args = parser.parse_args()

    x = torch.from_numpy(np.random.uniform(-10, 10, 2144)).float()  # 2144 samples / batch_size 32 = 67 batches
    y = args.a * x * x + args.b * x + args.c
    x2 = torch.from_numpy(np.random.uniform(-10, 10, 100)).float()
    y2 = args.a * x2 * x2 + args.b * x2 + args.c

    dataset = torch.utils.data.TensorDataset(x, y)
    val_dataset = torch.utils.data.TensorDataset(x2, y2)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False)

    model = Quadratic()
    
    ####
    accumulate_grad_batches = 3
    val_check_interval = 5 * accumulate_grad_batches  # 15 batches, intended as every 5 optimizer steps
    ####

    trainer = pl.Trainer(max_epochs=args.epoch, accelerator='cpu', callbacks=[CustomModelCheckpoint(
                    dirpath='.',
                    filename='steps_{step}',
                    monitor='step',
                    mode='max',
                    save_last=False,
                    save_top_k=5
                )],
            val_check_interval=val_check_interval,
            check_val_every_n_epoch=None,
            num_sanity_val_steps=0,
            accumulate_grad_batches=accumulate_grad_batches)
    trainer.fit(model, dataloader, val_dataloader)
    
    # Print the results
    print("a = ", model.a.item())
    print("b = ", model.b.item())
    print("c = ", model.c.item())

Error messages and logs

Save checkpoint, global_step:  5 steps_step=5 monitor_candidates: {'epoch': 0, 'step': 5}
Save checkpoint, global_step:  10 steps_step=10 monitor_candidates: {'epoch': 0, 'step': 10}
Save checkpoint, global_step:  15 steps_step=15 monitor_candidates: {'epoch': 0, 'step': 15}
Save checkpoint, global_step:  20 steps_step=20 monitor_candidates: {'epoch': 0, 'step': 20}
Save checkpoint, global_step:  25 steps_step=25 monitor_candidates: {'epoch': 1, 'step': 25}
Save checkpoint, global_step:  30 steps_step=30 monitor_candidates: {'epoch': 1, 'step': 30}
Save checkpoint, global_step:  35 steps_step=35 monitor_candidates: {'epoch': 1, 'step': 35}
Save checkpoint, global_step:  40 steps_step=40 monitor_candidates: {'epoch': 1, 'step': 40}

Save checkpoint, global_step:  46 steps_step=46 monitor_candidates: {'epoch': 2, 'step': 46}  <-- drift
Save checkpoint, global_step:  51 steps_step=51 monitor_candidates: {'epoch': 2, 'step': 51}
Save checkpoint, global_step:  56 steps_step=56 monitor_candidates: {'epoch': 2, 'step': 56}
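
The logged steps are consistent with a simple counting model, sketched below in plain Python. It rests on two assumptions that are not confirmed by the Lightning source here: an optimizer step (and so a global_step increment) happens every accumulate_grad_batches batches within an epoch and once more on the last batch of the epoch, and validation fires every val_check_interval batches counted across epochs. The function name simulate_val_steps is made up for this illustration.

```python
def simulate_val_steps(num_batches, accumulate, val_interval, num_epochs):
    """Return (epoch, global_step) at each simulated validation run.

    Assumed model (not taken from Lightning's code):
      * global_step increases every `accumulate` batches within an epoch,
        and once more on the last batch of the epoch (leftover flush);
      * validation fires every `val_interval` batches counted across epochs.
    """
    global_step = 0
    total_batches = 0
    records = []
    for epoch in range(num_epochs):
        for batch_idx in range(num_batches):
            total_batches += 1
            is_last_batch = batch_idx + 1 == num_batches
            if (batch_idx + 1) % accumulate == 0 or is_last_batch:
                global_step += 1
            if total_batches % val_interval == 0:
                records.append((epoch, global_step))
    return records


# Settings from the report: 67 batches, accumulation 3, interval 15 batches.
records = simulate_val_steps(num_batches=67, accumulate=3, val_interval=15, num_epochs=3)
for epoch, step in records:
    print(f"epoch={epoch} global_step={step}")
```

Under this model each epoch contributes 23 optimizer steps (22 full accumulations plus one for the leftover batch) but 67 batches, so a batch-based interval of 15 no longer lines up with multiples of 5 optimizer steps once enough leftover steps have accumulated.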

Environment

Current environment
* CUDA:
        - GPU:
                - NVIDIA RTX A5000
                - NVIDIA RTX A5000
                - NVIDIA RTX A5000
                - NVIDIA RTX A5000
        - available:         True
        - version:           11.7
* Lightning:
        - lightning:         2.0.0
        - lightning-cloud:   0.5.32
        - lightning-lite:    1.8.6
        - lightning-utilities: 0.8.0
        - pytorch-lightning: 2.0.0
        - torch:             1.13.1
        - torchaudio:        0.13.1
        - torchcrepe:        0.0.17
        - torchmetrics:      0.11.4
        - torchvision:       0.14.1
* Packages:
        - absl-py:           1.3.0
        - aiobotocore:       2.4.2
        - aiohttp:           3.8.4
        - aioitertools:      0.11.0
        - aiosignal:         1.3.1
        - altgraph:          0.17.3
        - anyio:             3.6.2
        - appdirs:           1.4.4
        - arrow:             1.2.3
        - async-timeout:     4.0.2
        - attrs:             22.2.0
        - audioread:         3.0.0
        - backcall:          0.2.0
        - beautifulsoup4:    4.12.0
        - blessed:           1.20.0
        - blinker:           1.4
        - botocore:          1.27.59
        - brotlipy:          0.7.0
        - cachetools:        5.3.0
        - certifi:           2022.12.7
        - cffi:              1.15.1
        - charset-normalizer: 2.0.4
        - click:             8.1.3
        - contourpy:         1.0.7
        - croniter:          1.3.8
        - cryptography:      39.0.1
        - cycler:            0.11.0
        - dateutils:         0.6.12
        - decorator:         5.1.1
        - deepdiff:          6.3.0
        - distance:          0.1.3
        - dnspython:         2.3.0
        - einops:            0.6.0
        - email-validator:   1.3.1
        - et-xmlfile:        1.0.1
        - fastapi:           0.88.0
        - fire:              0.5.0
        - flit-core:         3.8.0
        - fonttools:         4.39.2
        - frozenlist:        1.3.3
        - fsspec:            2023.3.0
        - future:            0.18.2
        - g2p-en:            2.1.0
        - g2pm:              0.1.2.5
        - google-auth:       2.16.3
        - google-auth-oauthlib: 0.4.6
        - grpcio:            1.51.3
        - h11:               0.14.0
        - h5py:              3.7.0
        - httpcore:          0.16.3
        - httptools:         0.5.0
        - httpx:             0.23.3
        - idna:              3.4
        - imageio:           2.23.0
        - importlib-metadata: 6.1.0
        - inflect:           6.0.2
        - inquirer:          3.1.3
        - itsdangerous:      2.1.2
        - jinja2:            3.1.2
        - jmespath:          1.0.1
        - joblib:            1.2.0
        - kiwisolver:        1.4.4
        - librosa:           0.9.1
        - lightning:         2.0.0
        - lightning-cloud:   0.5.32
        - lightning-lite:    1.8.6
        - lightning-utilities: 0.8.0
        - llvmlite:          0.39.1
        - markdown:          3.4.3
        - markdown-it-py:    2.2.0
        - markupsafe:        2.1.2
        - matplotlib:        3.6.2
        - mdurl:             0.1.2
        - mkl-fft:           1.3.1
        - mkl-random:        1.2.2
        - mkl-service:       2.4.0
        - multidict:         6.0.4
        - networkx:          3.0
        - nltk:              3.8.1
        - numba:             0.56.4
        - numpy:             1.23.5
        - oauthlib:          3.2.2
        - ordered-set:       4.1.0
        - orjson:            3.8.8
        - packaging:         23.0
        - pillow:            9.4.0
        - pip:               23.0.1
        - platformdirs:      3.1.1
        - pooch:             1.7.0
        - praat-parselmouth: 0.4.3
        - protobuf:          3.13.0
        - psutil:            5.9.4
        - pyasn1:            0.4.8
        - pyasn1-modules:    0.2.8
        - pycparser:         2.21
        - pycwt:             0.3.0a22
        - pydantic:          1.10.7
        - pygments:          2.14.0
        - pyjwt:             2.6.0
        - pyloudnorm:        0.1.0
        - pyopenssl:         23.0.0
        - pyparsing:         3.0.9
        - pypinyin:          0.39.0
        - pysocks:           1.7.1
        - python-dateutil:   2.8.2
        - python-dotenv:     1.0.0
        - python-editor:     1.0.4
        - python-levenshtein: 0.12.2
        - python-multipart:  0.0.6
        - pytorch-lightning: 2.0.0
        - pytz:              2022.7.1
        - pywavelets:        1.4.1
        - pyyaml:            6.0
        - readchar:          4.0.5
        - regex:             2023.3.23
        - requests:          2.28.1
        - requests-oauthlib: 1.3.1
        - resampy:           0.4.2
        - resemblyzer:       0.1.1.dev0
        - rfc3986:           1.5.0
        - rich:              13.3.2
        - rsa:               4.9
        - s3fs:              2023.3.0
        - scikit-image:      0.19.3
        - scikit-learn:      1.2.2
        - scipy:             1.9.3
        - setuptools:        65.6.3
        - six:               1.16.0
        - snakeviz:          2.1.1
        - sniffio:           1.3.0
        - soundfile:         0.12.1
        - soupsieve:         2.4
        - starlette:         0.22.0
        - starsessions:      1.3.0
        - tensorboard:       2.11.0
        - tensorboard-data-server: 0.6.1
        - tensorboard-plugin-wit: 1.8.1
        - tensorboardx:      2.6
        - termcolor:         2.2.0
        - threadpoolctl:     3.1.0
        - tifffile:          2023.3.21
        - torch:             1.13.1
        - torchaudio:        0.13.1
        - torchcrepe:        0.0.17
        - torchmetrics:      0.11.4
        - torchvision:       0.14.1
        - tornado:           6.2
        - tqdm:              4.65.0
        - traitlets:         5.9.0
        - typing:            3.7.4.3
        - typing-extensions: 4.4.0
        - ujson:             5.7.0
        - urllib3:           1.26.14
        - uvicorn:           0.21.1
        - uvloop:            0.17.0
        - watchfiles:        0.18.1
        - wcwidth:           0.2.6
        - webrtcvad:         2.0.10
        - websocket-client:  1.5.1
        - websockets:        10.4
        - werkzeug:          2.2.3
        - wheel:             0.38.4
        - wrapt:             1.15.0
        - yarl:              1.8.2
        - zipp:              3.15.0
* System:
        - OS:                Linux
        - architecture:
                - 64bit
                - ELF
        - processor:         x86_64
        - python:            3.9.16
        - version:           #153-Ubuntu SMP Thu Nov 24 15:56:58 UTC 2022

More info

Apart from this drift, I have two more questions:

  1. Why is val_check_interval tied to the number of batches rather than to global_step?
  2. Why is validation re-run after loading a checkpoint that was saved right after a validation run? This also produces a duplicate checkpoint, which is very frustrating.
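
To make question 1 concrete, here is a plain-Python sketch of what a trigger tied to global_step would look like, as opposed to one tied to the batch count. It is only a counting model, not a Lightning API: the function name simulate_step_based_validation and the parameter every_n_steps are hypothetical, and the rule for when an optimizer step happens (every accumulate batches, plus once on the last batch of an epoch) is an assumption.

```python
def simulate_step_based_validation(num_batches, accumulate, every_n_steps, num_epochs):
    """Return (epoch, global_step) for a hypothetical step-based trigger.

    Validation fires right after an optimizer step whenever global_step
    is a multiple of `every_n_steps`, regardless of how many batches
    were consumed to get there.
    """
    global_step = 0
    records = []
    for epoch in range(num_epochs):
        for batch_idx in range(num_batches):
            is_last_batch = batch_idx + 1 == num_batches
            # Assumed stepping rule: full accumulation window or end of epoch.
            stepped = (batch_idx + 1) % accumulate == 0 or is_last_batch
            if stepped:
                global_step += 1
                if global_step % every_n_steps == 0:
                    records.append((epoch, global_step))
    return records


# Same data layout as the report: 67 batches per epoch, accumulation of 3.
step_based = simulate_step_based_validation(
    num_batches=67, accumulate=3, every_n_steps=5, num_epochs=3
)
print([step for _, step in step_based])
```

With this kind of trigger the validation steps stay on multiples of 5 across epoch boundaries, which is the behavior the val_check_interval setting above was meant to achieve.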

cc @carmocca @justusschock

Labels: bug (Something isn't working), help wanted (Open to be worked on), loops (Related to the Loop API), ver: 2.0.x
