🐛 Bug
When using mixed-precision training, the scheduler and the optimizer are called in the wrong order, and the following warning is generated:
UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`.
Please reproduce using the BoringModel
https://colab.research.google.com/drive/1G7pk6E9XUYq-pS41DXKhqM9Srx8sikiP?usp=sharing
There are four tests. Three of them don't raise the warning:
- test_amp_scheduler(precision=16, configure_optimizers=configure_optimizers_1)
- test_amp_scheduler(precision=32, configure_optimizers=configure_optimizers_1)
- test_amp_scheduler(precision=32, configure_optimizers=configure_optimizers_2)
This test case raises the warning (a minimal sketch of such a test follows the list):
- test_amp_scheduler(precision=16, configure_optimizers=configure_optimizers_2)
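The actual tests are in the linked Colab; a rough, self-contained sketch of what each `test_amp_scheduler(...)` call does might look like the code below. The `RandomDataset`, the tiny one-layer model, and the `Trainer` flags are my own minimal stand-ins for the Colab code, not the exact implementation.

```python
import warnings

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)


class BoringModel(pl.LightningModule):
    """Tiny one-layer model, as in the Lightning bug-report template."""

    def __init__(self, configure_optimizers_fn):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self._configure_optimizers_fn = configure_optimizers_fn

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return self._configure_optimizers_fn(self)


def test_amp_scheduler(precision, configure_optimizers):
    model = BoringModel(configure_optimizers)
    trainer = pl.Trainer(gpus=1, precision=precision, max_steps=4,
                         progress_bar_refresh_rate=0, weights_summary=None)
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        trainer.fit(model, DataLoader(RandomDataset(32, 64), batch_size=2))
    # the test fails if PyTorch complained about the step() order
    assert not any("before `optimizer.step()`" in str(w.message) for w in caught)
```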
To Reproduce
- Create a model whose `configure_optimizers` returns the scheduler in the following dictionary style:
def configure_optimizers_2(model):
    optimizer = torch.optim.SGD(model.layer.parameters(), lr=0.1)
    scheduler = {'scheduler': torch.optim.lr_scheduler.StepLR(optimizer, step_size=1),
                 'name': 'learning_rate',
                 'interval': 'step',
                 'frequency': 1}
    return {"optimizer": optimizer, "lr_scheduler": scheduler}
- Enable mixed-precision training by setting `precision=16` in the `Trainer`
- Start training (see the sketch below)
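With the test sketch from above and the `configure_optimizers_2` definition from step 1, the failing combination is (only a sketch; the exact invocation is in the Colab):

```python
# Raises the UserWarning about the step() order:
test_amp_scheduler(precision=16, configure_optimizers=configure_optimizers_2)

# The same scheduler dict without AMP does not warn:
test_amp_scheduler(precision=32, configure_optimizers=configure_optimizers_2)
```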
Note
When the scheduler is returned directly rather than wrapped in a dictionary, the issue does not seem to occur:
def configure_optimizers_1(model):
    optimizer = torch.optim.SGD(model.layer.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
    return {"optimizer": optimizer, "lr_scheduler": scheduler}
Expected behavior
No warning: `optimizer.step()` should be called before `lr_scheduler.step()`, regardless of how the scheduler is configured (see the sketch below for the expected call order).
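For context, this is the ordering PyTorch expects with native AMP, written as a plain PyTorch loop (a sketch of the equivalent manual loop; Lightning's AMP handling is expected to do the same internally):

```python
import torch

# Expected ordering with native AMP and a per-step scheduler:
model = torch.nn.Linear(32, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
scaler = torch.cuda.amp.GradScaler()

for _ in range(4):
    batch = torch.randn(8, 32, device="cuda")
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = model(batch).sum()
    scaler.scale(loss).backward()
    scaler.step(optimizer)   # optimizer.step() happens here (may be skipped on inf/NaN gradients)
    scaler.update()
    scheduler.step()         # scheduler steps only after the optimizer
```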
Environment
- CUDA:
  - GPU:
    - Tesla P100-PCIE-16GB
  - available: True
  - version: 10.1
- Packages:
  - numpy: 1.19.5
  - pyTorch_debug: True
  - pyTorch_version: 1.7.0+cu101
  - pytorch-lightning: 1.1.4
  - tqdm: 4.41.1
- System:
  - OS: Linux
  - architecture:
    - 64bit
  - processor: x86_64
  - python: 3.6.9
  - version: #1 SMP Thu Jul 23 08:00:38 PDT 2020
cc @tchaton @rohitgr7 @carmocca @justusschock @awaelchli @akihironitta