Support saving and loading 8-bit block weights #273

Open · wants to merge 4 commits into base: main
8 changes: 7 additions & 1 deletion src/petals/bloom/from_pretrained.py
@@ -22,12 +22,13 @@

from petals.bloom.block import WrappedBloomBlock
from petals.server.block_utils import get_block_size
from petals.utils.convert_block import replace_8bit_linear
from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for

logger = get_logger(__name__)

CLIENT_BRANCH = "main"
BLOCK_BRANCH_PREFIX = "block_"
BLOCK_BRANCH_PREFIX = "int8_block"
Member Author:
We'll roll that back before merging.



def load_pretrained_block(
@@ -38,6 +39,8 @@ def load_pretrained_block(
use_auth_token: Optional[str] = None,
cache_dir: Optional[str] = None,
max_disk_space: Optional[int] = None,
load_in_8bit=False,
Collaborator:
Suggested change:
load_in_8bit=False,
load_in_8bit: bool = False,

device: Optional[Union[str, torch.device]] = None,
) -> WrappedBloomBlock:
"""Load one BLOOM block from a converted model. See convert_model.py (or README.md) on how to convert it."""
assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
@@ -49,6 +52,9 @@

with init_empty_weights():
block = WrappedBloomBlock(config)
if load_in_8bit:
block = replace_8bit_linear(block)
block = block.to(device)
Comment on lines +55 to +57
Member Author:
I moved replace_8bit_linear here because it's not possible to correctly load the quantized Linear8bitLt checkpoint into the model before it's converted and quantized.


state_dict = _load_state_dict(
converted_model_name_or_path,
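To make the ordering concrete, here is a minimal sketch of the load path this hunk implements, reusing the `WrappedBloomBlock` and `replace_8bit_linear` helpers named in the diff; the body is an illustration of the ordering constraint, not the exact Petals code:

```python
import torch
from accelerate import init_empty_weights

from petals.bloom.block import WrappedBloomBlock
from petals.utils.convert_block import replace_8bit_linear


def load_block_int8(config, state_dict, device="cuda"):
    with init_empty_weights():
        block = WrappedBloomBlock(config)  # shell module, weights not materialized yet
    block = replace_8bit_linear(block)     # swap nn.Linear -> Linear8bitLt *before* loading weights
    block = block.to(device)               # materializing on the GPU is what quantizes to int8
    block.load_state_dict(state_dict)      # keys now match the quantized Linear8bitLt layout
    return block
```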
18 changes: 14 additions & 4 deletions src/petals/cli/convert_model.py
@@ -15,15 +15,17 @@

logger = get_logger(__name__)

DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")
DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, int8=torch.int8, auto="auto")


def main():
parser = argparse.ArgumentParser(description="Load bloom layers and convert to 8-bit using torch quantization.")

parser.add_argument("--model", type=str, default="bigscience/bloom-6b3", help="Model name for from_pretrained")
parser.add_argument("--revision", type=str, default=None, help="Optional commit id from HF hub")
parser.add_argument("--torch_dtype", type=str, default="auto", help="Load initial model in this dtype")
parser.add_argument(
"--torch_dtype", type=str, choices=DTYPE_MAP.keys(), default="auto", help="Load initial model in this dtype"
)
parser.add_argument("--output_path", type=str, default="./converted_model", help="Track output repo to this folder")
parser.add_argument("--output_repo", type=str, default="bigscience/test-bloomd", help="Push to this HF hub repo")
parser.add_argument("--client_branch", type=str, default=CLIENT_BRANCH, help="Save client version to this branch")
@@ -41,7 +43,6 @@ def main():
if args.model == "bigscience/bloom" and free_ram_gb < 400:
logger.warning(f"ACHTUNG! converting bloom-176b will use up 350-400GB RAM, you have {free_ram_gb:.3f} free")

assert args.torch_dtype in DTYPE_MAP, f"torch_dtype must be one of {list(DTYPE_MAP.keys())}"
if os.path.exists(args.output_path) and (
len(os.listdir(args.output_path)) != 0 or not os.path.isdir(args.output_path)
):
@@ -54,8 +55,17 @@
config.dht_prefix = args.output_repo

model = BloomModel.from_pretrained(
args.model, use_auth_token=args.use_auth_token, revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype]
args.model,
use_auth_token=args.use_auth_token,
revision=args.revision,
torch_dtype=DTYPE_MAP[args.torch_dtype] if args.torch_dtype != "int8" else "float16",
load_in_8bit=args.torch_dtype == "int8",
device_map="auto" if args.torch_dtype == "int8" else None,
)
if args.torch_dtype == "int8":
# trigger weight quantization
model = model.cuda()

if args.resize_token_embeddings:
logger.info(f"Resizing token embeddings, new size = {args.resize_token_embeddings}")
model.resize_token_embeddings(args.resize_token_embeddings)
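The extra `model = model.cuda()` call is what actually performs the quantization: bitsandbytes quantizes `Linear8bitLt` weights at the moment their parameters are moved to a CUDA device. A minimal illustration, independent of this PR (layer sizes and threshold are arbitrary, and a GPU is required):

```python
import torch
import bitsandbytes as bnb

# Create an 8-bit linear layer and give it fp16 weights while still on the CPU.
linear = bnb.nn.Linear8bitLt(1024, 1024, has_fp16_weights=False, threshold=6.0)
linear.weight.data = torch.randn(1024, 1024, dtype=torch.float16)
print(linear.weight.dtype)  # torch.float16 -- still unquantized on CPU

linear = linear.cuda()      # quantization happens here, inside the parameter's .cuda()
print(linear.weight.dtype)  # torch.int8
```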
4 changes: 3 additions & 1 deletion src/petals/server/server.py
@@ -401,8 +401,10 @@ def create(
use_auth_token=use_auth_token,
cache_dir=cache_dir,
max_disk_space=max_disk_space,
load_in_8bit=load_in_8bit,
device=device,
)
block = convert_block(block, block_config, tensor_parallel_devices, device, load_in_8bit, freeze=True)
block = convert_block(block, block_config, tensor_parallel_devices, device, freeze=True)

backend_dtype = next(block.parameters()).dtype if torch_dtype == "auto" else torch_dtype
blocks[module_uid] = TransformerBackend(
6 changes: 4 additions & 2 deletions src/petals/server/throughput.py
@@ -13,7 +13,7 @@

from petals.bloom.block import WrappedBloomBlock
from petals.server.block_utils import resolve_block_dtype
from petals.utils.convert_block import convert_block
from petals.utils.convert_block import convert_block, replace_8bit_linear
from petals.utils.disk_cache import DEFAULT_CACHE_DIR

logger = get_logger(__name__)
@@ -149,7 +149,9 @@ def measure_compute_rps(
tensor_parallel_devices = (device,)
with torch.inference_mode():
block = WrappedBloomBlock(config).to(dtype)
block = convert_block(block, config, tensor_parallel_devices, device, load_in_8bit=load_in_8bit, freeze=True)
if load_in_8bit:
block = replace_8bit_linear(block)
block = convert_block(block, config, tensor_parallel_devices, device, freeze=True)

cache = None
elapsed = 0
9 changes: 1 addition & 8 deletions src/petals/utils/convert_block.py
@@ -24,21 +24,17 @@ def convert_block(
config: BloomConfig,
tensor_parallel_devices: Sequence[torch.device],
output_device: torch.device,
load_in_8bit: bool,
threshold: float = 6.0,
freeze: bool = True,
) -> tp.TensorParallel:
"""
Optimize a transformer block for use in a Petals server, apply tensor parallelism and/or LLM.8bit quantization
Optimize a transformer block for use in a Petals server and apply tensor parallelism

:note: some optimizations will modify the input block in-place!
:param block: a single transformer block, either pre-trained or newly initialized
:param config: HF transformers config for the full model
:param tensor_parallel_devices: if specified, use tensor parallelism to split the model between these devices
:note: if there is only a single device, model will still be wrapped with TensorParallel (for uniformity)
:param output_device: if tensor_parallel_devices is True, output
:param load_in_8bit: if True, use LLM.int8() quantization to reduce the model memory footprint
:param threshold: a quantization threshold from LLM.int8() paper ( https://arxiv.org/abs/2208.07339 )
:param freeze: if True (default), make all module parameters non-trainable
:return: a module that acts like the original block, but runs with all specified optimizations

@@ -49,9 +45,6 @@

block = make_tensor_parallel(block, config, tensor_parallel_devices, output_device=output_device)

if load_in_8bit:
block = replace_8bit_linear(block, threshold=threshold)

for shard, device in zip(block.module_shards, block.devices):
shard.to(device)

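With quantization removed from `convert_block`, callers now apply `replace_8bit_linear` themselves before calling it (as in `from_pretrained.py` and `throughput.py` above). For reference, a hedged sketch of what such a helper typically does; the real Petals implementation may skip certain layers and set different flags:

```python
import torch.nn as nn
import bitsandbytes as bnb


def replace_8bit_linear_sketch(module: nn.Module, threshold: float = 6.0) -> nn.Module:
    """Recursively swap nn.Linear submodules for bitsandbytes Linear8bitLt."""
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            new_layer = bnb.nn.Linear8bitLt(
                child.in_features,
                child.out_features,
                bias=child.bias is not None,
                has_fp16_weights=False,  # keep weights in int8 after quantization
                threshold=threshold,     # outlier threshold from the LLM.int8() paper
            )
            new_layer.weight.data = child.weight.data
            if child.bias is not None:
                new_layer.bias.data = child.bias.data
            setattr(module, name, new_layer)
        else:
            replace_8bit_linear_sketch(child, threshold)
    return module
```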