3 files changed: +34 −8

tests/framework/callbacks
torchtnt/framework/callbacks

@@ -255,6 +255,27 @@ def test_save_restore_no_lr_scheduler_restore(
         app_state = mock_dist_cp.load.call_args.args[0]["app_state"].state_dict()
         self.assertIn("lr_scheduler", app_state)

+    @patch("torchtnt.framework.callbacks.dcp_saver._init_optim_state")
+    @patch("torchtnt.framework.callbacks.dcp_saver.dcp")
+    def test_save_restore_no_init_optim_state(
+        self, _: MagicMock, mock_init_optim_state: MagicMock
+    ) -> None:
+        my_unit = DummyTrainUnit(input_dim=2)
+        restore_options = RestoreOptions(init_optim_states=False)
+        DistributedCheckpointSaver.restore(
+            path="path/to/snapshot",
+            unit=my_unit,
+            restore_options=restore_options,
+        )
+        mock_init_optim_state.assert_not_called()
+
+        DistributedCheckpointSaver.restore(
+            path="path/to/snapshot",
+            unit=my_unit,
+            restore_options=RestoreOptions(),
+        )
+        mock_init_optim_state.assert_called()
+
     @skip_if_not_distributed
     def test_save_restore_ddp(self) -> None:
         spawn_multi_process(
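The test patches both dcp (so no real checkpoint I/O happens) and _init_optim_state (so calls can be counted). One detail worth noting: stacked @patch decorators are applied bottom-up, which is why the dcp mock lands in the unused `_` argument while mock_init_optim_state receives the helper. A small standalone sketch of that pattern, using json.dump and json.dumps purely as stand-in patch targets (they are not related to torchtnt):

import json
import unittest
from unittest.mock import MagicMock, patch


class PatchOrderTest(unittest.TestCase):
    # Bottom decorator -> first mock argument after self; top decorator -> second.
    @patch("json.dumps")
    @patch("json.dump")
    def test_patch_order(self, mock_dump: MagicMock, mock_dumps: MagicMock) -> None:
        json.dump({"a": 1}, None)  # call goes to the mock, so no real serialization happens
        mock_dump.assert_called_once()
        mock_dumps.assert_not_called()


if __name__ == "__main__":
    unittest.main()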
@@ -43,6 +43,9 @@ class RestoreOptions:
         restore_optimizers: Whether to restore the optimizer states.
         restore_lr_schedulers: Whether to restore the lr scheduler states.
         strict: Whether to strictly restore app state and the module state dict.
+        init_optim_states: Whether to initialize the optimizer state. Defaults to True. Toggle off
+            if running into issues with loading optimizer state. This will reset optimizer state,
+            which may affect training in some cases.
     """

     restore_modules: bool = True

@@ -52,3 +55,4 @@ class RestoreOptions:
     restore_optimizers: bool = True
     restore_lr_schedulers: bool = True
     strict: bool = True
+    init_optim_states: bool = True
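Based on the new docstring and the test above, opting out of optimizer-state initialization is just a matter of passing the flag through RestoreOptions. A minimal usage sketch; the checkpoint path and the train unit are placeholders, and the import locations are assumed (adjust to your torchtnt version if they differ):

from torchtnt.framework.callbacks import DistributedCheckpointSaver
from torchtnt.framework.callbacks.checkpointer_types import RestoreOptions

my_unit = ...  # your TrainUnit instance holding the module and optimizer

DistributedCheckpointSaver.restore(
    path="path/to/snapshot",  # placeholder checkpoint path
    unit=my_unit,
    restore_options=RestoreOptions(init_optim_states=False),  # skip lazy optimizer-state init
)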
@@ -370,14 +370,15 @@ def restore_with_id(
             predict_dataloader,
         )

-        # necessary for loading optimizers since states are initialized lazy
-        for obj in app_state.values():
-            # sometimes optimizers are actually held in a wrapper which handles calling
-            # state_dict and load_state_dict, as is the case for
-            # `torchtnt.utils.prepare_module.FSDPOptimizerWrapper`; this handles that case.
-            optimizer = getattr(obj, "optimizer", obj)
-            if isinstance(optimizer, torch.optim.Optimizer):
-                _init_optim_state(optimizer)
+        if restore_options.init_optim_states:
+            # necessary for loading optimizers since states are initialized lazily
+            for obj in app_state.values():
+                # sometimes optimizers are actually held in a wrapper which handles calling
+                # state_dict and load_state_dict, as is the case for
+                # `torchtnt.utils.prepare_module.FSDPOptimizerWrapper`; this handles that case.
+                optimizer = getattr(obj, "optimizer", obj)
+                if isinstance(optimizer, torch.optim.Optimizer):
+                    _init_optim_state(optimizer)

         with get_or_create_gloo_pg(candidate_pg=process_group) as pg:
             dcp.load(
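For context on what the guarded loop works around: PyTorch optimizers create their per-parameter state lazily, so a freshly constructed optimizer has nothing for an in-place distributed checkpoint load to populate. A standalone plain-PyTorch illustration of this (it mirrors the motivation, not _init_optim_state's actual implementation):

import torch

model = torch.nn.Linear(2, 2)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
print(len(optim.state))  # 0: no exp_avg / exp_avg_sq buffers exist yet

# One common way to materialize the state without changing the weights:
# step once with zeroed gradients.
for param in model.parameters():
    param.grad = torch.zeros_like(param)
optim.step()
print(len(optim.state))  # 2: per-parameter state now exists and can receive a checkpoint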