Description
I'm trying to run inference within the `LightningTrainer` using a litgpt model with 2D parallelism (TP + FSDP) while using a `BitsandbytesPrecision` plugin to enable quantization, but I run into issues.
Here's a sample script I use:
```python
from __future__ import annotations

from pathlib import Path
from typing import Union

import lightning as L
import litgpt
import torch
from lightning.pytorch.plugins import BitsandbytesPrecision
from lightning.pytorch.strategies.model_parallel import ModelParallelStrategy
from litgpt import GPT, Config
from litgpt.model import CausalSelfAttention, GptNeoxMLP, LLaMAMLP, LLaMAMoE
from torch import nn
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import MixedPrecisionPolicy
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    SequenceParallel,
    parallelize_module,
)


class RandomTokenDataset(torch.utils.data.Dataset):
    def __init__(self, vocab_size: int, seq_length: int):
        self.vocab_size = vocab_size
        self.seq_length = seq_length
        self.tokens = torch.randint(
            self.vocab_size,
            size=(len(self), self.seq_length + 1),
            # Set a seed to make this toy dataset the same on each rank
            # Fabric will add a `DistributedSampler` to shard the data correctly
            generator=torch.Generator().manual_seed(42),
        )

    def __len__(self) -> int:
        return 128

    def __getitem__(self, item: int):
        return self.tokens[item][:-1], self.tokens[item][1:]

def _apply_data_parallel(model: GPT, device_mesh: DeviceMesh, mp_policy: MixedPrecisionPolicy):
    """Apply data parallelism using PyTorch FSDP (fully sharded data parallelism).

    Applies FSDP to each transformer block and to the root model, configuring
    mixed precision and customized sharding.

    Args:
        model: litgpt ``GPT`` model to apply data parallelism to
        device_mesh: Device mesh containing the data_parallel submesh
        mp_policy: MixedPrecisionPolicy to apply
    """
    from torch.distributed.fsdp import fully_shard

    dp_mesh = device_mesh["data_parallel"]
    assert dp_mesh.ndim == 1, f"Hybrid-sharding not supported; dp_mesh.ndim ({dp_mesh.ndim}) != 1"

    # NOTE: Currently, the user is required to manually handle precision settings such as `mp_policy` here
    # because the model parallel strategy does not respect all settings of `Fabric(precision=...)` at the moment.
    fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}

    layers = model.transformer.h
    for layer_id, transformer_block in enumerate(layers):
        # As an optimization, do not reshard after forward for the last transformer block
        # since FSDP would prefetch it immediately
        reshard_after_forward = int(layer_id) < len(layers) - 1
        fully_shard(
            transformer_block,
            **fsdp_config,
            reshard_after_forward=reshard_after_forward,
        )
        layers[layer_id] = transformer_block

    # shard the remaining parameters (e.g. the embedding) in the root module
    model = fully_shard(model, **fsdp_config)  # type: ignore

def tensor_parallel_mlp(mesh, mlp: Union[GptNeoxMLP, LLaMAMLP, LLaMAMoE]) -> None:
    plan = {}
    if isinstance(mlp, LLaMAMLP):
        plan["fc_1"] = ColwiseParallel()
        plan["fc_2"] = ColwiseParallel()
        plan["proj"] = RowwiseParallel()
    elif isinstance(mlp, GptNeoxMLP):
        plan["fc"] = ColwiseParallel()
        plan["proj"] = RowwiseParallel()
    elif isinstance(mlp, LLaMAMoE):
        # We use expert slicing across ranks. Alternatively, we could create an expert-parallelism group
        # when the number of experts is a multiple of the world size.
        for expert in mlp.experts:
            tensor_parallel_mlp(mesh, expert)
    else:
        raise NotImplementedError
    parallelize_module(mlp, mesh, plan)

def tensor_parallel_attn(mesh: DeviceMesh, attn: CausalSelfAttention) -> None:
    def shard(x, dim, world_size, rank):
        assert x.size(dim=dim) % world_size == 0
        return torch.tensor_split(x, world_size, dim=dim)[rank]

    plan = {
        "proj": RowwiseParallel(),
    }
    query_size = attn.config.n_head * attn.config.head_size
    key_size = value_size = attn.config.n_query_groups * attn.config.head_size
    q, k, v = attn.qkv.weight.split((query_size, key_size, value_size), dim=0)
    q = shard(q, 0, mesh["tensor_parallel"].size(), mesh.get_local_rank("tensor_parallel"))
    k = shard(k, 0, mesh["tensor_parallel"].size(), mesh.get_local_rank("tensor_parallel"))
    v = shard(v, 0, mesh["tensor_parallel"].size(), mesh.get_local_rank("tensor_parallel"))
    attn.qkv.weight = nn.Parameter(torch.cat((q, k, v), dim=0), requires_grad=False)
    parallelize_module(attn, mesh["tensor_parallel"], plan)

def parallelize(model, device_mesh):
    tp_mesh = device_mesh["tensor_parallel"]
    dp_mesh = device_mesh["data_parallel"]
    assert tp_mesh.size() > 1

    plan = {
        "transformer.wte": RowwiseParallel(input_layouts=Replicate()),
        "transformer.ln_f": SequenceParallel(),
        "lm_head": ColwiseParallel(input_layouts=Shard(1), output_layouts=Replicate()),
    }
    parallelize_module(model, tp_mesh, plan)

    for block in model.transformer.h:
        tensor_parallel_mlp(tp_mesh, block.mlp)
        tensor_parallel_attn(device_mesh, block.attn)

    # Update the config values to the shard sizes.
    # This is only relevant for `tensor_parallel_attn`, but it needs to run only once.
    attrs = ["n_head", "n_embd", "n_query_groups"]
    for attr in attrs:
        size = getattr(model.config, attr)
        if size % tp_mesh.size() != 0:
            raise ValueError(
                f"This {attr} value ({size}) is not evenly divisible by the world size ({tp_mesh.size()})"
            )
        print(f"Setting {attr} to {size // tp_mesh.size()}")
        setattr(model.config, attr, size // tp_mesh.size())

    _apply_data_parallel(
        model,
        dp_mesh,
        mp_policy=MixedPrecisionPolicy(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.float32,
            output_dtype=torch.bfloat16,
        ),
    )
    return model

class LitLLM(L.LightningModule):
    def __init__(self):
        super().__init__()
        config = Config.from_checkpoint(Path("Qwen/Qwen3-8B"))
        self.model = GPT(config=config)

    def configure_model(self):
        state_dict = torch.load(
            "checkpoints/Qwen/Qwen3-8B/lit_model.pth", map_location="cpu", mmap=True, weights_only=True
        )
        self.model.load_state_dict(state_dict=state_dict, assign=True)
        parallelize(self.model, self.device_mesh)

    def predict_step(self, batch):
        input_ids, targets = batch
        logits = self.model(input_ids)
        loss = litgpt.utils.chunked_cross_entropy(logits, targets)
        return loss

if __name__ == "__main__":
    data = RandomTokenDataset(vocab_size=151936, seq_length=2048)
    dataloader = torch.utils.data.DataLoader(data, batch_size=1)

    strategy = ModelParallelStrategy(data_parallel_size=2, tensor_parallel_size=2)
    bnb = BitsandbytesPrecision("fp4")
    trainer = L.Trainer(
        devices=4, max_epochs=2, accumulate_grad_batches=1, precision=None, plugins=bnb, strategy=strategy
    )

    with trainer.init_module(empty_init=True):
        model = LitLLM()

    trainer.predict(model, dataloader)

    trainer.print("Inference successfully completed!")
    trainer.print(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")
```
The first issue is inside `torch/distributed/tensor/parallel/style.py`: when applying the different parallel styles, it checks `isinstance(module, nn.Linear)`, which fails. (Changing that to `isinstance(module, torch.nn.modules.linear.Linear)` moves past it.)
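In case it helps narrow this down, here is the kind of quick diagnostic I would run to see what the linear layers actually are after the `BitsandbytesPrecision` replacement. This is just a sketch written for this report (the helper name `report_linear_types` is mine, not anything from Lightning or torch):

```python
import torch
from torch import nn


def report_linear_types(model: nn.Module) -> None:
    """Print the concrete class and MRO of every linear-like module.

    Helps see why `isinstance(module, nn.Linear)` might fail after
    quantization and/or tensor parallelization.
    """
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) or "linear" in type(module).__name__.lower():
            mro = " -> ".join(f"{c.__module__}.{c.__qualname__}" for c in type(module).__mro__)
            print(f"{name}: isinstance nn.Linear={isinstance(module, nn.Linear)} | mro: {mro}")
```

My plan would be to call this on `self.model` inside `configure_model`, before and after `parallelize`, to check whether the quantized modules really stop being subclasses of `nn.Linear` or whether two different `nn.Linear` classes end up being compared.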
The next issue is within bitsandbytes:

```
[rank1]:   File "/home/user/miniconda3/envs/condaenv/lib/python3.11/site-packages/lightning/fabric/plugins/precision/bitsandbytes.py", line 187, in _replace_param
[rank1]:     return bnb.nn.Params4bit(
[rank1]:            ^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/user/miniconda3/envs/condaenv/lib/python3.11/site-packages/bitsandbytes/nn/modules.py", line 230, in __new__
[rank1]:     self = torch.Tensor._make_subclass(cls, data, requires_grad)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: RuntimeError: Creating a new Tensor subclass Params4bit but the raw Tensor object is already associated to a python object of type DTensor
```
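The only workaround direction I can think of for this one is to hand bitsandbytes the local shard instead of the DTensor itself. Below is a rough sketch of that idea; the helper `_params4bit_from_maybe_dtensor` is purely hypothetical, and I have not verified that quantizing only the local shard behaves correctly under TP + FSDP:

```python
import bitsandbytes as bnb
import torch
from torch.distributed._tensor import DTensor


def _params4bit_from_maybe_dtensor(param: torch.nn.Parameter, quant_type: str = "fp4") -> bnb.nn.Params4bit:
    """Wrap a (possibly distributed) parameter in a bitsandbytes `Params4bit`.

    If the parameter is a DTensor, only its local shard is handed to bitsandbytes,
    because `Params4bit` cannot subclass a tensor that is already a DTensor.
    """
    data = param.data
    if isinstance(data, DTensor):
        # Keep only the shard owned by this rank; the DTensor placement information is lost here.
        data = data.to_local()
    return bnb.nn.Params4bit(data=data, requires_grad=False, quant_type=quant_type)
```

Presumably something along these lines inside Lightning's `_replace_param` would avoid the RuntimeError, but the quantized weight would then lose its placement information, so I doubt the TP collectives would still do the right thing, which is really what I'm asking about.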
Any pointers on how to make this work? @awaelchli