2 files changed: +17 −2

@@ -134,7 +134,6 @@ def __init__(self, cfg: DictConfig) -> None:
                 "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
             )

-        self._checkpoint_client = CheckpointClient(cfg)
         # Set up the backend for distributed training (NCCL, GLOO, etc.)
         self._enable_async_checkpointing = cfg.get("enable_async_checkpointing", False)
         self.fsdp_cpu_offload = cfg.get("fsdp_cpu_offload", False)
@@ -149,6 +148,8 @@ def __init__(self, cfg: DictConfig) -> None:

         self._is_rank_zero = self.rank == 0

+        self._checkpoint_client = CheckpointClient(cfg)
+
         # logging attributes
         self._output_dir = cfg.output_dir
         self._log_every_n_steps = cfg.get("log_every_n_steps", 1)
@@ -296,7 +297,7 @@ def setup(self, cfg: DictConfig) -> None:
                 )
             )
         except Exception as e:
-            log.warning(
+            self._logger.warning(
                 f"Failed to load distributed checkpoint: {e}. Training will start from the base checkpoint."
             )


Second file (under torchtune/training/checkpointing):

@@ -256,6 +256,20 @@ def _save_checkpoint_sync(
             optim_state_dict = {}

         if is_not_distributed_checkpointer and not single_device:
+            # This wait is needed because staging an async checkpoint relies on CPU
+            # gathering, which the sync checkpoint save below also uses, and the two
+            # cause issues when they run concurrently. In theory this case should
+            # never be hit, since an epoch takes much longer than an async checkpoint,
+            # but it does come up in a test case with a very fast epoch.
+            if self._get_dcp_checkpointer()._checkpoint_future is not None:
+                time_start_waiting = time.perf_counter()
+                self._get_dcp_checkpointer()._checkpoint_future.result()
+                if self._is_rank_zero:
+                    log.info(
+                        "Waiting for async checkpoint to finish, to save sync checkpoint, "
+                        f"took {time.perf_counter() - time_start_waiting:.2f} secs"
+                    )
+
             # To prevent GPU memory from spiking during checkpoint save,
             # we consolidate the full model and optim state dicts on CPU for rank 0
             model_state_dict = training.gather_cpu_state_dict(
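
For context on the guard added to _save_checkpoint_sync above: the sync save path first blocks on any in-flight async checkpoint future so that only one CPU gather/write runs at a time. The sketch below illustrates that pattern in isolation; TinyCheckpointer, save_async, save_sync, and _write are hypothetical names for illustration, not torchtune APIs.

import time
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Optional


class TinyCheckpointer:
    """Hypothetical stand-in for the checkpoint client in the diff above."""

    def __init__(self) -> None:
        self._executor = ThreadPoolExecutor(max_workers=1)
        # Tracks an in-flight async save, mirroring the _checkpoint_future
        # attribute checked in _save_checkpoint_sync.
        self._checkpoint_future: Optional[Future] = None

    def save_async(self, state: dict) -> None:
        # Kick off a background save and remember its future.
        self._checkpoint_future = self._executor.submit(self._write, state, "async")

    def save_sync(self, state: dict) -> None:
        # Block on any in-flight async save first so the two saves never
        # gather/write state concurrently (the guard added in the diff).
        if self._checkpoint_future is not None:
            start = time.perf_counter()
            self._checkpoint_future.result()
            print(f"waited {time.perf_counter() - start:.2f}s for async checkpoint")
            self._checkpoint_future = None
        self._write(state, "sync")

    def _write(self, state: dict, kind: str) -> None:
        # Stand-in for gathering state on CPU and writing it to disk.
        time.sleep(0.1)
        print(f"{kind} checkpoint saved: {sorted(state)}")


if __name__ == "__main__":
    ckpt = TinyCheckpointer()
    ckpt.save_async({"model": 1})
    ckpt.save_sync({"model": 1, "optim": 2})  # waits for the async save first

Waiting on .result() simply serializes the two saves, which is cheap in the normal case where no async save is pending.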