From 00727e85605ae808439e82c2cd36a279d9cb2b03 Mon Sep 17 00:00:00 2001 From: Yunhoe Ku Date: Wed, 25 Jun 2025 11:26:41 +0900 Subject: [PATCH 1/2] add: `MultiModelDDPStrategy` and its execution codes --- .../generative_adversarial_net_ddp.py | 260 ++++++++++++++++++ src/lightning/pytorch/strategies/ddp.py | 50 +++- 2 files changed, 297 insertions(+), 13 deletions(-) create mode 100644 examples/pytorch/domain_templates/generative_adversarial_net_ddp.py diff --git a/examples/pytorch/domain_templates/generative_adversarial_net_ddp.py b/examples/pytorch/domain_templates/generative_adversarial_net_ddp.py new file mode 100644 index 0000000000000..ba5e1d98b328a --- /dev/null +++ b/examples/pytorch/domain_templates/generative_adversarial_net_ddp.py @@ -0,0 +1,260 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""To run this template just do: python generative_adversarial_net.py. + +After a few epochs, launch TensorBoard to see the images being generated at every batch: + +tensorboard --logdir default + +""" +import math +from argparse import ArgumentParser, Namespace + +import torch +import torch.nn as nn +import torch.nn.functional as F + +# ! TESTING +import os +import sys + +sys.path.append(os.path.join(os.getcwd(), "src")) # noqa: E402 +# ! TESTING + +from lightning.pytorch import cli_lightning_logo +from lightning.pytorch.core import LightningModule +from lightning.pytorch.demos.mnist_datamodule import MNISTDataModule +from lightning.pytorch.trainer import Trainer +from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE +from lightning.pytorch.strategies.ddp import DDPStrategy, MultiModelDDPStrategy + +if _TORCHVISION_AVAILABLE: + import torchvision + + +class Generator(nn.Module): + """ + >>> Generator(img_shape=(1, 8, 8)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Generator( + (model): Sequential(...) + ) + """ + + def __init__(self, latent_dim: int = 100, img_shape: tuple = (1, 28, 28)): + super().__init__() + self.img_shape = img_shape + + def block(in_feat, out_feat, normalize=True): + layers = [nn.Linear(in_feat, out_feat)] + if normalize: + layers.append(nn.BatchNorm1d(out_feat, 0.8)) + layers.append(nn.LeakyReLU(0.2, inplace=True)) + return layers + + self.model = nn.Sequential( + *block(latent_dim, 128, normalize=False), + *block(128, 256), + *block(256, 512), + *block(512, 1024), + nn.Linear(1024, int(math.prod(img_shape))), + nn.Tanh(), + ) + + def forward(self, z): + img = self.model(z) + return img.view(img.size(0), *self.img_shape) + + +class Discriminator(nn.Module): + """ + >>> Discriminator(img_shape=(1, 28, 28)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Discriminator( + (model): Sequential(...) 
+ ) + """ + + def __init__(self, img_shape): + super().__init__() + + self.model = nn.Sequential( + nn.Linear(int(math.prod(img_shape)), 512), + nn.LeakyReLU(0.2, inplace=True), + nn.Linear(512, 256), + nn.LeakyReLU(0.2, inplace=True), + nn.Linear(256, 1), + ) + + def forward(self, img): + img_flat = img.view(img.size(0), -1) + return self.model(img_flat) + + +class GAN(LightningModule): + """ + >>> GAN(img_shape=(1, 8, 8)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + GAN( + (generator): Generator( + (model): Sequential(...) + ) + (discriminator): Discriminator( + (model): Sequential(...) + ) + ) + """ + + def __init__( + self, + img_shape: tuple = (1, 28, 28), + lr: float = 0.0002, + b1: float = 0.5, + b2: float = 0.999, + latent_dim: int = 100, + ): + super().__init__() + self.save_hyperparameters() + self.automatic_optimization = False + + # networks + self.generator = Generator(latent_dim=self.hparams.latent_dim, img_shape=img_shape) + self.discriminator = Discriminator(img_shape=img_shape) + + self.validation_z = torch.randn(8, self.hparams.latent_dim) + + self.example_input_array = torch.zeros(2, self.hparams.latent_dim) + + # ! TESTING + self.save_path = "pl_test_multi_gpu" + os.makedirs(self.save_path, exist_ok=True) + + def forward(self, z): + return self.generator(z) + + @staticmethod + def adversarial_loss(y_hat, y): + return F.binary_cross_entropy_with_logits(y_hat, y) + + def training_step(self, batch): + imgs, _ = batch + + opt_g, opt_d = self.optimizers() + + # sample noise + z = torch.randn(imgs.shape[0], self.hparams.latent_dim) + z = z.type_as(imgs) + + # Train generator + # ground truth result (ie: all fake) + # put on GPU because we created this tensor inside training_loop + valid = torch.ones(imgs.size(0), 1) + valid = valid.type_as(imgs) + + self.toggle_optimizer(opt_g) + # adversarial loss is binary cross-entropy + g_loss = self.adversarial_loss(self.discriminator(self(z)), valid) + opt_g.zero_grad() + self.manual_backward(g_loss) + opt_g.step() + self.untoggle_optimizer(opt_g) + + # Train discriminator + # Measure discriminator's ability to classify real from generated samples + # how well can it label as real? + valid = torch.ones(imgs.size(0), 1) + valid = valid.type_as(imgs) + + self.toggle_optimizer(opt_d) + real_loss = self.adversarial_loss(self.discriminator(imgs), valid) + + # how well can it label as fake? + fake = torch.zeros(imgs.size(0), 1) + fake = fake.type_as(imgs) + + fake_loss = self.adversarial_loss(self.discriminator(self(z).detach()), fake) + + # discriminator loss is the average of these + d_loss = (real_loss + fake_loss) / 2 + + opt_d.zero_grad() + self.manual_backward(d_loss) + opt_d.step() + self.untoggle_optimizer(opt_d) + + self.log_dict({"d_loss": d_loss, "g_loss": g_loss}) + + def configure_optimizers(self): + lr = self.hparams.lr + b1 = self.hparams.b1 + b2 = self.hparams.b2 + + opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2)) + opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2)) + return opt_g, opt_d + + # ! TESTING + def on_train_epoch_start(self): + if self.trainer.is_global_zero: + print("GEN: ", self.generator.module.model[0].bias[:10]) + print("DISC: ", self.discriminator.module.model[0].bias[:10]) + + # ! TESTING + def validation_step(self, batch, batch_idx): + pass + + # ! 
TESTING + @torch.no_grad() + def on_validation_epoch_end(self): + if self.current_epoch % 5: + self.generator.eval(), self.discriminator.eval() + + z = self.validation_z.type_as(self.generator.module.model[0].weight) + sample_imgs = self(z) + + if self.trainer.is_global_zero: + grid = torchvision.utils.make_grid(sample_imgs) + torchvision.utils.save_image(grid, os.path.join(self.save_path, f"epoch_{self.current_epoch}.png")) + + self.generator.train(), self.discriminator.train() + + +def main(args: Namespace) -> None: + model = GAN(lr=args.lr, b1=args.b1, b2=args.b2, latent_dim=args.latent_dim) + + # ! `MultiModelDDPStrategy` is critical for multi-gpu training + # ! Otherwise, it will not work with multiple models. + # ! There are two ways to run training codes with previous `DDPStrategy`; + # ! 1) activate `find_unused_parameters=True`, 2) change from self.manual_backward(loss) to loss.backward() + # ! Neither of them is desirable. + dm = MNISTDataModule() + trainer = Trainer( + accelerator="auto", + devices=[0, 1, 2, 3], + strategy=MultiModelDDPStrategy(), + max_epochs=100, + ) + + trainer.fit(model, dm) + + +if __name__ == "__main__": + cli_lightning_logo() + parser = ArgumentParser() + + # Hyperparameters + parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate") + parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient") + parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient") + parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space") + args = parser.parse_args() + + main(args) diff --git a/src/lightning/pytorch/strategies/ddp.py b/src/lightning/pytorch/strategies/ddp.py index fd3f66ef42471..4e4bce82f2c5c 100644 --- a/src/lightning/pytorch/strategies/ddp.py +++ b/src/lightning/pytorch/strategies/ddp.py @@ -107,9 +107,7 @@ def __init__( @property def is_distributed(self) -> bool: # pragma: no-cover """Legacy property kept for backwards compatibility.""" - rank_zero_deprecation( - f"`{type(self).__name__}.is_distributed` is deprecated. Use is discouraged.", stacklevel=6 - ) + rank_zero_deprecation(f"`{type(self).__name__}.is_distributed` is deprecated. Use is discouraged.", stacklevel=6) return True @property @@ -229,9 +227,7 @@ def _register_ddp_hooks(self) -> None: def _enable_model_averaging(self) -> None: log.debug(f"{self.__class__.__name__}: reinitializing optimizers with post localSGD") if self._model_averaging_period is None: - raise ValueError( - "Post-localSGD algorithm is used, but model averaging period is not provided to DDP strategy." - ) + raise ValueError("Post-localSGD algorithm is used, but model averaging period is not provided to DDP strategy.") from torch.distributed.optim import DistributedOptimizer, PostLocalSGDOptimizer, ZeroRedundancyOptimizer for optimizer in self.optimizers: @@ -240,10 +236,7 @@ def _enable_model_averaging(self) -> None: is_distributed_optimizer = isinstance(optimizer, DistributedOptimizer) if not _IS_WINDOWS else False if isinstance(optimizer, (ZeroRedundancyOptimizer, PostLocalSGDOptimizer)) or is_distributed_optimizer: - raise ValueError( - f"Currently model averaging cannot work with a distributed optimizer of type " - f"{optimizer.__class__.__name__}." 
- ) + raise ValueError(f"Currently model averaging cannot work with a distributed optimizer of type " f"{optimizer.__class__.__name__}.") assert self._ddp_comm_state is not None self._model_averager = torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager( @@ -323,9 +316,7 @@ def model_to_device(self) -> None: self.model.to(self.root_device) @override - def reduce( - self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean" - ) -> Tensor: + def reduce(self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean") -> Tensor: """Reduces a tensor from several distributed processes to one aggregated tensor. Args: @@ -419,6 +410,39 @@ def teardown(self) -> None: super().teardown() +class MultiModelDDPStrategy(DDPStrategy): + @override + def _setup_model(self, model: Module) -> Module: + device_ids = self.determine_ddp_device_ids() + log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}") + # https://pytorch.org/docs/stable/notes/cuda.html#id5 + ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext() + with ctx: + for name, module in model.named_children(): + if isinstance(module, Module): + ddp_module = DistributedDataParallel(module, device_ids=device_ids, **self._ddp_kwargs) + setattr(model, name, ddp_module) + + return model + + @override + def _register_ddp_hooks(self) -> None: + log.debug(f"{self.__class__.__name__}: registering ddp hooks") + # currently, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode + # https://github.com/pytorch/pytorch/blob/v1.8.0/torch/nn/parallel/distributed.py#L1080-L1084 + if self.root_device.type == "cuda": + assert isinstance(self.model, Module) + + for name, module in self.model.named_children(): + assert isinstance(module, DistributedDataParallel) + _register_ddp_comm_hook( + model=module, + ddp_comm_state=self._ddp_comm_state, + ddp_comm_hook=self._ddp_comm_hook, + ddp_comm_wrapper=self._ddp_comm_wrapper, + ) + + class _DDPForwardRedirection(_ForwardRedirection): @override def on_after_inner_forward(self, wrapper_module: Module, original_module: "pl.LightningModule") -> None: From e6b061afee0875490a9553f44cd7288df20209a7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 25 Jun 2025 02:45:24 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../generative_adversarial_net_ddp.py | 13 +++++++------ src/lightning/pytorch/strategies/ddp.py | 17 +++++++++++++---- 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/examples/pytorch/domain_templates/generative_adversarial_net_ddp.py b/examples/pytorch/domain_templates/generative_adversarial_net_ddp.py index ba5e1d98b328a..7faec21cb8276 100644 --- a/examples/pytorch/domain_templates/generative_adversarial_net_ddp.py +++ b/examples/pytorch/domain_templates/generative_adversarial_net_ddp.py @@ -18,26 +18,27 @@ tensorboard --logdir default """ + import math + +# ! TESTING +import os +import sys from argparse import ArgumentParser, Namespace import torch import torch.nn as nn import torch.nn.functional as F -# ! TESTING -import os -import sys - -sys.path.append(os.path.join(os.getcwd(), "src")) # noqa: E402 +sys.path.append(os.path.join(os.getcwd(), "src")) # ! 
TESTING from lightning.pytorch import cli_lightning_logo from lightning.pytorch.core import LightningModule from lightning.pytorch.demos.mnist_datamodule import MNISTDataModule +from lightning.pytorch.strategies.ddp import MultiModelDDPStrategy from lightning.pytorch.trainer import Trainer from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE -from lightning.pytorch.strategies.ddp import DDPStrategy, MultiModelDDPStrategy if _TORCHVISION_AVAILABLE: import torchvision diff --git a/src/lightning/pytorch/strategies/ddp.py b/src/lightning/pytorch/strategies/ddp.py index 4e4bce82f2c5c..f69baa7ae2b13 100644 --- a/src/lightning/pytorch/strategies/ddp.py +++ b/src/lightning/pytorch/strategies/ddp.py @@ -107,7 +107,9 @@ def __init__( @property def is_distributed(self) -> bool: # pragma: no-cover """Legacy property kept for backwards compatibility.""" - rank_zero_deprecation(f"`{type(self).__name__}.is_distributed` is deprecated. Use is discouraged.", stacklevel=6) + rank_zero_deprecation( + f"`{type(self).__name__}.is_distributed` is deprecated. Use is discouraged.", stacklevel=6 + ) return True @property @@ -227,7 +229,9 @@ def _register_ddp_hooks(self) -> None: def _enable_model_averaging(self) -> None: log.debug(f"{self.__class__.__name__}: reinitializing optimizers with post localSGD") if self._model_averaging_period is None: - raise ValueError("Post-localSGD algorithm is used, but model averaging period is not provided to DDP strategy.") + raise ValueError( + "Post-localSGD algorithm is used, but model averaging period is not provided to DDP strategy." + ) from torch.distributed.optim import DistributedOptimizer, PostLocalSGDOptimizer, ZeroRedundancyOptimizer for optimizer in self.optimizers: @@ -236,7 +240,10 @@ def _enable_model_averaging(self) -> None: is_distributed_optimizer = isinstance(optimizer, DistributedOptimizer) if not _IS_WINDOWS else False if isinstance(optimizer, (ZeroRedundancyOptimizer, PostLocalSGDOptimizer)) or is_distributed_optimizer: - raise ValueError(f"Currently model averaging cannot work with a distributed optimizer of type " f"{optimizer.__class__.__name__}.") + raise ValueError( + f"Currently model averaging cannot work with a distributed optimizer of type " + f"{optimizer.__class__.__name__}." + ) assert self._ddp_comm_state is not None self._model_averager = torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager( @@ -316,7 +323,9 @@ def model_to_device(self) -> None: self.model.to(self.root_device) @override - def reduce(self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean") -> Tensor: + def reduce( + self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean" + ) -> Tensor: """Reduces a tensor from several distributed processes to one aggregated tensor. Args:
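
A minimal usage sketch, assuming the patch above is applied: `MultiModelDDPStrategy` wraps each direct child module of the LightningModule in its own `DistributedDataParallel` instance instead of wrapping the whole module once, so submodule internals are reached through the extra `.module` indirection (as in the example's `self.generator.module.model[0]` access).

    from lightning.pytorch import Trainer
    from lightning.pytorch.demos.mnist_datamodule import MNISTDataModule
    from lightning.pytorch.strategies.ddp import MultiModelDDPStrategy

    # GAN is the LightningModule defined in the example file above; its generator and
    # discriminator children are each wrapped in a separate DistributedDataParallel,
    # which keeps manual optimization of the two models working under DDP without
    # find_unused_parameters=True.
    model = GAN()
    trainer = Trainer(accelerator="gpu", devices=4, strategy=MultiModelDDPStrategy())
    trainer.fit(model, MNISTDataModule())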