Bitsandbytes quantization for litgpt 2d parallel model (TP+FSDP) within LightningTrainer #2115

@radulescupetru

Description

I'm trying to run inference within the LightningTrainer using a litgpt model with 2D parallelization (TP+FSDP) and a Bitsandbytes precision plugin to enable quantization; however, I run into issues.
Here's a sample script I use:

from __future__ import annotations

from pathlib import Path
from typing import Union

import lightning as L
import litgpt
import torch
from lightning.pytorch.plugins import (
    BitsandbytesPrecision,
)
from lightning.pytorch.strategies.model_parallel import ModelParallelStrategy
from litgpt import GPT, Config
from litgpt.model import CausalSelfAttention, GptNeoxMLP, LLaMAMLP, LLaMAMoE
from torch import nn
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import MixedPrecisionPolicy
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    SequenceParallel,
    parallelize_module,
)


class RandomTokenDataset(torch.utils.data.Dataset):
    def __init__(self, vocab_size: int, seq_length: int):
        self.vocab_size = vocab_size
        self.seq_length = seq_length
        self.tokens = torch.randint(
            self.vocab_size,
            size=(len(self), self.seq_length + 1),
            # Set a seed to make this toy dataset the same on each rank
            # The Trainer will add a `DistributedSampler` to shard the data correctly
            generator=torch.Generator().manual_seed(42),
        )

    def __len__(self) -> int:
        return 128

    def __getitem__(self, item: int):
        return self.tokens[item][:-1], self.tokens[item][1:]


def _apply_data_parallel(model: GPT, device_mesh: DeviceMesh, mp_policy: MixedPrecisionPolicy):
    """Apply data parallelism using PyTorch FSDP (Fully Sharded Data Parallelism).

    Applies FSDP to each transformer block and the inner model, configuring
    mixed precision and customized sharding.

    Args:
        model: litgpt GPT model to apply data parallelism to
        device_mesh: Device mesh containing the data_parallel submesh
        mp_policy: MixedPrecisionPolicy to apply
    """
    from torch.distributed.fsdp import fully_shard

    dp_mesh = device_mesh["data_parallel"]

    assert dp_mesh.ndim == 1, f"Hybrid-sharding not supported; dp_mesh.ndim ({dp_mesh.ndim}) != 1"

    # NOTE: Currently, the user is required to manually handle precision settings such as the `mp_policy` here
    # because the model parallel strategy does not respect all settings of `Fabric(precision=...)` at the moment.
    fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}

    layers = model.transformer.h
    for layer_id, transformer_block in enumerate(layers):
        # As an optimization, do not reshard after forward for the last transformer block since FSDP would prefetch it immediately
        reshard_after_forward = int(layer_id) < len(layers) - 1
        fully_shard(
            transformer_block,
            **fsdp_config,
            reshard_after_forward=reshard_after_forward,
        )
        layers[layer_id] = transformer_block

    # shard the remaining root parameters (embedding, final norm, lm_head)
    model = fully_shard(model, **fsdp_config)  # type: ignore


def tensor_parallel_mlp(mesh, mlp: Union[GptNeoxMLP, LLaMAMLP, LLaMAMoE]) -> None:
    plan = {}
    if isinstance(mlp, LLaMAMLP):
        plan["fc_1"] = ColwiseParallel()
        plan["fc_2"] = ColwiseParallel()
        plan["proj"] = RowwiseParallel()
    elif isinstance(mlp, GptNeoxMLP):
        plan["fc"] = ColwiseParallel()
        plan["proj"] = RowwiseParallel()
    elif isinstance(mlp, LLaMAMoE):
        # We use expert slicing across ranks; alternatively, we could create an expert
        # parallelism group when the number of experts is a multiple of the world size
        for expert in mlp.experts:
            tensor_parallel_mlp(mesh, expert)
    else:
        raise NotImplementedError

    parallelize_module(mlp, mesh, plan)


def tensor_parallel_attn(mesh: DeviceMesh, attn: CausalSelfAttention) -> None:
    def shard(x, dim, world_size, rank):
        assert x.size(dim=dim) % world_size == 0
        return torch.tensor_split(x, world_size, dim=dim)[rank]

    plan = {
        "proj": RowwiseParallel(),
    }
    query_size = attn.config.n_head * attn.config.head_size
    key_size = value_size = attn.config.n_query_groups * attn.config.head_size
    q, k, v = attn.qkv.weight.split((query_size, key_size, value_size), dim=0)
    q = shard(q, 0, mesh["tensor_parallel"].size(), mesh.get_local_rank("tensor_parallel"))
    k = shard(k, 0, mesh["tensor_parallel"].size(), mesh.get_local_rank("tensor_parallel"))
    v = shard(v, 0, mesh["tensor_parallel"].size(), mesh.get_local_rank("tensor_parallel"))
    attn.qkv.weight = nn.Parameter(torch.cat((q, k, v), dim=0), requires_grad=False)
    parallelize_module(attn, mesh["tensor_parallel"], plan)


def parallelize(model, device_mesh):
    tp_mesh = device_mesh["tensor_parallel"]
    dp_mesh = device_mesh["data_parallel"]

    assert tp_mesh.size() > 1

    plan = {
        "transformer.wte": RowwiseParallel(input_layouts=Replicate()),
        "transformer.ln_f": SequenceParallel(),
        "lm_head": ColwiseParallel(input_layouts=Shard(1), output_layouts=Replicate()),
    }
    parallelize_module(model, tp_mesh, plan)

    for block in model.transformer.h:
        tensor_parallel_mlp(tp_mesh, block.mlp)
        tensor_parallel_attn(device_mesh, block.attn)

    # update the config values to the shard sizes
    # this is only relevant for `tensor_parallel_attn`, but it needs to run only once
    attrs = ["n_head", "n_embd", "n_query_groups"]
    for attr in attrs:
        size = getattr(model.config, attr)
        if size % tp_mesh.size() != 0:
            raise ValueError(f"This {attr} value ({size}) is not evenly divisible by the world size ({tp_mesh.size()})")
        print("Setting %s to %s", attr, size // tp_mesh.size())
        setattr(model.config, attr, size // tp_mesh.size())

    _apply_data_parallel(
        model,
        dp_mesh,
        mp_policy=MixedPrecisionPolicy(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.float32,
            output_dtype=torch.bfloat16,
        ),
    )

    return model


class LitLLM(L.LightningModule):
    def __init__(self):
        super().__init__()
        config = Config.from_checkpoint(Path("Qwen/Qwen3-8B"))
        self.model = GPT(config=config)

    def configure_model(self):
        state_dict = torch.load(
            "checkpoints/Qwen/Qwen3-8B/lit_model.pth", map_location="cpu", mmap=True, weights_only=True
        )
        self.model.load_state_dict(state_dict=state_dict, assign=True)
        parallelize(self.model, self.device_mesh)

    def predict_step(self, batch):
        input_ids, targets = batch
        logits = self.model(input_ids)
        loss = litgpt.utils.chunked_cross_entropy(logits, targets)
        return loss


if __name__ == "__main__":
    data = RandomTokenDataset(vocab_size=151936, seq_length=2048)
    dataloader = torch.utils.data.DataLoader(data, batch_size=1)
    strategy = ModelParallelStrategy(data_parallel_size=2, tensor_parallel_size=2)
    bnb = BitsandbytesPrecision("fp4")
    trainer = L.Trainer(
        devices=4, max_epochs=2, accumulate_grad_batches=1, precision=None, plugins=bnb, strategy=strategy
    )
    with trainer.init_module(empty_init=True):
        model = LitLLM()

    trainer.predict(model, dataloader)
    trainer.print("Inference successfully completed!")
    trainer.print(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")

The first issue is inside torch/distributed/tensor/parallel/style.py: when applying the different parallel styles, it checks isinstance(module, nn.Linear), which fails. (Changing that to isinstance(module, torch.nn.modules.linear.Linear) moves past it.)
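
For reference, the local edit looks roughly like this (the surrounding context is approximate; only the isinstance check was changed):

# torch/distributed/tensor/parallel/style.py (approximate context of the local edit)
# before:
#     if isinstance(module, nn.Linear):
#         ...
# after:
#     if isinstance(module, torch.nn.modules.linear.Linear):
#         ...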

The next issue is within bitsandbytes:

[rank1]:   File "/home/user/miniconda3/envs/condaenv/lib/python3.11/site-packages/lightning/fabric/plugins/precision/bitsandbytes.py", line 187, in _replace_param
[rank1]:     return bnb.nn.Params4bit(
[rank1]:            ^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/user/miniconda3/envs/condaenv/lib/python3.11/site-packages/bitsandbytes/nn/modules.py", line 230, in __new__
[rank1]:     self = torch.Tensor._make_subclass(cls, data, requires_grad)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: RuntimeError: Creating a new Tensor subclass Params4bit but the raw Tensor object is already associated to a python object of type DTensor
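
For what it's worth, I think the bitsandbytes failure can be reproduced outside the Trainer by handing a DTensor directly to bnb.nn.Params4bit. A minimal single-process sketch of what I believe triggers the same error (the mesh, sizes, and layout are made up, and it assumes a bitsandbytes build that imports without a GPU):

import os

import bitsandbytes as bnb
import torch
import torch.distributed as dist
from torch.distributed._tensor import Shard, distribute_tensor
from torch.distributed.device_mesh import init_device_mesh

# single-rank gloo group so we can build a DTensor on CPU
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

mesh = init_device_mesh("cpu", (1,))
weight = torch.randn(16, 16)
dtensor_weight = distribute_tensor(weight, mesh, [Shard(0)])

# This is roughly what BitsandbytesPrecision's _replace_param ends up doing with an
# already-parallelized weight, and it should raise:
# "Creating a new Tensor subclass Params4bit but the raw Tensor object is
#  already associated to a python object of type DTensor"
bnb.nn.Params4bit(data=dtensor_weight, requires_grad=False)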

Any pointers on how to make this work? @awaelchli
