'self.current_epoch' always returns 0 when compiling LightningModule #17933

@ebschroeder

Bug description

When I call "torch.compile()" on my LightningModule, "self.current_epoch" always returns 0. This happens only when "self.current_epoch" is read inside "training_step()" and "validation_step()"; reading it inside "on_validation_epoch_end()" returns the correct value.

What version are you seeing the problem on?

v2.0

How to reproduce the bug

import lightning.pytorch as pl
import torch
import torch.nn.functional as F


class FooModule(pl.LightningModule):
    def __init__(self):
        super().__init__()

        self.model = torch.nn.Linear(10, 1)

    def training_step(self, batch, batch_idx):
        print(f"\ncurrent train_step epoch = {self.current_epoch}")
        x, y = batch
        y_hat = torch.squeeze(self.model(x).view(1, -1))
        loss = F.binary_cross_entropy_with_logits(y_hat, y.float())
        return loss

    def validation_step(self, batch, batch_idx):
        print(f"current val_step epoch = {self.current_epoch}")
        x, y = batch
        y_hat = torch.squeeze(self.model(x).view(1, -1))
        loss = F.binary_cross_entropy_with_logits(y_hat, y.float())

    def on_validation_epoch_end(self):
        print(f"current val_end epoch = {self.current_epoch}")

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1e-4)


class FooDataset(torch.utils.data.Dataset):
    def __init__(self):
        self.sample_tuples_list = []
        for i in range(2):
            data = torch.rand(10, dtype=torch.float16)

            if i % 2 == 0:
                truth_val = torch.tensor(0)
            else:
                truth_val = torch.tensor(1)

            self.sample_tuples_list.append((data, truth_val))

    def __len__(self):
        return len(self.sample_tuples_list)

    def __getitem__(self, idx):
        sample_tensor, truth_val = self.sample_tuples_list[idx]
        return (sample_tensor, truth_val)


def main():
    """
    Compiling the module causes 'self.current_epoch' in training_step() and
    validation_step() to always return 0. Without compiling, it returns the
    current epoch.
    """
    module = torch.compile(FooModule())
    # module = FooModule()

    train_dataset = FooDataset()
    val_dataset = FooDataset()

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=len(train_dataset)
    )
    val_dataloader = torch.utils.data.DataLoader(
        val_dataset, batch_size=len(val_dataset)
    )

    trainer = pl.Trainer(
        accelerator="gpu",
        precision="16-mixed",
        logger=False,
        enable_checkpointing=False,
        num_sanity_val_steps=0,
        enable_progress_bar=False,
        max_epochs=4,
    )

    trainer.fit(
        model=module,
        train_dataloaders=train_dataloader,
        val_dataloaders=val_dataloader,
    )


if __name__ == "__main__":
    main()

Error messages and logs

current train epoch = 0
current val epoch = 0
current val_end epoch = 0

current train epoch = 0
current val epoch = 0
current val_end epoch = 1

current train epoch = 0
current val epoch = 0
current val_end epoch = 2

current train epoch = 0
current val epoch = 0
current val_end epoch = 3
Trainer.fit stopped: max_epochs=4 reached.
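
The pattern in the log above can also be checked mechanically. The following is a minimal sketch, not part of the original report: it parses the printed lines and lists the labels whose epoch never changes. The helper names `epochs_by_hook` and `stuck_hooks` are hypothetical, and the sketch assumes each label prints once per epoch, as in the log.

```python
import re
from collections import defaultdict

# Matches printed lines such as "current val_end epoch = 3".
LINE_RE = re.compile(r"current (\w+) epoch = (\d+)")


def epochs_by_hook(log_text: str) -> dict[str, list[int]]:
    """Group the printed epoch values by the label that printed them."""
    seen = defaultdict(list)
    for match in LINE_RE.finditer(log_text):
        label, epoch = match.group(1), int(match.group(2))
        seen[label].append(epoch)
    return dict(seen)


def stuck_hooks(log_text: str) -> list[str]:
    """Return labels that printed more than once but always the same epoch."""
    return sorted(
        label
        for label, epochs in epochs_by_hook(log_text).items()
        if len(epochs) > 1 and len(set(epochs)) == 1
    )


# The log from this report, copied verbatim.
LOG = """\
current train epoch = 0
current val epoch = 0
current val_end epoch = 0

current train epoch = 0
current val epoch = 0
current val_end epoch = 1

current train epoch = 0
current val epoch = 0
current val_end epoch = 2

current train epoch = 0
current val epoch = 0
current val_end epoch = 3
"""

if __name__ == "__main__":
    print(epochs_by_hook(LOG))
    print(stuck_hooks(LOG))
```

Running the same check on the output of the uncompiled module (the commented-out "module = FooModule()" line) should return an empty list if the epoch advances in every hook.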

Environment

Current environment
  • CUDA:
    - GPU: NVIDIA A40
    - available: True
    - version: 11.8
  • Lightning:
    - lightning: 2.0.4
    - lightning-cloud: 0.5.36
    - lightning-utilities: 0.8.0
    - pytorch-lightning: 2.0.3
    - pytorchvideo: 0.1.5
    - torch: 2.0.1
    - torchaudio: 2.0.2
    - torchmetrics: 1.0.0rc0
    - torchvision: 0.15.2
  • Packages:
    - aiohttp: 3.8.4
    - aiosignal: 1.3.1
    - anyio: 3.7.0
    - appdirs: 1.4.4
    - arrow: 1.2.3
    - asteval: 0.9.30
    - asttokens: 2.2.1
    - async-timeout: 4.0.2
    - attrs: 23.1.0
    - av: 10.0.0
    - backcall: 0.2.0
    - backports.cached-property: 1.0.2
    - backports.functools-lru-cache: 1.6.4
    - beautifulsoup4: 4.12.2
    - blessed: 1.19.1
    - bottleneck: 1.3.5
    - brotlipy: 0.7.0
    - build: 0.10.0
    - cachecontrol: 0.12.14
    - certifi: 2023.5.7
    - cffi: 1.15.0
    - charset-normalizer: 3.1.0
    - cleo: 2.0.1
    - click: 8.1.3
    - colorama: 0.4.6
    - contourpy: 1.0.5
    - crashtest: 0.4.1
    - croniter: 1.3.15
    - cryptography: 39.0.1
    - cycler: 0.11.0
    - dateutils: 0.6.12
    - decorator: 5.1.1
    - deepdiff: 6.2.2
    - distlib: 0.3.6
    - docker-pycreds: 0.4.0
    - dulwich: 0.21.3
    - exceptiongroup: 1.1.1
    - executing: 1.2.0
    - fastapi: 0.98.0
    - filelock: 3.12.2
    - fonttools: 4.25.0
    - frozenlist: 1.3.3
    - fsspec: 2023.6.0
    - future: 0.18.3
    - fvcore: 0.1.5.post20221221
    - gitdb: 4.0.10
    - gitpython: 3.1.31
    - h11: 0.14.0
    - html5lib: 1.1
    - huggingface-hub: 0.15.1
    - idna: 3.4
    - imageio: 2.31.1
    - importlib-metadata: 6.7.0
    - importlib-resources: 5.12.0
    - inquirer: 3.1.3
    - installer: 0.7.0
    - iopath: 0.1.10
    - ipython: 8.14.0
    - itsdangerous: 2.1.2
    - jaraco.classes: 3.2.3
    - jedi: 0.18.2
    - jeepney: 0.8.0
    - jinja2: 3.1.2
    - jsonschema: 4.17.3
    - keyring: 23.13.1
    - kiwisolver: 1.4.4
    - lightning: 2.0.4
    - lightning-cloud: 0.5.36
    - lightning-utilities: 0.8.0
    - line-profiler: 3.4.0
    - lmfit: 1.2.1
    - lockfile: 0.12.2
    - markdown-it-py: 3.0.0
    - markupsafe: 2.1.1
    - matplotlib: 3.7.1
    - matplotlib-inline: 0.1.6
    - mdurl: 0.1.0
    - mkl-fft: 1.3.1
    - mkl-random: 1.2.2
    - mkl-service: 2.4.0
    - more-itertools: 9.1.0
    - mpmath: 1.3.0
    - msgpack: 1.0.3
    - multidict: 6.0.4
    - munkres: 1.1.4
    - networkx: 3.1
    - numexpr: 2.8.4
    - numpy: 1.24.3
    - ordered-set: 4.1.0
    - packaging: 23.1
    - pandas: 1.4.2
    - parameterized: 0.9.0
    - parso: 0.8.3
    - pathtools: 0.1.2
    - pexpect: 4.8.0
    - pickleshare: 0.7.5
    - pillow: 9.4.0
    - pip: 23.1.2
    - pkginfo: 1.9.6
    - pkgutil-resolve-name: 1.3.10
    - platformdirs: 3.6.0
    - ply: 3.11
    - poetry: 1.5.1
    - poetry-core: 1.6.1
    - poetry-plugin-export: 1.4.0
    - portalocker: 2.7.0
    - prompt-toolkit: 3.0.38
    - protobuf: 3.20.3
    - psutil: 5.9.0
    - ptyprocess: 0.7.0
    - pure-eval: 0.2.2
    - pycparser: 2.21
    - pydantic: 1.9.1
    - pygments: 2.15.1
    - pyjwt: 2.7.0
    - pyopenssl: 23.2.0
    - pyparsing: 3.1.0
    - pyproject-hooks: 1.0.0
    - pyqt5-sip: 12.11.0
    - pyrsistent: 0.18.1
    - pysocks: 1.7.1
    - python-dateutil: 2.8.2
    - python-editor: 1.0.4
    - python-multipart: 0.0.6
    - pytorch-lightning: 2.0.3
    - pytorchvideo: 0.1.5
    - pytz: 2023.3
    - pyyaml: 6.0
    - qjobs: 0.2.0
    - rapidfuzz: 2.13.7
    - readchar: 4.0.5.dev0
    - requests: 2.31.0
    - requests-toolbelt: 1.0.0
    - rich: 13.4.2
    - scipy: 1.8.1
    - secretstorage: 3.3.3
    - sentry-sdk: 1.21.1
    - setproctitle: 1.2.2
    - setuptools: 67.7.2
    - setuptools-scm: 7.1.0
    - shellingham: 1.5.1
    - sip: 6.6.2
    - six: 1.16.0
    - smmap: 3.0.5
    - sniffio: 1.3.0
    - soupsieve: 2.3.2.post1
    - stack-data: 0.6.2
    - starlette: 0.27.0
    - starsessions: 1.3.0
    - sympy: 1.12
    - tabulate: 0.9.0
    - termcolor: 2.3.0
    - timm: 0.6.13
    - toml: 0.10.2
    - tomli: 2.0.1
    - tomlkit: 0.11.8
    - torch: 2.0.1
    - torchaudio: 2.0.2
    - torchmetrics: 1.0.0rc0
    - torchvision: 0.15.2
    - tornado: 6.1
    - tqdm: 4.65.0
    - traitlets: 5.9.0
    - triton: 2.0.0
    - trove-classifiers: 2023.5.24
    - typing-extensions: 4.6.3
    - uncertainties: 3.1.7
    - urllib3: 1.26.15
    - uvicorn: 0.22.0
    - virtualenv: 20.23.1
    - wandb: 0.15.4
    - wcwidth: 0.2.6
    - webencodings: 0.5.1
    - websocket-client: 1.6.0
    - websockets: 10.3
    - wheel: 0.40.0
    - yacs: 0.1.8
    - yarl: 1.9.2
    - zipp: 3.15.0
  • System:
    - OS: Linux
    - architecture: 64bit, ELF
    - processor: x86_64
    - python: 3.10.11
    - release: 3.10.0-1160.83.1.el7.x86_64
    - version: 1 SMP Wed Jan 25 16:41:43 UTC 2023

More info

No response

cc @carmocca
