diff --git a/recipes/full_finetune_single_device.py b/recipes/full_finetune_single_device.py
index 16aa0dbb0e..371d056e8b 100644
--- a/recipes/full_finetune_single_device.py
+++ b/recipes/full_finetune_single_device.py
@@ -289,6 +289,7 @@ def setup(self, cfg: DictConfig) -> None:
                 ckpt_dict = self._checkpoint_client.load_distributed_checkpoint(
                     self._model,
                     self.optimizer,
+                    single_device=True,
                 )
             except Exception as e:
                 self._logger.warning(
diff --git a/recipes/knowledge_distillation_single_device.py b/recipes/knowledge_distillation_single_device.py
index 0947832ad2..3ca39dc03e 100644
--- a/recipes/knowledge_distillation_single_device.py
+++ b/recipes/knowledge_distillation_single_device.py
@@ -248,7 +248,7 @@ def setup(self, cfg: DictConfig) -> None:
                 self._model,
                 self._optimizer,
                 self._adapter_config,
-                self._save_adapter_weights_only,
+                single_device=True,
             )
 
         if training.ADAPTER_KEY not in checkpoint_dict:
diff --git a/recipes/lora_dpo_single_device.py b/recipes/lora_dpo_single_device.py
index d8149d8c8b..288ca5a2d6 100644
--- a/recipes/lora_dpo_single_device.py
+++ b/recipes/lora_dpo_single_device.py
@@ -233,6 +233,7 @@ def setup(self, cfg: DictConfig) -> None:
                 self._model,
                 self._optimizer,
                 self._adapter_config,
+                single_device=True,
             )
 
         if training.ADAPTER_KEY not in checkpoint_dict:
diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py
index a56a4df269..1a61bb1289 100644
--- a/recipes/lora_finetune_single_device.py
+++ b/recipes/lora_finetune_single_device.py
@@ -285,6 +285,7 @@ def setup(self, cfg: DictConfig) -> None:
                 self._model,
                 self._optimizer,
                 self._adapter_config,
+                single_device=True,
             )
 
         if training.ADAPTER_KEY not in checkpoint_dict:
diff --git a/torchtune/modules/peft/_utils.py b/torchtune/modules/peft/_utils.py
index d594e9cd1d..5874a46363 100644
--- a/torchtune/modules/peft/_utils.py
+++ b/torchtune/modules/peft/_utils.py
@@ -8,6 +8,7 @@
 from typing import Any, Generator, Literal, Optional, Protocol, runtime_checkable, Union
 
 import torch
+import torch.distributed as dist
 from torch import nn
 
 from torchtune.utils._logging import deprecate_parameter
@@ -194,6 +195,7 @@ def get_merged_lora_ckpt(
     state_dict: dict[str, Any],
     rank: int,
     alpha: float,
+    use_distributed_barriers: bool = False,
 ) -> dict[str, Any]:
     """
     Merge LoRA weights into the base model format for efficient inference.
@@ -207,18 +209,24 @@
         state_dict (dict[str, Any]): State dict from a model.
         rank (int): The rank of LoRA matrices.
         alpha (float): The alpha value used for scaling LoRA decompositions.
+        use_distributed_barriers (bool): Whether to include a distributed barrier before operations.
+            This is useful when using distributed operations like distributed matrix multiplication, to keep
+            operations in sync across ranks. Default: False
 
     Returns:
         dict[str, Any]: The merged state dict.
     """
     lora_modules = _get_lora_modules(state_dict)
     lora_moe_modules = _get_lora_moe_modules(state_dict)
-    for module in lora_modules.union(lora_moe_modules):
+    for module in sorted(lora_modules.union(lora_moe_modules)):
         # TODO: we don't currently support DoRA for MoE layers
         if "experts" in module:
             for param in ["gate", "up", "down"]:
                 lora_a_weight = state_dict[f"{module}.lora_{param}_a"]
                 lora_b_weight = state_dict[f"{module}.lora_{param}_b"]
+
+                if use_distributed_barriers:
+                    dist.barrier()
                 state_dict[f"{module}.{param}_proj"] += (
                     (alpha / rank)
                     * lora_b_weight.transpose(1, 2)
@@ -236,8 +244,13 @@
         if lora_magnitude is not None:
             base_weight = state_dict[f"{module}.weight"].to(lora_a_weight.dtype)
 
+            if use_distributed_barriers:
+                dist.barrier()
             lora_weight = (alpha / rank) * lora_b_weight @ lora_a_weight
             merged_weight = base_weight + lora_weight
+
+            if use_distributed_barriers:
+                dist.barrier()
             weight_norm = torch.linalg.norm(base_weight + lora_weight, dim=1)
             mag_norm_scale = (lora_magnitude / weight_norm).view(-1, 1)
             merged_weight *= mag_norm_scale
@@ -246,6 +259,8 @@
 
         # Otherwise it is just vanilla LoRA
         else:
+            if use_distributed_barriers:
+                dist.barrier()
             state_dict[f"{module}.weight"] += (
                 (alpha / rank) * lora_b_weight @ lora_a_weight
             )
diff --git a/torchtune/training/checkpointing/_checkpoint_client.py b/torchtune/training/checkpointing/_checkpoint_client.py
index 9cf32717e2..293a84af33 100644
--- a/torchtune/training/checkpointing/_checkpoint_client.py
+++ b/torchtune/training/checkpointing/_checkpoint_client.py
@@ -130,6 +130,7 @@ def _save_checkpoint_async(
         epoch: int,
         adapter_config: Optional[dict[str, Any]],
         adapter_only: bool,
+        single_device: bool,
     ) -> None:
         """
         Checkpoint the training state asynchronously as a distributed checkpoint. Saving
@@ -170,18 +171,15 @@
                 ckpt_dict[training.MODEL_KEY],
                 adapter_config["r"],
                 adapter_config["lora_alpha"],
+                use_distributed_barriers=not single_device,
             )
 
         dcp_saver = self._get_dcp_checkpointer()
-        if not adapter_only:
-            dcp_saver.save_checkpoint(ckpt_dict, epoch=epoch, save_async=True)
-
-            if self._is_rank_zero:
-                log.info(
-                    f"Saving asynchronous checkpoint took {time.perf_counter() - cp_start:.2f} secs"
-                )
 
         if adapter_config is not None:
+            # save adapter weights first because it is faster
+            # so will block training for less time
+            # because you can only do async checkpointing one at a time
            adapter_start = time.perf_counter()
 
             save_path = dcp_saver.get_output_path(epoch=epoch)
@@ -205,6 +203,14 @@
                     f"Saving asynchronous checkpoint for adapter weights took {time.perf_counter() - adapter_start:.2f} secs"
                 )
 
+        if not adapter_only:
+            dcp_saver.save_checkpoint(ckpt_dict, epoch=epoch, save_async=True)
+
+            if self._is_rank_zero:
+                log.info(
+                    f"Saving asynchronous checkpoint took {time.perf_counter() - cp_start:.2f} secs"
+                )
+
     def _save_checkpoint_sync(
         self,
         model: torch.nn.Module,
@@ -368,6 +374,7 @@ def save_checkpoint(
                 epoch,
                 adapter_config,
                 adapter_only,
+                single_device,
             )
         else:
             self._save_checkpoint_sync(
@@ -392,6 +399,7 @@ def load_distributed_checkpoint(
         model: torch.nn.Module,
         optimizer: Union[torch.optim.Optimizer, OptimizerInBackwardWrapper],
        adapter_config: Optional[dict[str, Any]] = None,
+        single_device: bool = False,
     ) -> dict[str, Any]:
         """
         This method is used to resume training from a distributed checkpoint state.
@@ -442,6 +450,7 @@ def load_distributed_checkpoint(
                 checkpoint_dict[training.MODEL_KEY],
                 adapter_config["r"],
                 adapter_config["lora_alpha"],
+                use_distributed_barriers=not single_device,
             )
 
         adapter_only = False
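
For reference, the sketch below shows how a caller might thread the new single_device flag into get_merged_lora_ckpt. The merge_adapter wrapper and its name are illustrative, not part of this patch; only the keyword argument mirrors the change above. The guard follows the use_distributed_barriers=not single_device call sites: dist.barrier() requires an initialized process group, so single-device runs must keep barriers disabled.

# Illustrative sketch only (not part of the patch).
import torch.distributed as dist

from torchtune.modules.peft import get_merged_lora_ckpt


def merge_adapter(state_dict, rank: int, alpha: float, single_device: bool):
    # Barriers keep ranks in step while merging, but dist.barrier() is only
    # legal once a process group exists, so enable them for distributed runs only.
    use_barriers = (
        not single_device and dist.is_available() and dist.is_initialized()
    )
    return get_merged_lora_ckpt(
        state_dict,
        rank,
        alpha,
        use_distributed_barriers=use_barriers,
    )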