diff --git a/compress_model.py b/compress_model.py
new file mode 100644
index 000000000..fa67bead0
--- /dev/null
+++ b/compress_model.py
@@ -0,0 +1,60 @@
+# python3 compress_model.py --model_id meta-llama/Llama-3.2-1B-Instruct --transform_type random-hadamard
+import argparse
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from llmcompressor import oneshot
+from llmcompressor.modifiers.quantization import QuantizationModifier
+from llmcompressor.modifiers.transform import SpinQuantModifier
+from llmcompressor.utils import dispatch_for_generation
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model_id", type=str, help="Model stub to compress")
+    parser.add_argument("--transform_type", type=str, default=None, help="Type of transform used in SpinQuantModifier")
+    parser.add_argument("--scheme", type=str, default=None, help="Quantization scheme (e.g. W4A16)")
+    return parser.parse_args()
+
+if __name__ == "__main__":
+    args = parse_args()
+
+    # Select model and load it.
+    MODEL_ID = args.model_id
+    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+
+    # Select number of samples. 512 samples is a good place to start.
+    # Increasing the number of samples can improve accuracy.
+    NUM_CALIBRATION_SAMPLES = 512
+    MAX_SEQUENCE_LENGTH = 2048
+
+    # Configure the quantization algorithm to run.
+    recipe = []
+    if args.transform_type:
+        recipe.append(SpinQuantModifier(rotations=["R1", "R2"], transform_type=args.transform_type))
+
+    if args.scheme:
+        recipe.append(QuantizationModifier(targets="Linear", scheme=args.scheme, ignore=["lm_head"]))
+
+    # Apply algorithms.
+    oneshot(
+        model=model,
+        recipe=recipe,
+        dataset="ultrachat_200k",
+        splits={"calibration": f"train_sft[:{NUM_CALIBRATION_SAMPLES}]"},
+        max_seq_length=MAX_SEQUENCE_LENGTH,
+        num_calibration_samples=NUM_CALIBRATION_SAMPLES,
+    )
+
+    # Confirm generations of the quantized model look sane.
+    print("\n\n")
+    print("========== SAMPLE GENERATION ==============")
+    dispatch_for_generation(model)
+    input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
+    output = model.generate(input_ids, max_new_tokens=100)
+    print(tokenizer.decode(output[0]))
+    print("==========================================\n\n")
+
+    # Save to disk compressed.
+    SAVE_DIR = MODEL_ID.split("/")[1] + f"-{args.transform_type}-{args.scheme}"
+    model.save_pretrained(SAVE_DIR, save_compressed=True)
+    tokenizer.save_pretrained(SAVE_DIR)
diff --git a/examples/transform/quip_example.py b/examples/transform/quip_example.py
new file mode 100644
index 000000000..9a0cacee7
--- /dev/null
+++ b/examples/transform/quip_example.py
@@ -0,0 +1,108 @@
+"""
+WARNING: This example requires the following minimum versions:
+ * compressed-tensors>=0.10.3.dev
+ * transformers>=4.56.dev
+Note that you may need to install these from source.
+
+Models produced by this example will not be runnable in vLLM without
+the following changes: https://github.com/vllm-project/vllm/pull/22219
+"""
+
+from datasets import load_dataset
+from packaging import version
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.utils.import_utils import _is_package_available
+
+from llmcompressor import oneshot
+from llmcompressor.modifiers.quantization import QuantizationModifier
+from llmcompressor.modifiers.transform import QuIPModifier
+from llmcompressor.utils import dispatch_for_generation
+
+# check correct versioning
+_, ct_version = _is_package_available("compressed_tensors", return_version=True)
+_, tfms_version = _is_package_available("transformers", return_version=True)
+if version.parse(ct_version) < version.parse("0.10.3.dev"):
+    print(version.parse(ct_version))
+    raise ValueError("Please install compressed-tensors>=0.10.3 or from source")
+if version.parse(tfms_version) < version.parse("4.56.dev"):
+    raise ValueError("Please install transformers>=4.56 or from source")
+
+# Select model and load it.
+MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
+
+model = AutoModelForCausalLM.from_pretrained(
+    MODEL_ID,
+    torch_dtype="auto",
+)
+tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+
+# Select calibration dataset.
+DATASET_ID = "HuggingFaceH4/ultrachat_200k"
+DATASET_SPLIT = "train_sft"
+
+# Select number of samples. 512 samples is a good place to start.
+# Increasing the number of samples can improve accuracy.
+NUM_CALIBRATION_SAMPLES = 512
+MAX_SEQUENCE_LENGTH = 2048
+
+# Load dataset and preprocess.
+ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
+ds = ds.shuffle(seed=42)
+
+
+def preprocess(example):
+    return {
+        "text": tokenizer.apply_chat_template(
+            example["messages"],
+            tokenize=False,
+        )
+    }
+
+
+ds = ds.map(preprocess)
+
+
+# Tokenize inputs.
+def tokenize(sample):
+    return tokenizer(
+        sample["text"],
+        padding=False,
+        max_length=MAX_SEQUENCE_LENGTH,
+        truncation=True,
+        add_special_tokens=False,
+    )
+
+
+ds = ds.map(tokenize, remove_columns=ds.column_names)
+
+# Configure the quantization algorithm to run.
+# * apply QuIP transforms to the model in order to make quantization easier
+# * quantize the weights to 4 bit with a group size of 128
+recipe = [
+    QuIPModifier(transform_type="random-hadamard"),
+    QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
+]
+
+# Apply algorithms.
+oneshot(
+    model=model,
+    recipe=recipe,
+    dataset=ds,
+    max_seq_length=MAX_SEQUENCE_LENGTH,
+    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
+    pipeline="datafree",
+)
+
+# Confirm generations of the quantized model look sane.
+print("\n\n") +print("========== SAMPLE GENERATION ==============") +dispatch_for_generation(model) +input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda") +output = model.generate(input_ids, max_new_tokens=100) +print(tokenizer.decode(output[0])) +print("==========================================\n\n") + +# Save to disk compressed. +SAVE_DIR = MODEL_ID.split("/")[1] + "-quip-w4a16" +model.save_pretrained(SAVE_DIR, save_compressed=True) +tokenizer.save_pretrained(SAVE_DIR) diff --git a/src/llmcompressor/modeling/__init__.py b/src/llmcompressor/modeling/__init__.py index e2c22ed1f..76b6b0391 100644 --- a/src/llmcompressor/modeling/__init__.py +++ b/src/llmcompressor/modeling/__init__.py @@ -1,3 +1,4 @@ # flake8: noqa +from .fuse import * from .prepare import * diff --git a/src/llmcompressor/modifiers/transform/__init__.py b/src/llmcompressor/modifiers/transform/__init__.py new file mode 100644 index 000000000..e37e388f4 --- /dev/null +++ b/src/llmcompressor/modifiers/transform/__init__.py @@ -0,0 +1,3 @@ +# flake8: noqa + +from .quip import QuIPModifier diff --git a/src/llmcompressor/modifiers/transform/quip/__init__.py b/src/llmcompressor/modifiers/transform/quip/__init__.py new file mode 100644 index 000000000..8bdc93d14 --- /dev/null +++ b/src/llmcompressor/modifiers/transform/quip/__init__.py @@ -0,0 +1,3 @@ +# flake8: noqa + +from .base import * diff --git a/src/llmcompressor/modifiers/transform/quip/base.py b/src/llmcompressor/modifiers/transform/quip/base.py new file mode 100644 index 000000000..5d89dab95 --- /dev/null +++ b/src/llmcompressor/modifiers/transform/quip/base.py @@ -0,0 +1,140 @@ +from typing import List, Literal, Optional, Union + +import torch +from compressed_tensors.transform import ( + TransformArgs, + TransformConfig, + TransformScheme, + apply_transform_config, +) +from compressed_tensors.utils import TorchDtype +from pydantic import Field, ValidationInfo, field_validator + +from llmcompressor.core import Event, EventType, State +from llmcompressor.modifiers import Modifier + +__all__ = ["QuIPModifier"] + + +class QuIPModifier(Modifier): + """ + Implements the transforms according to + [QuIP#: Even Better LLM Quantization with Hadamard Incoherence and Lattice Codebooks](https://arxiv.org/pdf/2402.04396) # noqa: E501 + [QuIP: 2-Bit Quantization of Large Language Models With Guarantees](https://arxiv.org/abs/2307.13304) # noqa: E501 + + Transforms (rotations) are extra layers added to a model which reduce the accuracy + loss induced by quantization. This is achieved through "rotating" weights and + activations into a space with a smaller dynamic range of values, thus decreasing + the range of scales required for quantization. + + QuIP and QuIP# apply transforms to every linear layer, two of which are fused into + the model weights and two of which remain as online rotations computed at runtime. + + :param transform_type: The type of transform to apply to the model. + `"hadamard"` has the least performance cost but only supports sizes which are + powers of power of two. + `"random-hadamard"` has more performance cost, but supports a much larger set of + sizes. + `"random-matrix"` has the greatest performance cost, but supports any size + :param randomize: If true, create distinct transforms for each application + :param learnable: If true, attach gradients to transform weights for training + :param precision: Precision at which all transforms should be applied. 
+    :param ignore: Modules to ignore when attaching transforms
+    :param transform_config: Optional transform config for overriding provided arguments
+    """
+
+    transform_type: Literal["hadamard", "random-hadamard", "random-matrix"] = Field(
+        default="random-hadamard"
+    )
+    randomize: bool = Field(default=False)
+    learnable: bool = Field(default=False)
+    precision: TorchDtype = Field(default=torch.float64)
+    ignore: Union[str, List[str]] = Field(default="lm_head")
+
+    # optional override for more fine-grained control
+    # also included in recipe serialization
+    transform_config: Optional[TransformConfig] = Field(default=None, repr=False)
+
+    @field_validator("randomize", "learnable", mode="before")
+    def validate_not_implemented(cls, value, info: ValidationInfo):
+        if value:
+            raise NotImplementedError(f"{info.field_name} is not supported right now")
+        return value
+
+    def on_initialize(self, state: State, **kwargs) -> bool:
+        if self.transform_config is not None:
+            return True
+
+        self.transform_config = self._create_config()
+        return True
+
+    def on_start(self, state: State, event: Event, **kwargs):
+        self.started_ = True
+
+        apply_transform_config(state.model, self.transform_config)
+
+    def on_event(self, state: State, event: Event, **kwargs):
+        if event.type_ == EventType.CALIBRATION_EPOCH_START:
+            if not self.started_:
+                self.on_start(state, None)
+
+        elif event.type_ == EventType.SEQUENTIAL_EPOCH_END:
+            pass
+
+        elif event.type_ == EventType.CALIBRATION_EPOCH_END:
+            if not self.ended_:
+                self.on_end(state, None)
+
+    def on_end(self, state: State, event: Event, **kwargs):
+        self.ended_ = True
+
+    def on_finalize(self, state: State, **kwargs) -> bool:
+        if not self.ended_:
+            self.on_end(state, None)
+
+        return True
+
+    def _create_config(self) -> TransformConfig:
+        return TransformConfig(
+            config_groups={
+                "v": TransformScheme(
+                    type=self.transform_type,
+                    apply=[
+                        TransformArgs(
+                            targets=["Linear"],
+                            location="input",  # non-mergable
+                            ignore=self.ignore,
+                        ),
+                        TransformArgs(
+                            targets=["Linear"],
+                            location="weight_input",
+                            inverse=True,
+                            ignore=self.ignore,
+                        ),
+                    ],
+                    randomize=self.randomize,
+                    requires_grad=self.learnable,
+                    precision=self.precision,
+                ),
+                "u": TransformScheme(
+                    type=self.transform_type,
+                    apply=[
+                        TransformArgs(
+                            targets=["Linear"],
+                            location="weight_output",
+                            ignore=self.ignore,
+                        ),
+                        TransformArgs(
+                            targets=["Linear"],
+                            location="output",  # non-mergable
+                            inverse=True,
+                            ignore=self.ignore,
+                        ),
+                    ],
+                    randomize=self.randomize,
+                    requires_grad=self.learnable,
+                    precision=self.precision,
+                ),
+            }
+        )
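The two config groups built by _create_config above rest on a simple identity: a rotation applied online to a layer's input (or output) is cancelled exactly by fusing its inverse into the adjacent weight dimension, so the network's function is unchanged before any quantization error is introduced. The standalone sketch below is illustrative only and not part of the patch; it uses random orthogonal matrices in place of Hadamard rotations to check both the "v" and "u" pairings with plain torch.

import torch

torch.manual_seed(0)
d_in, d_out, batch = 64, 32, 4
W = torch.randn(d_out, d_in, dtype=torch.float64)  # Linear weight; forward is y = x @ W.T
x = torch.randn(batch, d_in, dtype=torch.float64)
y_ref = x @ W.T

# Random orthogonal matrices stand in for the (random-)Hadamard rotations.
V, _ = torch.linalg.qr(torch.randn(d_in, d_in, dtype=torch.float64))    # "v" group
U, _ = torch.linalg.qr(torch.randn(d_out, d_out, dtype=torch.float64))  # "u" group

# "v" scheme: rotate the input online, fuse the inverse into the weight's input dim.
x_rot = x @ V.T  # analogous to location="input" (online rotation of activations)
W_v = W @ V.T    # analogous to location="weight_input" with inverse=True (V^-1 == V.T here)
assert torch.allclose(x_rot @ W_v.T, y_ref)

# "u" scheme: fuse the rotation into the weight's output dim, undo it online afterwards.
W_u = U @ W       # analogous to location="weight_output"
y_rot = x @ W_u.T  # layer forward with the fused weight
y_out = y_rot @ U  # analogous to location="output" with inverse=True ((U^-1).T == U here)
assert torch.allclose(y_out, y_ref)
print("both transform pairs preserve the layer output")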
diff --git a/src/llmcompressor/pipelines/data_free/pipeline.py b/src/llmcompressor/pipelines/data_free/pipeline.py
index 587f7ca69..7ad6d56dc 100644
--- a/src/llmcompressor/pipelines/data_free/pipeline.py
+++ b/src/llmcompressor/pipelines/data_free/pipeline.py
@@ -5,6 +5,7 @@
 
 from llmcompressor.core.session_functions import LifecycleCallbacks
 from llmcompressor.pipelines.registry import CalibrationPipeline
+from llmcompressor.utils.dev import dispatch_for_generation
 
 if TYPE_CHECKING:
     from llmcompressor.args.dataset_arguments import DatasetArguments
@@ -27,5 +28,9 @@ def __call__(
         :param dataloader: loads data for calibration
         :param dataset_args: dataset arguments relevant to pipelines
         """
+        # some ops are still performed on the model by modifiers
+        # we want those ops to occur on the GPU
+        dispatch_for_generation(model)
+
         LifecycleCallbacks.calibration_epoch_start()
         LifecycleCallbacks.calibration_epoch_end()
diff --git a/tests/llmcompressor/modifiers/transform/quip/test_correctness.py b/tests/llmcompressor/modifiers/transform/quip/test_correctness.py
new file mode 100644
index 000000000..276060b6b
--- /dev/null
+++ b/tests/llmcompressor/modifiers/transform/quip/test_correctness.py
@@ -0,0 +1,44 @@
+import os
+
+import pytest
+import torch
+from transformers import AutoModelForCausalLM
+
+from llmcompressor.core import State
+from llmcompressor.modifiers.transform import QuIPModifier
+from tests.testing_utils import requires_gpu
+
+
+@requires_gpu
+@pytest.mark.skipif(
+    (not os.getenv("HF_TOKEN")),
+    reason="Skipping tests that require gated model access",
+)
+@pytest.mark.parametrize(
+    "model_dtype,precision,exp_mse",
+    [
+        (torch.bfloat16, torch.bfloat16, 5e-3),  # 0.0019
+        (torch.bfloat16, torch.float32, 5e-3),  # 0.0022
+        (torch.float32, torch.float32, 5e-10),  # 1.0777e-10
+        (torch.float32, torch.float64, 5e-11),  # 2.6632e-11
+    ],
+)
+def test_apply_correctness(model_dtype, precision, exp_mse):
+    model = AutoModelForCausalLM.from_pretrained(
+        "meta-llama/Llama-3.2-1B-Instruct", device_map="cuda", torch_dtype=model_dtype
+    )
+    state = State(model=model)
+    modifier = QuIPModifier(transform_type="random-hadamard", precision=precision)
+
+    input = {k: v.to("cuda") for k, v in model.dummy_inputs.items()}
+    with torch.no_grad():
+        true_output = model(**input)
+
+    modifier.on_initialize(state)
+    modifier.on_start(state, None)
+
+    with torch.no_grad():
+        output = model(**input)
+
+    print(torch.nn.MSELoss()(output.logits, true_output.logits))
+    assert torch.nn.MSELoss()(output.logits, true_output.logits) <= exp_mse
diff --git a/tests/llmcompressor/modifiers/transform/quip/test_serialization.py b/tests/llmcompressor/modifiers/transform/quip/test_serialization.py
new file mode 100644
index 000000000..3dc682728
--- /dev/null
+++ b/tests/llmcompressor/modifiers/transform/quip/test_serialization.py
@@ -0,0 +1,7 @@
+from llmcompressor.modifiers.transform import QuIPModifier
+
+
+def test_reload():
+    modifier = QuIPModifier(transform_type="hadamard")
+    dump = modifier.model_dump()
+    assert QuIPModifier.model_validate(dump) == modifier
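A closing note on the transform_type options documented in QuIPModifier and exercised in the tests above: the deterministic "hadamard" type is restricted to power-of-two sizes because such Hadamard matrices are typically built recursively (Sylvester's construction), while "random-hadamard" and "random-matrix" relax that constraint at higher cost. The sketch below is illustrative and independent of the patch; it constructs a Sylvester Hadamard matrix and checks the property H @ H.T == n * I, which is what makes the normalized rotation trivially invertible.

import torch


def sylvester_hadamard(n: int) -> torch.Tensor:
    # Recursively build an n x n Hadamard matrix for n a power of two.
    assert n > 0 and (n & (n - 1)) == 0, "only power-of-two sizes are supported"
    H = torch.ones(1, 1, dtype=torch.float64)
    while H.shape[0] < n:
        H = torch.cat([torch.cat([H, H], dim=1), torch.cat([H, -H], dim=1)], dim=0)
    return H


H = sylvester_hadamard(64)
n = H.shape[0]
# Entries are +/-1 and rows are mutually orthogonal, so H @ H.T == n * I and the
# normalized matrix H / sqrt(n) is an exact rotation whose inverse is its own transpose.
assert torch.equal(H.abs(), torch.ones_like(H))
assert torch.allclose(H @ H.T, n * torch.eye(n, dtype=torch.float64))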