-
Notifications
You must be signed in to change notification settings - Fork 188
[Transform] QuIP Modifier #1648
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
kylesayrs
wants to merge
45
commits into
main
Choose a base branch
from
kylesayrs/transform-quip-modifier
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+365
−33
Draft
Changes from all commits
Commits
Show all changes
45 commits
Select commit
Hold shift + click to select a range
ba617db
wip
kylesayrs 2f5b1c8
use random-hadamard, add correctness tests
kylesayrs 3aa35e7
add correctness test, note that precision makes a large difference
kylesayrs b6c088e
add on lifecycle methods
brian-dellabetta d1eb2a1
Merge branch 'main' into kylesayrs/transform-modifier
brian-dellabetta 3207124
TransformModifier with SpinQuant R1&R2
brian-dellabetta a88ca3c
spinquant and quip_online, running but outputting gibberish
brian-dellabetta 5bd51df
updated example
brian-dellabetta 3c216dd
DummyModel script
brian-dellabetta bbcdc8c
implement fuse_norm_linears
kylesayrs bd7f4d5
Merge branch 'kylesayrs/fuse-helpers' into bdellabe/transform-modifier
kylesayrs f5c2150
R1 working
kylesayrs dc5c30c
add r2, increase precision
kylesayrs 7172c26
spinquant modifier
kylesayrs 9298e82
remove space
kylesayrs f77226d
use iterable
kylesayrs fdb64b5
add rotation validation
kylesayrs 5daa2d5
embedding fusion
kylesayrs 0e9af7b
add missing norm fusion
kylesayrs fce83be
use norm mappings
kylesayrs a979f8a
break into separate files
kylesayrs 4cab29e
small cleanup
kylesayrs f1cc987
cleanup
kylesayrs a7bb2e2
more cleanup
kylesayrs 0cf0188
make new weight on cpu
kylesayrs 53ea307
standardize, make modifier serializable
kylesayrs 4b4257f
add compress model script
kylesayrs dc7ac1a
use untie_word_embeddings
kylesayrs 8542f8d
style
kylesayrs b1e637e
better registery logic
kylesayrs b44ac81
remove dummy model test (add later)
kylesayrs 7a52b71
docstring
kylesayrs f4d7ec6
update docstring
kylesayrs f18d0e8
rename example file
kylesayrs cec2914
use match_modules_set
kylesayrs f6c797e
Merge branch 'main' into bdellabe/transform-modifier
brian-dellabetta 0c5c514
unit test fixes
brian-dellabetta f2ef7cf
style fixes
brian-dellabetta d0e5bc5
remove hardcoded pipeline logic
brian-dellabetta 31ac8e9
docstrings
brian-dellabetta a4abb3d
stylefixes
brian-dellabetta 490b987
implement quip
kylesayrs ac7dbcd
add example, cleanup
kylesayrs a5d3ddc
update quip example
kylesayrs a21648d
prepare for merge without spinquant
kylesayrs File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
# python3 compress_model.py --model_id meta-llama/Llama-3.2-1B-Instruct --transform_type random-hadamard | ||
import argparse | ||
from transformers import AutoModelForCausalLM, AutoTokenizer | ||
|
||
from llmcompressor import oneshot | ||
from llmcompressor.modifiers.quantization import QuantizationModifier | ||
from llmcompressor.modifiers.transform import SpinQuantModifier | ||
from llmcompressor.utils import dispatch_for_generation | ||
|
||
def parse_args(): | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--model_id", type=str, help="Model stub to compress") | ||
parser.add_argument("--transform_type", type=str, default=None, help="Type of transform used in SpinQuantModifier") | ||
parser.add_argument("--scheme", type=str, default=None, help="Quantization scheme (e.g. W4A16)") | ||
return parser.parse_args() | ||
|
||
if __name__ == "__main__": | ||
args = parse_args() | ||
|
||
# Select model and load it. | ||
MODEL_ID = args.model_id | ||
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto") | ||
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) | ||
|
||
# Select number of samples. 512 samples is a good place to start. | ||
# Increasing the number of samples can improve accuracy. | ||
NUM_CALIBRATION_SAMPLES = 512 | ||
MAX_SEQUENCE_LENGTH = 2048 | ||
|
||
# Configure the quantization algorithm to run. | ||
recipe = [] | ||
if args.transform_type: | ||
recipe.append(SpinQuantModifier(rotations=["R1", "R2"], transform_type=args.transform_type)) | ||
|
||
if args.scheme: | ||
recipe.append(QuantizationModifier(targets="Linear", scheme=args.scheme, ignore=["lm_head"])) | ||
|
||
# Apply algorithms. | ||
oneshot( | ||
model=model, | ||
recipe=recipe, | ||
dataset="ultrachat_200k", | ||
splits={"calibration": f"train_sft[:{NUM_CALIBRATION_SAMPLES}]"}, | ||
max_seq_length=MAX_SEQUENCE_LENGTH, | ||
num_calibration_samples=NUM_CALIBRATION_SAMPLES, | ||
) | ||
|
||
# Confirm generations of the quantized model look sane. | ||
print("\n\n") | ||
print("========== SAMPLE GENERATION ==============") | ||
dispatch_for_generation(model) | ||
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda") | ||
output = model.generate(input_ids, max_new_tokens=100) | ||
print(tokenizer.decode(output[0])) | ||
print("==========================================\n\n") | ||
|
||
# Save to disk compressed. | ||
SAVE_DIR = MODEL_ID.split("/")[1] + f"-{args.transform_type}-{args.scheme}" | ||
model.save_pretrained(SAVE_DIR, save_compressed=True) | ||
tokenizer.save_pretrained(SAVE_DIR) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
from datasets import load_dataset | ||
from transformers import AutoModelForCausalLM, AutoTokenizer | ||
|
||
from llmcompressor import oneshot | ||
from llmcompressor.modifiers.quantization import QuantizationModifier | ||
from llmcompressor.modifiers.transform import QuIPModifier | ||
from llmcompressor.utils import dispatch_for_generation | ||
|
||
# Select model and load it. | ||
MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct" | ||
|
||
model = AutoModelForCausalLM.from_pretrained( | ||
MODEL_ID, | ||
torch_dtype="auto", | ||
) | ||
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) | ||
|
||
# Select calibration dataset. | ||
DATASET_ID = "HuggingFaceH4/ultrachat_200k" | ||
DATASET_SPLIT = "train_sft" | ||
|
||
# Select number of samples. 512 samples is a good place to start. | ||
# Increasing the number of samples can improve accuracy. | ||
NUM_CALIBRATION_SAMPLES = 512 | ||
MAX_SEQUENCE_LENGTH = 2048 | ||
|
||
# Load dataset and preprocess. | ||
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]") | ||
ds = ds.shuffle(seed=42) | ||
|
||
|
||
def preprocess(example): | ||
return { | ||
"text": tokenizer.apply_chat_template( | ||
example["messages"], | ||
tokenize=False, | ||
) | ||
} | ||
|
||
|
||
ds = ds.map(preprocess) | ||
|
||
|
||
# Tokenize inputs. | ||
def tokenize(sample): | ||
return tokenizer( | ||
sample["text"], | ||
padding=False, | ||
max_length=MAX_SEQUENCE_LENGTH, | ||
truncation=True, | ||
add_special_tokens=False, | ||
) | ||
|
||
|
||
ds = ds.map(tokenize, remove_columns=ds.column_names) | ||
|
||
# Configure the quantization algorithm to run. | ||
# * apply spinquant transforms to model in order to make quantization easier | ||
# * quantize the weights to 4 bit with GPTQ with a group size 128 | ||
recipe = [ | ||
QuIPModifier(transform_type="random-hadamard"), | ||
QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]), | ||
] | ||
|
||
# Apply algorithms. | ||
oneshot( | ||
model=model, | ||
recipe=recipe, | ||
dataset=ds, | ||
max_seq_length=MAX_SEQUENCE_LENGTH, | ||
num_calibration_samples=NUM_CALIBRATION_SAMPLES, | ||
pipeline="datafree", | ||
) | ||
|
||
# Confirm generations of the quantized model look sane. | ||
print("\n\n") | ||
print("========== SAMPLE GENERATION ==============") | ||
dispatch_for_generation(model) | ||
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda") | ||
output = model.generate(input_ids, max_new_tokens=100) | ||
print(tokenizer.decode(output[0])) | ||
print("==========================================\n\n") | ||
|
||
# Save to disk compressed. | ||
SAVE_DIR = MODEL_ID.split("/")[1] + "-quip-w4a16" | ||
model.save_pretrained(SAVE_DIR, save_compressed=True) | ||
tokenizer.save_pretrained(SAVE_DIR) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
# flake8: noqa | ||
|
||
from .fuse import * | ||
from .prepare import * |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# flake8: noqa | ||
|
||
from .quip import QuIPModifier |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# flake8: noqa | ||
|
||
from .base import * | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
from typing import List, Literal, Optional, Union | ||
|
||
from compressed_tensors.transform import ( | ||
TransformArgs, | ||
TransformConfig, | ||
TransformScheme, | ||
apply_transform_config, | ||
) | ||
from pydantic import Field, ValidationInfo, field_validator | ||
|
||
from llmcompressor.core import Event, EventType, State | ||
from llmcompressor.modifiers import Modifier | ||
|
||
__all__ = ["QuIPModifier"] | ||
|
||
|
||
class QuIPModifier(Modifier): | ||
""" | ||
Implements the transforms according to | ||
[QuIP#: Even Better LLM Quantization with Hadamard Incoherence and Lattice Codebooks](https://arxiv.org/pdf/2402.04396) # noqa: E501 | ||
[QuIP: 2-Bit Quantization of Large Language Models With Guarantees](https://arxiv.org/abs/2307.13304) # noqa: E501 | ||
|
||
Transforms (rotations) are extra layers added to a model which reduce the accuracy | ||
loss induced by quantization. This is achieved through "rotating" weights and | ||
activations into a space with a smaller dynamic range of values, thus decreasing | ||
the range of scales required for quantization. | ||
|
||
QuIP and QuIP# apply transforms to every linear layer, two of which are fused into | ||
the model weights and two of which remain as online rotations computed at runtime. | ||
|
||
:param transform_type: The type of transform to apply to the model. | ||
`"hadamard"` has the least performance cost but only supports sizes which are | ||
powers of power of two. | ||
`"random-hadamard"` has more performance cost, but supports a much larger set of | ||
sizes. | ||
`"random-matrix"` has the greatest performance cost, but supports any size | ||
:param randomize: If true, create distinct transforms for each application | ||
:param learnable: If true, attach gradients to transform weights for training | ||
:param ignore: Modules to ignore when attaching transforms | ||
:param transform_config: Optional transform config for overriding provided arguments | ||
""" | ||
|
||
transform_type: Literal["hadamard", "random-hadamard", "random-matrix"] = Field( | ||
default="hadamard", exclude=True | ||
) | ||
randomize: bool = Field(default=False, exclude=True) | ||
learnable: bool = Field(default=False, exclude=True) | ||
ignore: Union[str, List[str]] = Field(default="lm_head", exclude=True) | ||
|
||
# optional override for more fine-grained control | ||
# also included in recipe serialization | ||
transform_config: Optional[TransformConfig] = Field(default=None, repr=False) | ||
|
||
@field_validator("randomize", "learnable", mode="before") | ||
def validate_not_implemented(cls, value, info: ValidationInfo): | ||
if value: | ||
raise NotImplementedError(f"{info.field_name} is not supported right now") | ||
return value | ||
|
||
def on_initialize(self, state: State, **kwargs) -> bool: | ||
if self.transform_config is not None: | ||
return True | ||
|
||
self.transform_config = self._create_config() | ||
return True | ||
kylesayrs marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def on_start(self, state: State, event: Event, **kwargs): | ||
kylesayrs marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self.started_ = True | ||
|
||
apply_transform_config(state.model, self.transform_config) | ||
|
||
def on_event(self, state: State, event: Event, **kwargs): | ||
if event.type_ == EventType.CALIBRATION_EPOCH_START: | ||
if not self.started_: | ||
self.on_start(state, None) | ||
|
||
elif event.type_ == EventType.SEQUENTIAL_EPOCH_END: | ||
pass | ||
|
||
elif event.type_ == EventType.CALIBRATION_EPOCH_END: | ||
if not self.ended_: | ||
self.on_end(state, None) | ||
|
||
def on_end(self, state: State, event: Event, **kwargs): | ||
kylesayrs marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self.ended_ = True | ||
|
||
def on_finalize(self, state: State, **kwargs) -> bool: | ||
if not self.ended_: | ||
self.on_end(state, None) | ||
|
||
return True | ||
|
||
def _create_config(self) -> TransformConfig: | ||
return TransformConfig( | ||
config_groups={ | ||
"v": TransformScheme( | ||
type=self.transform_type, | ||
apply=[ | ||
TransformArgs( | ||
targets=["Linear"], | ||
location="input", # non-mergable | ||
ignore=self.ignore, | ||
), | ||
TransformArgs( | ||
targets=["Linear"], | ||
location="weight_input", | ||
# location="input", | ||
inverse=True, | ||
ignore=self.ignore, | ||
), | ||
], | ||
randomize=self.randomize, | ||
requires_grad=self.learnable, | ||
), | ||
"u": TransformScheme( | ||
type=self.transform_type, | ||
apply=[ | ||
TransformArgs( | ||
targets=["Linear"], | ||
location="weight_output", | ||
# location="output", | ||
ignore=self.ignore, | ||
), | ||
TransformArgs( | ||
targets=["Linear"], | ||
location="output", # non-mergable | ||
inverse=True, | ||
ignore=self.ignore, | ||
), | ||
], | ||
randomize=self.randomize, | ||
requires_grad=self.learnable, | ||
), | ||
} | ||
) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.