From bca098dfbe4a5139e278a771cffe43df039b8c4a Mon Sep 17 00:00:00 2001 From: Ankita George Date: Tue, 17 Jun 2025 20:55:46 -0700 Subject: [PATCH 1/7] dist merge --- torchtune/modules/peft/__init__.py | 1 + torchtune/modules/peft/_utils.py | 118 ++++++++++++++++++ .../checkpointing/_checkpoint_client.py | 35 +++--- 3 files changed, 140 insertions(+), 14 deletions(-) diff --git a/torchtune/modules/peft/__init__.py b/torchtune/modules/peft/__init__.py index ddf983d885..827d8ea0f4 100644 --- a/torchtune/modules/peft/__init__.py +++ b/torchtune/modules/peft/__init__.py @@ -11,6 +11,7 @@ get_adapter_state_dict, get_lora_module_names, get_merged_lora_ckpt, + get_merged_lora_dist_ckpt, LORA_ATTN_MODULES, set_trainable_params, validate_missing_and_unexpected_for_lora, diff --git a/torchtune/modules/peft/_utils.py b/torchtune/modules/peft/_utils.py index d594e9cd1d..c637f3f545 100644 --- a/torchtune/modules/peft/_utils.py +++ b/torchtune/modules/peft/_utils.py @@ -8,7 +8,9 @@ from typing import Any, Generator, Literal, Optional, Protocol, runtime_checkable, Union import torch +import torch.distributed as dist from torch import nn +from torchao.dtypes.nf4tensor import NF4Tensor from torchtune.utils._logging import deprecate_parameter # Modules from MultiHeadAttention that LoRA can be applied to @@ -256,6 +258,122 @@ def get_merged_lora_ckpt( return state_dict +@torch.no_grad +def get_merged_lora_dist_ckpt( + state_dict: dict[str, Any], + rank: int, + alpha: float, +) -> dict[str, Any]: + """ + Merge LoRA weights into the base model format for efficient inference using distributed operations. + This function is designed for distributed training scenarios and uses distributed operations + like distributed matrix multiplication. + NOTE: This function modifies state_dict inplace. If you do not want to do that, + make a copy prior to calling this function. + NOTE: This does not work for NF4Tensors as they don't support the add and mul operations used here. + For every LoRA module in the state dict, this function will convert its + base weight then delete the LoRA-specific parameters. + Args: + state_dict (dict[str, Any]): State dict from a model. + rank (int): The rank of LoRA matrices. + alpha (float): The alpha value used for scaling LoRA decompositions. + Returns: + dict[str, Any]: The merged state dict. 
+ """ + + lora_modules = _get_lora_modules(state_dict) + lora_moe_modules = _get_lora_moe_modules(state_dict) + + # Create a simple module for matrix multiplication + class MatMulModule(torch.nn.Module): + def forward(self, x, y): + return (alpha / rank) * torch.matmul(x, y) + + for module in sorted(lora_modules.union(lora_moe_modules)): + # TODO: we don't currently support DoRA for MoE layers + if "experts" in module: + for param in ["gate", "up", "down"]: + lora_a_weight = state_dict[f"{module}.lora_{param}_a"] + lora_b_weight = state_dict[f"{module}.lora_{param}_b"] + + # Create a simple module for transpose operation + class TransposeModule(torch.nn.Module): + def __init__(self, dim0, dim1): + super().__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x): + return torch.transpose(x, self.dim0, self.dim1) + + # Parallelize transpose operations + transpose_module = TransposeModule(1, 2) + dist.barrier() + # Apply distributed transpose + transposed_b = transpose_module(lora_b_weight) + transposed_a = transpose_module(lora_a_weight) + + mm_module = MatMulModule() + dist.barrier() + result = mm_module(transposed_b, transposed_a) + + # Apply the result using out-of-place addition + proj_weight = state_dict[f"{module}.{param}_proj"] + + dist.barrier() + transposed_result = transpose_module(result) + + state_dict[f"{module}.{param}_proj"] = proj_weight + transposed_result + + del state_dict[f"{module}.lora_{param}_a"] + del state_dict[f"{module}.lora_{param}_b"] + continue + + lora_a_weight = state_dict[f"{module}.lora_a.weight"] + lora_b_weight = state_dict[f"{module}.lora_b.weight"] + lora_magnitude = state_dict.get(f"{module}.magnitude", None) + + # If magnitude is present, calculate merged DoRA weight + if lora_magnitude is not None: + base_weight = state_dict[f"{module}.weight"].to(lora_a_weight.dtype) + + mm_module = MatMulModule() + dist.barrier() + lora_weight = mm_module(lora_b_weight, lora_a_weight) + + merged_weight = base_weight + lora_weight + dist.barrier() + + # Create a simple module for norm calculation + class NormModule(torch.nn.Module): + def forward(self, x): + return torch.linalg.norm(x, dim=1) + + norm_module = NormModule() + dist.barrier() + weight_norm = norm_module(merged_weight) + + mag_norm_scale = (lora_magnitude / weight_norm).view(-1, 1) + merged_weight *= mag_norm_scale + state_dict[f"{module}.weight"] = merged_weight + del state_dict[f"{module}.magnitude"] + + # Otherwise it is just vanilla LoRA + else: + mm_module = MatMulModule() + dist.barrier() + lora_weight = mm_module( + lora_b_weight, + lora_a_weight, + ) + + del state_dict[f"{module}.lora_a.weight"] + del state_dict[f"{module}.lora_b.weight"] + + dist.barrier() + return state_dict + + @contextlib.contextmanager def disable_adapter(model: nn.Module) -> Generator[None, None, None]: """ diff --git a/torchtune/training/checkpointing/_checkpoint_client.py b/torchtune/training/checkpointing/_checkpoint_client.py index 9cf32717e2..a74545145b 100644 --- a/torchtune/training/checkpointing/_checkpoint_client.py +++ b/torchtune/training/checkpointing/_checkpoint_client.py @@ -23,6 +23,7 @@ from torchtune.modules.peft import ( get_adapter_state_dict, get_merged_lora_ckpt, + get_merged_lora_dist_ckpt, validate_missing_and_unexpected_for_lora, ) from torchtune.training.checkpointing._checkpointer import DistributedCheckpointer @@ -87,7 +88,7 @@ def __init__( device = self._cfg.get("device", None) self._device = utils.get_device(device=device) - _, self._rank = utils.get_world_size_and_rank() + 
self._world_size, self._rank = utils.get_world_size_and_rank() self._is_rank_zero = self._rank == 0 def _get_checkpointer(self): @@ -166,11 +167,18 @@ def _save_checkpoint_async( } ) - get_merged_lora_ckpt( - ckpt_dict[training.MODEL_KEY], - adapter_config["r"], - adapter_config["lora_alpha"], - ) + if self._world_size == 0: + get_merged_lora_ckpt( + ckpt_dict[training.MODEL_KEY], + adapter_config["r"], + adapter_config["lora_alpha"], + ) + else: + get_merged_lora_dist_ckpt( + ckpt_dict[training.MODEL_KEY], + adapter_config["r"], + adapter_config["lora_alpha"], + ) dcp_saver = self._get_dcp_checkpointer() if not adapter_only: @@ -269,10 +277,10 @@ def _save_checkpoint_sync( # This check can be removed once we fully migrate over to ``OptimizerInBackward`` if isinstance(optimizer, OptimizerInBackwardWrapper): for param, opt in optimizer.optim_map.items(): - optim_state_dict[ - param - ] = training.get_full_optimizer_state_dict( - model, opt, self._is_rank_zero, device=self._device + optim_state_dict[param] = ( + training.get_full_optimizer_state_dict( + model, opt, self._is_rank_zero, device=self._device + ) ) elif isinstance(optimizer, OptimizerInBackward): optim_state_dict = optimizer.state_dict() @@ -359,7 +367,6 @@ def save_checkpoint( checkpointer user has configured. """ intermediate_checkpoint = epoch + 1 < training_progress.total_epochs - if intermediate_checkpoint and self._enable_async_checkpointing: self._save_checkpoint_async( model, @@ -414,9 +421,9 @@ def load_distributed_checkpoint( if "param_groups" in optim_state_dict: for param_group in optim_state_dict["param_groups"]: if param_group.get("initial_lr") is None: - param_group[ - "initial_lr" - ] = 0.0 # This will get overriden by the actual value in optimizer + param_group["initial_lr"] = ( + 0.0 # This will get overriden by the actual value in optimizer + ) checkpoint_dict.update( { From 172174c7cf7e54f18d813e557b9077b327f68304 Mon Sep 17 00:00:00 2001 From: Ankita George Date: Wed, 18 Jun 2025 05:57:52 -0700 Subject: [PATCH 2/7] add single device as arg --- torchtune/training/checkpointing/_checkpoint_client.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchtune/training/checkpointing/_checkpoint_client.py b/torchtune/training/checkpointing/_checkpoint_client.py index a74545145b..be9c844ded 100644 --- a/torchtune/training/checkpointing/_checkpoint_client.py +++ b/torchtune/training/checkpointing/_checkpoint_client.py @@ -88,7 +88,7 @@ def __init__( device = self._cfg.get("device", None) self._device = utils.get_device(device=device) - self._world_size, self._rank = utils.get_world_size_and_rank() + _, self._rank = utils.get_world_size_and_rank() self._is_rank_zero = self._rank == 0 def _get_checkpointer(self): @@ -131,6 +131,7 @@ def _save_checkpoint_async( epoch: int, adapter_config: Optional[dict[str, Any]], adapter_only: bool, + single_device: bool, ) -> None: """ Checkpoint the training state asynchronously as a distributed checkpoint. 
Saving @@ -167,7 +168,7 @@ def _save_checkpoint_async( } ) - if self._world_size == 0: + if single_device: get_merged_lora_ckpt( ckpt_dict[training.MODEL_KEY], adapter_config["r"], @@ -375,6 +376,7 @@ def save_checkpoint( epoch, adapter_config, adapter_only, + single_device, ) else: self._save_checkpoint_sync( From 72d5636dfb87cde345da71f9f5f51d8814555a6c Mon Sep 17 00:00:00 2001 From: Ankita George Date: Wed, 18 Jun 2025 06:46:51 -0700 Subject: [PATCH 3/7] fix vanilla lora --- torchtune/modules/peft/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtune/modules/peft/_utils.py b/torchtune/modules/peft/_utils.py index c637f3f545..fdd5109cb2 100644 --- a/torchtune/modules/peft/_utils.py +++ b/torchtune/modules/peft/_utils.py @@ -10,7 +10,6 @@ import torch import torch.distributed as dist from torch import nn -from torchao.dtypes.nf4tensor import NF4Tensor from torchtune.utils._logging import deprecate_parameter # Modules from MultiHeadAttention that LoRA can be applied to @@ -366,6 +365,7 @@ def forward(self, x): lora_b_weight, lora_a_weight, ) + state_dict[f"{module}.weight"] += lora_weight del state_dict[f"{module}.lora_a.weight"] del state_dict[f"{module}.lora_b.weight"] From b0f211e9cfc05f56feeb6c2906133f80603d6ada Mon Sep 17 00:00:00 2001 From: Ankita George Date: Wed, 18 Jun 2025 07:18:19 -0700 Subject: [PATCH 4/7] change order or checkpoint save --- .../checkpointing/_checkpoint_client.py | 32 +++++++++++-------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/torchtune/training/checkpointing/_checkpoint_client.py b/torchtune/training/checkpointing/_checkpoint_client.py index be9c844ded..7e51fafb6a 100644 --- a/torchtune/training/checkpointing/_checkpoint_client.py +++ b/torchtune/training/checkpointing/_checkpoint_client.py @@ -182,15 +182,11 @@ def _save_checkpoint_async( ) dcp_saver = self._get_dcp_checkpointer() - if not adapter_only: - dcp_saver.save_checkpoint(ckpt_dict, epoch=epoch, save_async=True) - - if self._is_rank_zero: - log.info( - f"Saving asynchronous checkpoint took {time.perf_counter() - cp_start:.2f} secs" - ) if adapter_config is not None: + # save adapter weights first because it is faster + # so will block training for less time + # because you can only do async checkpointing one at a time adapter_start = time.perf_counter() save_path = dcp_saver.get_output_path(epoch=epoch) @@ -214,6 +210,14 @@ def _save_checkpoint_async( f"Saving asynchronous checkpoint for adapter weights took {time.perf_counter() - adapter_start:.2f} secs" ) + if not adapter_only: + dcp_saver.save_checkpoint(ckpt_dict, epoch=epoch, save_async=True) + + if self._is_rank_zero: + log.info( + f"Saving asynchronous checkpoint took {time.perf_counter() - cp_start:.2f} secs" + ) + def _save_checkpoint_sync( self, model: torch.nn.Module, @@ -278,10 +282,10 @@ def _save_checkpoint_sync( # This check can be removed once we fully migrate over to ``OptimizerInBackward`` if isinstance(optimizer, OptimizerInBackwardWrapper): for param, opt in optimizer.optim_map.items(): - optim_state_dict[param] = ( - training.get_full_optimizer_state_dict( - model, opt, self._is_rank_zero, device=self._device - ) + optim_state_dict[ + param + ] = training.get_full_optimizer_state_dict( + model, opt, self._is_rank_zero, device=self._device ) elif isinstance(optimizer, OptimizerInBackward): optim_state_dict = optimizer.state_dict() @@ -423,9 +427,9 @@ def load_distributed_checkpoint( if "param_groups" in optim_state_dict: for param_group in optim_state_dict["param_groups"]: if 
param_group.get("initial_lr") is None: - param_group["initial_lr"] = ( - 0.0 # This will get overriden by the actual value in optimizer - ) + param_group[ + "initial_lr" + ] = 0.0 # This will get overriden by the actual value in optimizer checkpoint_dict.update( { From 00be546b2a32ebe1af1f0d5ee5df7b1be0e9c4db Mon Sep 17 00:00:00 2001 From: Ankita George Date: Fri, 20 Jun 2025 11:03:02 -0700 Subject: [PATCH 5/7] fix load too --- recipes/full_finetune_single_device.py | 1 + .../knowledge_distillation_single_device.py | 2 +- recipes/lora_dpo_single_device.py | 1 + recipes/lora_finetune_single_device.py | 1 + .../checkpointing/_checkpoint_client.py | 18 +++++++++++++----- 5 files changed, 17 insertions(+), 6 deletions(-) diff --git a/recipes/full_finetune_single_device.py b/recipes/full_finetune_single_device.py index 16aa0dbb0e..371d056e8b 100644 --- a/recipes/full_finetune_single_device.py +++ b/recipes/full_finetune_single_device.py @@ -289,6 +289,7 @@ def setup(self, cfg: DictConfig) -> None: ckpt_dict = self._checkpoint_client.load_distributed_checkpoint( self._model, self.optimizer, + single_device=True, ) except Exception as e: self._logger.warning( diff --git a/recipes/knowledge_distillation_single_device.py b/recipes/knowledge_distillation_single_device.py index 0947832ad2..3ca39dc03e 100644 --- a/recipes/knowledge_distillation_single_device.py +++ b/recipes/knowledge_distillation_single_device.py @@ -248,7 +248,7 @@ def setup(self, cfg: DictConfig) -> None: self._model, self._optimizer, self._adapter_config, - self._save_adapter_weights_only, + single_device=True, ) if training.ADAPTER_KEY not in checkpoint_dict: diff --git a/recipes/lora_dpo_single_device.py b/recipes/lora_dpo_single_device.py index d8149d8c8b..288ca5a2d6 100644 --- a/recipes/lora_dpo_single_device.py +++ b/recipes/lora_dpo_single_device.py @@ -233,6 +233,7 @@ def setup(self, cfg: DictConfig) -> None: self._model, self._optimizer, self._adapter_config, + single_device=True, ) if training.ADAPTER_KEY not in checkpoint_dict: diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py index a56a4df269..1a61bb1289 100644 --- a/recipes/lora_finetune_single_device.py +++ b/recipes/lora_finetune_single_device.py @@ -285,6 +285,7 @@ def setup(self, cfg: DictConfig) -> None: self._model, self._optimizer, self._adapter_config, + single_device=True, ) if training.ADAPTER_KEY not in checkpoint_dict: diff --git a/torchtune/training/checkpointing/_checkpoint_client.py b/torchtune/training/checkpointing/_checkpoint_client.py index 7e51fafb6a..da0a2841d4 100644 --- a/torchtune/training/checkpointing/_checkpoint_client.py +++ b/torchtune/training/checkpointing/_checkpoint_client.py @@ -405,6 +405,7 @@ def load_distributed_checkpoint( model: torch.nn.Module, optimizer: Union[torch.optim.Optimizer, OptimizerInBackwardWrapper], adapter_config: Optional[dict[str, Any]] = None, + single_device: bool = False, ) -> dict[str, Any]: """ This method is used to resume training from a distributed checkpoint state. 
@@ -451,11 +452,18 @@ def load_distributed_checkpoint( } ) - get_merged_lora_ckpt( - checkpoint_dict[training.MODEL_KEY], - adapter_config["r"], - adapter_config["lora_alpha"], - ) + if single_device: + get_merged_lora_ckpt( + checkpoint_dict[training.MODEL_KEY], + adapter_config["r"], + adapter_config["lora_alpha"], + ) + else: + get_merged_lora_dist_ckpt( + checkpoint_dict[training.MODEL_KEY], + adapter_config["r"], + adapter_config["lora_alpha"], + ) adapter_only = False dcp_checkpointer = self._get_dcp_checkpointer() From 6e2920248df68894b5094b91c825337244c57b7b Mon Sep 17 00:00:00 2001 From: Ankita George Date: Mon, 23 Jun 2025 13:03:16 -0700 Subject: [PATCH 6/7] separate dist method not needed --- recipes/configs/llama3_2/3B_lora.yaml | 2 +- torchtune/modules/peft/__init__.py | 1 - torchtune/modules/peft/_utils.py | 133 ++---------------- .../checkpointing/_checkpoint_client.py | 37 ++--- 4 files changed, 28 insertions(+), 145 deletions(-) diff --git a/recipes/configs/llama3_2/3B_lora.yaml b/recipes/configs/llama3_2/3B_lora.yaml index eb220d794a..f5cd065de9 100644 --- a/recipes/configs/llama3_2/3B_lora.yaml +++ b/recipes/configs/llama3_2/3B_lora.yaml @@ -77,7 +77,7 @@ loss: _component_: torchtune.modules.loss.LinearCrossEntropyLoss # Training -epochs: 1 +epochs: 2 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size clip_grad_norm: null diff --git a/torchtune/modules/peft/__init__.py b/torchtune/modules/peft/__init__.py index 827d8ea0f4..ddf983d885 100644 --- a/torchtune/modules/peft/__init__.py +++ b/torchtune/modules/peft/__init__.py @@ -11,7 +11,6 @@ get_adapter_state_dict, get_lora_module_names, get_merged_lora_ckpt, - get_merged_lora_dist_ckpt, LORA_ATTN_MODULES, set_trainable_params, validate_missing_and_unexpected_for_lora, diff --git a/torchtune/modules/peft/_utils.py b/torchtune/modules/peft/_utils.py index fdd5109cb2..5874a46363 100644 --- a/torchtune/modules/peft/_utils.py +++ b/torchtune/modules/peft/_utils.py @@ -195,6 +195,7 @@ def get_merged_lora_ckpt( state_dict: dict[str, Any], rank: int, alpha: float, + use_distributed_barriers: bool = False, ) -> dict[str, Any]: """ Merge LoRA weights into the base model format for efficient inference. @@ -208,18 +209,24 @@ def get_merged_lora_ckpt( state_dict (dict[str, Any]): State dict from a model. rank (int): The rank of LoRA matrices. alpha (float): The alpha value used for scaling LoRA decompositions. + use_distributed_barriers (bool): Whether to include a distributed barrier before operations. + This is useful when using distributed operations like distributed matrix multiplication, to keep + operations in sync across ranks. Default: False Returns: dict[str, Any]: The merged state dict. 
""" lora_modules = _get_lora_modules(state_dict) lora_moe_modules = _get_lora_moe_modules(state_dict) - for module in lora_modules.union(lora_moe_modules): + for module in sorted(lora_modules.union(lora_moe_modules)): # TODO: we don't currently support DoRA for MoE layers if "experts" in module: for param in ["gate", "up", "down"]: lora_a_weight = state_dict[f"{module}.lora_{param}_a"] lora_b_weight = state_dict[f"{module}.lora_{param}_b"] + + if use_distributed_barriers: + dist.barrier() state_dict[f"{module}.{param}_proj"] += ( (alpha / rank) * lora_b_weight.transpose(1, 2) @@ -237,8 +244,13 @@ def get_merged_lora_ckpt( if lora_magnitude is not None: base_weight = state_dict[f"{module}.weight"].to(lora_a_weight.dtype) + if use_distributed_barriers: + dist.barrier() lora_weight = (alpha / rank) * lora_b_weight @ lora_a_weight merged_weight = base_weight + lora_weight + + if use_distributed_barriers: + dist.barrier() weight_norm = torch.linalg.norm(base_weight + lora_weight, dim=1) mag_norm_scale = (lora_magnitude / weight_norm).view(-1, 1) merged_weight *= mag_norm_scale @@ -247,6 +259,8 @@ def get_merged_lora_ckpt( # Otherwise it is just vanilla LoRA else: + if use_distributed_barriers: + dist.barrier() state_dict[f"{module}.weight"] += ( (alpha / rank) * lora_b_weight @ lora_a_weight ) @@ -257,123 +271,6 @@ def get_merged_lora_ckpt( return state_dict -@torch.no_grad -def get_merged_lora_dist_ckpt( - state_dict: dict[str, Any], - rank: int, - alpha: float, -) -> dict[str, Any]: - """ - Merge LoRA weights into the base model format for efficient inference using distributed operations. - This function is designed for distributed training scenarios and uses distributed operations - like distributed matrix multiplication. - NOTE: This function modifies state_dict inplace. If you do not want to do that, - make a copy prior to calling this function. - NOTE: This does not work for NF4Tensors as they don't support the add and mul operations used here. - For every LoRA module in the state dict, this function will convert its - base weight then delete the LoRA-specific parameters. - Args: - state_dict (dict[str, Any]): State dict from a model. - rank (int): The rank of LoRA matrices. - alpha (float): The alpha value used for scaling LoRA decompositions. - Returns: - dict[str, Any]: The merged state dict. 
- """ - - lora_modules = _get_lora_modules(state_dict) - lora_moe_modules = _get_lora_moe_modules(state_dict) - - # Create a simple module for matrix multiplication - class MatMulModule(torch.nn.Module): - def forward(self, x, y): - return (alpha / rank) * torch.matmul(x, y) - - for module in sorted(lora_modules.union(lora_moe_modules)): - # TODO: we don't currently support DoRA for MoE layers - if "experts" in module: - for param in ["gate", "up", "down"]: - lora_a_weight = state_dict[f"{module}.lora_{param}_a"] - lora_b_weight = state_dict[f"{module}.lora_{param}_b"] - - # Create a simple module for transpose operation - class TransposeModule(torch.nn.Module): - def __init__(self, dim0, dim1): - super().__init__() - self.dim0 = dim0 - self.dim1 = dim1 - - def forward(self, x): - return torch.transpose(x, self.dim0, self.dim1) - - # Parallelize transpose operations - transpose_module = TransposeModule(1, 2) - dist.barrier() - # Apply distributed transpose - transposed_b = transpose_module(lora_b_weight) - transposed_a = transpose_module(lora_a_weight) - - mm_module = MatMulModule() - dist.barrier() - result = mm_module(transposed_b, transposed_a) - - # Apply the result using out-of-place addition - proj_weight = state_dict[f"{module}.{param}_proj"] - - dist.barrier() - transposed_result = transpose_module(result) - - state_dict[f"{module}.{param}_proj"] = proj_weight + transposed_result - - del state_dict[f"{module}.lora_{param}_a"] - del state_dict[f"{module}.lora_{param}_b"] - continue - - lora_a_weight = state_dict[f"{module}.lora_a.weight"] - lora_b_weight = state_dict[f"{module}.lora_b.weight"] - lora_magnitude = state_dict.get(f"{module}.magnitude", None) - - # If magnitude is present, calculate merged DoRA weight - if lora_magnitude is not None: - base_weight = state_dict[f"{module}.weight"].to(lora_a_weight.dtype) - - mm_module = MatMulModule() - dist.barrier() - lora_weight = mm_module(lora_b_weight, lora_a_weight) - - merged_weight = base_weight + lora_weight - dist.barrier() - - # Create a simple module for norm calculation - class NormModule(torch.nn.Module): - def forward(self, x): - return torch.linalg.norm(x, dim=1) - - norm_module = NormModule() - dist.barrier() - weight_norm = norm_module(merged_weight) - - mag_norm_scale = (lora_magnitude / weight_norm).view(-1, 1) - merged_weight *= mag_norm_scale - state_dict[f"{module}.weight"] = merged_weight - del state_dict[f"{module}.magnitude"] - - # Otherwise it is just vanilla LoRA - else: - mm_module = MatMulModule() - dist.barrier() - lora_weight = mm_module( - lora_b_weight, - lora_a_weight, - ) - state_dict[f"{module}.weight"] += lora_weight - - del state_dict[f"{module}.lora_a.weight"] - del state_dict[f"{module}.lora_b.weight"] - - dist.barrier() - return state_dict - - @contextlib.contextmanager def disable_adapter(model: nn.Module) -> Generator[None, None, None]: """ diff --git a/torchtune/training/checkpointing/_checkpoint_client.py b/torchtune/training/checkpointing/_checkpoint_client.py index da0a2841d4..55e6f1c708 100644 --- a/torchtune/training/checkpointing/_checkpoint_client.py +++ b/torchtune/training/checkpointing/_checkpoint_client.py @@ -23,7 +23,6 @@ from torchtune.modules.peft import ( get_adapter_state_dict, get_merged_lora_ckpt, - get_merged_lora_dist_ckpt, validate_missing_and_unexpected_for_lora, ) from torchtune.training.checkpointing._checkpointer import DistributedCheckpointer @@ -168,18 +167,12 @@ def _save_checkpoint_async( } ) - if single_device: - get_merged_lora_ckpt( - 
ckpt_dict[training.MODEL_KEY], - adapter_config["r"], - adapter_config["lora_alpha"], - ) - else: - get_merged_lora_dist_ckpt( - ckpt_dict[training.MODEL_KEY], - adapter_config["r"], - adapter_config["lora_alpha"], - ) + get_merged_lora_ckpt( + ckpt_dict[training.MODEL_KEY], + adapter_config["r"], + adapter_config["lora_alpha"], + use_distributed_barriers=not single_device, + ) dcp_saver = self._get_dcp_checkpointer() @@ -452,18 +445,12 @@ def load_distributed_checkpoint( } ) - if single_device: - get_merged_lora_ckpt( - checkpoint_dict[training.MODEL_KEY], - adapter_config["r"], - adapter_config["lora_alpha"], - ) - else: - get_merged_lora_dist_ckpt( - checkpoint_dict[training.MODEL_KEY], - adapter_config["r"], - adapter_config["lora_alpha"], - ) + get_merged_lora_ckpt( + checkpoint_dict[training.MODEL_KEY], + adapter_config["r"], + adapter_config["lora_alpha"], + use_distributed_barriers=not single_device, + ) adapter_only = False dcp_checkpointer = self._get_dcp_checkpointer() From 9a0a6eb831bcc6b1c6f16fa9570fb4532d6d80bb Mon Sep 17 00:00:00 2001 From: Ankita George Date: Mon, 23 Jun 2025 13:22:44 -0700 Subject: [PATCH 7/7] fix config --- recipes/configs/llama3_2/3B_lora.yaml | 2 +- torchtune/training/checkpointing/_checkpoint_client.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/recipes/configs/llama3_2/3B_lora.yaml b/recipes/configs/llama3_2/3B_lora.yaml index f5cd065de9..eb220d794a 100644 --- a/recipes/configs/llama3_2/3B_lora.yaml +++ b/recipes/configs/llama3_2/3B_lora.yaml @@ -77,7 +77,7 @@ loss: _component_: torchtune.modules.loss.LinearCrossEntropyLoss # Training -epochs: 2 +epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 # Use to increase effective batch size clip_grad_norm: null diff --git a/torchtune/training/checkpointing/_checkpoint_client.py b/torchtune/training/checkpointing/_checkpoint_client.py index 55e6f1c708..293a84af33 100644 --- a/torchtune/training/checkpointing/_checkpoint_client.py +++ b/torchtune/training/checkpointing/_checkpoint_client.py @@ -365,6 +365,7 @@ def save_checkpoint( checkpointer user has configured. """ intermediate_checkpoint = epoch + 1 < training_progress.total_epochs + if intermediate_checkpoint and self._enable_async_checkpointing: self._save_checkpoint_async( model,
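
Note for reviewers: the sketch below reproduces, on a toy state dict, the vanilla-LoRA branch of the merge that this series consolidates into get_merged_lora_ckpt (W_merged = W + (alpha / rank) * B @ A). It is illustrative only — the module key "layers.0.attn.q_proj" and the tensor shapes are made up, and the real helper additionally handles MoE expert weights, DoRA magnitudes, and the use_distributed_barriers synchronization added in PATCH 6.

# Toy, self-contained example of the LoRA merge formula used by the patched helper.
import torch

rank, alpha = 8, 16
out_dim, in_dim = 32, 64

state_dict = {
    # Hypothetical key names following the "{module}.weight" / "{module}.lora_a.weight" pattern.
    "layers.0.attn.q_proj.weight": torch.randn(out_dim, in_dim),
    "layers.0.attn.q_proj.lora_a.weight": torch.randn(rank, in_dim),
    "layers.0.attn.q_proj.lora_b.weight": torch.randn(out_dim, rank),
}

# Merge in place and drop the adapter-specific parameters, as the helper does.
lora_a = state_dict.pop("layers.0.attn.q_proj.lora_a.weight")
lora_b = state_dict.pop("layers.0.attn.q_proj.lora_b.weight")
state_dict["layers.0.attn.q_proj.weight"] += (alpha / rank) * lora_b @ lora_a

# In the checkpoint client, the same merge is invoked on the full model state dict;
# in a distributed run the flag keeps all ranks in lockstep around the merge steps:
#
#   get_merged_lora_ckpt(
#       ckpt_dict[training.MODEL_KEY],
#       adapter_config["r"],
#       adapter_config["lora_alpha"],
#       use_distributed_barriers=not single_device,
#   )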