Bug description
First of all, my task relies on step count rather than epochs, so I run validation by steps and save checkpoints after each validation. However, once I turn gradient accumulation on and the batch count is not divisible by the accumulation factor, I see a strange drift in the actual step at which validation, and therefore checkpointing, happens.
In the example below, I override the _save_checkpoint function to monitor the actual file name, and the step encoded in it drifts. My general setting is val_check_interval = accumulation * 5 to make it validate every 5 effective optimizer steps, with accumulation = 3 and 67 batches per epoch, so there is one leftover batch.
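To show where the drift seems to come from, here is a small arithmetic sketch of my understanding (an assumption, not verified against the Loop internals): with check_val_every_n_epoch=None, val_check_interval appears to count raw training batches across epochs, while the leftover batch at the end of each epoch still triggers an extra optimizer step, so one epoch of 67 batches produces 23 optimizer steps rather than 67 / 3 ≈ 22.33.

import math

batches_per_epoch = 67
accumulation = 3
val_check_interval = 5 * accumulation  # validation every 15 training batches

# Assumed model: 22 complete accumulation windows per epoch plus one optimizer
# step for the single leftover batch.
steps_per_epoch = math.ceil(batches_per_epoch / accumulation)  # 23

for total_batches in range(val_check_interval, 10 * val_check_interval + 1, val_check_interval):
    full_epochs, batch_in_epoch = divmod(total_batches, batches_per_epoch)
    global_step = full_epochs * steps_per_epoch + batch_in_epoch // accumulation
    print(f"batches={total_batches} -> global_step={global_step}")

# The last lines print 135 -> 46 and 150 -> 51, the same drift that appears in the log below.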
How to reproduce the bug
import numpy as np
import pathlib
import time
import torch
import torch.nn as nn
import torch.optim
import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint
class Quadratic(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.a = nn.Parameter(torch.tensor(0.0))
        self.b = nn.Parameter(torch.tensor(0.0))
        self.c = nn.Parameter(torch.tensor(0.0))

    def forward(self, x):
        time.sleep(0.02)
        return self.a * x * x + self.b * x + self.c

    def _common_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = torch.nn.functional.mse_loss(y_hat, y)
        return loss

    def training_step(self, batch, batch_idx):
        loss = self._common_step(batch, batch_idx)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._common_step(batch, batch_idx)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)
class CustomModelCheckpoint(ModelCheckpoint):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
        # Only prints the target filename and global_step to expose the drift;
        # it intentionally does not write any checkpoint file.
        monitor_candidates = {k: v.item() if isinstance(v, torch.Tensor) else v for k, v in self._monitor_candidates(trainer).items()}
        print("\n", "Save checkpoint, global_step: ", trainer.global_step, pathlib.Path(filepath).stem, "monitor_candidates: " + str(monitor_candidates), "\n", flush=True)

    def _remove_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
        # print("Remove checkpoint: ", filepath, flush=True)
        pass
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('-a', type=float, default=2.0)
    parser.add_argument('-b', type=float, default=3.0)
    parser.add_argument('-c', type=float, default=4.0)
    parser.add_argument('--epoch', type=int, default=500)
    args = parser.parse_args()

    x = torch.from_numpy(np.random.uniform(-10, 10, 2144)).float()  # Make 67 batches
    y = args.a * x * x + args.b * x + args.c
    x2 = torch.from_numpy(np.random.uniform(-10, 10, 100)).float()
    y2 = args.a * x2 * x2 + args.b * x2 + args.c
    dataset = torch.utils.data.TensorDataset(x, y)
    val_dataset = torch.utils.data.TensorDataset(x2, y2)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False)

    model = Quadratic()

    ####
    accumulate_grad_batches = 3
    val_check_interval = 5 * accumulate_grad_batches  # to make interval for effective batches
    ####

    trainer = pl.Trainer(
        max_epochs=args.epoch,
        accelerator='cpu',
        callbacks=[CustomModelCheckpoint(
            dirpath='.',
            filename='steps_{step}',
            monitor='step',
            mode='max',
            save_last=False,
            save_top_k=5,
        )],
        val_check_interval=val_check_interval,
        check_val_every_n_epoch=None,
        num_sanity_val_steps=0,
        accumulate_grad_batches=accumulate_grad_batches,
    )
    trainer.fit(model, dataloader, val_dataloader)

    # Print the results
    print("a = ", model.a.item())
    print("b = ", model.b.item())
    print("c = ", model.c.item())
Error messages and logs
Save checkpoint, global_step: 5 steps_step=5 monitor_candidates: {'epoch': 0, 'step': 5}
Save checkpoint, global_step: 10 steps_step=10 monitor_candidates: {'epoch': 0, 'step': 10}
Save checkpoint, global_step: 15 steps_step=15 monitor_candidates: {'epoch': 0, 'step': 15}
Save checkpoint, global_step: 20 steps_step=20 monitor_candidates: {'epoch': 0, 'step': 20}
Save checkpoint, global_step: 25 steps_step=25 monitor_candidates: {'epoch': 1, 'step': 25}
Save checkpoint, global_step: 30 steps_step=30 monitor_candidates: {'epoch': 1, 'step': 30}
Save checkpoint, global_step: 35 steps_step=35 monitor_candidates: {'epoch': 1, 'step': 35}
Save checkpoint, global_step: 40 steps_step=40 monitor_candidates: {'epoch': 1, 'step': 40}
Save checkpoint, global_step: 46 steps_step=46 monitor_candidates: {'epoch': 2, 'step': 46} <-- drift
Save checkpoint, global_step: 51 steps_step=51 monitor_candidates: {'epoch': 2, 'step': 51}
Save checkpoint, global_step: 56 steps_step=56 monitor_candidates: {'epoch': 2, 'step': 56}
Environment
Current environment
* CUDA:
- GPU:
- NVIDIA RTX A5000
- NVIDIA RTX A5000
- NVIDIA RTX A5000
- NVIDIA RTX A5000
- available: True
- version: 11.7
* Lightning:
- lightning: 2.0.0
- lightning-cloud: 0.5.32
- lightning-lite: 1.8.6
- lightning-utilities: 0.8.0
- pytorch-lightning: 2.0.0
- torch: 1.13.1
- torchaudio: 0.13.1
- torchcrepe: 0.0.17
- torchmetrics: 0.11.4
- torchvision: 0.14.1
* Packages:
- absl-py: 1.3.0
- aiobotocore: 2.4.2
- aiohttp: 3.8.4
- aioitertools: 0.11.0
- aiosignal: 1.3.1
- altgraph: 0.17.3
- anyio: 3.6.2
- appdirs: 1.4.4
- arrow: 1.2.3
- async-timeout: 4.0.2
- attrs: 22.2.0
- audioread: 3.0.0
- backcall: 0.2.0
- beautifulsoup4: 4.12.0
- blessed: 1.20.0
- blinker: 1.4
- botocore: 1.27.59
- brotlipy: 0.7.0
- cachetools: 5.3.0
- certifi: 2022.12.7
- cffi: 1.15.1
- charset-normalizer: 2.0.4
- click: 8.1.3
- contourpy: 1.0.7
- croniter: 1.3.8
- cryptography: 39.0.1
- cycler: 0.11.0
- dateutils: 0.6.12
- decorator: 5.1.1
- deepdiff: 6.3.0
- distance: 0.1.3
- dnspython: 2.3.0
- einops: 0.6.0
- email-validator: 1.3.1
- et-xmlfile: 1.0.1
- fastapi: 0.88.0
- fire: 0.5.0
- flit-core: 3.8.0
- fonttools: 4.39.2
- frozenlist: 1.3.3
- fsspec: 2023.3.0
- future: 0.18.2
- g2p-en: 2.1.0
- g2pm: 0.1.2.5
- google-auth: 2.16.3
- google-auth-oauthlib: 0.4.6
- grpcio: 1.51.3
- h11: 0.14.0
- h5py: 3.7.0
- httpcore: 0.16.3
- httptools: 0.5.0
- httpx: 0.23.3
- idna: 3.4
- imageio: 2.23.0
- importlib-metadata: 6.1.0
- inflect: 6.0.2
- inquirer: 3.1.3
- itsdangerous: 2.1.2
- jinja2: 3.1.2
- jmespath: 1.0.1
- joblib: 1.2.0
- kiwisolver: 1.4.4
- librosa: 0.9.1
- lightning: 2.0.0
- lightning-cloud: 0.5.32
- lightning-lite: 1.8.6
- lightning-utilities: 0.8.0
- llvmlite: 0.39.1
- markdown: 3.4.3
- markdown-it-py: 2.2.0
- markupsafe: 2.1.2
- matplotlib: 3.6.2
- mdurl: 0.1.2
- mkl-fft: 1.3.1
- mkl-random: 1.2.2
- mkl-service: 2.4.0
- multidict: 6.0.4
- networkx: 3.0
- nltk: 3.8.1
- numba: 0.56.4
- numpy: 1.23.5
- oauthlib: 3.2.2
- ordered-set: 4.1.0
- orjson: 3.8.8
- packaging: 23.0
- pillow: 9.4.0
- pip: 23.0.1
- platformdirs: 3.1.1
- pooch: 1.7.0
- praat-parselmouth: 0.4.3
- protobuf: 3.13.0
- psutil: 5.9.4
- pyasn1: 0.4.8
- pyasn1-modules: 0.2.8
- pycparser: 2.21
- pycwt: 0.3.0a22
- pydantic: 1.10.7
- pygments: 2.14.0
- pyjwt: 2.6.0
- pyloudnorm: 0.1.0
- pyopenssl: 23.0.0
- pyparsing: 3.0.9
- pypinyin: 0.39.0
- pysocks: 1.7.1
- python-dateutil: 2.8.2
- python-dotenv: 1.0.0
- python-editor: 1.0.4
- python-levenshtein: 0.12.2
- python-multipart: 0.0.6
- pytorch-lightning: 2.0.0
- pytz: 2022.7.1
- pywavelets: 1.4.1
- pyyaml: 6.0
- readchar: 4.0.5
- regex: 2023.3.23
- requests: 2.28.1
- requests-oauthlib: 1.3.1
- resampy: 0.4.2
- resemblyzer: 0.1.1.dev0
- rfc3986: 1.5.0
- rich: 13.3.2
- rsa: 4.9
- s3fs: 2023.3.0
- scikit-image: 0.19.3
- scikit-learn: 1.2.2
- scipy: 1.9.3
- setuptools: 65.6.3
- six: 1.16.0
- snakeviz: 2.1.1
- sniffio: 1.3.0
- soundfile: 0.12.1
- soupsieve: 2.4
- starlette: 0.22.0
- starsessions: 1.3.0
- tensorboard: 2.11.0
- tensorboard-data-server: 0.6.1
- tensorboard-plugin-wit: 1.8.1
- tensorboardx: 2.6
- termcolor: 2.2.0
- threadpoolctl: 3.1.0
- tifffile: 2023.3.21
- torch: 1.13.1
- torchaudio: 0.13.1
- torchcrepe: 0.0.17
- torchmetrics: 0.11.4
- torchvision: 0.14.1
- tornado: 6.2
- tqdm: 4.65.0
- traitlets: 5.9.0
- typing: 3.7.4.3
- typing-extensions: 4.4.0
- ujson: 5.7.0
- urllib3: 1.26.14
- uvicorn: 0.21.1
- uvloop: 0.17.0
- watchfiles: 0.18.1
- wcwidth: 0.2.6
- webrtcvad: 2.0.10
- websocket-client: 1.5.1
- websockets: 10.4
- werkzeug: 2.2.3
- wheel: 0.38.4
- wrapt: 1.15.0
- yarl: 1.8.2
- zipp: 3.15.0
* System:
- OS: Linux
- architecture:
- 64bit
- ELF
- processor: x86_64
- python: 3.9.16
- version: #153-Ubuntu SMP Thu Nov 24 15:56:58 UTC 2022
More info
Other than this phenomenon, I have two more questions:
- Why is val_check_interval tied to the number of batches rather than to global_step? (A possible workaround sketch follows this list.)
- Why is validation re-run after loading a checkpoint that was saved right after a validation step? This also produces a duplicate checkpoint, which is very frustrating.
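Regarding the first question, a possible mitigation that comes to mind (just a sketch, not a verified fix, and it drops one training batch per epoch) is to keep the number of training batches per epoch divisible by accumulate_grad_batches, for example with limit_train_batches, so that the batch-based interval and global_step cannot drift apart. Reusing the names from the reproduction script above:

# Untested workaround sketch: limit each epoch to 66 training batches so the
# per-epoch batch count is an exact multiple of accumulate_grad_batches=3;
# then every 15 training batches correspond to exactly 5 optimizer steps.
trainer = pl.Trainer(
    max_epochs=args.epoch,
    accelerator='cpu',
    callbacks=[CustomModelCheckpoint(dirpath='.', filename='steps_{step}', monitor='step',
                                     mode='max', save_last=False, save_top_k=5)],
    limit_train_batches=66,  # drop the leftover 67th batch of each epoch
    val_check_interval=val_check_interval,
    check_val_every_n_epoch=None,
    num_sanity_val_steps=0,
    accumulate_grad_batches=accumulate_grad_batches,
)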