Open
Labels
bug ("Something isn't working"), strategy: ddp (DistributedDataParallel), ver: 2.0.x
Description
Bug description
Hi, I'm trying to use the official PyTorch EMA support (new feature: https://pytorch.org/docs/main/optim.html#weight-averaging-swa-and-ema) in Lightning with DDP. To do this I create the averaged model:

```python
self.ema_model = torch.optim.swa_utils.AveragedModel(
    self.model,
    multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(ema_rate),
)
```
And then I update the EMA model after each optimizer step, using the LightningModule hook:

```python
def on_before_zero_grad(self, optimizer):
    if self.ema_model is not None:
        self.ema_model.update_parameters(self.model)
```
And then at the end of training I update the BatchNorm statistics and replace the 'normal' model with the EMA model:

```python
def on_train_end(self):
    # Our final validation will use the EMA model, as it replaces our normal model
    if self.ema_model is not None:
        logger.info("Updating the EMA model's BatchNorm layers...")
        torch.optim.swa_utils.update_bn(self.trainer.train_dataloader, self.ema_model)
        logger.info("Replacing the standard model with the EMA model for the last validation run")
        self.model = self.ema_model
```
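For context, `update_bn` resets each BatchNorm module's running statistics and recomputes them via forward passes over the given loader, so the loader must yield data in the form the model's `forward` expects (plain inputs or `(input, target)` pairs). A small standalone sketch, with a toy model and random data purely for illustration:

```python
import torch
from torch.optim.swa_utils import AveragedModel, update_bn

torch.manual_seed(0)
model = torch.nn.Sequential(torch.nn.Linear(3, 3), torch.nn.BatchNorm1d(3))
ema_model = AveragedModel(model)

# update_bn accepts any iterable of inputs (or (input, target) pairs)
loader = [torch.randn(8, 3) for _ in range(4)]
update_bn(loader, ema_model)

bn = ema_model.module[1]
print(bn.running_mean)  # running stats now reflect the loader's activations
```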
The issue is that Lightning tries to wrap `self.ema_model` in DDP along with the rest of the LightningModule, and because it doesn't have any parameters requiring gradients, PyTorch throws an error:

```
RuntimeError: DistributedDataParallel is not needed when a module doesn't have any parameter that requires a gradient.
```
Do you know if I can/should keep `self.ema_model` on rank 0 only and prevent it from being wrapped in DDP, unlike `self.model`?
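One possible workaround, sketched under the assumption that only modules registered on the LightningModule get wrapped: keep the `AveragedModel` out of the `nn.Module` attribute registry entirely. Assigning through `object.__setattr__` bypasses `nn.Module.__setattr__`, so the EMA copy is never treated as a submodule and DDP (and `state_dict`) will ignore it; the trade-off is that you must checkpoint the EMA weights yourself. This is an illustrative sketch, not a Lightning-sanctioned API:

```python
import torch
from torch import nn
from torch.optim.swa_utils import AveragedModel

class MyModule(nn.Module):  # stand-in for the LightningModule
    def __init__(self):
        super().__init__()
        self.model = nn.Linear(2, 2)
        # Bypass nn.Module.__setattr__ so ema_model is a plain attribute,
        # not a registered submodule: DDP wrapping and state_dict ignore it.
        object.__setattr__(self, "ema_model", AveragedModel(self.model))

m = MyModule()
assert "ema_model" not in dict(m.named_modules())   # not a submodule
assert all("ema" not in k for k in m.state_dict())  # caveat: not checkpointed
m.ema_model.update_parameters(m.model)              # but still fully usable
```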
What version are you seeing the problem on?
v2.0
How to reproduce the bug
As above
Error messages and logs
```
RuntimeError: DistributedDataParallel is not needed when a module doesn't have any parameter that requires a gradient.
```
Environment
More info
No response