[Transform] R3 and R4 attention rotations #397

Draft · wants to merge 4 commits into base: kylesayrs/transform-merge
@@ -392,15 +392,18 @@ def compress_model(self, model: Module):
for prefix, module in tqdm(model.named_modules(), desc="Compressing model"):

if prefix in module_to_scheme or prefix in sparse_compression_targets:
-module_device = get_execution_device(module).type
-is_meta = module_device == "meta"
+module_device = get_execution_device(module)
+is_meta = module_device.type == "meta"

exec_device = "meta" if is_meta else "cpu"
onloading_device = "meta" if is_meta else module_device

# in the future, support compression on same device
with align_module_device(module, execution_device=exec_device):
-state_dict = module.state_dict(prefix=f"{prefix}.")
+state_dict = {
+    f"{prefix}.{name}": param
+    for name, param in module.named_parameters(recurse=False)
+}

# quantization first
if prefix in module_to_scheme:
@@ -421,7 +424,7 @@ def compress_model(self, model: Module):

# remove any existing parameters
offload_device = get_offloaded_device(module)
-for name, _ in list(module.named_parameters()):
+for name, _ in list(module.named_parameters(recurse=False)):
delete_offload_parameter(module, name)

# replace with compressed parameters
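For context, a minimal sketch of why the PR swaps `state_dict(prefix=...)` for a dict over `named_parameters(recurse=False)`: `state_dict` recurses into child submodules, including any transform modules attached to the layer, while the new form collects only the module's own parameters. The toy `Linear` and its attached `weight_transform` child below are illustrative stand-ins, not code from this repo.

import torch

# hypothetical stand-in: a layer with an attached transform submodule
layer = torch.nn.Linear(4, 4, bias=False)
layer.add_module("weight_transform", torch.nn.Linear(4, 4, bias=False))

# state_dict(prefix=...) recurses into children and picks up their tensors
print(sorted(layer.state_dict(prefix="layer.").keys()))
# ['layer.weight', 'layer.weight_transform.weight']

# named_parameters(recurse=False) yields only the module's own parameters
print(sorted(f"layer.{name}" for name, _ in layer.named_parameters(recurse=False)))
# ['layer.weight']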
@@ -458,7 +461,10 @@ def decompress_model(self, model: Module):
if prefix in module_to_scheme or prefix in sparse_compression_targets:
# in the future, support decompression on same device
with align_module_device(module, execution_device="cpu"):
-state_dict = module.state_dict(prefix=f"{prefix}.")
+state_dict = {
+    f"{prefix}.{name}": param
+    for name, param in module.named_parameters(recurse=False)
+}

# sparsity first
if prefix in sparse_compression_targets:
@@ -483,7 +489,7 @@ def decompress_model(self, model: Module):
# remove any existing parameters
exec_device = get_execution_device(module)
offload_device = get_offloaded_device(module)
-for name, _ in list(module.named_parameters()):
+for name, _ in list(module.named_parameters(recurse=False)):
delete_offload_parameter(module, name)

# replace with decompressed parameters
@@ -747,12 +753,16 @@ def _replace_weights(self, dense_weight_generator, model: Module):

def map_module_to_scheme(model: Module) -> Dict[str, QuantizationScheme]:
"""
-Returns a dictionary which maps quantized module names to their quantization schemes
+Returns a dictionary which maps quantized module names to their quantization
+schemes. Only includes modules with weight quantization
"""
return {
fix_fsdp_module_name(name): module.quantization_scheme
for name, module in model.named_modules()
-if is_module_quantized(module)
+if (
+    hasattr(module, "quantization_scheme") and
+    module.quantization_scheme.weights is not None
+)
}
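As a side note on the relaxed filter above: a scheme that only quantizes activations, as attention calibration through `CompressedAttentionImpl` would use, still carries a `quantization_scheme` attribute but has no weight args, so it no longer appears in the mapping. A hedged sketch, assuming the usual `QuantizationScheme`/`QuantizationArgs` fields and made-up target patterns:

from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme

# activation-only scheme: no `weights` entry, so map_module_to_scheme skips it
attn_scheme = QuantizationScheme(
    targets=["re:.*self_attn$"],
    input_activations=QuantizationArgs(num_bits=8),
)
assert attn_scheme.weights is None

# weight-quantized scheme: still included in the mapping
linear_scheme = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(num_bits=4, group_size=128),
)
assert linear_scheme.weights is not None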


@@ -785,4 +795,4 @@ def override_quantization_status(
try:
yield
finally:
-config.quantization_status = original_status
\ No newline at end of file
+config.quantization_status = original_status
1 change: 1 addition & 0 deletions src/compressed_tensors/modeling/README.md
@@ -0,0 +1 @@
This folder contains code which models existing `torch` logic as used by `transformers`
159 changes: 159 additions & 0 deletions src/compressed_tensors/modeling/attention.py
@@ -0,0 +1,159 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict, defaultdict
from typing import TYPE_CHECKING, Callable, Optional

import torch
from compressed_tensors.utils import getattr_chain
from torch.utils.hooks import RemovableHandle
from transformers import AttentionInterface, PreTrainedModel
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.models.llama.modeling_llama import eager_attention_forward


if TYPE_CHECKING:
from compressed_tensors.quantization import QuantizationArgs, QuantizationStatus


__all__ = ["CompressedAttentionImpl", "enable_compressed_attention", "call_attn_impl"]


ActivationHookFn = Callable[[torch.nn.Module, torch.Tensor], None]


class CompressedAttentionImpl(torch.nn.Module):
"""
Callable attention implementation which applies transforms, calibration, and
quantization if applicable. Can be hooked with calibration hooks in order to
trigger quantization observers.

:param attn_implementation: original attention implementation to call after hooks
"""

NAME = "compressed_attention"
ATTN_IMPL = "eager"
_ATTN_IMPLS = dict()

@classmethod
def from_module(cls, module: torch.nn.Module):
if module not in cls._ATTN_IMPLS:
cls._ATTN_IMPLS[module] = cls()
return cls._ATTN_IMPLS[module]

def __init__(self):
super().__init__()
self.query_hooks: OrderedDict[int, ActivationHookFn] = OrderedDict()
self.key_hooks: OrderedDict[int, ActivationHookFn] = OrderedDict()
self.value_hooks: OrderedDict[int, ActivationHookFn] = OrderedDict()

def register_query_hook(self, hook: ActivationHookFn) -> RemovableHandle:
handle = RemovableHandle(self.query_hooks)
self.query_hooks[handle.id] = hook

return handle

def register_key_hook(self, hook: ActivationHookFn) -> RemovableHandle:
handle = RemovableHandle(self.key_hooks)
self.key_hooks[handle.id] = hook

return handle

def register_value_hook(self, hook: ActivationHookFn) -> RemovableHandle:
handle = RemovableHandle(self.value_hooks)
self.value_hooks[handle.id] = hook

return handle

def forward(
self,
module: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
):
from compressed_tensors.quantization import forward_quantize

for hook in self.query_hooks.values():
output = hook(self, query)
if output is not None:
query = output

for hook in self.key_hooks.values():
output = hook(self, key)
if output is not None:
key = output

for hook in self.value_hooks.values():
output = hook(self, value)
if output is not None:
value = output

# TODO: attnq
# 2. calibrate/ apply quantization
# args_path = "quantization_scheme.input_activations"
# status_path = "quantization_status"
# input_args: Optional[QuantizationArgs] = getattr_chain(
# module, args_path, None
# )
# status: Optional[QuantizationStatus] = getattr(module, status_path, None)
# if input_args is not None and status in (
# QuantizationStatus.CALIBRATION,
# QuantizationStatus.FROZEN,
# ):
# query = forward_quantize(module, query, "q", input_args)
# key = forward_quantize(module, key, "k", input_args)
# value = forward_quantize(module, value, "v", input_args)

# 3. apply original attention function
# `eager_attention_forward` is duplicated across models by design
# assume that llama implementation is representative of all attention functions
# see: https://github.com/huggingface/transformers/issues/38541#issuecomment-2958567250 # noqa: 501

attention_fn: Callable = (
eager_attention_forward
# if self.ATTN_IMPL == "eager"
# else ALL_ATTENTION_FUNCTIONS[self.ATTN_IMPL]
)
# print(self.ATTN_IMPL)
return attention_fn(
module, query, key, value, attention_mask, scaling, dropout, **kwargs
)


def enable_compressed_attention(model: torch.nn.Module):
"""
Enables transforms, calibration, and quantization for an attention implementation.
This function can safely be called multiple times on the same model.

:param model: model to enable compressed quantization for
:return: singleton instance of `CompressedAttentionImpl`
"""
if not isinstance(model, PreTrainedModel):
return

attn_impl = getattr(model.config, "_attn_implementation", "eager")

CompressedAttentionImpl.ATTN_IMPL = attn_impl
AttentionInterface.register(CompressedAttentionImpl.NAME, call_attn_impl)
model.config._attn_implementation = CompressedAttentionImpl.NAME
# model.set_attention_implementation(CompressedAttentionImpl.NAME)


def call_attn_impl(module: torch.nn.Module, *args, **kwargs):
return CompressedAttentionImpl.from_module(module)(module, *args, **kwargs)
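For reviewers, a minimal usage sketch of the new implementation, not taken from this PR: the checkpoint name and the `model.model.layers[0].self_attn` path are Llama-style assumptions, and the hook body is illustrative only.

import torch
from transformers import AutoModelForCausalLM
from compressed_tensors.modeling.attention import (
    CompressedAttentionImpl,
    enable_compressed_attention,
)

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
enable_compressed_attention(model)  # points config._attn_implementation at "compressed_attention"

attn = model.model.layers[0].self_attn
impl = CompressedAttentionImpl.from_module(attn)

def key_hook(_impl, key: torch.Tensor):
    # returning a tensor replaces the key states; returning None leaves them unchanged
    return key

handle = impl.register_key_hook(key_hook)
# ... run calibration forward passes ...
handle.remove()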
4 changes: 4 additions & 0 deletions src/compressed_tensors/quantization/lifecycle/apply.py
@@ -22,6 +22,7 @@

import torch
from compressed_tensors.config import CompressionFormat
from compressed_tensors.modeling.attention import enable_compressed_attention
from compressed_tensors.quantization.lifecycle.compressed import (
compress_quantized_weights,
)
@@ -144,6 +145,9 @@ def apply_quantization_config(
for target in scheme.targets:
target_to_scheme[target] = scheme

# enable attention calibration/ quantization
enable_compressed_attention(model)

if run_compressed:
from compressed_tensors.linear.compressed_linear import CompressedLinear

3 changes: 3 additions & 0 deletions src/compressed_tensors/transform/apply.py
@@ -13,6 +13,7 @@
# limitations under the License.

import torch
from compressed_tensors.modeling.attention import enable_compressed_attention
from compressed_tensors.transform import TransformConfig, TransformFactory


@@ -27,6 +28,8 @@ def apply_transform_config(model: torch.nn.Module, config: TransformConfig):
:param model: model to apply config to
:param config: transform config to apply
"""
enable_compressed_attention(model)

for name, scheme in config.config_groups.items():
factory = TransformFactory.from_scheme(scheme, name=name)
factory.apply_to_model(model)
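To illustrate how this gets exercised, a hedged sketch of a config that routes rotations through the new attention locations. The `hadamard` type, the `q_attn`/`k_attn` location strings (inferred from the `ATTN_Q`/`ATTN_K` branches added in `base.py` below), the checkpoint, and the target pattern are assumptions, not values taken from this PR.

from transformers import AutoModelForCausalLM
from compressed_tensors.transform import (
    TransformArgs,
    TransformConfig,
    TransformScheme,
    apply_transform_config,
)

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")

config = TransformConfig(
    config_groups={
        "R3": TransformScheme(
            type="hadamard",
            apply=[
                TransformArgs(targets=["re:.*self_attn$"], location="q_attn"),
                TransformArgs(targets=["re:.*self_attn$"], location="k_attn"),
            ],
        ),
    }
)

# enable_compressed_attention(model) runs first, then each factory is applied
apply_transform_config(model, config)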
52 changes: 45 additions & 7 deletions src/compressed_tensors/transform/factory/base.py
@@ -18,17 +18,19 @@

import torch
import torch.nn.utils.parametrize as P
from compressed_tensors import InternalModule, match_named_modules
from compressed_tensors.modeling.attention import CompressedAttentionImpl
from compressed_tensors.registry.registry import RegistryMixin, T
from compressed_tensors.transform import (
TransformArgs,
TransformLocation,
TransformScheme,
)
from compressed_tensors.utils import (
InternalModule,
align_module_device,
delete_offload_module,
has_offloaded_params,
match_named_modules,
patch_attr,
register_offload_module,
update_offload_parameter,
@@ -111,15 +113,15 @@ def _apply_to_module(self, module: Module, args: TransformArgs):

# create transform as submodule
transform_name = f"{self.name}_{args.location}"
-transform = self.create_transform(module, args)
-self.transforms.append(transform)
-register_offload_module(module, transform_name, transform)

# register input transformation hook
if args.location == TransformLocation.INPUT:
+transform = self.create_transform(module, args)
+self.transforms.append(transform)
+register_offload_module(module, transform_name, transform)

def input_hook(_, args):
-input = args[0]
+input = args[0] if isinstance(args, tuple) else args
return transform(input)

module.register_forward_pre_hook(input_hook, prepend=True)
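A brief aside on the hook contract relied on here, shown with a throwaway module: a forward-pre-hook that returns a value replaces the module's positional inputs, and PyTorch passes those inputs as a tuple, which is what the added `isinstance` guard accounts for.

import torch

layer = torch.nn.Linear(4, 4)

def scale_input_hook(_module, args):
    hidden = args[0] if isinstance(args, tuple) else args
    # the returned tensor becomes the module's new input
    return 2.0 * hidden

layer.register_forward_pre_hook(scale_input_hook, prepend=True)
out = layer(torch.randn(1, 4))  # forward sees the scaled input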
@@ -129,6 +131,9 @@ def input_hook(_, args):
TransformLocation.WEIGHT_INPUT,
TransformLocation.WEIGHT_OUTPUT,
):
+transform = self.create_transform(module, args)
+register_offload_module(module, transform_name, transform)

# fuse transform into weight
assert hasattr(module, "weight")
with torch.no_grad(), align_module_device(module):
@@ -140,22 +145,55 @@ def input_hook(_, args):
if has_offloaded_params(module):
raise ValueError("Offloaded training is not supported")
P.register_parametrization(module, "weight", transform)
self.transforms.append(transform)

else:
# transform is no longer needed (unfusing is not supported)
delete_offload_module(module, transform_name)

# register output transformation hook
elif args.location == TransformLocation.OUTPUT:
transform = self.create_transform(module, args)
self.transforms.append(transform)
register_offload_module(module, transform_name, transform)

def output_hook(_, _input, output):
return transform(output)

module.register_forward_hook(output_hook)

# other locations such as q_attn and k_attn have not been implemented
# query hook registered to `CompressedAttentionImpl`
elif args.location in TransformLocation.ATTN_Q:
# TODO: makes name assumptions. Maybe we can target q_proj in the config
# then assume parent? Not sure
transform = self.create_transform(module.q_proj, args)
self.transforms.append(transform)
register_offload_module(module, transform_name, transform)

attention_impl = CompressedAttentionImpl.from_module(module)

def query_hook(_, query):
return transform(query)

attention_impl.register_query_hook(query_hook)

# key hook registered to `CompressedAttentionImpl`
elif args.location in TransformLocation.ATTN_K:
# TODO: makes name assumptions. Maybe we can target k_proj in the config
# then assume parent? Not sure
transform = self.create_transform(module.k_proj, args)
self.transforms.append(transform)
register_offload_module(module, transform_name, transform)

attention_impl = CompressedAttentionImpl.from_module(module)

def key_hook(_, key):
return transform(key)

attention_impl.register_key_hook(key_hook)

else:
raise NotImplementedError()
raise ValueError()

def _update_tied_weights(self):
"""