Commit d195d2b

rustamzh, pre-commit-ci[bot], and Borda authored
refactor: add toggled_optimizer context manager (#20771)
* Apply suggestions from code review

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
1 parent 9ddb418 commit d195d2b

File tree

5 files changed (+48 -1 lines changed)


docs/source-pytorch/conf.py

Lines changed: 1 addition & 0 deletions
@@ -487,6 +487,7 @@ def _load_py_module(name: str, location: str) -> ModuleType:
     ("py:meth", "setup"),
     ("py:meth", "test_step"),
     ("py:meth", "toggle_optimizer"),
+    ("py:meth", "toggled_optimizer"),
     ("py:class", "torch.ScriptModule"),
     ("py:class", "torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffload"),
     ("py:class", "torch.distributed.fsdp.fully_sharded_data_parallel.MixedPrecision"),

docs/source-pytorch/model/manual_optimization.rst

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@ To manually optimize, do the following:
 * ``optimizer.zero_grad()`` to clear the gradients from the previous training step
 * ``self.manual_backward(loss)`` instead of ``loss.backward()``
 * ``optimizer.step()`` to update your model parameters
-* ``self.toggle_optimizer()`` and ``self.untoggle_optimizer()`` if needed
+* ``self.toggle_optimizer()`` and ``self.untoggle_optimizer()``, or ``self.toggled_optimizer()`` if needed
 
 Here is a minimal example of manual optimization.

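To make the updated bullet concrete, here is a rough sketch of a manual-optimization training_step that uses the new context manager in place of the explicit toggle/untoggle pair. The module, its single linear layer, and the loss are illustrative placeholders rather than part of this commit; only the toggled_optimizer() call reflects the API added here.

    import torch
    from lightning.pytorch import LightningModule


    class ManualOptModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)
            # manual optimization: Lightning will not call backward()/step() for us
            self.automatic_optimization = False

        def training_step(self, batch, batch_idx):
            opt = self.optimizers()
            # enters toggle_optimizer(opt) and guarantees untoggle_optimizer(opt) on exit
            with self.toggled_optimizer(opt):
                loss = self.layer(batch).sum()  # placeholder loss
                opt.zero_grad()
                self.manual_backward(loss)
                opt.step()

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)
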
src/lightning/pytorch/CHANGELOG.md

Lines changed: 4 additions & 0 deletions
@@ -11,6 +11,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Add enable_autolog_hparams argument to Trainer ([#20593](https://github.com/Lightning-AI/pytorch-lightning/pull/20593))
 
+
+- Add `toggled_optimizer(optimizer)` method to the LightningModule, which is a context manager version of `toggle_optimize` and `untoggle_optimizer` ([#20771](https://github.com/Lightning-AI/pytorch-lightning/pull/20771))
+
+
 - For cross-device local checkpoints, instruct users to install `fsspec>=2025.5.0` if unavailable ([#20780](https://github.com/Lightning-AI/pytorch-lightning/pull/20780))

src/lightning/pytorch/core/module.py

Lines changed: 26 additions & 0 deletions
@@ -1141,6 +1141,32 @@ def untoggle_optimizer(self, optimizer: Union[Optimizer, LightningOptimizer]) ->
         # save memory
         self._param_requires_grad_state = {}
 
+    @contextmanager
+    def toggled_optimizer(self, optimizer: Union[Optimizer, LightningOptimizer]) -> Generator:
+        """Makes sure only the gradients of the current optimizer's parameters are calculated in the training step to
+        prevent dangling gradients in multiple-optimizer setup. Combines :meth:`toggle_optimizer` and
+        :meth:`untoggle_optimizer` into context manager.
+
+        Args:
+            optimizer: The optimizer to toggle.
+
+        Example::
+
+            def training_step(...):
+                opt = self.optimizers()
+                with self.toggled_optimizer(opt):
+                    loss = ...
+                    opt.zero_grad()
+                    self.manual_backward(loss)
+                    opt.step()
+
+        """
+        self.toggle_optimizer(optimizer)
+        try:
+            yield
+        finally:
+            self.untoggle_optimizer(optimizer)
+
     def clip_gradients(
         self,
         optimizer: Optimizer,

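Because the new method wraps the pair of calls in try/finally, untoggle_optimizer() runs even when the body of the with block raises, so parameter requires_grad state is always restored. A small sketch of that behavior, built on the same Mock-trainer setup the test below uses (the raised RuntimeError is purely illustrative, not part of this commit):

    from unittest.mock import Mock

    import pytest
    import torch

    from lightning.pytorch.demos.boring_classes import BoringModel

    model = BoringModel()
    model.trainer = Mock()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    model.trainer.optimizers = [optimizer]

    # _param_requires_grad_state is populated only while the optimizer is toggled
    with pytest.raises(RuntimeError):
        with model.toggled_optimizer(optimizer):
            assert model._param_requires_grad_state
            raise RuntimeError("simulated failure inside the training step")

    # the finally branch ran untoggle_optimizer(), so the saved state was cleared again
    assert not model._param_requires_grad_state
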
tests/tests_pytorch/core/test_lightning_module.py

Lines changed: 16 additions & 0 deletions
@@ -119,6 +119,22 @@ def test_1_optimizer_toggle_model():
     assert not model._param_requires_grad_state
 
 
+def test_optimizer_toggle_model_context_manager():
+    """Test toggle_model runs when only one optimizer is used."""
+    model = BoringModel()
+    trainer = Mock()
+    model.trainer = trainer
+    params = model.parameters()
+    optimizer = torch.optim.SGD(params, lr=0.1)
+    trainer.optimizers = [optimizer]
+
+    assert not model._param_requires_grad_state
+    # toggle optimizer was failing with a single optimizer
+    with model.toggled_optimizer(optimizer):
+        assert model._param_requires_grad_state
+    assert not model._param_requires_grad_state
+
+
 def test_toggle_untoggle_2_optimizers_no_shared_parameters(tmp_path):
     class TestModel(BoringModel):
         def __init__(self):
