New official Pytorch EMA gives RunTimeError when using Lightning with DDP #17847

@jphdotam

Bug description

Hi, I'm trying to use the official PyTorch EMA (a new feature: https://pytorch.org/docs/main/optim.html#weight-averaging-swa-and-ema) in Lightning with DDP. To do this, I create the averaged model:

self.ema_model = torch.optim.swa_utils.AveragedModel(self.model, multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(ema_rate))
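For reference, the update rule that an EMA averaging function applies per parameter can be sketched in plain Python. This is only an illustration of the arithmetic (new average = decay * old average + (1 - decay) * current weight), not PyTorch's implementation; `ema_update` and the numbers are hypothetical.

```python
def ema_update(ema_value: float, current_value: float, decay: float) -> float:
    """One EMA step: keep `decay` of the old average, blend in the rest."""
    return decay * ema_value + (1.0 - decay) * current_value


# Hypothetical numbers, purely for illustration.
decay = 0.9
ema = 1.0  # starting average (AveragedModel copies the weights on its first update)
for weight in [2.0, 2.0, 2.0]:
    ema = ema_update(ema, weight, decay)  # drifts slowly toward 2.0

print(round(ema, 4))
```

A decay close to 1 makes the averaged weights change slowly, which is why the EMA copy lags behind the trained model.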

I then update the EMA model after each optimizer step, using the LightningModule hook:

def on_before_zero_grad(self, optimizer):
    if self.ema_model is not None:
        self.ema_model.update_parameters(self.model)

Finally, at the end of training, I update the BatchNorm statistics and replace the 'normal' model with the EMA model:

def on_train_end(self):
    # Our final validation will use the ema model, as it replaces our normal model
    if self.ema_model is not None:
        logger.info("Updating the EMA model's BatchNorm layers...")
        torch.optim.swa_utils.update_bn(self.trainer.train_dataloader, self.ema_model)
        logger.info("Replacing the standard model with the EMA model for last validation run")
        self.model = self.ema_model
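As a standalone illustration of the `update_bn` step, the sketch below recomputes BatchNorm running statistics for an averaged model on CPU. The model and the list-based loader are hypothetical stand-ins. Note that `update_bn` expects each batch to be a tensor, or a list/tuple whose first element is the input, so a dataloader that yields dicts would need adapting.

```python
import torch
from torch import nn
from torch.optim.swa_utils import AveragedModel, update_bn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
ema_model = AveragedModel(model)  # buffers are not averaged by default (use_buffers=False)

# Hypothetical loader: a list of (inputs, targets) batches; update_bn uses batch[0].
loader = [(torch.randn(16, 4) + 5.0, torch.zeros(16)) for _ in range(3)]

bn = ema_model.module[1]
update_bn(loader, ema_model)  # resets, then re-estimates running_mean / running_var

batches_seen = int(bn.num_batches_tracked)
print(batches_seen, bn.running_mean)
```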

The issue is that I get an error: Lightning tries to wrap the self.ema_model in the LightningModule in DDP, and because it doesn't have any parameters that require gradients, PyTorch throws:

RuntimeError: DistributedDataParallel is not needed when a module doesn't have any parameter that requires a gradient.
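Going by the wording of the message, DDP refuses to wrap a module when none of its parameters has requires_grad=True. A quick diagnostic for that condition might look like the sketch below; `has_trainable_params` is a hypothetical helper name, and the two `nn.Linear` modules are placeholders rather than the models from this report.

```python
import torch
from torch import nn


def has_trainable_params(module: nn.Module) -> bool:
    """True if at least one parameter has requires_grad=True."""
    return any(p.requires_grad for p in module.parameters())


trainable = nn.Linear(4, 2)
frozen = nn.Linear(4, 2).requires_grad_(False)  # every parameter frozen

print(has_trainable_params(trainable), has_trainable_params(frozen))
```

Running the check on both self.model and self.ema_model before training starts would show which module, if either, meets the condition named in the error.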

Do you know whether I can (or should) keep self.ema_model on rank 0 only and stop it from being distributed the way self.model is?

What version are you seeing the problem on?

v2.0

How to reproduce the bug

As above

Error messages and logs

RuntimeError: DistributedDataParallel is not needed when a module doesn't have any parameter that requires a gradient.

Environment

Current environment
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0):
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 2.0):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):

More info

No response

cc @justusschock @awaelchli
