
Commit eaac21f

tests passing

1 parent d1d5fae

File tree

2 files changed: 17 additions, 2 deletions

recipes/lora_dpo_distributed.py

Lines changed: 3 additions & 2 deletions
@@ -134,7 +134,6 @@ def __init__(self, cfg: DictConfig) -> None:
                 "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
             )
 
-        self._checkpoint_client = CheckpointClient(cfg)
         # Set up the backend for distributed training (NCCL, GLOO, etc.)
         self._enable_async_checkpointing = cfg.get("enable_async_checkpointing", False)
         self.fsdp_cpu_offload = cfg.get("fsdp_cpu_offload", False)

@@ -149,6 +148,8 @@ def __init__(self, cfg: DictConfig) -> None:
 
         self._is_rank_zero = self.rank == 0
 
+        self._checkpoint_client = CheckpointClient(cfg)
+
         # logging attributes
         self._output_dir = cfg.output_dir
         self._log_every_n_steps = cfg.get("log_every_n_steps", 1)

@@ -296,7 +297,7 @@ def setup(self, cfg: DictConfig) -> None:
                 )
             )
         except Exception as e:
-            log.warning(
+            self._logger.warning(
                 f"Failed to load distributed checkpoint: {e}. Training will start from the base checkpoint."
             )

torchtune/training/checkpointing/_checkpoint_client.py

Lines changed: 14 additions & 0 deletions
@@ -256,6 +256,20 @@ def _save_checkpoint_sync(
             optim_state_dict = {}
 
         if is_not_distributed_checkpointer and not single_device:
+            # Staging an async checkpoint and saving a sync checkpoint both gather
+            # state dicts on CPU, which causes issues when the two run concurrently.
+            # In theory this case should never be hit, since an epoch takes much
+            # longer than an async checkpoint, but it does come up in a test case
+            # with a very fast epoch, so wait for any in-flight async checkpoint.
+            if self._get_dcp_checkpointer()._checkpoint_future is not None:
+                time_start_waiting = time.perf_counter()
+                self._get_dcp_checkpointer()._checkpoint_future.result()
+                if self._is_rank_zero:
+                    log.info(
+                        "Waiting for the async checkpoint to finish before saving the sync checkpoint "
+                        f"took {time.perf_counter() - time_start_waiting:.2f} secs"
+                    )
+
             # To prevent GPU memory from spiking during checkpoint save,
             # we consolidate the full model and optim state dicts on CPU for rank 0
             model_state_dict = training.gather_cpu_state_dict(
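
For reference, the pattern being guarded here is generic to torch.distributed.checkpoint (DCP): async_save returns a future for the background write, and a later blocking save that also gathers state on CPU should first wait on that future. Below is a minimal standalone sketch of that pattern, not torchtune code; it assumes a recent PyTorch where torch.distributed.checkpoint.async_save is available, and the single-process gloo setup, toy state dict, and /tmp checkpoint paths are illustrative choices only.

# Minimal sketch: wait on an in-flight async checkpoint future before a blocking save.
import os
import time

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp


def main() -> None:
    # Single-process "distributed" setup so DCP has a process group to coordinate with.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    state_dict = {"model": torch.nn.Linear(8, 8).state_dict()}

    # Kick off an async checkpoint: DCP stages the state dict and returns a
    # future that completes when the background write finishes.
    future = dcp.async_save(state_dict, checkpoint_id="/tmp/ckpt_async")

    # Before a synchronous save (which also materializes state on CPU), wait
    # for the in-flight async checkpoint so the two do not overlap.
    if future is not None:
        start = time.perf_counter()
        future.result()
        print(f"Waiting for the async checkpoint took {time.perf_counter() - start:.2f} secs")

    dcp.save(state_dict, checkpoint_id="/tmp/ckpt_sync")
    dist.destroy_process_group()


if __name__ == "__main__":
    main()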
