-
Notifications
You must be signed in to change notification settings - Fork 3.6k
Open
Labels
Milestone
Description
Bug description
When I wrap my LightningModule with `torch.compile()`, reading the current epoch via `self.current_epoch` always returns 0. This happens only when `self.current_epoch` is read inside `training_step()` and `validation_step()`; reading it inside `on_validation_epoch_end()` returns the correct value.
What version are you seeing the problem on?
v2.0
How to reproduce the bug
import lightning.pytorch as pl
import torch
import torch.nn.functional as F
class FooModule(pl.LightningModule):
    """Minimal LightningModule that prints ``self.current_epoch`` from several hooks.

    Used to reproduce the report that, once the module is wrapped with
    ``torch.compile()``, ``self.current_epoch`` read inside ``training_step``
    and ``validation_step`` always prints 0, while ``on_validation_epoch_end``
    prints the expected epoch.
    """

    def __init__(self):
        super().__init__()
        # Single linear layer mapping 10 input features to one logit.
        self.model = torch.nn.Linear(10, 1)

    def _compute_loss(self, batch):
        """Return the binary cross-entropy-with-logits loss for one ``(x, y)`` batch.

        The ``(batch, 1)`` logits are flattened with ``view(-1)`` so they always
        have shape ``(batch,)`` and match ``y``. Squeezing a ``(1, batch)`` view
        would collapse a batch of one to a 0-d tensor, which
        ``binary_cross_entropy_with_logits`` rejects against a 1-element target.
        """
        x, y = batch
        y_hat = self.model(x).view(-1)
        return F.binary_cross_entropy_with_logits(y_hat, y.float())

    def training_step(self, batch, batch_idx):
        # NOTE(review): per the report, this prints 0 every epoch when the
        # module is wrapped with torch.compile().
        print(f"\ncurrent train_step epoch = {self.current_epoch}")
        return self._compute_loss(batch)

    def validation_step(self, batch, batch_idx):
        # NOTE(review): per the report, this prints 0 every epoch when the
        # module is wrapped with torch.compile().
        print(f"current val_step epoch = {self.current_epoch}")
        # The loss is computed (so the forward pass runs) but is neither
        # returned nor logged.
        self._compute_loss(batch)

    def on_validation_epoch_end(self):
        # Reported to print the correct epoch even when compiled.
        print(f"current val_end epoch = {self.current_epoch}")

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1e-4)
class FooDataset(torch.utils.data.Dataset):
    """Tiny in-memory dataset of two ``(features, label)`` pairs.

    Each feature vector holds 10 random float16 values; the label is the
    sample index modulo 2 (0 for the first sample, 1 for the second).
    """

    def __init__(self):
        # One torch.rand draw per sample, taken in index order.
        self.sample_tuples_list = [
            (torch.rand(10, dtype=torch.float16), torch.tensor(idx % 2))
            for idx in range(2)
        ]

    def __len__(self):
        return len(self.sample_tuples_list)

    def __getitem__(self, idx):
        features, label = self.sample_tuples_list[idx]
        return features, label
def main(*, compile_module=True):
    """Run a short ``Trainer.fit`` that prints ``self.current_epoch`` from each hook.

    With ``compile_module=True`` (the default) the module is wrapped with
    ``torch.compile()``. Per the report, ``FooModule.training_step`` and
    ``FooModule.validation_step`` then always print 0, while
    ``FooModule.on_validation_epoch_end`` prints the current epoch. With
    ``compile_module=False`` every hook prints the current epoch.

    Args:
        compile_module: Whether to wrap ``FooModule`` with ``torch.compile()``.
    """
    module = FooModule()
    if compile_module:
        module = torch.compile(module)

    train_dataset = FooDataset()
    val_dataset = FooDataset()
    # batch_size equals the dataset length, so each epoch is a single batch.
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=len(train_dataset)
    )
    val_dataloader = torch.utils.data.DataLoader(
        val_dataset, batch_size=len(val_dataset)
    )

    trainer = pl.Trainer(
        accelerator="gpu",
        precision="16-mixed",
        logger=False,
        enable_checkpointing=False,
        # Skip sanity validation so every printed epoch comes from real fit epochs.
        num_sanity_val_steps=0,
        enable_progress_bar=False,
        max_epochs=4,
    )
    trainer.fit(
        model=module,
        train_dataloaders=train_dataloader,
        val_dataloaders=val_dataloader,
    )


if __name__ == "__main__":
    main()
Error messages and logs
current train epoch = 0
current val epoch = 0
current val_end epoch = 0
current train epoch = 0
current val epoch = 0
current val_end epoch = 1
current train epoch = 0
current val epoch = 0
current val_end epoch = 2
current train epoch = 0
current val epoch = 0
current val_end epoch = 3
`Trainer.fit` stopped: `max_epochs=4` reached.
Environment
Current environment
- CUDA:
- GPU:
- NVIDIA A40
- available: True
- version: 11.8
- Lightning:
- lightning: 2.0.4
- lightning-cloud: 0.5.36
- lightning-utilities: 0.8.0
- pytorch-lightning: 2.0.3
- pytorchvideo: 0.1.5
- torch: 2.0.1
- torchaudio: 2.0.2
- torchmetrics: 1.0.0rc0
- torchvision: 0.15.2
- Packages:
- aiohttp: 3.8.4
- aiosignal: 1.3.1
- anyio: 3.7.0
- appdirs: 1.4.4
- arrow: 1.2.3
- asteval: 0.9.30
- asttokens: 2.2.1
- async-timeout: 4.0.2
- attrs: 23.1.0
- av: 10.0.0
- backcall: 0.2.0
- backports.cached-property: 1.0.2
- backports.functools-lru-cache: 1.6.4
- beautifulsoup4: 4.12.2
- blessed: 1.19.1
- bottleneck: 1.3.5
- brotlipy: 0.7.0
- build: 0.10.0
- cachecontrol: 0.12.14
- certifi: 2023.5.7
- cffi: 1.15.0
- charset-normalizer: 3.1.0
- cleo: 2.0.1
- click: 8.1.3
- colorama: 0.4.6
- contourpy: 1.0.5
- crashtest: 0.4.1
- croniter: 1.3.15
- cryptography: 39.0.1
- cycler: 0.11.0
- dateutils: 0.6.12
- decorator: 5.1.1
- deepdiff: 6.2.2
- distlib: 0.3.6
- docker-pycreds: 0.4.0
- dulwich: 0.21.3
- exceptiongroup: 1.1.1
- executing: 1.2.0
- fastapi: 0.98.0
- filelock: 3.12.2
- fonttools: 4.25.0
- frozenlist: 1.3.3
- fsspec: 2023.6.0
- future: 0.18.3
- fvcore: 0.1.5.post20221221
- gitdb: 4.0.10
- gitpython: 3.1.31
- h11: 0.14.0
- html5lib: 1.1
- huggingface-hub: 0.15.1
- idna: 3.4
- imageio: 2.31.1
- importlib-metadata: 6.7.0
- importlib-resources: 5.12.0
- inquirer: 3.1.3
- installer: 0.7.0
- iopath: 0.1.10
- ipython: 8.14.0
- itsdangerous: 2.1.2
- jaraco.classes: 3.2.3
- jedi: 0.18.2
- jeepney: 0.8.0
- jinja2: 3.1.2
- jsonschema: 4.17.3
- keyring: 23.13.1
- kiwisolver: 1.4.4
- lightning: 2.0.4
- lightning-cloud: 0.5.36
- lightning-utilities: 0.8.0
- line-profiler: 3.4.0
- lmfit: 1.2.1
- lockfile: 0.12.2
- markdown-it-py: 3.0.0
- markupsafe: 2.1.1
- matplotlib: 3.7.1
- matplotlib-inline: 0.1.6
- mdurl: 0.1.0
- mkl-fft: 1.3.1
- mkl-random: 1.2.2
- mkl-service: 2.4.0
- more-itertools: 9.1.0
- mpmath: 1.3.0
- msgpack: 1.0.3
- multidict: 6.0.4
- munkres: 1.1.4
- networkx: 3.1
- numexpr: 2.8.4
- numpy: 1.24.3
- ordered-set: 4.1.0
- packaging: 23.1
- pandas: 1.4.2
- parameterized: 0.9.0
- parso: 0.8.3
- pathtools: 0.1.2
- pexpect: 4.8.0
- pickleshare: 0.7.5
- pillow: 9.4.0
- pip: 23.1.2
- pkginfo: 1.9.6
- pkgutil-resolve-name: 1.3.10
- platformdirs: 3.6.0
- ply: 3.11
- poetry: 1.5.1
- poetry-core: 1.6.1
- poetry-plugin-export: 1.4.0
- portalocker: 2.7.0
- prompt-toolkit: 3.0.38
- protobuf: 3.20.3
- psutil: 5.9.0
- ptyprocess: 0.7.0
- pure-eval: 0.2.2
- pycparser: 2.21
- pydantic: 1.9.1
- pygments: 2.15.1
- pyjwt: 2.7.0
- pyopenssl: 23.2.0
- pyparsing: 3.1.0
- pyproject-hooks: 1.0.0
- pyqt5-sip: 12.11.0
- pyrsistent: 0.18.1
- pysocks: 1.7.1
- python-dateutil: 2.8.2
- python-editor: 1.0.4
- python-multipart: 0.0.6
- pytorch-lightning: 2.0.3
- pytorchvideo: 0.1.5
- pytz: 2023.3
- pyyaml: 6.0
- qjobs: 0.2.0
- rapidfuzz: 2.13.7
- readchar: 4.0.5.dev0
- requests: 2.31.0
- requests-toolbelt: 1.0.0
- rich: 13.4.2
- scipy: 1.8.1
- secretstorage: 3.3.3
- sentry-sdk: 1.21.1
- setproctitle: 1.2.2
- setuptools: 67.7.2
- setuptools-scm: 7.1.0
- shellingham: 1.5.1
- sip: 6.6.2
- six: 1.16.0
- smmap: 3.0.5
- sniffio: 1.3.0
- soupsieve: 2.3.2.post1
- stack-data: 0.6.2
- starlette: 0.27.0
- starsessions: 1.3.0
- sympy: 1.12
- tabulate: 0.9.0
- termcolor: 2.3.0
- timm: 0.6.13
- toml: 0.10.2
- tomli: 2.0.1
- tomlkit: 0.11.8
- torch: 2.0.1
- torchaudio: 2.0.2
- torchmetrics: 1.0.0rc0
- torchvision: 0.15.2
- tornado: 6.1
- tqdm: 4.65.0
- traitlets: 5.9.0
- triton: 2.0.0
- trove-classifiers: 2023.5.24
- typing-extensions: 4.6.3
- uncertainties: 3.1.7
- urllib3: 1.26.15
- uvicorn: 0.22.0
- virtualenv: 20.23.1
- wandb: 0.15.4
- wcwidth: 0.2.6
- webencodings: 0.5.1
- websocket-client: 1.6.0
- websockets: 10.3
- wheel: 0.40.0
- yacs: 0.1.8
- yarl: 1.9.2
- zipp: 3.15.0
- System:
- OS: Linux
- architecture:
- 64bit
- ELF
- processor: x86_64
- python: 3.10.11
- release: 3.10.0-1160.83.1.el7.x86_64
- version: 1 SMP Wed Jan 25 16:41:43 UTC 2023
More info
No response
cc @carmocca
awaelchli