num_training_batches is inf in configure_optimizers #16060

@davidgilbertson

Bug description

The value of num_training_batches is inf when referenced in configure_optimizers(); it doesn't actually receive its correct value until some point later in setup. This causes a very hard-to-find issue: training runs without error, except that the loss is nan.

Something inside torch.optim.lr_scheduler.CyclicLR then sets the optimizer's lr to nan: with inf step sizes, the scheduler's internal step_ratio (step_size_up / total_size) works out to inf / inf, which is nan, and that nan propagates into every learning rate it computes.
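
The nan is easy to reproduce outside Lightning. Here is a minimal sketch using plain PyTorch, where the inf step sizes stand in for the unpopulated num_training_batches:

import math

import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)
# inf step sizes mimic num_training_batches before it is populated;
# CyclicLR's step_ratio becomes inf / inf = nan on construction
scheduler = torch.optim.lr_scheduler.CyclicLR(
    optimizer,
    base_lr=0.01,
    max_lr=0.1,
    step_size_up=math.inf,
    step_size_down=math.inf,
    cycle_momentum=False,
)
print(optimizer.param_groups[0]["lr"])  # nan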

It would be nice if:

  • This value was available by the time configure_optimizers() is called, or
  • There was a warning when accessing it before it's set
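
In the meantime, a workaround that seems to work is Trainer.estimated_stepping_batches (available in recent Lightning versions), which forces the dataloaders to be set up and is therefore already correct inside configure_optimizers(). A sketch of the configure_optimizers() from the repro below rewritten with it; deriving a per-epoch count by dividing by max_epochs is my assumption and only holds when max_epochs is set and there is no gradient accumulation:

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        # unlike num_training_batches, estimated_stepping_batches triggers
        # dataloader setup, so it already has its real value at this point
        steps_per_epoch = int(self.trainer.estimated_stepping_batches // self.trainer.max_epochs)
        lr_scheduler = torch.optim.lr_scheduler.CyclicLR(
            optimizer=optimizer,
            base_lr=0.01,
            max_lr=0.1,
            step_size_up=steps_per_epoch * 1,
            step_size_down=steps_per_epoch * 2,
            cycle_momentum=False,
        )
        return [optimizer], [lr_scheduler]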

How to reproduce the bug

import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        print(f"{optimizer.param_groups[0]['lr'] = }")  # 0.1
        lr_scheduler = torch.optim.lr_scheduler.CyclicLR(
            optimizer=optimizer,
            base_lr=0.01,
            max_lr=0.1,
            step_size_up=self.trainer.num_training_batches * 1,  # problematic!
            step_size_down=self.trainer.num_training_batches * 2,  # problematic!
            cycle_momentum=False,
        )
        print(f"{optimizer.param_groups[0]['lr'] = }")  # nan
        return [optimizer], [lr_scheduler]


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
        enable_checkpointing=False,
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    trainer.test(model, dataloaders=test_data)


if __name__ == "__main__":
    run()
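
Until such a warning exists, a user-side guard at the top of configure_optimizers() turns the silent nan into an immediate error. A sketch (the error message wording is mine):

    def configure_optimizers(self):
        import math  # or at module top

        # fail fast instead of letting CyclicLR silently produce a nan lr
        if math.isinf(self.trainer.num_training_batches):
            raise RuntimeError(
                "num_training_batches is still inf in configure_optimizers(); "
                "it is only populated after the train dataloader is set up"
            )
        ...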

Error messages and logs

The main hint that something is wrong is actually TensorBoard printing "NaN or Inf found in input tensor" - but even that doesn't come with a traceback telling me who is printing it.
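
That message appears to come from the TensorBoard summary writer's NaN check. Assuming it is emitted through the standard logging module (it is in tensorboardX; treat that and the substring match below as assumptions), a temporary monkeypatch can reveal the call site:

import logging
import traceback

_original_handle = logging.Logger.handle

def handle_with_trace(self, record):
    # dump the call stack whenever the mystery message passes through logging
    if "NaN or Inf" in record.getMessage():
        traceback.print_stack()
    return _original_handle(self, record)

logging.Logger.handle = handle_with_trace  # install before trainer.fit(...)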

Environment

No response

More info

No response

cc @justusschock @awaelchli @carmocca


Labels

bug (Something isn't working) · data handling (Generic data-related topic) · loops (Related to the Loop API)
