From c8d50bdef6b93ed1b97e4f7c84d42a00f411f53c Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Thu, 9 Jan 2025 12:28:08 +0100 Subject: [PATCH 1/9] Weight averaging callback * A callback that updates a torch.optim.swa_utils.AveragedModel after specific steps or epochs. * The user can provide a callback that defines after which steps or epochs the average model is updated. --- docs/source-pytorch/api_references.rst | 1 + docs/source-pytorch/extensions/callbacks.rst | 1 + src/lightning/pytorch/callbacks/__init__.py | 2 + .../pytorch/callbacks/weight_averaging.py | 288 ++++++++++++++++++ .../utilities/test_distributed.py | 2 + .../callbacks/test_weight_averaging.py | 288 ++++++++++++++++++ 6 files changed, 582 insertions(+) create mode 100644 src/lightning/pytorch/callbacks/weight_averaging.py create mode 100644 tests/tests_pytorch/callbacks/test_weight_averaging.py diff --git a/docs/source-pytorch/api_references.rst b/docs/source-pytorch/api_references.rst index 1f58f6ac23dd5..278cc98ef5547 100644 --- a/docs/source-pytorch/api_references.rst +++ b/docs/source-pytorch/api_references.rst @@ -48,6 +48,7 @@ callbacks ThroughputMonitor Timer TQDMProgressBar + WeightAveraging cli ----- diff --git a/docs/source-pytorch/extensions/callbacks.rst b/docs/source-pytorch/extensions/callbacks.rst index c2a621f8b6d7b..7ed285591c4dc 100644 --- a/docs/source-pytorch/extensions/callbacks.rst +++ b/docs/source-pytorch/extensions/callbacks.rst @@ -83,6 +83,7 @@ Lightning has a few built-in callbacks. StochasticWeightAveraging Timer TQDMProgressBar + WeightAveraging ---------- diff --git a/src/lightning/pytorch/callbacks/__init__.py b/src/lightning/pytorch/callbacks/__init__.py index 9ee34f3866b27..d0ffb7b6a990c 100644 --- a/src/lightning/pytorch/callbacks/__init__.py +++ b/src/lightning/pytorch/callbacks/__init__.py @@ -32,6 +32,7 @@ from lightning.pytorch.callbacks.stochastic_weight_avg import StochasticWeightAveraging from lightning.pytorch.callbacks.throughput_monitor import ThroughputMonitor from lightning.pytorch.callbacks.timer import Timer +from lightning.pytorch.callbacks.weight_averaging import WeightAveraging __all__ = [ "BackboneFinetuning", @@ -58,4 +59,5 @@ "ThroughputMonitor", "Timer", "TQDMProgressBar", + "WeightAveraging", ] diff --git a/src/lightning/pytorch/callbacks/weight_averaging.py b/src/lightning/pytorch/callbacks/weight_averaging.py new file mode 100644 index 0000000000000..1595e7f987cf8 --- /dev/null +++ b/src/lightning/pytorch/callbacks/weight_averaging.py @@ -0,0 +1,288 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+r""" +Weight Averaging Callback +^^^^^^^^^^^^^^^^^^^^^^^^^ +""" + +import itertools +from copy import deepcopy +from typing import Any, Callable, Optional, Union + +import torch +from torch import Tensor +from torch.optim.swa_utils import AveragedModel + +import lightning.pytorch as pl +from lightning.pytorch.callbacks.callback import Callback +from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn +from lightning.pytorch.utilities.types import STEP_OUTPUT + + +def _return_true(x: int) -> bool: + return True + + +def _return_false(x: int) -> bool: + return False + + +class WeightAveraging(Callback): + r"""A callback that updates an averaged model for Stochastic Weight Averaging (SWA) or Exponential Moving Average + (EMA) after each training step. + + The user should provide either `update_on_step` or `update_on_epoch`, a function that determines when the average + model should be updated. If neither function is provided, the average model will be updated after every optimizer + step. + + During validation and after the training finishes, the current model parameters will be replaced with the averaged + values. + + Args: + device: If provided, the :class:`AveragedModel` will be stored on the ``device``. If ``None`` the device will be + inferred from the original model. + avg_fn: The averaging function used to update the parameters. The function must take in an + :class:`AveragedModel` parameter, a current model parameter, and the number of models already averaged. If + ``None``, an equally weighted average will be used. + update_on_step: A function that takes the number of optimizer steps taken, and returns ``True`` if the average + model should be updated. + update_on_epoch: A function that takes the zero-based epoch number, and returns ``True`` if the average model + should be updated. + + """ + + def __init__( + self, + device: Optional[Union[torch.device, int]] = torch.device("cpu"), + avg_fn: Optional[Callable[[Tensor, Tensor, Union[Tensor, int]], Tensor]] = None, + update_on_step: Optional[Callable[[int], bool]] = None, + update_on_epoch: Optional[Callable[[int], bool]] = None, + ): + self._device = device + self._avg_fn = avg_fn + + if (update_on_step is None) and (update_on_epoch is None): + self._update_on_step: Callable[[int], bool] = _return_true + self._update_on_epoch: Callable[[int], bool] = _return_false + else: + self._update_on_step = _return_false if update_on_step is None else update_on_step + self._update_on_epoch = _return_false if update_on_epoch is None else update_on_epoch + + self._average_model: Optional[AveragedModel] = None + + # Number of optimizer steps taken, when the average model was last updated. Initializing this with zero ensures + # that the average model will be first updated after the first optimizer step, which takes place after N batches + # when using accumulate_grad_batches=N. + self._latest_update_step = 0 + # The epoch after which the average model was last updated. The first epoch is 0, so initializing this to a + # negative value means that if update_on_step(0) returns True, the first update is after the first epoch. + self._latest_update_epoch = -1 + + def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None: + """Called when fit, validate, test, predict, or tune begins. + + Creates an :class:`AveragedModel` when fit begins. + + Args: + trainer: The current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance. 
+ pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance. + stage: The :class:`~lightning.pytorch.trainer.trainer.Trainer` state. + + """ + if stage == "fit": + device = self._device or pl_module.device + self._average_model = AveragedModel(model=pl_module, device=device, avg_fn=self._avg_fn, use_buffers=True) + + def on_train_batch_end( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int + ) -> None: + """Called when a training batch ends. + + Updates the :class:`AveragedModel` parameters, if requested by ``update_on_step()``. + + Args: + trainer: The current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance. + pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance. + outputs: Outputs from the training batch. + batch: The training batch. + batch_idx: Index of the training batch. + + """ + if self._update_on_step(trainer.global_step) and (trainer.global_step > self._latest_update_step): + assert self._average_model is not None + self._average_model.update_parameters(pl_module) + self._latest_update_step = trainer.global_step + + def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + """Called when a training epoch ends. + + Updates the :class:`AveragedModel` parameters, if requested by ``update_on_epoch()``. + + Args: + trainer: The current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance. + pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance. + + """ + if self._update_on_epoch(trainer.current_epoch) and (trainer.current_epoch > self._latest_update_epoch): + assert self._average_model is not None + self._average_model.update_parameters(pl_module) + self._latest_update_epoch = trainer.current_epoch + + def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + """Called when training ends. + + Transfers parameters from the :class:`AveragedModel` to the current model. + + Args: + trainer: The current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance. + pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance. + + """ + assert self._average_model is not None + self._copy_average_to_current(pl_module) + + def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + """Called when a validation epoch begins. + + Transfers parameter values from the :class:`AveragedModel` to the current model. + + Args: + trainer: The current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance. + pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance. + + """ + if self._average_model is not None: + rank_zero_info("Loading the average model parameters for validation.") + self._swap_models(pl_module) + + def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + """Called when a validation epoch ends. + + Recovers the current model parameters from the :class:`AveragedModel`. + + Args: + trainer: The current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance. + pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance. + + """ + if self._average_model is not None: + rank_zero_info("Recovering the current model parameters after validation.") + self._swap_models(pl_module) + + def state_dict(self) -> dict[str, Any]: + """Called when saving a checkpoint. 
+ + Creates a ``state_dict`` of the callback state. + + Returns: + A dictionary containing the callback state. + + """ + return {"latest_update_step": self._latest_update_step} + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + """Called when loading a checkpoint. + + Reloads the callback state given a ``state_dict``. + + Args: + state_dict: A dictionary containing the callback state. + + """ + self._latest_update_step = state_dict["latest_update_step"] + + def on_save_checkpoint( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: dict[str, Any] + ) -> None: + r"""Called when saving a checkpoint. + + Moves the current model state to the key ``current_model_state``, and places the average model state in + ``state_dict`` instead. Any other state variables of the ``AveragedModel`` will be saved in + ``averaging_state``. + + Args: + trainer: The current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance. + pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance. + checkpoint: The checkpoint dictionary that will be saved. + + """ + if self._average_model is None: + raise Exception("Trying to save a checkpoint, but no average model (outside fit). Don't know what to do.") + + rank_zero_info("The average model parameters will be saved to the state_dict in the checkpoint.") + average_model_state = self._average_model.state_dict() + checkpoint["current_model_state"] = checkpoint["state_dict"] + checkpoint["state_dict"] = { + name[7:]: value for name, value in average_model_state.items() if name.startswith("module.") + } + checkpoint["averaging_state"] = { + name: value for name, value in average_model_state.items() if not name.startswith("module.") + } + + def on_load_checkpoint( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: dict[str, Any] + ) -> None: + r"""Called when loading a model checkpoint. + + Loads the current model and the :class:`AveragedModel` parameters from the checkpoint. + + Args: + trainer: The current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance. + pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance. + checkpoint: The full checkpoint dictionary that got loaded by the Trainer. + + """ + if self._average_model is None: + raise Exception("Trying to load a checkpoint, but no average model (outside fit). Don't know what to do.") + + if ("current_model_state" in checkpoint) and ("averaging_state" in checkpoint): + rank_zero_info("Found current_model_state in the checkpoint. This will be used to initialize the model.") + average_model_state = {"module." + name: value for name, value in checkpoint["state_dict"].items()} + average_model_state |= checkpoint["averaging_state"] + self._average_model.load_state_dict(average_model_state) + checkpoint["state_dict"] = checkpoint["current_model_state"] + else: + rank_zero_warn( + "The checkpoint was not created with WeightAveraging. Both the current and the average model will be " + "initialized with state_dict." + ) + self._average_model.module.load_state_dict(deepcopy(checkpoint["state_dict"]), strict=False) + + def _swap_models(self, pl_module: "pl.LightningModule") -> None: + """Swaps the parameter values of the current model and the :class:`AveragedModel`. + + Args: + pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance. 
+ + """ + assert self._average_model is not None + average_params = itertools.chain(self._average_model.module.parameters(), self._average_model.module.buffers()) + current_params = itertools.chain(pl_module.parameters(), pl_module.buffers()) + for average_param, current_param in zip(average_params, current_params): + tmp = average_param.data.clone() + average_param.data.copy_(current_param.data) + current_param.data.copy_(tmp) + + def _copy_average_to_current(self, pl_module: "pl.LightningModule") -> None: + """Copies the parameter values from the :class:`AveragedModel` to the current model. + + Args: + pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance. + + """ + assert self._average_model is not None + average_params = itertools.chain(self._average_model.module.parameters(), self._average_model.module.buffers()) + current_params = itertools.chain(pl_module.parameters(), pl_module.buffers()) + for average_param, current_param in zip(average_params, current_params): + current_param.data.copy_(average_param.data) diff --git a/tests/tests_fabric/utilities/test_distributed.py b/tests/tests_fabric/utilities/test_distributed.py index 3450ed89f6cc7..9282f00f1ffb6 100644 --- a/tests/tests_fabric/utilities/test_distributed.py +++ b/tests/tests_fabric/utilities/test_distributed.py @@ -105,6 +105,8 @@ def _test_all_reduce(strategy): assert result is tensor # inplace +# flaky with "process 0 terminated with signal SIGABRT" (GLOO) +@pytest.mark.flaky(reruns=3, only_rerun="torch.multiprocessing.spawn.ProcessExitedException") @RunIf(skip_windows=True) @pytest.mark.parametrize( "process", diff --git a/tests/tests_pytorch/callbacks/test_weight_averaging.py b/tests/tests_pytorch/callbacks/test_weight_averaging.py new file mode 100644 index 0000000000000..4dc28a9b71c6b --- /dev/null +++ b/tests/tests_pytorch/callbacks/test_weight_averaging.py @@ -0,0 +1,288 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os +from pathlib import Path +from typing import Any, Optional + +import pytest +import torch +from torch import Tensor, nn +from torch.optim.swa_utils import get_swa_avg_fn +from torch.utils.data import DataLoader + +from lightning.pytorch import LightningModule, Trainer +from lightning.pytorch.callbacks import WeightAveraging +from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset +from tests_pytorch.helpers.runif import RunIf + + +class WeightAveragingTestModel(BoringModel): + def __init__( + self, batch_norm: bool = True, iterable_dataset: bool = False, crash_on_epoch: Optional[int] = None + ) -> None: + super().__init__() + layers = [nn.Linear(32, 32)] + if batch_norm: + layers.append(nn.BatchNorm1d(32)) + layers += [nn.ReLU(), nn.Linear(32, 2)] + self.layer = nn.Sequential(*layers) + self.iterable_dataset = iterable_dataset + self.crash_on_epoch = crash_on_epoch + + def training_step(self, batch: Tensor, batch_idx: int) -> None: + if self.crash_on_epoch and self.trainer.current_epoch >= self.crash_on_epoch: + raise Exception("CRASH TEST") + return super().training_step(batch, batch_idx) + + def train_dataloader(self) -> None: + dataset_class = RandomIterableDataset if self.iterable_dataset else RandomDataset + return DataLoader(dataset_class(32, 32), batch_size=4) + + def configure_optimizers(self) -> None: + return torch.optim.SGD(self.layer.parameters(), lr=0.1) + + +class EMAAveragingFunction: + """EMA averaging function. + + Functionally equivalent to the closure that ``get_ema_avg_fn()`` would return. This class is needed because we + cannot use a closure with ddp_spawn. (``Popen(process_obj)`` would fail with + ``Can't get local object 'get_ema_avg_fn..ema_update'``). + + """ + + def __init__(self, decay: float = 0.999) -> None: + self.decay = decay + + @torch.no_grad() + def __call__(self, ema_param: Tensor, current_param: Tensor, num_averaged: Tensor) -> Tensor: + return self.decay * ema_param + (1 - self.decay) * current_param + + +class EMATestCallback(WeightAveraging): + def __init__(self, devices: int = 1, **kwargs: Any) -> None: + super().__init__(avg_fn=EMAAveragingFunction(), **kwargs) + self.devices = devices + self.swap_calls = 0 + self.copy_calls = 0 + # Record the first epoch, as if we are resuming from a checkpoint this may not be equal to 0. + self.first_epoch: Optional[int] = None + + def _swap_models(self, *args: Any, **kwargs: Any): + self.swap_calls += 1 + return super()._swap_models(*args, **kwargs) + + def _copy_average_to_current(self, *args: Any, **kwargs: Any): + self.copy_calls += 1 + return super()._copy_average_to_current(*args, **kwargs) + + def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + super().on_train_start(trainer, pl_module) + assert self.swap_calls == 0 + assert self.copy_calls == 0 + + def on_train_epoch_start(self, trainer: Trainer, *args: Any) -> None: + super().on_train_epoch_start(trainer, *args) + # Since the checkpoint loaded was saved `on_train_epoch_end`, the first `FitLoop` iteration will not update the + # model and will just call the epoch-level hooks. For that reason, we check that we are not restarting before + # choosing the first epoch. 
+ if self.first_epoch is None and not trainer.fit_loop.restarting: + self.first_epoch = trainer.current_epoch + + def on_train_epoch_end(self, trainer: Trainer, *args: Any) -> None: + super().on_train_epoch_end(trainer, *args) + assert self._average_model.n_averaged == trainer.global_step + assert self.swap_calls == (trainer.current_epoch + 1 - self.first_epoch) * 2 + assert self.copy_calls == 0 + + def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + super().on_train_end(trainer, pl_module) + # length=32, batch_size=4, accumulate_grad_batches=2 + # => Using one process we have 4 optimizer steps per epoch. + # => Using two processes we have 2 optimizer steps per epoch. + steps_per_epoch = 4 // self.devices + assert self._average_model.n_averaged == trainer.max_epochs * steps_per_epoch + assert self.swap_calls == (trainer.max_epochs - self.first_epoch) * 2 + assert self.copy_calls == 1 + + +class SWATestCallback(WeightAveraging): + def __init__(self, **kwargs: Any) -> None: + avg_fn = get_swa_avg_fn() + update_on_epoch = lambda x: x in (3, 5, 7) + super().__init__(avg_fn=avg_fn, update_on_epoch=update_on_epoch, **kwargs) + + self.swap_calls = 0 + self.copy_calls = 0 + # Record the first epoch, as if we are resuming from a checkpoint this may not be equal to 0. + self.first_epoch: Optional[int] = None + + def _swap_models(self, *args: Any, **kwargs: Any): + self.swap_calls += 1 + return super()._swap_models(*args, **kwargs) + + def _copy_average_to_current(self, *args: Any, **kwargs: Any): + self.copy_calls += 1 + return super()._copy_average_to_current(*args, **kwargs) + + def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + super().on_train_start(trainer, pl_module) + assert self.swap_calls == 0 + assert self.copy_calls == 0 + + def on_train_epoch_start(self, trainer: Trainer, *args: Any) -> None: + super().on_train_epoch_start(trainer, *args) + # Since the checkpoint loaded was saved `on_train_epoch_end`, the first `FitLoop` iteration will not update the + # model and will just call the epoch-level hooks. For that reason, we check that we are not restarting before + # choosing the first epoch. 
+ if self.first_epoch is None and not trainer.fit_loop.restarting: + self.first_epoch = trainer.current_epoch + + def on_train_epoch_end(self, trainer: Trainer, *args: Any) -> None: + super().on_train_epoch_end(trainer, *args) + if trainer.current_epoch < 3: + assert self._average_model.n_averaged == 0 + elif trainer.current_epoch < 5: + assert self._average_model.n_averaged == 1 + elif trainer.current_epoch < 7: + assert self._average_model.n_averaged == 2 + else: + assert self._average_model.n_averaged == 3 + assert self.swap_calls == (trainer.current_epoch + 1 - self.first_epoch) * 2 + assert self.copy_calls == 0 + + def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + super().on_train_end(trainer, pl_module) + assert self._average_model.n_averaged == 3 + assert self.swap_calls == (trainer.max_epochs - self.first_epoch) * 2 + assert self.copy_calls == 1 + + +def test_weight_averaging_deepcopy(tmp_path): + """Ensure that WeightAveraging callback doesn't deepcopy the data loaders or the data module and consume memory + more than necessary.""" + + class TestCallback(WeightAveraging): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setup_called = False + + def setup(self, trainer, pl_module, stage) -> None: + super().setup(trainer, pl_module, stage) + assert self._average_model.module.train_dataloader is not pl_module.train_dataloader + assert self._average_model.module.train_dataloader.__self__ == self._average_model.module + assert self._average_model.module._trainer is None + self.setup_called = True + + callback = TestCallback() + trainer = Trainer(default_root_dir=tmp_path, callbacks=callback, fast_dev_run=True) + trainer.fit(BoringModel(), train_dataloaders=DataLoader(RandomDataset(32, 2))) + assert callback.setup_called + + +@pytest.mark.parametrize("batch_norm", [True, False]) +@pytest.mark.parametrize("iterable_dataset", [True, False]) +def test_ema(tmp_path, batch_norm: bool, iterable_dataset: bool): + _train(tmp_path, EMATestCallback(), batch_norm=batch_norm, iterable_dataset=iterable_dataset) + + +@pytest.mark.parametrize( + "accelerator", [pytest.param("gpu", marks=RunIf(min_cuda_gpus=1)), pytest.param("mps", marks=RunIf(mps=True))] +) +def test_ema_accelerator(tmp_path, accelerator): + _train(tmp_path, EMATestCallback(), accelerator=accelerator, devices=1) + + +@RunIf(min_cuda_gpus=2, standalone=True) +def test_ema_ddp(tmp_path): + _train(tmp_path, EMATestCallback(devices=2), strategy="ddp", accelerator="gpu", devices=2) + + +@RunIf(min_cuda_gpus=2) +def test_ema_ddp_spawn(tmp_path): + _train(tmp_path, EMATestCallback(devices=2), strategy="ddp_spawn", accelerator="gpu", devices=2) + + +@RunIf(skip_windows=True) +def test_ema_ddp_spawn_cpu(tmp_path): + _train(tmp_path, EMATestCallback(devices=2), strategy="ddp_spawn", accelerator="cpu", devices=2) + + +@pytest.mark.parametrize("crash_on_epoch", [1, 3]) +def test_ema_resume(tmp_path, crash_on_epoch): + _train_and_resume(tmp_path, crash_on_epoch=crash_on_epoch) + + +@RunIf(skip_windows=True) +def test_ema_resume_ddp(tmp_path): + _train_and_resume(tmp_path, crash_on_epoch=3, use_ddp=True) + + +def test_swa(tmp_path): + _train(tmp_path, SWATestCallback()) + + +def _train( + tmp_path: str, + callback: WeightAveraging, + batch_norm: bool = True, + strategy: str = "auto", + accelerator: str = "cpu", + devices: int = 1, + iterable_dataset: bool = False, + checkpoint_path: Optional[str] = None, + crash_on_epoch: Optional[int] = None, +) -> None: + trainer = Trainer( + 
default_root_dir=tmp_path, + enable_progress_bar=False, + enable_model_summary=False, + logger=False, + max_epochs=8, + num_sanity_val_steps=0, + callbacks=callback, + accumulate_grad_batches=2, + strategy=strategy, + accelerator=accelerator, + devices=devices, + ) + model = WeightAveragingTestModel( + batch_norm=batch_norm, iterable_dataset=iterable_dataset, crash_on_epoch=crash_on_epoch + ) + + if crash_on_epoch is None: + trainer.fit(model, ckpt_path=checkpoint_path) + else: + with pytest.raises(Exception, match="CRASH TEST"): + trainer.fit(model, ckpt_path=checkpoint_path) + + assert trainer.lightning_module == model + + +def _train_and_resume(tmp_path: str, crash_on_epoch: int, use_ddp: bool = False) -> None: + strategy = "ddp_spawn" if use_ddp else "auto" + devices = 2 if use_ddp else 1 + + _train( + tmp_path, EMATestCallback(devices=devices), strategy=strategy, devices=devices, crash_on_epoch=crash_on_epoch + ) + + checkpoint_dir = Path(tmp_path) / "checkpoints" + checkpoint_names = os.listdir(checkpoint_dir) + assert len(checkpoint_names) == 1 + checkpoint_path = str(checkpoint_dir / checkpoint_names[0]) + + _train( + tmp_path, EMATestCallback(devices=devices), strategy=strategy, devices=devices, checkpoint_path=checkpoint_path + ) From 99b6638fc113270a897d3a07fb6f12f0d1c536ec Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Tue, 4 Feb 2025 16:58:18 +0200 Subject: [PATCH 2/9] More generic customization of the WeightAveraging callback - The user can specify when to update the average model by overriding the should_update() method - Any keyword arguments will be passed to the AveragedModel constructor --- src/lightning/pytorch/CHANGELOG.md | 12 ++ .../pytorch/callbacks/weight_averaging.py | 166 ++++++++++++------ .../callbacks/test_weight_averaging.py | 124 +++++++------ 3 files changed, 191 insertions(+), 111 deletions(-) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 9f7317c218c30..877b507fea9e9 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -4,6 +4,18 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). +## [unreleased] - YYYY-MM-DD + +### Added + +- WeightAveraging callback that wraps the PyTorch AveragedModel class ([#20545](https://github.com/Lightning-AI/pytorch-lightning/pull/20545)) + +### Changed + +### Removed + +### Fixed + ## [2.5.0] - 2024-12-19 ### Added diff --git a/src/lightning/pytorch/callbacks/weight_averaging.py b/src/lightning/pytorch/callbacks/weight_averaging.py index 1595e7f987cf8..4b4df2c095fc3 100644 --- a/src/lightning/pytorch/callbacks/weight_averaging.py +++ b/src/lightning/pytorch/callbacks/weight_averaging.py @@ -18,11 +18,11 @@ import itertools from copy import deepcopy -from typing import Any, Callable, Optional, Union +from typing import Any, Optional, Union import torch -from torch import Tensor from torch.optim.swa_utils import AveragedModel +from typing_extensions import override import lightning.pytorch as pl from lightning.pytorch.callbacks.callback import Callback @@ -30,65 +30,97 @@ from lightning.pytorch.utilities.types import STEP_OUTPUT -def _return_true(x: int) -> bool: - return True - - -def _return_false(x: int) -> bool: - return False - - class WeightAveraging(Callback): r"""A callback that updates an averaged model for Stochastic Weight Averaging (SWA) or Exponential Moving Average (EMA) after each training step. 
- The user should provide either `update_on_step` or `update_on_epoch`, a function that determines when the average - model should be updated. If neither function is provided, the average model will be updated after every optimizer - step. + Arguments given to the constructor will be passed to the :class:`AveragedModel` constructor. There are a couple of + differences to the default values, however. By default, the average model is stored on the CPU. If ``device`` is set + to ``None``, the device will be inferred from the original model. By default, the callback will compute running + averages for both the parameters and the buffers of the model. Setting ``use_buffers`` to ``False`` will cause only + the model parameters to be averaged, leaving updating the batch normalization statistics to the user (using + ``torch.optim.swa_utils.update_bn()``). + + You can provide a custom averaging function with the ``avg_fn`` or ``multi_avg_fn`` parameter. See the + :class:`AveragedModel` class for details. If no averaging function is provided, the default is to compute the + equally-weighted average of the weights (SWA). + + You can customize when the average model is updated by overriding the ``should_update()`` method. The callback calls + it with either ``step_idx`` or ``epoch_idx`` and the method returns a boolean indicating whether to update after the + given step or epoch. The default is to update after every step. During validation and after the training finishes, the current model parameters will be replaced with the averaged values. + Example:: + + from lightning.pytorch.callbacks import WeightAveraging + from torch.optim.swa_utils import get_ema_avg_fn + + class EMAWeightAveraging(WeightAveraging): + def __init__(self): + super().__init__(avg_fn=get_ema_avg_fn()) + + def should_update(self, step_idx=None, epoch_idx=None): + # Start after 100 steps. + return (step_idx is not None) and (step_idx >= 100) + + trainer = Trainer(callbacks=EMAWeightAveraging(), max_epochs=10) + trainer.fit(model, dataloader) + Args: device: If provided, the :class:`AveragedModel` will be stored on the ``device``. If ``None`` the device will be inferred from the original model. - avg_fn: The averaging function used to update the parameters. The function must take in an - :class:`AveragedModel` parameter, a current model parameter, and the number of models already averaged. If - ``None``, an equally weighted average will be used. - update_on_step: A function that takes the number of optimizer steps taken, and returns ``True`` if the average - model should be updated. - update_on_epoch: A function that takes the zero-based epoch number, and returns ``True`` if the average model - should be updated. + use_buffers: If ``False``, the buffers of the model will not be averaged. + kwargs: Additional keyword arguments to be passed to the :class:`AveragedModel` constructor, such as ``avg_fn`` + or ``multi_avg_fn``. 
""" def __init__( self, - device: Optional[Union[torch.device, int]] = torch.device("cpu"), - avg_fn: Optional[Callable[[Tensor, Tensor, Union[Tensor, int]], Tensor]] = None, - update_on_step: Optional[Callable[[int], bool]] = None, - update_on_epoch: Optional[Callable[[int], bool]] = None, - ): - self._device = device - self._avg_fn = avg_fn - - if (update_on_step is None) and (update_on_epoch is None): - self._update_on_step: Callable[[int], bool] = _return_true - self._update_on_epoch: Callable[[int], bool] = _return_false + device: Optional[Union[torch.device, str, int]] = "cpu", + use_buffers: bool = True, + **kwargs: Any, + ) -> None: + # The default value is a string so that jsonargparse knows how to serialize it. + if isinstance(device, str): + self._device: Optional[Union[torch.device, int]] = torch.device(device) else: - self._update_on_step = _return_false if update_on_step is None else update_on_step - self._update_on_epoch = _return_false if update_on_epoch is None else update_on_epoch + self._device = device + self._use_buffers = use_buffers + self._kwargs = kwargs self._average_model: Optional[AveragedModel] = None # Number of optimizer steps taken, when the average model was last updated. Initializing this with zero ensures - # that the average model will be first updated after the first optimizer step, which takes place after N batches - # when using accumulate_grad_batches=N. + # that self.should_update() will be first called after the first optimizer step, which takes place after N + # batches when using accumulate_grad_batches=N. self._latest_update_step = 0 # The epoch after which the average model was last updated. The first epoch is 0, so initializing this to a - # negative value means that if update_on_step(0) returns True, the first update is after the first epoch. + # negative value means that if self.should_update(epoch_idx=0) returns True, the first update is after the first + # epoch. self._latest_update_epoch = -1 + def should_update(self, step_idx: Optional[int] = None, epoch_idx: Optional[int] = None) -> bool: + """Called after every optimizer step and after every training epoch to check whether the average model should + be updated. + + One of the arguments is set to the zero-based index of the last training step or epoch. The default + implementation returns ``True`` when any ``step_idx`` is provided. The user can customize when the average model + gets updated by overriding this method. + + Args: + step_idx: Index of the last optimizer step, or ``None`` when called at the epoch end. + epoch_idx: Index of the last epoch, or ``None`` when called after an optimizer step. + + Returns: + ``True`` if the average model should be updated and ``False`` if not. + + """ + return step_idx is not None + + @override def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None: """Called when fit, validate, test, predict, or tune begins. @@ -102,14 +134,17 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: s """ if stage == "fit": device = self._device or pl_module.device - self._average_model = AveragedModel(model=pl_module, device=device, avg_fn=self._avg_fn, use_buffers=True) + self._average_model = AveragedModel( + model=pl_module, device=device, use_buffers=self._use_buffers, **self._kwargs + ) + @override def on_train_batch_end( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int ) -> None: """Called when a training batch ends. 
- Updates the :class:`AveragedModel` parameters, if requested by ``update_on_step()``. + Updates the :class:`AveragedModel` parameters, if requested by ``self.should_update()``. Args: trainer: The current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance. @@ -119,26 +154,31 @@ def on_train_batch_end( batch_idx: Index of the training batch. """ - if self._update_on_step(trainer.global_step) and (trainer.global_step > self._latest_update_step): + # trainer.global_step is the number of optimizer steps taken so far, i.e. 1 after the first optimizer step. To + # make step_idx consistent with epoch_idx, we'll pass a zero-based index. + step_idx = trainer.global_step - 1 + if (trainer.global_step > self._latest_update_step) and self.should_update(step_idx=step_idx): assert self._average_model is not None self._average_model.update_parameters(pl_module) self._latest_update_step = trainer.global_step + @override def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: """Called when a training epoch ends. - Updates the :class:`AveragedModel` parameters, if requested by ``update_on_epoch()``. + Updates the :class:`AveragedModel` parameters, if requested by ``self.should_update()``. Args: trainer: The current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance. pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance. """ - if self._update_on_epoch(trainer.current_epoch) and (trainer.current_epoch > self._latest_update_epoch): + if (trainer.current_epoch > self._latest_update_epoch) and self.should_update(epoch_idx=trainer.current_epoch): assert self._average_model is not None self._average_model.update_parameters(pl_module) self._latest_update_epoch = trainer.current_epoch + @override def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: """Called when training ends. @@ -150,8 +190,10 @@ def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") - """ assert self._average_model is not None + rank_zero_info("Loading the average model parameters to the final model.") self._copy_average_to_current(pl_module) + @override def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: """Called when a validation epoch begins. @@ -166,6 +208,7 @@ def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.Lightn rank_zero_info("Loading the average model parameters for validation.") self._swap_models(pl_module) + @override def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: """Called when a validation epoch ends. @@ -180,6 +223,7 @@ def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.Lightnin rank_zero_info("Recovering the current model parameters after validation.") self._swap_models(pl_module) + @override def state_dict(self) -> dict[str, Any]: """Called when saving a checkpoint. @@ -191,6 +235,7 @@ def state_dict(self) -> dict[str, Any]: """ return {"latest_update_step": self._latest_update_step} + @override def load_state_dict(self, state_dict: dict[str, Any]) -> None: """Called when loading a checkpoint. 
@@ -202,6 +247,7 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None: """ self._latest_update_step = state_dict["latest_update_step"] + @override def on_save_checkpoint( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: dict[str, Any] ) -> None: @@ -218,18 +264,23 @@ def on_save_checkpoint( """ if self._average_model is None: - raise Exception("Trying to save a checkpoint, but no average model (outside fit). Don't know what to do.") - - rank_zero_info("The average model parameters will be saved to the state_dict in the checkpoint.") - average_model_state = self._average_model.state_dict() - checkpoint["current_model_state"] = checkpoint["state_dict"] - checkpoint["state_dict"] = { - name[7:]: value for name, value in average_model_state.items() if name.startswith("module.") - } - checkpoint["averaging_state"] = { - name: value for name, value in average_model_state.items() if not name.startswith("module.") - } - + rank_zero_info( + "You're using the WeightAveraging callback, but saving a checkpoint outside the 'fit' stage. The state " + "of the WeightAveraging callback won't be saved in the checkpoint. If training has finished, the " + "average model parameters will be saved to the state_dict in the checkpoint." + ) + else: + rank_zero_info("The average model parameters will be saved to the state_dict in the checkpoint.") + average_model_state = self._average_model.state_dict() + checkpoint["current_model_state"] = checkpoint["state_dict"] + checkpoint["state_dict"] = { + name[7:]: value for name, value in average_model_state.items() if name.startswith("module.") + } + checkpoint["averaging_state"] = { + name: value for name, value in average_model_state.items() if not name.startswith("module.") + } + + @override def on_load_checkpoint( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: dict[str, Any] ) -> None: @@ -244,9 +295,12 @@ def on_load_checkpoint( """ if self._average_model is None: - raise Exception("Trying to load a checkpoint, but no average model (outside fit). Don't know what to do.") - - if ("current_model_state" in checkpoint) and ("averaging_state" in checkpoint): + rank_zero_warn( + "You're using the WeightAveraging callback, but loading a checkpoint outside the 'fit' stage. The " + "WeightAveraging state cannot be restored. If you're using the checkpoint for prediction or testing, " + "you can ignore this warning. To disable the warning, remove the WeightAveraging callback." + ) + elif ("current_model_state" in checkpoint) and ("averaging_state" in checkpoint): rank_zero_info("Found current_model_state in the checkpoint. This will be used to initialize the model.") average_model_state = {"module." + name: value for name, value in checkpoint["state_dict"].items()} average_model_state |= checkpoint["averaging_state"] diff --git a/tests/tests_pytorch/callbacks/test_weight_averaging.py b/tests/tests_pytorch/callbacks/test_weight_averaging.py index 4dc28a9b71c6b..e5591e2f2f6a0 100644 --- a/tests/tests_pytorch/callbacks/test_weight_averaging.py +++ b/tests/tests_pytorch/callbacks/test_weight_averaging.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import os +from copy import deepcopy from pathlib import Path from typing import Any, Optional @@ -19,7 +20,7 @@ import torch from torch import Tensor, nn from torch.optim.swa_utils import get_swa_avg_fn -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, Dataset from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.callbacks import WeightAveraging @@ -27,28 +28,21 @@ from tests_pytorch.helpers.runif import RunIf -class WeightAveragingTestModel(BoringModel): - def __init__( - self, batch_norm: bool = True, iterable_dataset: bool = False, crash_on_epoch: Optional[int] = None - ) -> None: +class TestModel(BoringModel): + def __init__(self, batch_norm: bool = True) -> None: super().__init__() layers = [nn.Linear(32, 32)] if batch_norm: layers.append(nn.BatchNorm1d(32)) layers += [nn.ReLU(), nn.Linear(32, 2)] self.layer = nn.Sequential(*layers) - self.iterable_dataset = iterable_dataset - self.crash_on_epoch = crash_on_epoch + self.crash_on_epoch = None def training_step(self, batch: Tensor, batch_idx: int) -> None: if self.crash_on_epoch and self.trainer.current_epoch >= self.crash_on_epoch: - raise Exception("CRASH TEST") + raise Exception("CRASH") return super().training_step(batch, batch_idx) - def train_dataloader(self) -> None: - dataset_class = RandomIterableDataset if self.iterable_dataset else RandomDataset - return DataLoader(dataset_class(32, 32), batch_size=4) - def configure_optimizers(self) -> None: return torch.optim.SGD(self.layer.parameters(), lr=0.1) @@ -119,15 +113,15 @@ def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None: class SWATestCallback(WeightAveraging): def __init__(self, **kwargs: Any) -> None: - avg_fn = get_swa_avg_fn() - update_on_epoch = lambda x: x in (3, 5, 7) - super().__init__(avg_fn=avg_fn, update_on_epoch=update_on_epoch, **kwargs) - + super().__init__(avg_fn=get_swa_avg_fn(), **kwargs) self.swap_calls = 0 self.copy_calls = 0 # Record the first epoch, as if we are resuming from a checkpoint this may not be equal to 0. 
self.first_epoch: Optional[int] = None + def should_update(self, step_idx: Optional[int] = None, epoch_idx: Optional[int] = None) -> bool: + return epoch_idx in (3, 5, 7) + def _swap_models(self, *args: Any, **kwargs: Any): self.swap_calls += 1 return super()._swap_models(*args, **kwargs) @@ -194,95 +188,115 @@ def setup(self, trainer, pl_module, stage) -> None: @pytest.mark.parametrize("batch_norm", [True, False]) @pytest.mark.parametrize("iterable_dataset", [True, False]) def test_ema(tmp_path, batch_norm: bool, iterable_dataset: bool): - _train(tmp_path, EMATestCallback(), batch_norm=batch_norm, iterable_dataset=iterable_dataset) + model = TestModel(batch_norm=batch_norm) + dataset = RandomIterableDataset(32, 32) if iterable_dataset else RandomDataset(32, 32) + _train(model, dataset, tmp_path, EMATestCallback()) @pytest.mark.parametrize( "accelerator", [pytest.param("gpu", marks=RunIf(min_cuda_gpus=1)), pytest.param("mps", marks=RunIf(mps=True))] ) def test_ema_accelerator(tmp_path, accelerator): - _train(tmp_path, EMATestCallback(), accelerator=accelerator, devices=1) + model = TestModel() + dataset = RandomDataset(32, 32) + _train(model, dataset, tmp_path, EMATestCallback(), accelerator=accelerator, devices=1) @RunIf(min_cuda_gpus=2, standalone=True) def test_ema_ddp(tmp_path): - _train(tmp_path, EMATestCallback(devices=2), strategy="ddp", accelerator="gpu", devices=2) + model = TestModel() + dataset = RandomDataset(32, 32) + _train(model, dataset, tmp_path, EMATestCallback(devices=2), strategy="ddp", accelerator="gpu", devices=2) @RunIf(min_cuda_gpus=2) def test_ema_ddp_spawn(tmp_path): - _train(tmp_path, EMATestCallback(devices=2), strategy="ddp_spawn", accelerator="gpu", devices=2) + model = TestModel() + dataset = RandomDataset(32, 32) + _train(model, dataset, tmp_path, EMATestCallback(devices=2), strategy="ddp_spawn", accelerator="gpu", devices=2) @RunIf(skip_windows=True) def test_ema_ddp_spawn_cpu(tmp_path): - _train(tmp_path, EMATestCallback(devices=2), strategy="ddp_spawn", accelerator="cpu", devices=2) + model = TestModel() + dataset = RandomDataset(32, 32) + _train(model, dataset, tmp_path, EMATestCallback(devices=2), strategy="ddp_spawn", accelerator="cpu", devices=2) -@pytest.mark.parametrize("crash_on_epoch", [1, 3]) +@pytest.mark.parametrize("crash_on_epoch", [1, 3, 5]) def test_ema_resume(tmp_path, crash_on_epoch): - _train_and_resume(tmp_path, crash_on_epoch=crash_on_epoch) + dataset = RandomDataset(32, 32) + model1 = TestModel() + model2 = deepcopy(model1) + + _train(model1, dataset, tmp_path, EMATestCallback()) + + model2.crash_on_epoch = crash_on_epoch + model2 = _train_and_resume(model2, dataset, tmp_path) + + for param1, param2 in zip(model1.parameters(), model2.parameters()): + assert torch.allclose(param1, param2, atol=0.001) @RunIf(skip_windows=True) def test_ema_resume_ddp(tmp_path): - _train_and_resume(tmp_path, crash_on_epoch=3, use_ddp=True) + model = TestModel() + model.crash_on_epoch = 3 + dataset = RandomDataset(32, 32) + _train_and_resume(model, dataset, tmp_path, strategy="ddp_spawn", devices=2) def test_swa(tmp_path): - _train(tmp_path, SWATestCallback()) + model = TestModel() + dataset = RandomDataset(32, 32) + _train(model, dataset, tmp_path, SWATestCallback()) def _train( + model: TestModel, + dataset: Dataset, tmp_path: str, callback: WeightAveraging, - batch_norm: bool = True, strategy: str = "auto", accelerator: str = "cpu", devices: int = 1, - iterable_dataset: bool = False, checkpoint_path: Optional[str] = None, - crash_on_epoch: 
Optional[int] = None, -) -> None: + will_crash: bool = False, +) -> TestModel: + deterministic = accelerator == "cpu" trainer = Trainer( - default_root_dir=tmp_path, - enable_progress_bar=False, - enable_model_summary=False, + accelerator=accelerator, + strategy=strategy, + devices=devices, logger=False, + callbacks=callback, max_epochs=8, num_sanity_val_steps=0, - callbacks=callback, + enable_checkpointing=will_crash, + enable_progress_bar=False, + enable_model_summary=False, accumulate_grad_batches=2, - strategy=strategy, - accelerator=accelerator, - devices=devices, - ) - model = WeightAveragingTestModel( - batch_norm=batch_norm, iterable_dataset=iterable_dataset, crash_on_epoch=crash_on_epoch + deterministic=deterministic, + default_root_dir=tmp_path, ) - - if crash_on_epoch is None: - trainer.fit(model, ckpt_path=checkpoint_path) + dataloader = DataLoader(dataset, batch_size=4, shuffle=False) + if will_crash: + with pytest.raises(Exception, match="CRASH"): + trainer.fit(model, dataloader, ckpt_path=checkpoint_path) else: - with pytest.raises(Exception, match="CRASH TEST"): - trainer.fit(model, ckpt_path=checkpoint_path) - + trainer.fit(model, dataloader, ckpt_path=checkpoint_path) assert trainer.lightning_module == model -def _train_and_resume(tmp_path: str, crash_on_epoch: int, use_ddp: bool = False) -> None: - strategy = "ddp_spawn" if use_ddp else "auto" - devices = 2 if use_ddp else 1 - - _train( - tmp_path, EMATestCallback(devices=devices), strategy=strategy, devices=devices, crash_on_epoch=crash_on_epoch - ) +def _train_and_resume(model: TestModel, dataset: Dataset, tmp_path: str, devices: int = 1, **kwargs) -> TestModel: + _train(model, dataset, tmp_path, EMATestCallback(devices=devices), devices=devices, will_crash=True, **kwargs) checkpoint_dir = Path(tmp_path) / "checkpoints" checkpoint_names = os.listdir(checkpoint_dir) assert len(checkpoint_names) == 1 checkpoint_path = str(checkpoint_dir / checkpoint_names[0]) - _train( - tmp_path, EMATestCallback(devices=devices), strategy=strategy, devices=devices, checkpoint_path=checkpoint_path - ) + model = TestModel.load_from_checkpoint(checkpoint_path) + callback = EMATestCallback(devices=devices) + _train(model, dataset, tmp_path, callback, devices=devices, checkpoint_path=checkpoint_path, **kwargs) + return model From aec9f6e3f1639251bab171dd69e6ccd1e74c6353 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Sat, 22 Mar 2025 09:52:15 +0200 Subject: [PATCH 3/9] Training tricks mentions WeightAveraging and EMA --- .../advanced/training_tricks.rst | 43 ++++++++++++++----- .../model/build_model_intermediate.rst | 2 +- docs/source-pytorch/starter/introduction.rst | 2 +- 3 files changed, 34 insertions(+), 13 deletions(-) diff --git a/docs/source-pytorch/advanced/training_tricks.rst b/docs/source-pytorch/advanced/training_tricks.rst index 25dd996c628a4..23c1c67b8734e 100644 --- a/docs/source-pytorch/advanced/training_tricks.rst +++ b/docs/source-pytorch/advanced/training_tricks.rst @@ -50,23 +50,44 @@ Read more about :ref:`Configuring Gradient Clipping `__ by the PyTorch team. +Lightning provides two callbacks to facilitate weight averaging. :class:`~lightning.pytorch.callbacks.WeightAveraging` +is a generic callback that wraps the +`AveragedModel `__ class from +PyTorch. It allows SWA, EMA, or a custom averaging strategy to be used and it can be customized to run at specific steps +or epochs. -.. 
seealso:: The :class:`~lightning.pytorch.callbacks.StochasticWeightAveraging` callback +The older :class:`~lightning.pytorch.callbacks.StochasticWeightAveraging` callback is specific to SWA. It starts the SWA +procedure after a certain number of epochs and always runs on every epoch. Additionally, it switches to a constant +learning rate schedule (`SWALR `__) when the +procedure starts. + +.. seealso:: + For a more detailed explanation of SWA and how it works, read + `this post `__ by the PyTorch team. .. testcode:: - # Enable Stochastic Weight Averaging using the callback - trainer = Trainer(callbacks=[StochasticWeightAveraging(swa_lrs=1e-2)]) + from lightning.pytorch.callbacks import StochasticWeightAveraging, WeightAveraging + from torch.optim.swa_utils import get_ema_avg_fn + + # Enable Exponential Moving Average after 100 steps + class EMAWeightAveraging(WeightAveraging): + def __init__(self): + super().__init__(avg_fn=get_ema_avg_fn()) + def should_update(self, step_idx=None, epoch_idx=None): + return (step_idx is not None) and (step_idx >= 100) + trainer = Trainer(callbacks=EMAWeightAveraging()) + + # Enable Stochastic Weight Averaging after 10 epochs with learning rate 0.01 + trainer = Trainer(callbacks=StochasticWeightAveraging(swa_epoch_start=10, swa_lrs=0.01)) ---------- diff --git a/docs/source-pytorch/model/build_model_intermediate.rst b/docs/source-pytorch/model/build_model_intermediate.rst index 82362af7ecc83..8a56d20947334 100644 --- a/docs/source-pytorch/model/build_model_intermediate.rst +++ b/docs/source-pytorch/model/build_model_intermediate.rst @@ -27,7 +27,7 @@ Enable advanced training features using Trainer arguments. These are SOTA techni ) # access the latest state of the art techniques - trainer = Trainer(callbacks=[StochasticWeightAveraging(...)]) + trainer = Trainer(callbacks=[WeightAveraging(...)]) ---- diff --git a/docs/source-pytorch/starter/introduction.rst b/docs/source-pytorch/starter/introduction.rst index 8e55afb907aab..ecdda6ac1c53f 100644 --- a/docs/source-pytorch/starter/introduction.rst +++ b/docs/source-pytorch/starter/introduction.rst @@ -252,7 +252,7 @@ Enable advanced training features using Trainer arguments. 
These are state-of-th ) # access the latest state of the art techniques - trainer = L.Trainer(callbacks=[StochasticWeightAveraging(...)]) + trainer = L.Trainer(callbacks=[WeightAveraging(...)]) ---- From 247935f63648c9dccc28ef2a913a57ea0b64fc27 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Thu, 3 Apr 2025 20:50:29 +0300 Subject: [PATCH 4/9] Removed logging from WeightAveraging --- src/lightning/pytorch/callbacks/weight_averaging.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/lightning/pytorch/callbacks/weight_averaging.py b/src/lightning/pytorch/callbacks/weight_averaging.py index 4b4df2c095fc3..c3f019d62244e 100644 --- a/src/lightning/pytorch/callbacks/weight_averaging.py +++ b/src/lightning/pytorch/callbacks/weight_averaging.py @@ -190,7 +190,6 @@ def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") - """ assert self._average_model is not None - rank_zero_info("Loading the average model parameters to the final model.") self._copy_average_to_current(pl_module) @override @@ -205,7 +204,6 @@ def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.Lightn """ if self._average_model is not None: - rank_zero_info("Loading the average model parameters for validation.") self._swap_models(pl_module) @override @@ -220,7 +218,6 @@ def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.Lightnin """ if self._average_model is not None: - rank_zero_info("Recovering the current model parameters after validation.") self._swap_models(pl_module) @override @@ -270,7 +267,6 @@ def on_save_checkpoint( "average model parameters will be saved to the state_dict in the checkpoint." ) else: - rank_zero_info("The average model parameters will be saved to the state_dict in the checkpoint.") average_model_state = self._average_model.state_dict() checkpoint["current_model_state"] = checkpoint["state_dict"] checkpoint["state_dict"] = { From 822231f0aef2262c9b1c526232cb5dc86fa6ae7d Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Thu, 3 Apr 2025 21:11:49 +0300 Subject: [PATCH 5/9] Fixed the documentation --- docs/source-pytorch/glossary/index.rst | 16 ++++++++-------- .../pytorch/callbacks/stochastic_weight_avg.py | 2 +- .../pytorch/callbacks/weight_averaging.py | 3 +++ 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/docs/source-pytorch/glossary/index.rst b/docs/source-pytorch/glossary/index.rst index 6b5e4b12b307f..333ef9834ef84 100644 --- a/docs/source-pytorch/glossary/index.rst +++ b/docs/source-pytorch/glossary/index.rst @@ -42,13 +42,13 @@ Strategy registry <../advanced/strategy_registry> Strategy integrations <../integrations/strategies/index> Style guide <../starter/style_guide> - SWA <../advanced/training_tricks> SLURM <../clouds/cluster_advanced> Tensor Parallel <../advanced/model_parallel/tp> Transfer learning <../advanced/transfer_learning> Trainer <../common/trainer> TorchRun (TorchElastic) <../clouds/cluster_intermediate_2> Warnings <../advanced/warnings> + Weight averaging <../advanced/training_tricks> ######## @@ -326,13 +326,6 @@ Glossary :button_link: ../starter/style_guide.html :height: 100 -.. displayitem:: - :header: SWA - :description: Stochastic Weight Averaging (SWA) can make your models generalize better - :col_css: col-md-12 - :button_link: ../advanced/training_tricks.html#stochastic-weight-averaging - :height: 100 - .. 
displayitem:: :header: SLURM :description: Simple Linux Utility for Resource Management, or simply Slurm, is a free and open-source job scheduler for Linux clusters @@ -375,6 +368,13 @@ Glossary :button_link: ../advanced/warnings.html :height: 100 +.. displayitem:: + :header: Weight averaging + :description: Stochastic Weight Averaging (SWA) or Exponential Moving Average (EMA) can make your models generalize better + :col_css: col-md-12 + :button_link: ../advanced/training_tricks.html#weight-averaging + :height: 100 + .. raw:: html diff --git a/src/lightning/pytorch/callbacks/stochastic_weight_avg.py b/src/lightning/pytorch/callbacks/stochastic_weight_avg.py index 375bd15f29051..79c5423c54084 100644 --- a/src/lightning/pytorch/callbacks/stochastic_weight_avg.py +++ b/src/lightning/pytorch/callbacks/stochastic_weight_avg.py @@ -65,7 +65,7 @@ def __init__( .. warning:: ``StochasticWeightAveraging`` is currently only supported on every epoch. - See also how to :ref:`enable it directly on the Trainer ` + See also how to :ref:`enable it directly on the Trainer `. Arguments: diff --git a/src/lightning/pytorch/callbacks/weight_averaging.py b/src/lightning/pytorch/callbacks/weight_averaging.py index c3f019d62244e..e24febb429d0e 100644 --- a/src/lightning/pytorch/callbacks/weight_averaging.py +++ b/src/lightning/pytorch/callbacks/weight_averaging.py @@ -52,6 +52,9 @@ class WeightAveraging(Callback): During validation and after the training finishes, the current model parameters will be replaced with the averaged values. + See also the documentation on the :ref:`weight averaging callbacks ` + provided by Lightning. + Example:: from lightning.pytorch.callbacks import WeightAveraging From 5deb0bbc01c7414ac50fe2635ec97898060f50ab Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Thu, 3 Apr 2025 21:50:32 +0300 Subject: [PATCH 6/9] Fixed checkpoint loading with WeightAveraging --- src/lightning/pytorch/callbacks/weight_averaging.py | 5 ++++- tests/tests_pytorch/callbacks/test_weight_averaging.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/lightning/pytorch/callbacks/weight_averaging.py b/src/lightning/pytorch/callbacks/weight_averaging.py index e24febb429d0e..34b373e7be1f0 100644 --- a/src/lightning/pytorch/callbacks/weight_averaging.py +++ b/src/lightning/pytorch/callbacks/weight_averaging.py @@ -304,7 +304,10 @@ def on_load_checkpoint( average_model_state = {"module." + name: value for name, value in checkpoint["state_dict"].items()} average_model_state |= checkpoint["averaging_state"] self._average_model.load_state_dict(average_model_state) - checkpoint["state_dict"] = checkpoint["current_model_state"] + # The current model state has already been loaded from "state_dict" (which contains the average model + # weights) at this point, so overwriting "state_dict" in the checkpoint dictionary makes no difference. We + # have to reload the model state from "current_model_state". + pl_module.load_state_dict(checkpoint["current_model_state"]) else: rank_zero_warn( "The checkpoint was not created with WeightAveraging. 
From 5deb0bbc01c7414ac50fe2635ec97898060f50ab Mon Sep 17 00:00:00 2001
From: Seppo Enarvi
Date: Thu, 3 Apr 2025 21:50:32 +0300
Subject: [PATCH 6/9] Fixed checkpoint loading with WeightAveraging

---
 src/lightning/pytorch/callbacks/weight_averaging.py    | 5 ++++-
 tests/tests_pytorch/callbacks/test_weight_averaging.py | 2 +-
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/lightning/pytorch/callbacks/weight_averaging.py b/src/lightning/pytorch/callbacks/weight_averaging.py
index e24febb429d0e..34b373e7be1f0 100644
--- a/src/lightning/pytorch/callbacks/weight_averaging.py
+++ b/src/lightning/pytorch/callbacks/weight_averaging.py
@@ -304,7 +304,10 @@ def on_load_checkpoint(
             average_model_state = {"module." + name: value for name, value in checkpoint["state_dict"].items()}
             average_model_state |= checkpoint["averaging_state"]
             self._average_model.load_state_dict(average_model_state)
-            checkpoint["state_dict"] = checkpoint["current_model_state"]
+            # The current model state has already been loaded from "state_dict" (which contains the average model
+            # weights) at this point, so overwriting "state_dict" in the checkpoint dictionary makes no difference. We
+            # have to reload the model state from "current_model_state".
+            pl_module.load_state_dict(checkpoint["current_model_state"])
         else:
             rank_zero_warn(
                 "The checkpoint was not created with WeightAveraging. Both the current and the average model will be "

diff --git a/tests/tests_pytorch/callbacks/test_weight_averaging.py b/tests/tests_pytorch/callbacks/test_weight_averaging.py
index e5591e2f2f6a0..57ad62e9706d4 100644
--- a/tests/tests_pytorch/callbacks/test_weight_averaging.py
+++ b/tests/tests_pytorch/callbacks/test_weight_averaging.py
@@ -235,7 +235,7 @@ def test_ema_resume(tmp_path, crash_on_epoch):
     model2 = _train_and_resume(model2, dataset, tmp_path)

     for param1, param2 in zip(model1.parameters(), model2.parameters()):
-        assert torch.allclose(param1, param2, atol=0.001)
+        assert torch.allclose(param1, param2)


 @RunIf(skip_windows=True)
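The fix is visible when a run is resumed with the callback attached: the averaged weights are restored from ``averaging_state``, while the current weights are now reloaded from ``current_model_state`` instead of the averaged ``state_dict``. A resumption sketch using Lightning's demo classes; the checkpoint path is illustrative::

    import lightning as L
    from torch.utils.data import DataLoader

    from lightning.pytorch.callbacks import WeightAveraging
    from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset

    train_loader = DataLoader(RandomDataset(32, 64), batch_size=2)
    trainer = L.Trainer(max_epochs=20, callbacks=[WeightAveraging()])

    # Resuming restores both the current and the averaged weights, so averaging
    # continues from where the previous run stopped.
    trainer.fit(BoringModel(), train_loader, ckpt_path="last.ckpt")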
From 5a690570da9f8f36320f8aec6a2d4df282635fc7 Mon Sep 17 00:00:00 2001
From: Seppo Enarvi
Date: Sat, 26 Apr 2025 11:23:31 +0300
Subject: [PATCH 7/9] WeightAveraging calls the configure_model hook but issues a warning

---
 .../pytorch/callbacks/weight_averaging.py        | 18 ++++++++++
 .../callbacks/test_weight_averaging.py           | 35 +++++++++++++++++--
 2 files changed, 51 insertions(+), 2 deletions(-)

diff --git a/src/lightning/pytorch/callbacks/weight_averaging.py b/src/lightning/pytorch/callbacks/weight_averaging.py
index 34b373e7be1f0..16479f3107242 100644
--- a/src/lightning/pytorch/callbacks/weight_averaging.py
+++ b/src/lightning/pytorch/callbacks/weight_averaging.py
@@ -26,6 +26,7 @@

 import lightning.pytorch as pl
 from lightning.pytorch.callbacks.callback import Callback
+from lightning.pytorch.utilities.model_helpers import is_overridden
 from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn
 from lightning.pytorch.utilities.types import STEP_OUTPUT

@@ -55,6 +56,13 @@ class WeightAveraging(Callback):
     See also the documentation on the :ref:`weight averaging callbacks `
     provided by Lightning.

+    Note:
+        To ensure that the :class:`AveragedModel` will contain all layers,
+        :meth:`~lightning.pytorch.callbacks.weight_averaging.WeightAveraging.setup` will call
+        :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` before instantiating the
+        :class:`AveragedModel`. However, that hook is not called in a strategy-aware context, so sharded models do not
+        work with weight averaging, and a warning will be issued.
+
     Example::

         from lightning.pytorch.callbacks import WeightAveraging
@@ -137,6 +145,16 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: s
         """
         if stage == "fit":
             device = self._device or pl_module.device
+
+            # If the configure_model hook is overridden, call it to create the layers before constructing the
+            # AveragedModel. However, sharding will not be done and a warning will be issued.
+            if is_overridden("configure_model", pl_module):
+                rank_zero_warn(
+                    "You're using the WeightAveraging callback with a model that overrides the configure_model "
+                    "hook. WeightAveraging doesn't support sharding model layers, so you may run out of memory."
+                )
+                pl_module.configure_model()
+
             self._average_model = AveragedModel(
                 model=pl_module, device=device, use_buffers=self._use_buffers, **self._kwargs
             )

diff --git a/tests/tests_pytorch/callbacks/test_weight_averaging.py b/tests/tests_pytorch/callbacks/test_weight_averaging.py
index 57ad62e9706d4..d54856e3bda1f 100644
--- a/tests/tests_pytorch/callbacks/test_weight_averaging.py
+++ b/tests/tests_pytorch/callbacks/test_weight_averaging.py
@@ -47,6 +47,19 @@ def configure_optimizers(self) -> None:
         return torch.optim.SGD(self.layer.parameters(), lr=0.1)


+class LargeTestModel(BoringModel):
+    def __init__(self):
+        super().__init__()
+        self.layer = None
+
+    def configure_model(self):
+        print("XXX configure_model")
+        self.layer = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2))
+
+    def configure_optimizers(self):
+        return torch.optim.SGD(self.parameters(), lr=0.01)
+
+
 class EMAAveragingFunction:
     """EMA averaging function.

@@ -252,8 +265,26 @@ def test_swa(tmp_path):
     _train(model, dataset, tmp_path, SWATestCallback())


+@pytest.mark.parametrize(
+    ("strategy", "accelerator", "devices"),
+    [
+        ("auto", "cpu", 1),
+        pytest.param("auto", "gpu", 1, marks=RunIf(min_cuda_gpus=1)),
+        pytest.param("fsdp", "gpu", 1, marks=RunIf(min_cuda_gpus=1)),
+        pytest.param("ddp", "gpu", 2, marks=RunIf(min_cuda_gpus=2)),
+        pytest.param("fsdp", "gpu", 2, marks=RunIf(min_cuda_gpus=2)),
+    ],
+)
+def test_ema_configure_model(tmp_path, strategy, accelerator, devices):
+    model = LargeTestModel()
+    dataset = RandomDataset(32, 32)
+    callback = EMATestCallback()
+    _train(model, dataset, tmp_path, callback, strategy=strategy, accelerator=accelerator, devices=devices)
+    assert isinstance(callback._average_model.module.layer, nn.Sequential)
+
+
 def _train(
-    model: TestModel,
+    model: BoringModel,
     dataset: Dataset,
     tmp_path: str,
     callback: WeightAveraging,
@@ -262,7 +293,7 @@ def _train(
     devices: int = 1,
     checkpoint_path: Optional[str] = None,
     will_crash: bool = False,
-) -> TestModel:
+) -> None:
     deterministic = accelerator == "cpu"
     trainer = Trainer(
         accelerator=accelerator,
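For reference, the kind of module this change targets creates its layers in ``configure_model()``, like the ``LargeTestModel`` added above. A rough sketch of equivalent user code; the module name and layer sizes are illustrative, and no sharding strategy is assumed::

    import torch
    from torch import nn

    import lightning as L
    from lightning.pytorch.callbacks import WeightAveraging


    class LazyLayersModel(L.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = None

        def configure_model(self):
            # WeightAveraging.setup() now calls this hook before building the AveragedModel,
            # so the layers created here are included in the averaged copy (unsharded).
            if self.layer is None:
                self.layer = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2))

        def training_step(self, batch, batch_idx):
            return self.layer(batch).sum()

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.01)


    trainer = L.Trainer(callbacks=[WeightAveraging()])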
From 3dafb4c0b24745a3ca8d5a877978c2cf332ae6b3 Mon Sep 17 00:00:00 2001
From: Seppo Enarvi
Date: Sat, 26 Apr 2025 11:50:04 +0300
Subject: [PATCH 8/9] Fixed unit tests

* Fixed a reference in a docstring.
* Removed two unit tests to avoid running out of memory in the CI pipeline.
---
 src/lightning/pytorch/callbacks/weight_averaging.py    | 3 +--
 tests/tests_pytorch/callbacks/test_weight_averaging.py | 2 --
 2 files changed, 1 insertion(+), 4 deletions(-)

diff --git a/src/lightning/pytorch/callbacks/weight_averaging.py b/src/lightning/pytorch/callbacks/weight_averaging.py
index 16479f3107242..a983b32a1a161 100644
--- a/src/lightning/pytorch/callbacks/weight_averaging.py
+++ b/src/lightning/pytorch/callbacks/weight_averaging.py
@@ -57,8 +57,7 @@ class WeightAveraging(Callback):
     provided by Lightning.

     Note:
-        To ensure that the :class:`AveragedModel` will contain all layers,
-        :meth:`~lightning.pytorch.callbacks.weight_averaging.WeightAveraging.setup` will call
+        To ensure that the :class:`AveragedModel` will contain all layers, ``setup()`` will call
         :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` before instantiating the
         :class:`AveragedModel`. However, that hook is not called in a strategy-aware context, so sharded models do not
         work with weight averaging, and a warning will be issued.

diff --git a/tests/tests_pytorch/callbacks/test_weight_averaging.py b/tests/tests_pytorch/callbacks/test_weight_averaging.py
index d54856e3bda1f..ec230b2fd6c97 100644
--- a/tests/tests_pytorch/callbacks/test_weight_averaging.py
+++ b/tests/tests_pytorch/callbacks/test_weight_averaging.py
@@ -271,8 +271,6 @@ def test_swa(tmp_path):
         ("auto", "cpu", 1),
         pytest.param("auto", "gpu", 1, marks=RunIf(min_cuda_gpus=1)),
         pytest.param("fsdp", "gpu", 1, marks=RunIf(min_cuda_gpus=1)),
-        pytest.param("ddp", "gpu", 2, marks=RunIf(min_cuda_gpus=2)),
-        pytest.param("fsdp", "gpu", 2, marks=RunIf(min_cuda_gpus=2)),
     ],
 )
 def test_ema_configure_model(tmp_path, strategy, accelerator, devices):
From 410fe1400a3ea58a2be21f3884bc647fce37deef Mon Sep 17 00:00:00 2001
From: Seppo Enarvi
Date: Sun, 6 Jul 2025 12:24:28 +0300
Subject: [PATCH 9/9] The default device for the averaged model is the device of the original model

---
 .../pytorch/callbacks/weight_averaging.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/src/lightning/pytorch/callbacks/weight_averaging.py b/src/lightning/pytorch/callbacks/weight_averaging.py
index a983b32a1a161..2ffe02ae1bc24 100644
--- a/src/lightning/pytorch/callbacks/weight_averaging.py
+++ b/src/lightning/pytorch/callbacks/weight_averaging.py
@@ -35,12 +35,11 @@ class WeightAveraging(Callback):
     r"""A callback that updates an averaged model for Stochastic Weight Averaging (SWA) or Exponential Moving Average
     (EMA) after each training step.

-    Arguments given to the constructor will be passed to the :class:`AveragedModel` constructor. There are a couple of
-    differences to the default values, however. By default, the average model is stored on the CPU. If ``device`` is set
-    to ``None``, the device will be inferred from the original model. By default, the callback will compute running
-    averages for both the parameters and the buffers of the model. Setting ``use_buffers`` to ``False`` will cause only
-    the model parameters to be averaged, leaving updating the batch normalization statistics to the user (using
-    ``torch.optim.swa_utils.update_bn()``).
+    Arguments given to the constructor will be passed to the :class:`AveragedModel` constructor. If no ``device`` is
+    specified, the device of the original model will be used. Unlike :class:`AveragedModel`, ``use_buffers`` is set
+    to ``True`` by default. That is, by default the callback will compute running averages for both the parameters and
+    the buffers of the model. Setting ``use_buffers`` to ``False`` will cause only the model parameters to be averaged,
+    leaving updating the batch normalization statistics to the user (using ``torch.optim.swa_utils.update_bn()``).

     You can provide a custom averaging function with the ``avg_fn`` or ``multi_avg_fn`` parameter. See the
     :class:`AveragedModel` class for details. If no averaging function is provided, the default is to compute the
@@ -79,8 +78,9 @@ def should_update(self, step_idx=None, epoch_idx=None):
         trainer.fit(model, dataloader)

     Args:
-        device: If provided, the :class:`AveragedModel` will be stored on the ``device``. If ``None`` the device will be
-            inferred from the original model.
+        device: By default, the :class:`AveragedModel` will be stored on the same device as the original model. If the
+            ``device`` argument is provided, the :class:`AveragedModel` will be stored on this device instead. If you
+            run out of GPU memory, you might want to use ``"cpu"``.
         use_buffers: If ``False``, the buffers of the model will not be averaged.
         kwargs: Additional keyword arguments to be passed to the :class:`AveragedModel` constructor, such as
             ``avg_fn`` or ``multi_avg_fn``.
@@ -89,7 +89,7 @@ def should_update(self, step_idx=None, epoch_idx=None):

     def __init__(
         self,
-        device: Optional[Union[torch.device, str, int]] = "cpu",
+        device: Optional[Union[torch.device, str, int]] = None,
         use_buffers: bool = True,
         **kwargs: Any,
     ) -> None:
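After this final patch, the constructor arguments summarized in the docstring can be combined roughly as follows. A sketch grounded in that docstring; the decay value is illustrative::

    from lightning.pytorch.callbacks import WeightAveraging
    from torch.optim.swa_utils import get_ema_avg_fn

    # Default: the averaged copy lives on the same device as the original model.
    swa = WeightAveraging()

    # Keep the averaged copy in CPU memory if GPU memory is tight.
    swa_on_cpu = WeightAveraging(device="cpu")

    # EMA of the parameters only; batch normalization statistics are then left to the
    # user, e.g. via torch.optim.swa_utils.update_bn() after training.
    ema = WeightAveraging(use_buffers=False, avg_fn=get_ema_avg_fn(decay=0.99))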