
Commit a5358ce

JKSenthil authored and facebook-github-bot committed
add init_optim_state flag to restore options (#992)
Summary:
Pull Request resolved: #992
Reviewed By: diego-urgell
Differential Revision: D73678205
fbshipit-source-id: 28e1b748d2b6f9d347f4c78c049113f25c6a8457
1 parent c1dfeb7 commit a5358ce

File tree: 3 files changed (+34, −8 lines)

tests/framework/callbacks/test_dcp_saver.py

Lines changed: 21 additions & 0 deletions

@@ -255,6 +255,27 @@ def test_save_restore_no_lr_scheduler_restore(
         app_state = mock_dist_cp.load.call_args.args[0]["app_state"].state_dict()
         self.assertIn("lr_scheduler", app_state)
 
+    @patch("torchtnt.framework.callbacks.dcp_saver._init_optim_state")
+    @patch("torchtnt.framework.callbacks.dcp_saver.dcp")
+    def test_save_restore_no_init_optim_state(
+        self, _: MagicMock, mock_init_optim_state: MagicMock
+    ) -> None:
+        my_unit = DummyTrainUnit(input_dim=2)
+        restore_options = RestoreOptions(init_optim_states=False)
+        DistributedCheckpointSaver.restore(
+            path="path/to/snapshot",
+            unit=my_unit,
+            restore_options=restore_options,
+        )
+        mock_init_optim_state.assert_not_called()
+
+        DistributedCheckpointSaver.restore(
+            path="path/to/snapshot",
+            unit=my_unit,
+            restore_options=RestoreOptions(),
+        )
+        mock_init_optim_state.assert_called()
+
     @skip_if_not_distributed
     def test_save_restore_ddp(self) -> None:
         spawn_multi_process(

torchtnt/framework/callbacks/checkpointer_types.py

Lines changed: 4 additions & 0 deletions

@@ -43,6 +43,9 @@ class RestoreOptions:
         restore_optimizers: Whether to restore the optimizer states.
         restore_lr_schedulers: Whether to restore the lr scheduler states.
         strict: Whether to strictly restore app state and the module state dict.
+        init_optim_states: Whether to initialize the optimizer state. Defaults to True. Toggle off
+            if running into issues with loading optimizer state. This will reset optimizer state,
+            which may affect training in some cases.
     """
 
     restore_modules: bool = True
@@ -52,3 +55,4 @@ class RestoreOptions:
     restore_optimizers: bool = True
     restore_lr_schedulers: bool = True
     strict: bool = True
+    init_optim_states: bool = True
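For reference, the new flag is consumed at restore time. A minimal usage sketch, mirroring the new test above (the snapshot path is a placeholder, `my_unit` stands in for any unit whose app state is being restored, and the import locations are assumed from the test module):

from torchtnt.framework.callbacks import DistributedCheckpointSaver
from torchtnt.framework.callbacks.checkpointer_types import RestoreOptions

# Skip the lazy optimizer-state initialization step during restore.
# "path/to/snapshot" is a placeholder for an existing DCP checkpoint.
DistributedCheckpointSaver.restore(
    path="path/to/snapshot",
    unit=my_unit,
    restore_options=RestoreOptions(init_optim_states=False),
)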

torchtnt/framework/callbacks/dcp_saver.py

Lines changed: 9 additions & 8 deletions

@@ -370,14 +370,15 @@ def restore_with_id(
             predict_dataloader,
         )
 
-        # necessary for loading optimizers since states are initialized lazy
-        for obj in app_state.values():
-            # sometimes optimizers are actually held in a wrapper which handles calling
-            # state_dict and load_state_dict, sa is the case for
-            # `torchtnt.utils.prepare_module.FSDPOptimizerWrapper`, this handles that case.
-            optimizer = getattr(obj, "optimizer", obj)
-            if isinstance(optimizer, torch.optim.Optimizer):
-                _init_optim_state(optimizer)
+        if restore_options.init_optim_states:
+            # necessary for loading optimizers, since their states are initialized lazily
+            for obj in app_state.values():
+                # sometimes optimizers are actually held in a wrapper which handles calling
+                # state_dict and load_state_dict, as is the case for
+                # `torchtnt.utils.prepare_module.FSDPOptimizerWrapper`; this handles that case
+                optimizer = getattr(obj, "optimizer", obj)
+                if isinstance(optimizer, torch.optim.Optimizer):
+                    _init_optim_state(optimizer)
 
         with get_or_create_gloo_pg(candidate_pg=process_group) as pg:
             dcp.load(
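For context on why this initialization step exists at all: a freshly constructed torch.optim.Optimizer carries an empty per-parameter state until its first step() call, so loading checkpointed optimizer state into it can fail or mismatch. A standalone sketch of that lazy behavior (plain PyTorch, not torchtnt's actual _init_optim_state implementation):

import torch

model = torch.nn.Linear(2, 2)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

# Before the first step(), the per-parameter state dict is empty.
print(optim.state_dict()["state"])  # {}

# One backward/step populates the state (e.g. Adam's exp_avg buffers),
# giving a checkpoint loader concrete entries to load into.
model(torch.randn(4, 2)).sum().backward()
optim.step()
print(sorted(optim.state_dict()["state"][0].keys()))  # ['exp_avg', 'exp_avg_sq', 'step']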
