| 1 | +# RandLora: Full-rank parameter-efficient fine-tuning of large models |
| 2 | + |
| 3 | +## Introduction |
| 4 | +[RandLora](https://huggingface.co/papers/2502.00987) is a parameter-efficient fine-tuning technique that is similar to LoRA and VeRA but performs full-rank updates to improve performance. RandLora can be particularly useful when adapting large models to hard tasks that require complex updates, while preserving the parameter efficiency of LoRA. The full-rank update of RandLora is achieved by linearly scaling random bases. The random bases are a collection of multiple low-rank matrices such that the sum of their ranks is greater than or equal to the full rank of the parameter matrices. The trainable parameters of RandLora are two diagonal matrices (vectors) that get multiplied with the right-hand low-rank random bases, in a similar way to VeRA's update. To maintain low memory usage, RandLora uses a custom function that prevents storing unnecessary bases in memory for backpropagation. |
| 5 | + |
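The update described above can be written compactly. The following is a sketch in notation chosen here for illustration: the symbols $B_j$, $A_j$, $\Lambda_j$, $\Gamma_j$, $n$, $D$ and $d$ are assumptions of this sketch, not names defined in this README. For a weight matrix of shape $D \times d$, $n$ frozen random pairs of rank $r$ are each rescaled by two trainable diagonal matrices and then summed:

$$
\Delta W = \sum_{j=1}^{n} B_j \Lambda_j A_j \Gamma_j, \qquad B_j \in \mathbb{R}^{D \times r},\; A_j \in \mathbb{R}^{r \times d},\; \Lambda_j = \operatorname{diag}(\lambda_j) \in \mathbb{R}^{r \times r},\; \Gamma_j = \operatorname{diag}(\gamma_j) \in \mathbb{R}^{d \times d}
$$

Each term has rank at most $r$, so the sum can only reach full rank when enough bases are used:

$$
\operatorname{rank}(\Delta W) \le \min(D, d, n r), \qquad \text{so a full-rank update requires } n r \ge \min(D, d).
$$

Only $\lambda_j$ and $\gamma_j$ would be trained in this picture, while $B_j$ and $A_j$ stay fixed.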
| 6 | +## Quick start |
| 7 | +```python |
| 8 | +from datasets import load_dataset |
| 9 | +from peft import RandLoraConfig, get_peft_model |
| 10 | +from transformers import AutoModelForCausalLM, AutoTokenizer |
| 11 | +from trl import SFTConfig, SFTTrainer  # supplies the SFT options that `transformers.Trainer` does not accept |
| 12 | + |
| 13 | +model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b", device_map="cuda") |
| 14 | +tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b") |
| 15 | +dataset = load_dataset("timdettmers/openassistant-guanaco", split="train") |
| 16 | +randlora_config = RandLoraConfig() |
| 17 | + |
| 18 | +peft_model = get_peft_model(model, randlora_config)  # wrap the base model with RandLora adapters |
| 19 | +training_args = SFTConfig(dataset_text_field="text", max_seq_length=2048)  # some TRL releases name this `max_length` |
| 20 | +trainer = SFTTrainer( |
| 21 | +    model=peft_model, |
| 22 | +    args=training_args, |
| 23 | +    train_dataset=dataset, |
| 24 | +    processing_class=tokenizer,  # older TRL releases use `tokenizer=` instead |
| 25 | +) |
| 26 | +trainer.train() |
| 27 | +peft_model.save_pretrained("randlora-llama-7b")  # saves only the adapter weights |
| 28 | +``` |
| 29 | + |
| 30 | +No additional change to your standard PEFT training procedure is needed: simply swap your `LoraConfig` for a `RandLoraConfig`. Note, however, that RandLora's trainable parameter count is **inversely proportional** to the rank parameter `r`. Lower `r` to increase the number of trainable parameters, and raise it to reduce them. |
| 31 | + |
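To make the inverse relationship between `r` and the trainable parameter count concrete, here is a rough back-of-the-envelope sketch. It assumes that a layer uses `ceil(min(out, in) / r)` random bases and that each basis owns one scale vector of length `r` and one of length `min(out, in)`. That layout, and the function names, are assumptions for illustration, so check the PEFT implementation for the exact count.

```python
import math


def randlora_trainable_params(out_features: int, in_features: int, r: int) -> int:
    """Trainable scale entries for one adapted layer (assumed layout, see text)."""
    min_dim = min(out_features, in_features)
    num_bases = math.ceil(min_dim / r)  # enough rank-r bases to reach full rank
    # Assumption: each basis owns one length-r and one length-min_dim scale vector.
    return num_bases * (r + min_dim)


def lora_trainable_params(out_features: int, in_features: int, r: int) -> int:
    """LoRA trains A (r x in_features) and B (out_features x r)."""
    return r * (in_features + out_features)


if __name__ == "__main__":
    for r in (8, 32, 128):
        print(
            f"r={r}: RandLora~{randlora_trainable_params(4096, 4096, r)}, "
            f"LoRA={lora_trainable_params(4096, 4096, r)}"
        )
```

Under these assumptions the RandLora count shrinks as `r` grows, while the LoRA count grows with `r`.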
| 32 | +Run the fine-tuning script with: |
| 33 | +```bash |
| 34 | +python examples/randlora_finetuning/randlora_finetuning.py --base_model meta-llama/Meta-Llama-3-8B --data_path timdettmers/openassistant-guanaco |
| 35 | +``` |
| 36 | +By default, this loads the model in PEFT with the RandLora config. To quickly compare it with LoRA, add `--use_lora` to the command line and reduce `--randlora_alpha` to 2x the rank. The example above then becomes: |
| 37 | + |
| 38 | +```bash |
| 39 | +python examples/randlora_finetuning/randlora_finetuning.py --base_model meta-llama/Meta-Llama-3-8B --data_path timdettmers/openassistant-guanaco --use_lora --rank 32 --randlora_alpha 64 |
| 40 | +``` |
| 41 | + |
| 42 | +RandLora can be made to use sparse or very sparse random bases. These sparse matrices can help reduce overfitting. Add `--very_sparse` to run with very sparse matrices or `--sparse` for sparse matrices: |
| 43 | + |
| 44 | +```bash |
| 45 | +python examples/randlora_finetuning/randlora_finetuning.py --base_model meta-llama/Meta-Llama-3-8B --sparse |
| 46 | +``` |
| 47 | + |
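For intuition about what a sparse random basis can look like, here is a minimal pure-Python sketch of a ternary sparse random matrix in the style of sparse random projections. The function names, the `s` parameter, and the sampling probabilities are assumptions for illustration and are not taken from the PEFT implementation behind `--sparse` and `--very_sparse`.

```python
import random


def sparse_random_basis(rows, cols, s=3.0, seed=0):
    """Sample entries from {-1, 0, +1}; an entry is nonzero with probability 1/s."""
    rng = random.Random(seed)
    matrix = []
    for _ in range(rows):
        row = []
        for _ in range(cols):
            u = rng.random()
            if u < 1.0 / (2.0 * s):
                row.append(-1)  # probability 1 / (2s)
            elif u < 1.0 / s:
                row.append(1)   # probability 1 / (2s)
            else:
                row.append(0)   # probability 1 - 1/s
        matrix.append(row)
    return matrix


def zero_fraction(matrix):
    """Fraction of entries that are exactly zero."""
    total = sum(len(row) for row in matrix)
    zeros = sum(row.count(0) for row in matrix)
    return zeros / total


if __name__ == "__main__":
    print(zero_fraction(sparse_random_basis(100, 100, s=3.0)))
    print(zero_fraction(sparse_random_basis(100, 100, s=10.0)))
```

In this sketch a larger `s` gives a sparser matrix, which is one way to picture the difference between a sparse and a very sparse basis.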
| 48 | +RandLora also supports quantization. To use 4-bit quantization, try: |
| 49 | + |
| 50 | +```bash |
| 51 | +python examples/randlora_finetuning/randlora_finetuning.py --base_model meta-llama/Meta-Llama-3-8B --quantize |
| 52 | +``` |
| 53 | + |
| 54 | +By default, the RandLora layers are the key and value layers of Llama models. Adding adapters on more layers will increase memory usage. To apply RandLora to a different set of layers, define them using: |
| 55 | +```bash |
| 56 | +python examples/randlora_finetuning/randlora_finetuning.py --randlora_target_modules "q_proj,k_proj,v_proj" |
| 57 | +``` |
| 58 | + |
| 59 | +### Full example of the script |
| 60 | +```bash |
| 61 | +python randlora_finetuning.py \ |
| 62 | + --base_model "PATH_TO_MODEL" \ |
| 63 | + --data_path "PATH_TO_DATASET" \ |
| 64 | + --output_dir "PATH_TO_OUTPUT_DIR" \ |
| 65 | + --batch_size 1 \ |
| 66 | + --num_epochs 3 \ |
| 67 | + --learning_rate 3e-4 \ |
| 68 | + --cutoff_len 512 \ |
| 69 | + --val_set_size 500 \ |
| 70 | + --quantize \ |
| 71 | + --eval_step 10 \ |
| 72 | + --save_step 100 \ |
| 73 | + --device "cuda:0" \ |
| 74 | + --rank 32 \ |
| 75 | + --randlora_alpha 640 \ |
| 76 | + --randlora_dropout 0.05 \ |
| 77 | + --randlora_target_modules "k_proj,v_proj" \ |
| 78 | + --hub_model_id "YOUR_HF_REPO" \ |
| 79 | + --push_to_hub |
| 80 | +``` |
| 81 | + |
| 82 | +## RandLora vs. LoRA |
| 83 | +RandLora differs from LoRA and other related low-rank approximation algorithms by challenging the low-rank paradigm. RandLora adapters learn **full-rank** updates, as the [paper](https://huggingface.co/papers/2502.00987) shows that the low-rank constraint of LoRA can limit performance gains as trainable parameters increase (with higher ranks). As a result, RandLora is specifically recommended for difficult tasks that are underfit by LoRA. However, RandLora also often improves performance on common tasks. If increasing LoRA's rank improves performance for your task, RandLora will most likely outperform LoRA. |
| 84 | + |
| 85 | +RandLora is expected to increase performance over LoRA for equivalent numbers of trainable parameters, mostly at larger parameter budgets (above the equivalent of LoRA rank 4). |
| 86 | + |
| 87 | +RandLora's performance increase comes with two limitations: |
| 88 | + |
| 89 | +1. Performance depends on using a large `randlora_alpha` scaling parameter (usually 20x the basis rank). This large parameter can sometimes make training the update unstable. If this is the case, reduce the learning rate or the scaling parameter. |
| 90 | + |
| 91 | +2. Training time increases over LoRA when using very low RandLora basis ranks. |
| 92 | + |
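The scaling rule of thumb from the first limitation can be sketched as a small helper that ties `randlora_alpha` to the basis rank. The helper name `randlora_config_kwargs` and its `alpha_multiplier` argument are hypothetical. The keys `r`, `randlora_alpha`, `randlora_dropout` and `target_modules` mirror the script flags shown earlier and are assumed to match the fields of `RandLoraConfig`. The smaller multiplier of 10 is only an illustrative fallback, not a recommended value.

```python
def randlora_config_kwargs(r, alpha_multiplier=20, dropout=0.05,
                           target_modules=("k_proj", "v_proj")):
    """Build keyword arguments with the scaling parameter tied to the basis rank."""
    return {
        "r": r,
        "randlora_alpha": alpha_multiplier * r,  # roughly 20x the rank by default
        "randlora_dropout": dropout,
        "target_modules": list(target_modules),
    }


kwargs = randlora_config_kwargs(r=32)
# If training is unstable, one option is a smaller multiplier (or a lower learning rate).
stable_kwargs = randlora_config_kwargs(r=32, alpha_multiplier=10)

try:
    from peft import RandLoraConfig  # optional: only needed to build the real config
    config = RandLoraConfig(**kwargs)
except ImportError:
    config = None  # PEFT (or a version with RandLora) is not installed
```

With `r=32` the default multiplier reproduces the `--rank 32 --randlora_alpha 640` pairing used in the full script example.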
| 93 | +## RandLora vs. VeRA |
| 94 | +RandLora shares similarities with VeRA in that both algorithms use random basis combinations to address some of LoRA's limitations. The limitations addressed by each algorithm are, however, different. |
| 95 | +VeRA aims to reduce trainable parameters beyond rank-1 LoRAs, while RandLora reduces the performance limitation due to the low rank of the update as the trainable parameter count increases. |
| 96 | + |
| 97 | +RandLora is expected to: |
| 98 | + |
| 99 | +1. Improve performance over VeRA when more trainable parameters are required (hard tasks) |
| 100 | + |
| 101 | +2. Reduce memory usage over VeRA thanks to RandLora's random base sharing strategy |
| 102 | + |
| 103 | + |
| 104 | +## Citation |
| 105 | +``` |
| 106 | +@inproceedings{2025_ICLR_RandLoRA, |
| 107 | + title="{RandLoRA: Full rank parameter-efficient fine-tuning of large models}", |
| 108 | + author="Albert, Paul and Zhang, Frederic Z. and Saratchandran, Hemanth and Rodriguez-Opazo, Cristian and van den Hengel, Anton and Abbasnejad, Ehsan", |
| 109 | + booktitle="{International Conference on Learning Representations (ICLR)}", |
| 110 | + year="2025" |
| 111 | +} |
| 112 | +``` |