Commit 631911c

corwinjoy and awaelchli authored
Add special logic for 'step' in _optimizer_to_device (#20019)
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
1 parent 345450b commit 631911c

File tree

4 files changed, +88 -28 lines changed


src/lightning/fabric/CHANGELOG.md

Lines changed: 1 addition & 1 deletion

@@ -38,7 +38,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an attribute error when loading a checkpoint into a quantized model using the `_lazy_load()` function ([#20121](https://github.com/Lightning-AI/lightning/pull/20121))
 
 
-
+- Fixed `_optimizer_to_device` logic for special 'step' key in optimizer state causing performance regression ([#20019](https://github.com/Lightning-AI/lightning/pull/20019))
 
 
 

src/lightning/fabric/utilities/optimizer.py

Lines changed: 11 additions & 3 deletions

@@ -12,13 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from collections.abc import MutableMapping
 from typing import Iterable
 
-from lightning_utilities.core.apply_func import apply_to_collection
 from torch import Tensor
 from torch.optim import Optimizer
 
-from lightning.fabric.utilities.apply_func import move_data_to_device
+from lightning.fabric.utilities.apply_func import apply_to_collection, move_data_to_device
 from lightning.fabric.utilities.types import _DEVICE
 
 
@@ -31,4 +31,12 @@ def _optimizers_to_device(optimizers: Iterable[Optimizer], device: _DEVICE) -> None:
 def _optimizer_to_device(optimizer: Optimizer, device: _DEVICE) -> None:
     """Moves the state of a single optimizer to the device."""
     for p, v in optimizer.state.items():
-        optimizer.state[p] = apply_to_collection(v, Tensor, move_data_to_device, device, allow_frozen=True)
+        if not isinstance(v, MutableMapping):
+            # Support for custom optimizers
+            optimizer.state[p] = apply_to_collection(v, Tensor, move_data_to_device, device, allow_frozen=True)
+            continue
+        for key, val in v.items():
+            # The 'step' parameter needs to remain unmoved (possibly on the CPU) since that is where the optimizer
+            # needs it. See https://github.com/pytorch/pytorch/issues/74424
+            if key != "step":
+                v[key] = move_data_to_device(val, device)
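For context, a minimal sketch of the behavior the patched helper guarantees, assuming at least one CUDA device is available (`exp_avg` is one of Adam's state buffers):

import torch

from lightning.fabric.utilities.optimizer import _optimizer_to_device

model = torch.nn.Linear(2, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
model(torch.randn(2, 2)).sum().backward()
optimizer.step()  # materializes the state, including the scalar "step" tensor

_optimizer_to_device(optimizer, "cuda")
state = optimizer.state[model.weight]
assert state["exp_avg"].device.type == "cuda"  # buffers follow the target device
assert state["step"].device.type == "cpu"  # "step" is deliberately left unmoved

Leaving "step" unmoved is the fix: in the default (non-capturable) case PyTorch keeps it as a CPU scalar tensor, and force-moving it to the GPU was the source of the performance regression (see pytorch/pytorch#74424).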

src/lightning/pytorch/CHANGELOG.md

Lines changed: 2 additions & 0 deletions

@@ -49,6 +49,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed `_LoggerConnector`'s `_ResultMetric` to move all registered keys to the device of the logged value if needed ([#19814](https://github.com/Lightning-AI/pytorch-lightning/issues/19814))
 
+- Fixed `_optimizer_to_device` logic for special 'step' key in optimizer state causing performance regression ([#20019](https://github.com/Lightning-AI/lightning/pull/20019))
+
 - Fixed parameter counts in `ModelSummary` when model has distributed parameters (DTensor) ([#20163](https://github.com/Lightning-AI/pytorch-lightning/pull/20163))
 
 

tests/tests_fabric/utilities/test_optimizer.py

Lines changed: 74 additions & 24 deletions

@@ -1,36 +1,86 @@
-import collections
 import dataclasses
 
+import pytest
 import torch
 from lightning.fabric.utilities.optimizer import _optimizer_to_device
 from torch import Tensor
 
+from tests_fabric.helpers.runif import RunIf
 
-def test_optimizer_to_device():
-    @dataclasses.dataclass(frozen=True)
+
+@pytest.mark.parametrize(
+    "optimizer_class",
+    [
+        torch.optim.Adam,
+        torch.optim.AdamW,
+        torch.optim.SGD,
+        torch.optim.RMSprop,
+        torch.optim.Adagrad,
+        torch.optim.Adadelta,
+        torch.optim.Adamax,
+    ],
+)
+@pytest.mark.parametrize(
+    "src_device",
+    [
+        torch.device("cpu"),
+        pytest.param(torch.device("cuda"), marks=RunIf(min_cuda_gpus=1)),
+    ],
+)
+@pytest.mark.parametrize(
+    "dst_device",
+    [
+        torch.device("cpu"),
+        pytest.param(torch.device("cuda"), marks=RunIf(min_cuda_gpus=1)),
+    ],
+)
+def test_optimizer_to_device(optimizer_class, src_device, dst_device):
+    # Optimizer with no state initialized
+    model = torch.nn.Linear(2, 2, device=src_device)
+    optimizer = optimizer_class(model.parameters(), lr=0.1)
+    _optimizer_to_device(optimizer, dst_device)
+    _assert_opt_parameters_on_device(optimizer, dst_device)
+
+    # Optimizer with state initialized
+    model = torch.nn.Linear(2, 2, device=src_device)
+    optimizer = optimizer_class(model.parameters(), lr=0.1)
+    model(torch.randn(2, 2, device=src_device)).sum().backward()
+    optimizer.step()
+    _optimizer_to_device(optimizer, dst_device)
+    _assert_opt_parameters_on_device(optimizer, dst_device)
+
+
+def _assert_opt_parameters_on_device(opt, device):
+    for _, v in opt.state.items():
+        for key, item in v.items():
+            if not isinstance(item, Tensor):
+                continue
+            if key == "step":
+                # The "step" tensor needs to remain on CPU
+                assert item.device.type == "cpu"
+            else:
+                assert item.device.type == device.type
+
+
+@RunIf(min_cuda_gpus=1)
+@pytest.mark.parametrize("frozen", [True, False])
+def test_optimizer_to_device_with_dataclass_in_state(frozen):
+    src_device = torch.device("cpu")
+    dst_device = torch.device("cuda")
+    model = torch.nn.Linear(32, 2, device=src_device)
+
+    @dataclasses.dataclass(frozen=frozen)
     class FooState:
-        bar: int
+        integer: int
+        tensor: Tensor
 
     class TestOptimizer(torch.optim.SGD):
         def __init__(self, *args, **kwargs):
             super().__init__(*args, **kwargs)
-            self.state["dummy"] = torch.tensor(0)
-            self.state["frozen"] = FooState(0)
-
-    layer = torch.nn.Linear(32, 2)
-    opt = TestOptimizer(layer.parameters(), lr=0.1)
-    _optimizer_to_device(opt, "cpu")
-    if torch.cuda.is_available():
-        _optimizer_to_device(opt, "cuda")
-        assert_opt_parameters_on_device(opt, "cuda")
-
-
-def assert_opt_parameters_on_device(opt, device: str):
-    for param in opt.state.values():
-        # Not sure there are any global tensors in the state dict
-        if isinstance(param, Tensor):
-            assert param.data.device.type == device
-        elif isinstance(param, collections.abc.Mapping):
-            for subparam in param.values():
-                if isinstance(subparam, Tensor):
-                    assert param.data.device.type == device
+            self.state[model.weight] = {"dummy": torch.tensor(0)}
+            self.state[model.bias] = FooState(0, torch.tensor(0))
+
+    optimizer = TestOptimizer(model.parameters(), lr=0.1)
+    _optimizer_to_device(optimizer, dst_device)
+    assert optimizer.state[model.weight]["dummy"].device.type == dst_device.type
+    assert optimizer.state[model.bias].tensor.device.type == ("cpu" if frozen else dst_device.type)
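The final assertion captures a subtlety of the dataclass path: `apply_to_collection(..., allow_frozen=True)` cannot reassign fields on a frozen dataclass, so the instance comes back with its fields untouched and any tensor inside keeps its original device. A minimal standalone sketch of that behavior, assuming a CUDA device is available (`FrozenState` is an illustrative name, not part of the change):

import dataclasses

import torch

from lightning.fabric.utilities.apply_func import apply_to_collection, move_data_to_device


@dataclasses.dataclass(frozen=True)
class FrozenState:  # illustrative stand-in for custom optimizer state
    tensor: torch.Tensor


state = FrozenState(torch.tensor(0))  # the tensor starts on the CPU
# allow_frozen=True tolerates the frozen dataclass instead of raising, but its
# fields cannot be reassigned, so the tensor is expected to stay on the CPU.
moved = apply_to_collection(state, torch.Tensor, move_data_to_device, "cuda", allow_frozen=True)
assert moved.tensor.device.type == "cpu"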
