
Commit 0c2bdbb

FEAT Add LoRA-FA to PEFT (#2468)
Adds LoRA with frozen A (LoRA-FA) to PEFT. Paper: https://arxiv.org/abs/2308.03303
1 parent 13c81df commit 0c2bdbb

6 files changed: +771 −3 lines changed


docs/source/developer_guides/lora.md

Lines changed: 34 additions & 1 deletion
@@ -271,7 +271,40 @@ The same logic applies to `alpha_pattern`. If you're in doubt, don't try to get
## Optimizers

- LoRA training can optionally include special purpose optimizers. Currently the only such optimizer is LoRA+.
+ LoRA training can optionally include special purpose optimizers. Currently PEFT supports LoRA-FA and LoRA+.
### LoRA-FA Optimizer

LoRA training can be made more effective and efficient with LoRA-FA, as described in [LoRA-FA](https://arxiv.org/abs/2308.03303). LoRA-FA reduces activation memory consumption by freezing the matrix $A$ and tuning only the matrix $B$. During training, the gradient of $B$ is optimized to approximate the full parameter fine-tuning gradient. Moreover, the memory consumption of LoRA-FA is not sensitive to the rank (since it eliminates the activation needed for $A$'s gradient), so performance can be improved by enlarging the LoRA rank without increasing memory consumption.
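Concretely, the mechanism can be sketched via the standard LoRA forward pass (an illustration, not a description of PEFT internals):

$$h = W_0 x + \frac{\alpha}{r} B A x$$

With $A$ frozen, the backward pass only needs the $r$-dimensional activation $Ax$ to compute the gradient of $B$, instead of the full input $x$ that plain LoRA must store to compute the gradient of $A$. This is why memory stays nearly constant as the rank grows.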
```py
from peft import LoraConfig, get_peft_model
from peft.optimizers import create_lorafa_optimizer
from transformers import AutoModelForCausalLM, Trainer, get_cosine_schedule_with_warmup

base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

config = LoraConfig(...)
model = get_peft_model(base_model, config)

optimizer = create_lorafa_optimizer(
    model=model,
    r=128,
    lora_alpha=32,
    lr=7e-5,
)

scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=100,
    num_training_steps=1000,
)

trainer = Trainer(
    ...,
    optimizers=(optimizer, scheduler),
)
```
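After creating the optimizer, all LoRA $A$ matrices are fixed and only the $B$ matrices remain trainable. A quick sanity check (a sketch, assuming PEFT's standard `lora_A`/`lora_B` parameter naming):

```py
# Verify that adapter A is frozen and only adapter B will be trained
for name, param in model.named_parameters():
    if "lora_A" in name:
        assert not param.requires_grad
    elif "lora_B" in name:
        assert param.requires_grad
```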

### LoRA+ optimized LoRA

examples/lorafa_finetune/README.md

Lines changed: 115 additions & 0 deletions
@@ -0,0 +1,115 @@
# LoRA-FA: Memory-efficient Low-rank Adaptation for Large Language Models Fine-tuning

## Introduction

[LoRA-FA](https://arxiv.org/abs/2308.03303) is a novel parameter-efficient fine-tuning method that freezes the projection-down layer (matrix A) during the LoRA training process, leading to lower GPU memory consumption by eliminating the need to store the activations of the input tensors (X). Furthermore, LoRA-FA narrows the gap between the weight update produced by low-rank fine-tuning and that of full fine-tuning. In summary, LoRA-FA reduces memory consumption and achieves superior performance compared to vanilla LoRA.
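The core idea in a minimal sketch (the shapes and scaling here are illustrative assumptions, not the PEFT implementation):

```python
import torch

d, r = 1024, 16                 # hidden size and LoRA rank (illustrative values)
W0 = torch.randn(d, d)          # frozen pretrained weight
A = torch.randn(r, d) / r**0.5  # projection down: frozen at init in LoRA-FA
B = torch.zeros(d, r, requires_grad=True)  # projection up: the only trained part

x = torch.randn(8, d)           # a batch of input activations
h = x @ W0.T + (x @ A.T) @ B.T
# The gradient of B only needs the r-dimensional activation x @ A.T, so the
# d-dimensional input x no longer has to be stored for the adapter's backward pass.
```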
## Quick start

```python
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from peft.optimizers import create_lorafa_optimizer
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTTrainer  # SFTTrainer accepts a raw text dataset; plain Trainer does not

model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
dataset = load_dataset("timdettmers/openassistant-guanaco", split="train")

lora_rank = 16
lora_alpha = 32

lora_config = LoraConfig(
    r=lora_rank,
    lora_alpha=lora_alpha,
    bias="none",
)
peft_model = get_peft_model(model, lora_config)
optimizer = create_lorafa_optimizer(
    model=peft_model,
    r=lora_rank,
    lora_alpha=lora_alpha,
    lr=7e-5,
)
# You can also use a scheduler; we recommend get_cosine_schedule_with_warmup
# from transformers for better model performance.
scheduler = None

trainer = SFTTrainer(
    model=peft_model,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=2048,
    tokenizer=tokenizer,
    optimizers=(optimizer, scheduler),
)
trainer.train()
peft_model.save_pretrained("lorafa-llama-3-8b-inst")
```

The only change to your code is to pass the LoRA-FA optimizer to the trainer (if training with a trainer). Do not forget `from peft.optimizers import create_lorafa_optimizer`!
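If you are not using a trainer, the optimizer works in an ordinary training loop as well. A minimal sketch, assuming `dataloader` yields tokenized batches with labels:

```python
for batch in dataloader:
    loss = peft_model(**batch).loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```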
## Example

In this directory, we also provide a simple example of fine-tuning with the LoRA-FA optimizer.

### Run on CPU, single-GPU or multi-GPU

The script below 👇 by default loads the model with the LoRA config set up via PEFT, and trains it with the LoRA-FA optimizer.
1. CPU

You can simply run LoRA-FA as below:

```bash
python lorafa_finetuning.py --base_model_name_or_path meta-llama/Meta-Llama-3-8B --dataset_name_or_path meta-math/MetaMathQA-40K --output_dir path/to/output --lorafa
```

2. Single-GPU

Run the finetuning script on 1 GPU:

```bash
CUDA_VISIBLE_DEVICES=0 python lorafa_finetuning.py --base_model_name_or_path meta-llama/Meta-Llama-3-8B --dataset_name_or_path meta-math/MetaMathQA-40K --output_dir path/to/output --lorafa
```

3. Multi-GPU

LoRA-FA can also be run on multiple GPUs, with 🤗 Accelerate:

```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch lorafa_finetuning.py --base_model_name_or_path meta-llama/Meta-Llama-3-8B --dataset_name_or_path meta-math/MetaMathQA-40K --output_dir path/to/output --lorafa
```
`accelerate launch` will automatically configure multi-GPU for you. You can also use `accelerate launch` in a single-GPU scenario.
### Use the model from 🤗

You can load and use the fine-tuned model like any other 🤗 model. For example, to load the adapter saved above together with its base model:

```python
from peft import AutoPeftModelForCausalLM

model = AutoPeftModelForCausalLM.from_pretrained("lorafa-llama-3-8b-inst")
```
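If you want a standalone checkpoint, you can optionally merge the adapter into the base weights; a short sketch using PEFT's `merge_and_unload` (the output path is just an example):

```python
merged_model = model.merge_and_unload()
merged_model.save_pretrained("lorafa-llama-3-8b-inst-merged")  # example output directory
```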
## Best practices for fine-tuning Llama with LoRA-FA: the hyperparameters

Achieving optimal LoRA fine-tuning can be challenging because there are more hyperparameters to consider than in full fine-tuning. For instance, not only does the commonly tuned learning rate matter, but the ideal LoRA rank may also vary depending on the specific model and task, and further factors such as LoRA alpha and sequence length come into play. To assist with this, we have created a repository of reproducible best practices in the [LoRA-FA examples](https://github.com/AaronZLT/lorafa) for reference. This resource showcases the optimal LoRA-FA fine-tuning hyperparameters for different models across various datasets. It significantly reduces the time and effort spent on hyperparameter tuning and may also provide insights for tuning other training hyperparameters. We encourage you to experiment and fine-tune on your own downstream tasks as well.
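As a starting point, here is a hypothetical configuration distilled from the settings used elsewhere in this example; the exact values should be tuned per model and task:

```python
from peft import LoraConfig

config = LoraConfig(
    r=128,           # LoRA-FA's memory use is largely insensitive to rank, so large ranks are affordable
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    bias="none",
)
```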
## LoRA-FA's advantages and limitations

By eliminating the activation storage for adapter A, LoRA-FA uses less memory for fine-tuning than LoRA. For instance, when fine-tuning Llama-2-7b-chat-hf with a batch size of 8 and a sequence length of 1024, LoRA-FA requires 36GB of memory to store activations, allowing it to run successfully on an 80GB GPU. In contrast, LoRA requires at least 60GB of memory for activations, leading to an out-of-memory (OOM) error. Additionally, the memory consumption of LoRA-FA is not sensitive to the rank, allowing performance improvements from a larger LoRA rank without additional memory usage. LoRA-FA further narrows the performance gap with full fine-tuning by minimizing the discrepancy between the low-rank gradient and the full gradient, enabling it to achieve performance on par with or even superior to vanilla LoRA.

Despite its advantages, LoRA-FA is inherently limited by its low-rank approximation nature and by potential catastrophic forgetting. The gradient approximation can also impact training throughput. Addressing these limitations, especially approximation accuracy and forgetting, is a promising direction for future research.
## Citation

```
@misc{zhang2023lorafamemoryefficientlowrankadaptation,
    title={LoRA-FA: Memory-efficient Low-rank Adaptation for Large Language Models Fine-tuning},
    author={Longteng Zhang and Lin Zhang and Shaohuai Shi and Xiaowen Chu and Bo Li},
    year={2023},
    eprint={2308.03303},
    archivePrefix={arXiv},
    primaryClass={cs.CL},
    url={https://arxiv.org/abs/2308.03303},
}
```
examples/lorafa_finetune/lorafa_finetuning.py

Lines changed: 214 additions & 0 deletions
@@ -0,0 +1,214 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Optional

import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from peft.optimizers import create_lorafa_optimizer


def train_model(
    base_model_name_or_path: str,
    dataset_name_or_path: str,
    output_dir: str,
    batch_size: int,
    num_epochs: int,
    lr: float,
    cutoff_len: int,
    quantize: bool,
    eval_step: int,
    save_step: int,
    lora_rank: int,
    lora_alpha: int,
    lora_dropout: float,
    lora_target_modules: Optional[str],
    lorafa: bool,
):
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    compute_dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
    device_map = "cuda" if torch.cuda.is_available() else None

    # load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(base_model_name_or_path)

    # load model
    if quantize:
        model = AutoModelForCausalLM.from_pretrained(
            base_model_name_or_path,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=compute_dtype,
                bnb_4bit_use_double_quant=False,
                bnb_4bit_quant_type="nf4",
            ),
            torch_dtype=compute_dtype,
            device_map=device_map,
        )
        # setup for quantized training
        model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
    else:
        model = AutoModelForCausalLM.from_pretrained(
            base_model_name_or_path, torch_dtype=compute_dtype, device_map=device_map
        )

    # LoRA config for the PEFT model
    if lora_target_modules is not None:
        if lora_target_modules == "all-linear":
            target_modules = "all-linear"
        else:
            target_modules = lora_target_modules.split(",")
    else:
        target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]

    lora_config = LoraConfig(
        r=lora_rank,
        lora_alpha=lora_alpha,
        target_modules=target_modules,
        lora_dropout=lora_dropout,
        bias="none",
    )

    # get the peft model with LoRA config
    model = get_peft_model(model, lora_config)

    tokenizer.pad_token = tokenizer.eos_token

    # Load the dataset
    dataset = load_dataset(dataset_name_or_path)

    def tokenize_function(examples):
        inputs = tokenizer(examples["query"], padding="max_length", truncation=True, max_length=cutoff_len)
        outputs = tokenizer(examples["response"], padding="max_length", truncation=True, max_length=cutoff_len)
        inputs["labels"] = outputs["input_ids"].copy()
        return inputs

    # Tokenize the dataset and prepare for training
    tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=dataset["train"].column_names)
    dataset = tokenized_datasets["train"].train_test_split(test_size=0.1, shuffle=True, seed=42)
    train_dataset = dataset["train"]
    eval_dataset = dataset["test"]

    # Data collator to dynamically pad the batched examples
    data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

    # Define training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        warmup_steps=100,
        weight_decay=0.01,
        logging_dir="./logs",
        logging_steps=eval_step,
        save_steps=save_step,
        save_total_limit=2,
        gradient_accumulation_steps=1,
        bf16=compute_dtype == torch.bfloat16,
        fp16=compute_dtype == torch.float16,
        learning_rate=lr,
    )

    # Here we initialize the LoRA-FA optimizer.
    # After this, all adapter A matrices will be fixed; only adapter B remains trainable.
    if lorafa:
        optimizer = create_lorafa_optimizer(
            model=model, r=lora_rank, lora_alpha=lora_alpha, lr=lr, weight_decay=training_args.weight_decay
        )
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            data_collator=data_collator,
            optimizers=(optimizer, None),
        )
    else:
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            data_collator=data_collator,
        )

    # Start model training
    trainer.train()

    # Save the model and tokenizer locally
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Fine-tune Meta-Llama-3-8B-Instruct with LoRA-FA and PEFT")
    parser.add_argument(
        "--base_model_name_or_path",
        type=str,
        default="meta-llama/Meta-Llama-3-8B-Instruct",
        help="Base model name or path",
    )
    parser.add_argument(
        "--dataset_name_or_path", type=str, default="meta-math/MetaMathQA-40K", help="Dataset name or path"
    )
    parser.add_argument("--output_dir", type=str, help="Output directory for the fine-tuned model")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
    parser.add_argument("--num_epochs", type=int, default=3, help="Number of training epochs")
    parser.add_argument("--lr", type=float, default=7e-5, help="Learning rate")
    parser.add_argument("--cutoff_len", type=int, default=1024, help="Cutoff length for tokenization")
    parser.add_argument("--quantize", action="store_true", help="Use quantization")
    parser.add_argument("--eval_step", type=int, default=10, help="Evaluation step interval")
    parser.add_argument("--save_step", type=int, default=100, help="Save step interval")
    parser.add_argument("--lora_rank", type=int, default=16, help="LoRA rank")
    parser.add_argument("--lora_alpha", type=int, default=32, help="LoRA alpha")
    parser.add_argument("--lora_dropout", type=float, default=0.05, help="LoRA dropout rate")
    parser.add_argument(
        "--lora_target_modules", type=str, default=None, help="Comma-separated list of target modules for LoRA"
    )
    parser.add_argument("--lorafa", action="store_true", help="Use LoRA-FA Optimizer")

    args = parser.parse_args()

    train_model(
        base_model_name_or_path=args.base_model_name_or_path,
        dataset_name_or_path=args.dataset_name_or_path,
        output_dir=args.output_dir,
        batch_size=args.batch_size,
        num_epochs=args.num_epochs,
        lr=args.lr,
        cutoff_len=args.cutoff_len,
        quantize=args.quantize,
        eval_step=args.eval_step,
        save_step=args.save_step,
        lora_rank=args.lora_rank,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        lora_target_modules=args.lora_target_modules,
        lorafa=args.lorafa,
    )

src/peft/optimizers/__init__.py

Lines changed: 3 additions & 2 deletions
@@ -1,4 +1,4 @@
- # Copyright 2024-present the HuggingFace Inc. team.
+ # Copyright 2025-present the HuggingFace Inc. team.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -12,7 +12,8 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

+ from .lorafa import create_lorafa_optimizer
  from .loraplus import create_loraplus_optimizer


- __all__ = ["create_loraplus_optimizer"]
+ __all__ = ["create_lorafa_optimizer", "create_loraplus_optimizer"]
