Commit 5b2e881

Update get_merged_lora_ckpt for dist checkpoints (#2834)
1 parent 9d91fe3 · commit 5b2e881

File tree

6 files changed: +36 -9 lines changed

    recipes/full_finetune_single_device.py
    recipes/knowledge_distillation_single_device.py
    recipes/lora_dpo_single_device.py
    recipes/lora_finetune_single_device.py
    torchtune/modules/peft/_utils.py
    torchtune/training/checkpointing/_checkpoint_client.py

recipes/full_finetune_single_device.py

Lines changed: 1 addition & 0 deletions

@@ -289,6 +289,7 @@ def setup(self, cfg: DictConfig) -> None:
             ckpt_dict = self._checkpoint_client.load_distributed_checkpoint(
                 self._model,
                 self.optimizer,
+                single_device=True,
             )
         except Exception as e:
             self._logger.warning(

recipes/knowledge_distillation_single_device.py

Lines changed: 1 addition & 1 deletion

@@ -248,7 +248,7 @@ def setup(self, cfg: DictConfig) -> None:
                 self._model,
                 self._optimizer,
                 self._adapter_config,
-                self._save_adapter_weights_only,
+                single_device=True,
             )

             if training.ADAPTER_KEY not in checkpoint_dict:

recipes/lora_dpo_single_device.py

Lines changed: 1 addition & 0 deletions

@@ -233,6 +233,7 @@ def setup(self, cfg: DictConfig) -> None:
                 self._model,
                 self._optimizer,
                 self._adapter_config,
+                single_device=True,
             )

             if training.ADAPTER_KEY not in checkpoint_dict:

recipes/lora_finetune_single_device.py

Lines changed: 1 addition & 0 deletions

@@ -285,6 +285,7 @@ def setup(self, cfg: DictConfig) -> None:
                 self._model,
                 self._optimizer,
                 self._adapter_config,
+                single_device=True,
             )

             if training.ADAPTER_KEY not in checkpoint_dict:
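
All four single-device recipes make the same change: they pass single_device=True when resuming from a distributed checkpoint, which (as the _checkpoint_client.py diff further down shows) translates into use_distributed_barriers=False during the LoRA merge. The reason is that dist.barrier() is a collective and is only valid once a process group has been initialized. A minimal, standalone illustration of that constraint in plain PyTorch (not torchtune code):

import torch.distributed as dist

# Plain-PyTorch illustration: a barrier is only legal after init_process_group(),
# so a single-device run has to skip it rather than try to synchronize.
if dist.is_available() and dist.is_initialized():
    dist.barrier()  # multi-rank run: keep every rank's merge steps in lockstep
# single device / no process group: proceed without any barrier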

torchtune/modules/peft/_utils.py

Lines changed: 16 additions & 1 deletion

@@ -8,6 +8,7 @@
 from typing import Any, Generator, Literal, Optional, Protocol, runtime_checkable, Union

 import torch
+import torch.distributed as dist
 from torch import nn
 from torchtune.utils._logging import deprecate_parameter

@@ -194,6 +195,7 @@ def get_merged_lora_ckpt(
     state_dict: dict[str, Any],
     rank: int,
     alpha: float,
+    use_distributed_barriers: bool = False,
 ) -> dict[str, Any]:
     """
     Merge LoRA weights into the base model format for efficient inference.
@@ -207,18 +209,24 @@ def get_merged_lora_ckpt(
         state_dict (dict[str, Any]): State dict from a model.
         rank (int): The rank of LoRA matrices.
         alpha (float): The alpha value used for scaling LoRA decompositions.
+        use_distributed_barriers (bool): Whether to include a distributed barrier before operations.
+            This is useful when using distributed operations like distributed matrix multiplication, to keep
+            operations in sync across ranks. Default: False

     Returns:
         dict[str, Any]: The merged state dict.
     """
     lora_modules = _get_lora_modules(state_dict)
     lora_moe_modules = _get_lora_moe_modules(state_dict)
-    for module in lora_modules.union(lora_moe_modules):
+    for module in sorted(lora_modules.union(lora_moe_modules)):
         # TODO: we don't currently support DoRA for MoE layers
         if "experts" in module:
             for param in ["gate", "up", "down"]:
                 lora_a_weight = state_dict[f"{module}.lora_{param}_a"]
                 lora_b_weight = state_dict[f"{module}.lora_{param}_b"]
+
+                if use_distributed_barriers:
+                    dist.barrier()
                 state_dict[f"{module}.{param}_proj"] += (
                     (alpha / rank)
                     * lora_b_weight.transpose(1, 2)
@@ -236,8 +244,13 @@ def get_merged_lora_ckpt(
         if lora_magnitude is not None:
             base_weight = state_dict[f"{module}.weight"].to(lora_a_weight.dtype)

+            if use_distributed_barriers:
+                dist.barrier()
             lora_weight = (alpha / rank) * lora_b_weight @ lora_a_weight
             merged_weight = base_weight + lora_weight
+
+            if use_distributed_barriers:
+                dist.barrier()
             weight_norm = torch.linalg.norm(base_weight + lora_weight, dim=1)
             mag_norm_scale = (lora_magnitude / weight_norm).view(-1, 1)
             merged_weight *= mag_norm_scale
@@ -246,6 +259,8 @@ def get_merged_lora_ckpt(

         # Otherwise it is just vanilla LoRA
         else:
+            if use_distributed_barriers:
+                dist.barrier()
             state_dict[f"{module}.weight"] += (
                 (alpha / rank) * lora_b_weight @ lora_a_weight
             )
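
For a plain LoRA layer, the loop above computes W ← W + (alpha / rank) · B @ A; for DoRA it additionally rescales each row by the stored magnitude over the merged weight's per-row norm. When the state dict holds sharded tensors, those matmuls become distributed operations, which is why modules are now visited in sorted(...) order and why optional dist.barrier() calls are offered to keep ranks aligned. A small self-contained check of the vanilla merge rule on ordinary tensors (dimensions here are illustrative, not taken from any torchtune model):

import torch

# Sanity check of the vanilla LoRA merge rule used above.
out_dim, in_dim, rank, alpha = 8, 16, 2, 4.0
scale = alpha / rank

base_weight = torch.randn(out_dim, in_dim)   # W
lora_a = torch.randn(rank, in_dim)           # A
lora_b = torch.randn(out_dim, rank)          # B
x = torch.randn(3, in_dim)                   # a toy batch of inputs

merged = base_weight + scale * lora_b @ lora_a

# Merging is output-preserving: the merged weight reproduces base + adapter path.
expected = x @ base_weight.T + scale * (x @ lora_a.T) @ lora_b.T
assert torch.allclose(x @ merged.T, expected, atol=1e-5)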

torchtune/training/checkpointing/_checkpoint_client.py

Lines changed: 16 additions & 7 deletions

@@ -130,6 +130,7 @@ def _save_checkpoint_async(
         epoch: int,
         adapter_config: Optional[dict[str, Any]],
         adapter_only: bool,
+        single_device: bool,
     ) -> None:
         """
         Checkpoint the training state asynchronously as a distributed checkpoint. Saving
@@ -170,18 +171,15 @@ def _save_checkpoint_async(
                 ckpt_dict[training.MODEL_KEY],
                 adapter_config["r"],
                 adapter_config["lora_alpha"],
+                use_distributed_barriers=not single_device,
             )

         dcp_saver = self._get_dcp_checkpointer()
-        if not adapter_only:
-            dcp_saver.save_checkpoint(ckpt_dict, epoch=epoch, save_async=True)
-
-            if self._is_rank_zero:
-                log.info(
-                    f"Saving asynchronous checkpoint took {time.perf_counter() - cp_start:.2f} secs"
-                )

         if adapter_config is not None:
+            # save adapter weights first because it is faster
+            # so will block training for less time
+            # because you can only do async checkpointing one at a time
             adapter_start = time.perf_counter()

             save_path = dcp_saver.get_output_path(epoch=epoch)
@@ -205,6 +203,14 @@
                     f"Saving asynchronous checkpoint for adapter weights took {time.perf_counter() - adapter_start:.2f} secs"
                 )

+        if not adapter_only:
+            dcp_saver.save_checkpoint(ckpt_dict, epoch=epoch, save_async=True)
+
+            if self._is_rank_zero:
+                log.info(
+                    f"Saving asynchronous checkpoint took {time.perf_counter() - cp_start:.2f} secs"
+                )
+
     def _save_checkpoint_sync(
         self,
         model: torch.nn.Module,
@@ -368,6 +374,7 @@ def save_checkpoint(
                 epoch,
                 adapter_config,
                 adapter_only,
+                single_device,
             )
         else:
             self._save_checkpoint_sync(
@@ -392,6 +399,7 @@ def load_distributed_checkpoint(
         model: torch.nn.Module,
         optimizer: Union[torch.optim.Optimizer, OptimizerInBackwardWrapper],
         adapter_config: Optional[dict[str, Any]] = None,
+        single_device: bool = False,
     ) -> dict[str, Any]:
         """
         This method is used to resume training from a distributed checkpoint state.
@@ -442,6 +450,7 @@
                 checkpoint_dict[training.MODEL_KEY],
                 adapter_config["r"],
                 adapter_config["lora_alpha"],
+                use_distributed_barriers=not single_device,
             )

         adapter_only = False
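
Putting the pieces together: a single-device recipe passes single_device=True, save_checkpoint / load_distributed_checkpoint forward it, and get_merged_lora_ckpt only emits barriers when use_distributed_barriers=not single_device is True. A condensed, hypothetical sketch of that flow with stand-in names and heavily abbreviated bodies (this is not the real torchtune implementation):

from typing import Any, Optional


def merge_lora(state_dict: dict[str, Any], r: int, alpha: float,
               use_distributed_barriers: bool = False) -> dict[str, Any]:
    # stand-in for torchtune's get_merged_lora_ckpt
    ...
    return state_dict


def load_distributed_checkpoint(checkpoint_dict: dict[str, Any],
                                adapter_config: Optional[dict[str, Any]] = None,
                                single_device: bool = False) -> dict[str, Any]:
    # stand-in for CheckpointClient.load_distributed_checkpoint; the real method
    # also takes the model and optimizer and performs the actual DCP load
    if adapter_config is not None:
        merge_lora(
            checkpoint_dict,
            adapter_config["r"],
            adapter_config["lora_alpha"],
            use_distributed_barriers=not single_device,  # barriers only for multi-rank runs
        )
    return checkpoint_dict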
