FSDP QLORA doesn't work with multiple adapters #2834

@ojh31

System Info

  • Accelerate version: 1.10.1
  • Platform: Linux-5.15.0-157-generic-x86_64-with-glibc2.39
  • accelerate bash location: /opt/uv/venv/bin/accelerate
  • Python version: 3.12.11
  • Numpy version: 2.2.6
  • PyTorch version: 2.9.0.dev20250825+cu128
  • PyTorch accelerator: CUDA
  • System RAM: 1003.13 GB
  • GPU type: NVIDIA H100 80GB HBM3
  • Accelerate default config:
    Not found
  • peft version: 0.15.2
  • transformers version: 4.56.1

Who can help?

@BenjaminBossan @githubnemo

Reproduction

"""Based on peft/examples/sft/run_peft_qlora_fsdp.sh

Launch command:
{
    "name": "Accelerate Launch - Minimal FSDP QLoRA Training",
    "type": "debugpy",
    "request": "launch",
    "module": "accelerate.commands.launch",
    "args": [
        "--config_file",
        "scripts/fsdp_config_qlora.yaml",
        "--num_processes",
        "2",
        "scripts/20251008_fsdp_qlora_sft_custom.py",
        "--seed",
        "100",
        "--model_name_or_path",
        "meta-llama/Llama-3.1-8B-Instruct",
        "--dataset_name",
        "smangrul/ultrachat-10k-chatml",
        "--add_special_tokens",
        "False",
        "--append_concat_token",
        "False",
        "--splits",
        "train,test",
        "--max_seq_len",
        "2048",
        "--num_train_epochs",
        "1",
        "--logging_steps",
        "5",
        "--log_level",
        "info",
        "--logging_strategy",
        "steps",
        "--learning_rate",
        "1e-4",
        "--lr_scheduler_type",
        "cosine",
        "--weight_decay",
        "1e-4",
        "--warmup_ratio",
        "0.0",
        "--max_grad_norm",
        "1.0",
        "--output_dir",
        "llama-sft-qlora-fsdp",
        "--per_device_train_batch_size",
        "2",
        "--per_device_eval_batch_size",
        "2",
        "--gradient_accumulation_steps",
        "2",
        "--gradient_checkpointing",
        "True",
        "--lora_r",
        "8",
        "--lora_alpha",
        "16",
        "--lora_dropout",
        "0.1",
        "--lora_target_modules",
        "all-linear",
        "--max_steps",
        "2",
    ],
    "console": "integratedTerminal",
    "justMyCode": false,
    "cwd": "${workspaceFolder}"
}
"""

import os
import sys
from dataclasses import dataclass, field

import torch
from accelerate import Accelerator
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
from datasets.builder import DatasetGenerationError
from peft import LoraConfig, PeftConfig, PeftModel
from peft.utils.other import fsdp_auto_wrap_policy
from torch.utils.data import DataLoader
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    PreTrainedModel,
    PreTrainedTokenizer,
    TrainingArguments,
    get_scheduler,
    set_seed,
)
from transformers.data.data_collator import DataCollatorWithPadding


class MinimalSFTTrainer:
    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        peft_config: PeftConfig,
        train_dataset,
        args: TrainingArguments,
    ):
        self.args = args
        self.train_dataset = train_dataset

        self.tokenizer = tokenizer

        # Initialize accelerator with FSDP
        self.accelerator = Accelerator(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            mixed_precision="bf16",
        )

        # Prepare PEFT model
        if args.gradient_checkpointing:
            model.gradient_checkpointing_enable()
            model.enable_input_require_grads()

        # Create PEFT model
        self.model = PeftModel.from_pretrained(
            model,
            "AlignmentResearch/Llama-3.1-8B-Instruct-gsm8k-lora-reference",
            autocast_adapter_dtype=False,
            adapter_name="reference",
        )
        self.model.load_adapter(
            "AlignmentResearch/Llama-3.1-8B-Instruct-gsm8k-lora-reference", adapter_name="policy", autocast_adapter_dtype=False
        )
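        # Loading the same LoRA checkpoint a second time under a different
        # adapter_name ("policy") creates the multi-adapter setup that triggers
        # the hang described in this issue.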

        # Critical: Update FSDP plugin for QLORA
        if self.accelerator.state.fsdp_plugin is not None:
            fsdp_plugin = self.accelerator.state.fsdp_plugin

            # Set auto wrap policy for PEFT
            fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(self.model)

            quant_storage = model.hf_quantizer.quantization_config.bnb_4bit_quant_storage
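            # FSDP flattens every parameter in a wrapped unit into a single flat tensor,
            # so the bnb 4-bit weights (stored in the quant_storage dtype) and the other
            # parameters must share one dtype; overriding the mixed-precision param dtype
            # with the storage dtype keeps that flat parameter consistent.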
            if quant_storage.is_floating_point:
                fsdp_plugin.set_mixed_precision(quant_storage, override=True)

        # Create dataloader
        formatted_ds = self.train_dataset.map(
            lambda x: {"content": tokenizer.apply_chat_template(x["messages"], tokenize=False)},
            batched=False,
            remove_columns=self.train_dataset.column_names,
        )
        tokenized_ds = formatted_ds.map(
            lambda x: self.tokenizer(x["content"], truncation=True), batched=True, remove_columns=formatted_ds.column_names
        )
        self.train_dataloader = DataLoader(
            tokenized_ds,
            batch_size=args.per_device_train_batch_size,
            collate_fn=DataCollatorWithPadding(self.tokenizer),
            shuffle=True,
        )

        # Create optimizer - only optimize trainable parameters
        optimizer_params = [p for p in model.parameters() if p.requires_grad]
        self.optimizer = torch.optim.AdamW(
            optimizer_params,
            lr=args.learning_rate,
            weight_decay=args.weight_decay,
        )

        # Calculate training steps
        num_update_steps_per_epoch = len(self.train_dataloader) // args.gradient_accumulation_steps
        max_steps = args.max_steps if args.max_steps > 0 else int(args.num_train_epochs * num_update_steps_per_epoch)

        # Create scheduler
        self.lr_scheduler = get_scheduler(
            args.lr_scheduler_type,
            optimizer=self.optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=max_steps,
        )

        # Prepare everything with accelerator
        self.model.base_model.set_adapter(["reference", "policy"])
        self.accelerator.print(f"Active adapters: {self.model.active_adapters}")
        for name, param in self.model.named_parameters():
            if "layers.0.self_attn.q_proj" in name:
                print(f"{name} {param.shape} {param.device} {param.dtype} {param.requires_grad}")
        # N.B. the call below will hang unless
        # peft/tuners/tuners_utils.py::BaseTunerLayer._move_adapter_to_device_of_base_layer
        # is overridden to remove its special meta-device handling.
        self.accelerator.print("Preparing everything with accelerator")
        self.model, self.optimizer, self.train_dataloader, self.lr_scheduler = self.accelerator.prepare(
            self.model, self.optimizer, self.train_dataloader, self.lr_scheduler
        )
        self.accelerator.print("Everything prepared with accelerator")

        self.global_step = 0
        self.max_steps = max_steps

def create_dataset(tokenizer, data_args):
    raw_datasets = DatasetDict()
    for split in data_args.splits.split(","):
        try:
            # Try first if dataset on a Hub repo
            dataset = load_dataset(data_args.dataset_name, split=split)
        except DatasetGenerationError:
            # If not, check local dataset
            dataset = load_from_disk(os.path.join(data_args.dataset_name, split))

        assert isinstance(dataset, Dataset)
        dataset = dataset.select(range(8))
        if "train" in split:
            raw_datasets["train"] = dataset
        elif "test" in split:
            raw_datasets["test"] = dataset
        else:
            raise ValueError(f"Split type {split} not recognized as one of test or train.")

    train_data = raw_datasets["train"]
    print(f"Size of the train set: {len(train_data)}")
    print(f"A sample of train dataset: {train_data[0]}")

    return train_data


def create_and_prepare_model(args):
    quant_storage_dtype = torch.bfloat16

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype="bfloat16",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_storage=quant_storage_dtype,
    )

    torch_dtype = quant_storage_dtype if quant_storage_dtype and quant_storage_dtype.is_floating_point else torch.float32

    # Prepare model loading arguments
    model_kwargs = {
        "trust_remote_code": True,
        "torch_dtype": torch_dtype,
        "attn_implementation": "flash_attention_2",
        "quantization_config": bnb_config,
        "use_cache": False,
    }
    model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, **model_kwargs)

    peft_config = LoraConfig(
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        r=args.lora_r,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=args.lora_target_modules.split(",")
        if args.lora_target_modules != "all-linear"
        else args.lora_target_modules,
    )

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token

    return model, peft_config, tokenizer


# Define and parse arguments.
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    max_seq_len: int | None = field(
        default=512,
        metadata={"help": "The maximum total input sequence length after tokenization."},
    )
    lora_alpha: int | None = field(default=16)
    lora_dropout: float | None = field(default=0.1)
    lora_r: int | None = field(default=64)
    lora_target_modules: str | None = field(
        default="q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj",
        metadata={"help": "comma separated list of target modules to apply LoRA layers to"},
    )


@dataclass
class DataTrainingArguments:
    dataset_name: str | None = field(
        default="timdettmers/openassistant-guanaco",
        metadata={"help": "The preference dataset to use."},
    )
    append_concat_token: bool | None = field(
        default=False,
        metadata={"help": "If True, appends `eos_token_id` at the end of each sample being packed."},
    )
    add_special_tokens: bool | None = field(
        default=False,
        metadata={"help": "If True, tokenizers adds special tokens to each sample being packed."},
    )
    splits: str | None = field(
        default="train,test",
        metadata={"help": "Comma separate list of the splits to use from the dataset."},
    )


def main(model_args, data_args, training_args):
    # Set seed for reproducibility
    set_seed(training_args.seed)

    # model
    model, peft_config, tokenizer = create_and_prepare_model(model_args)

    training_args.dataset_kwargs = {
        "append_concat_token": data_args.append_concat_token,
        "add_special_tokens": data_args.add_special_tokens,
    }

    # datasets
    train_dataset = create_dataset(
        tokenizer,
        data_args,
    )

    # trainer
    trainer = MinimalSFTTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_dataset,
        peft_config=peft_config,
    )
    trainer.accelerator.print(f"{trainer.model}")
    if hasattr(trainer.model, "print_trainable_parameters"):
        trainer.model.print_trainable_parameters()



if __name__ == "__main__":
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    main(model_args, data_args, training_args)

Expected behavior

The accelerator.prepare call should not hang. I would also expect the process on device 1 to show all tensors on the meta device, but instead it shows that the second adapter is redundantly left on cpu:

base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight torch.Size([4194304, 1]) meta torch.bfloat16 False
base_model.model.model.layers.0.self_attn.q_proj.lora_A.reference.weight torch.Size([64, 4096]) meta torch.bfloat16 True
base_model.model.model.layers.0.self_attn.q_proj.lora_A.policy.weight torch.Size([64, 4096]) cpu torch.float32 True
base_model.model.model.layers.0.self_attn.q_proj.lora_B.reference.weight torch.Size([4096, 64]) meta torch.bfloat16 True
base_model.model.model.layers.0.self_attn.q_proj.lora_B.policy.weight torch.Size([4096, 64]) cpu torch.float32 True
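
For reference, a minimal, untested sketch of one possible workaround in the spirit of the code comment in the script (instead of overriding _move_adapter_to_device_of_base_layer, explicitly put the adapters on the meta device on non-zero ranks before prepare). It assumes CPU-RAM-efficient loading with sync_module_states is enabled in the FSDP config; the helper name and its reliance on peft's standard lora_A/lora_B ModuleDicts are my assumptions, not a verified fix.

import torch

def move_adapters_to_meta_like_base(peft_model):
    # On non-zero ranks the quantized base weights sit on the meta device;
    # mirror that placement for every LoRA adapter so FSDP sees a consistent layout.
    meta = torch.device("meta")
    for module in peft_model.modules():
        base_layer = getattr(module, "base_layer", None)
        if base_layer is None or getattr(base_layer, "weight", None) is None:
            continue
        if base_layer.weight.device != meta:
            continue  # rank 0 holds the real weights; leave its adapters alone
        for attr in ("lora_A", "lora_B"):
            adapter_dict = getattr(module, attr, None)
            if adapter_dict is None:
                continue
            for adapter_name in list(adapter_dict.keys()):
                # .to("meta") discards the values, which should be acceptable here
                # because rank 0's weights are broadcast when sync_module_states is on
                adapter_dict[adapter_name].to(meta)

Calling this on every rank right before self.accelerator.prepare(...) should be a no-op on rank 0, where the base weights are not on meta.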
