From 3591df5a7e5a6fabc8face7d9b2942003239fd16 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 16 Jul 2025 16:37:08 -0400 Subject: [PATCH 1/4] add and enable compressed quantization Signed-off-by: Kyle Sayers --- src/compressed_tensors/modeling/README.md | 1 + src/compressed_tensors/modeling/attention.py | 142 ++++++++++++++++++ .../quantization/lifecycle/apply.py | 4 + src/compressed_tensors/transform/apply.py | 3 + .../transform/transform_args.py | 8 +- 5 files changed, 154 insertions(+), 4 deletions(-) create mode 100644 src/compressed_tensors/modeling/README.md create mode 100644 src/compressed_tensors/modeling/attention.py diff --git a/src/compressed_tensors/modeling/README.md b/src/compressed_tensors/modeling/README.md new file mode 100644 index 00000000..c5e6d3b6 --- /dev/null +++ b/src/compressed_tensors/modeling/README.md @@ -0,0 +1 @@ +This folder contains code which models existing `torch` logic as used by `transformers` \ No newline at end of file diff --git a/src/compressed_tensors/modeling/attention.py b/src/compressed_tensors/modeling/attention.py new file mode 100644 index 00000000..f7c2c06b --- /dev/null +++ b/src/compressed_tensors/modeling/attention.py @@ -0,0 +1,142 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import OrderedDict +from typing import Callable, Optional + +from compressed_tensors.quantization import ( + QuantizationArgs, + QuantizationStatus, + forward_quantize, +) +from compressed_tensors.transform import TransformBase, TransformLocation +from compressed_tensors.utils import getattr_chain +from torch import Module, Tensor +from torch.utils.hooks import RemovableHandle +from transformers import AttentionInterface, PreTrainedModel +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS +from transformers.models.llama.modeling_llama import eager_attention_forward + + +__all__ = ["CompressedAttentionImpl", "enable_compressed_attention"] + + +COMPRESSED_ATTENTION_NAME = "compressed_attention" + + +CalibHook = Callable[[Module, Tensor, Tensor, Tensor]] + + +class CompressedAttentionImpl(Module): + """ + Callable attention implementation which applies transforms, calibration, and + quantization if applicable. Can be hooked with calibrations hooks in order to + trigger quantization observers. 
+ + In the future, the idea of making attention implementions hookable Modules + could be upstreamed to transformers model definitions + + :param attn_implementation: original attention implementation to call after hooks + """ + + def __init__(self, attn_implementation: str): + self.attn_implementation = attn_implementation + self.calib_hooks: OrderedDict[int, CalibHook] = OrderedDict() + + # `eager_attention_forward` is duplicated across models by design + # assume that llama implementation is representative of all attention functions + # see: https://github.com/huggingface/transformers/issues/38541#issuecomment-2958567250 # noqa: 501 + self.attention_fn: Callable = ( + eager_attention_forward + if self.attn_implementation == "eager" + else ALL_ATTENTION_FUNCTIONS[self.attn_implementation] + ) + + def register_calib_hook(self, hook: CalibHook) -> RemovableHandle: + """ + Register a calibration hook which is called + after transforms and before quantization + + :param hook: hook to be called + :return: removable handle + """ + handle = RemovableHandle(self.calib_hooks) + self.calib_hooks[handle.id] = hook + + return handle + + def forward( + self, + module: Module, + query: Tensor, + key: Tensor, + value: Tensor, + attention_mask: Optional[Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, + ): + # 1. apply transforms + for submodule in module.children(): + if isinstance(submodule, TransformBase): + if TransformBase.args.location == TransformLocation.ATTN_Q: + query = submodule(query) + + if TransformBase.args.location == TransformLocation.ATTN_K: + key = submodule(key) + + # note that, unlike qk, v_proj does not undergo RoPE before attention + # and can therefore be targeted directly + + # TODO: attnq + # 2. calibrate/ apply quantization + # args_path = "quantization_scheme.input_activations" + # input_args: Optional[QuantizationArgs] = getattr_chain(module, args_path, None) # noqa: 501 + # if input_args is not None: + # status_path = "quantization_status" + # status: Optional[QuantizationStatus] = getattr(module, status_path, None) + + # # 2a. calibrate quantization + # if status == QuantizationStatus.CALIBRATION: + # assert len(self.calib_hooks) <= 1 + # for hook in self.calib_hooks.items(): + # hook(module, query, key, value) + + # # 2b. apply quantization + # if status in (QuantizationStatus.CALIBRATION, QuantizationStatus.FROZEN): + # query = forward_quantize(module, query, "q", input_args) + # key = forward_quantize(module, key, "k", input_args) + # value = forward_quantize(module, value, "v", input_args) + + # 3. apply original attention function + return self.attention_fn( + module, query, key, value, attention_mask, scaling, dropout, **kwargs + ) + + +def enable_compressed_attention(model: PreTrainedModel) -> CompressedAttentionImpl: + """ + Enables transforms, calibration, and quantization for an attention implementation. + This function can safetly be called multiple times on the same model. 
+ + :param model: model to enable compressed quantization for + :return: singleton instance of `CompressedAttentionImpl` + """ + attn_implementation = getattr(model.config, "attn_implementation", "eager") + if attn_implementation != COMPRESSED_ATTENTION_NAME: + compressed_attention = CompressedAttentionImpl(attn_implementation) + AttentionInterface.register(COMPRESSED_ATTENTION_NAME, compressed_attention) + model.config.attn_implementation = COMPRESSED_ATTENTION_NAME + + return ALL_ATTENTION_FUNCTIONS[COMPRESSED_ATTENTION_NAME] diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index 7afd2aba..7ff8f8ff 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -22,6 +22,7 @@ import torch from compressed_tensors.config import CompressionFormat +from compressed_tensors.modeling.attention import enable_compressed_attention from compressed_tensors.quantization.lifecycle.compressed import ( compress_quantized_weights, ) @@ -189,6 +190,9 @@ def apply_quantization_config( f"{set(config.ignore) - set(ignored_submodules)}" ) + # enable attention calibration/ quantization + enable_compressed_attention(model) + # apply current quantization status across all targeted layers apply_quantization_status(model, config.quantization_status) return names_to_scheme diff --git a/src/compressed_tensors/transform/apply.py b/src/compressed_tensors/transform/apply.py index a5d4c8c2..70464d58 100644 --- a/src/compressed_tensors/transform/apply.py +++ b/src/compressed_tensors/transform/apply.py @@ -13,6 +13,7 @@ # limitations under the License. import torch +from compressed_tensors.modeling.attention import enable_compressed_attention from compressed_tensors.transform import TransformConfig, TransformFactory @@ -30,3 +31,5 @@ def apply_transform_config(model: torch.nn.Module, config: TransformConfig): for name, scheme in config.config_groups.items(): factory = TransformFactory.from_scheme(scheme, name=name) factory.apply_to_model(model) + + enable_compressed_attention(model) diff --git a/src/compressed_tensors/transform/transform_args.py b/src/compressed_tensors/transform/transform_args.py index e94d4d2d..582510c8 100644 --- a/src/compressed_tensors/transform/transform_args.py +++ b/src/compressed_tensors/transform/transform_args.py @@ -33,8 +33,8 @@ class TransformLocation(str, Enum): | `WEIGHT_INPUT` | offline | weight | `prev.WEIGHT_OUTPUT`, `prev.OUTPUT`, `this.INPUT` | # noqa: E501 | `WEIGHT_OUTPUT` | offline | weight | `this.OUTPUT`, `next.INPUT`, `next.WEIGHT_INPUT` | # noqa: E501 | `OUTPUT` | online | activations | `this.WEIGHT_OUTPUT`, `next.INPUT`, `next.WEIGHT_INPUT` | # noqa: E501 - | `K_CACHE` | online | key_values | `q_proj.Q_ATTN` | # noqa: E501 - | `Q_ATTN` | online | query_values | `k_proj.K_CACHE` | # noqa: E501 + | `ATTN_Q` | online | query_states | `this.ATTN_K` | # noqa: E501 + | `ATTN_K` | online | key_states | `this.Q_ATTN` | # noqa: E501 | -------------------------------------------------------------------------------------------------------- | # noqa: E501 """ @@ -42,8 +42,8 @@ class TransformLocation(str, Enum): WEIGHT_INPUT = "weight_input" WEIGHT_OUTPUT = "weight_output" OUTPUT = "output" - K_CACHE = "k_cache" - Q_ATTN = "q_attn" + ATTN_Q = "attn_q" + ATTN_K = "attn_k" class TransformArgs(BaseModel, use_enum_values=True): From 6bc9b2fe7b391845eab9a6327f8a8f89bbfda57c Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 16 Jul 2025 17:28:49 
-0400 Subject: [PATCH 2/4] use qkv hooks Signed-off-by: Kyle Sayers --- src/compressed_tensors/modeling/attention.py | 87 +++++++++++-------- .../quantization/lifecycle/apply.py | 6 +- src/compressed_tensors/transform/apply.py | 4 +- .../transform/factory/base.py | 24 ++++- 4 files changed, 75 insertions(+), 46 deletions(-) diff --git a/src/compressed_tensors/modeling/attention.py b/src/compressed_tensors/modeling/attention.py index f7c2c06b..c43e5cfd 100644 --- a/src/compressed_tensors/modeling/attention.py +++ b/src/compressed_tensors/modeling/attention.py @@ -29,13 +29,17 @@ from transformers.models.llama.modeling_llama import eager_attention_forward -__all__ = ["CompressedAttentionImpl", "enable_compressed_attention"] +__all__ = [ + "CompressedAttentionImpl", + "enable_compressed_attention", + "get_compressed_attention_impl", +] COMPRESSED_ATTENTION_NAME = "compressed_attention" -CalibHook = Callable[[Module, Tensor, Tensor, Tensor]] +ActivationHookFn = Callable[[Module, Tensor]] class CompressedAttentionImpl(Module): @@ -52,7 +56,9 @@ class CompressedAttentionImpl(Module): def __init__(self, attn_implementation: str): self.attn_implementation = attn_implementation - self.calib_hooks: OrderedDict[int, CalibHook] = OrderedDict() + self.query_hooks: OrderedDict[int, ActivationHookFn] = OrderedDict() + self.key_hooks: OrderedDict[int, ActivationHookFn] = OrderedDict() + self.value_hooks: OrderedDict[int, ActivationHookFn] = OrderedDict() # `eager_attention_forward` is duplicated across models by design # assume that llama implementation is representative of all attention functions @@ -63,16 +69,21 @@ def __init__(self, attn_implementation: str): else ALL_ATTENTION_FUNCTIONS[self.attn_implementation] ) - def register_calib_hook(self, hook: CalibHook) -> RemovableHandle: - """ - Register a calibration hook which is called - after transforms and before quantization + def register_query_hook(self, hook: ActivationHookFn) -> RemovableHandle: + handle = RemovableHandle(self.query_hooks) + self.query_hooks[handle.id] = hook - :param hook: hook to be called - :return: removable handle - """ - handle = RemovableHandle(self.calib_hooks) - self.calib_hooks[handle.id] = hook + return handle + + def register_key_hook(self, hook: ActivationHookFn) -> RemovableHandle: + handle = RemovableHandle(self.key_hooks) + self.key_hooks[handle.id] = hook + + return handle + + def register_value_hook(self, hook: ActivationHookFn) -> RemovableHandle: + handle = RemovableHandle(self.value_hooks) + self.value_hooks[handle.id] = hook return handle @@ -87,37 +98,30 @@ def forward( dropout: float = 0.0, **kwargs, ): - # 1. apply transforms - for submodule in module.children(): - if isinstance(submodule, TransformBase): - if TransformBase.args.location == TransformLocation.ATTN_Q: - query = submodule(query) + for hook in self.query_hooks(): + query = hook(self, query) - if TransformBase.args.location == TransformLocation.ATTN_K: - key = submodule(key) + for hook in self.key_hooks(): + key = hook(self, key) - # note that, unlike qk, v_proj does not undergo RoPE before attention - # and can therefore be targeted directly + for hook in self.value_hooks(): + value = hook(self, value) # TODO: attnq # 2. 
calibrate/ apply quantization # args_path = "quantization_scheme.input_activations" - # input_args: Optional[QuantizationArgs] = getattr_chain(module, args_path, None) # noqa: 501 - # if input_args is not None: - # status_path = "quantization_status" - # status: Optional[QuantizationStatus] = getattr(module, status_path, None) - - # # 2a. calibrate quantization - # if status == QuantizationStatus.CALIBRATION: - # assert len(self.calib_hooks) <= 1 - # for hook in self.calib_hooks.items(): - # hook(module, query, key, value) - - # # 2b. apply quantization - # if status in (QuantizationStatus.CALIBRATION, QuantizationStatus.FROZEN): - # query = forward_quantize(module, query, "q", input_args) - # key = forward_quantize(module, key, "k", input_args) - # value = forward_quantize(module, value, "v", input_args) + # status_path = "quantization_status" + # input_args: Optional[QuantizationArgs] = getattr_chain( + # module, args_path, None + # ) + # status: Optional[QuantizationStatus] = getattr(module, status_path, None) + # if input_args is not None and status in ( + # QuantizationStatus.CALIBRATION, + # QuantizationStatus.FROZEN, + # ): + # query = forward_quantize(module, query, "q", input_args) + # key = forward_quantize(module, key, "k", input_args) + # value = forward_quantize(module, value, "v", input_args) # 3. apply original attention function return self.attention_fn( @@ -125,7 +129,7 @@ def forward( ) -def enable_compressed_attention(model: PreTrainedModel) -> CompressedAttentionImpl: +def enable_compressed_attention(model: PreTrainedModel): """ Enables transforms, calibration, and quantization for an attention implementation. This function can safetly be called multiple times on the same model. @@ -139,4 +143,11 @@ def enable_compressed_attention(model: PreTrainedModel) -> CompressedAttentionIm AttentionInterface.register(COMPRESSED_ATTENTION_NAME, compressed_attention) model.config.attn_implementation = COMPRESSED_ATTENTION_NAME + +def get_compressed_attention_impl() -> CompressedAttentionImpl: + if COMPRESSED_ATTENTION_NAME not in ALL_ATTENTION_FUNCTIONS: + raise ValueError( + "Please call `enable_compressed_attention(model)` before attempting " + "to get singleton instance of `CompressedAttentionImpl`" + ) return ALL_ATTENTION_FUNCTIONS[COMPRESSED_ATTENTION_NAME] diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index 7ff8f8ff..b8dafd25 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -145,6 +145,9 @@ def apply_quantization_config( for target in scheme.targets: target_to_scheme[target] = scheme + # enable attention calibration/ quantization + enable_compressed_attention(model) + if run_compressed: from compressed_tensors.linear.compressed_linear import CompressedLinear @@ -190,9 +193,6 @@ def apply_quantization_config( f"{set(config.ignore) - set(ignored_submodules)}" ) - # enable attention calibration/ quantization - enable_compressed_attention(model) - # apply current quantization status across all targeted layers apply_quantization_status(model, config.quantization_status) return names_to_scheme diff --git a/src/compressed_tensors/transform/apply.py b/src/compressed_tensors/transform/apply.py index 70464d58..970043f0 100644 --- a/src/compressed_tensors/transform/apply.py +++ b/src/compressed_tensors/transform/apply.py @@ -28,8 +28,8 @@ def apply_transform_config(model: torch.nn.Module, config: TransformConfig): :param model: 
model to apply config to :param config: transform config to apply """ + enable_compressed_attention(model) + for name, scheme in config.config_groups.items(): factory = TransformFactory.from_scheme(scheme, name=name) factory.apply_to_model(model) - - enable_compressed_attention(model) diff --git a/src/compressed_tensors/transform/factory/base.py b/src/compressed_tensors/transform/factory/base.py index 1fdfa121..93d531f5 100644 --- a/src/compressed_tensors/transform/factory/base.py +++ b/src/compressed_tensors/transform/factory/base.py @@ -19,6 +19,7 @@ import torch import torch.nn.utils.parametrize as P from compressed_tensors import InternalModule, match_named_modules +from compressed_tensors.modeling.attention import get_compressed_attention_impl from compressed_tensors.registry.registry import RegistryMixin, T from compressed_tensors.transform import ( TransformArgs, @@ -119,7 +120,7 @@ def _apply_to_module(self, module: Module, args: TransformArgs): if args.location == TransformLocation.INPUT: def input_hook(_, args): - input = args[0] + input = args[0] if isinstance(args, tuple) else args return transform(input) module.register_forward_pre_hook(input_hook, prepend=True) @@ -153,9 +154,26 @@ def output_hook(_, _input, output): module.register_forward_hook(output_hook) - # other locations such as q_attn and k_attn have not been implemented + # query hook registered to `CompressedAttentionImpl` + elif args.location in TransformLocation.ATTN_Q: + attention_impl = get_compressed_attention_impl() + + def query_hook(_, query): + return transform(query) + + attention_impl.register_query_hook(query_hook) + + # key hook registered to `CompressedAttentionImpl` + elif args.location in TransformLocation.ATTN_K: + attention_impl = get_compressed_attention_impl() + + def key_hook(_, key): + return transform(key) + + attention_impl.register_key_hook(key_hook) + else: - raise NotImplementedError() + raise ValueError() def _update_tied_weights(self): """ From 9059400fe7994b1bb40f6c423b66bd875dff37b9 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 17 Jul 2025 17:28:32 -0400 Subject: [PATCH 3/4] r3 r4 works, but not with sdpa Signed-off-by: Kyle Sayers --- src/compressed_tensors/modeling/attention.py | 124 +++++++++-------- .../transform/factory/base.py | 34 ++++- .../transform/utils/matrix.py | 130 +++++++++--------- tests/test_transform/conftest.py | 86 +++++------- .../factory/test_correctness.py | 38 +++-- 5 files changed, 224 insertions(+), 188 deletions(-) diff --git a/src/compressed_tensors/modeling/attention.py b/src/compressed_tensors/modeling/attention.py index c43e5cfd..85c4205c 100644 --- a/src/compressed_tensors/modeling/attention.py +++ b/src/compressed_tensors/modeling/attention.py @@ -12,63 +12,52 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from collections import OrderedDict -from typing import Callable, Optional - -from compressed_tensors.quantization import ( - QuantizationArgs, - QuantizationStatus, - forward_quantize, -) -from compressed_tensors.transform import TransformBase, TransformLocation +from collections import OrderedDict, defaultdict +from typing import TYPE_CHECKING, Callable, Optional + +import torch from compressed_tensors.utils import getattr_chain -from torch import Module, Tensor from torch.utils.hooks import RemovableHandle from transformers import AttentionInterface, PreTrainedModel from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS from transformers.models.llama.modeling_llama import eager_attention_forward -__all__ = [ - "CompressedAttentionImpl", - "enable_compressed_attention", - "get_compressed_attention_impl", -] +if TYPE_CHECKING: + from compressed_tensors.quantization import QuantizationArgs, QuantizationStatus -COMPRESSED_ATTENTION_NAME = "compressed_attention" +__all__ = ["CompressedAttentionImpl", "enable_compressed_attention", "call_attn_impl"] -ActivationHookFn = Callable[[Module, Tensor]] +ActivationHookFn = Callable[[torch.nn.Module, torch.Tensor], None] -class CompressedAttentionImpl(Module): +class CompressedAttentionImpl(torch.nn.Module): """ Callable attention implementation which applies transforms, calibration, and quantization if applicable. Can be hooked with calibrations hooks in order to trigger quantization observers. - In the future, the idea of making attention implementions hookable Modules - could be upstreamed to transformers model definitions - :param attn_implementation: original attention implementation to call after hooks """ - def __init__(self, attn_implementation: str): - self.attn_implementation = attn_implementation + NAME = "compressed_attention" + ATTN_IMPL = "eager" + _ATTN_IMPLS = dict() + + @classmethod + def from_module(cls, module: torch.nn.Module): + if module not in cls._ATTN_IMPLS: + cls._ATTN_IMPLS[module] = cls() + return cls._ATTN_IMPLS[module] + + def __init__(self): + super().__init__() self.query_hooks: OrderedDict[int, ActivationHookFn] = OrderedDict() self.key_hooks: OrderedDict[int, ActivationHookFn] = OrderedDict() self.value_hooks: OrderedDict[int, ActivationHookFn] = OrderedDict() - # `eager_attention_forward` is duplicated across models by design - # assume that llama implementation is representative of all attention functions - # see: https://github.com/huggingface/transformers/issues/38541#issuecomment-2958567250 # noqa: 501 - self.attention_fn: Callable = ( - eager_attention_forward - if self.attn_implementation == "eager" - else ALL_ATTENTION_FUNCTIONS[self.attn_implementation] - ) - def register_query_hook(self, hook: ActivationHookFn) -> RemovableHandle: handle = RemovableHandle(self.query_hooks) self.query_hooks[handle.id] = hook @@ -89,23 +78,31 @@ def register_value_hook(self, hook: ActivationHookFn) -> RemovableHandle: def forward( self, - module: Module, - query: Tensor, - key: Tensor, - value: Tensor, - attention_mask: Optional[Tensor], + module: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], scaling: float, dropout: float = 0.0, **kwargs, ): - for hook in self.query_hooks(): - query = hook(self, query) + from compressed_tensors.quantization import forward_quantize + + for hook in self.query_hooks.values(): + output = hook(self, query) + if output is not None: + query = output - for hook in self.key_hooks(): - key = hook(self, key) + for 
hook in self.key_hooks.values(): + output = hook(self, key) + if output is not None: + key = output - for hook in self.value_hooks(): - value = hook(self, value) + for hook in self.value_hooks.values(): + output = hook(self, value) + if output is not None: + value = output # TODO: attnq # 2. calibrate/ apply quantization @@ -124,12 +121,22 @@ def forward( # value = forward_quantize(module, value, "v", input_args) # 3. apply original attention function - return self.attention_fn( + # `eager_attention_forward` is duplicated across models by design + # assume that llama implementation is representative of all attention functions + # see: https://github.com/huggingface/transformers/issues/38541#issuecomment-2958567250 # noqa: 501 + + attention_fn: Callable = ( + eager_attention_forward + # if self.ATTN_IMPL == "eager" + # else ALL_ATTENTION_FUNCTIONS[self.ATTN_IMPL] + ) + # print(self.ATTN_IMPL) + return attention_fn( module, query, key, value, attention_mask, scaling, dropout, **kwargs ) -def enable_compressed_attention(model: PreTrainedModel): +def enable_compressed_attention(model: torch.nn.Module): """ Enables transforms, calibration, and quantization for an attention implementation. This function can safetly be called multiple times on the same model. @@ -137,17 +144,16 @@ def enable_compressed_attention(model: PreTrainedModel): :param model: model to enable compressed quantization for :return: singleton instance of `CompressedAttentionImpl` """ - attn_implementation = getattr(model.config, "attn_implementation", "eager") - if attn_implementation != COMPRESSED_ATTENTION_NAME: - compressed_attention = CompressedAttentionImpl(attn_implementation) - AttentionInterface.register(COMPRESSED_ATTENTION_NAME, compressed_attention) - model.config.attn_implementation = COMPRESSED_ATTENTION_NAME - - -def get_compressed_attention_impl() -> CompressedAttentionImpl: - if COMPRESSED_ATTENTION_NAME not in ALL_ATTENTION_FUNCTIONS: - raise ValueError( - "Please call `enable_compressed_attention(model)` before attempting " - "to get singleton instance of `CompressedAttentionImpl`" - ) - return ALL_ATTENTION_FUNCTIONS[COMPRESSED_ATTENTION_NAME] + if not isinstance(model, PreTrainedModel): + return + + attn_impl = getattr(model.config, "_attn_implementation", "eager") + + CompressedAttentionImpl.ATTN_IMPL = attn_impl + AttentionInterface.register(CompressedAttentionImpl.NAME, call_attn_impl) + model.config._attn_implementation = CompressedAttentionImpl.NAME + # model.set_attention_implementation(CompressedAttentionImpl.NAME) + + +def call_attn_impl(module: torch.nn.Module, *args, **kwargs): + return CompressedAttentionImpl.from_module(module)(module, *args, **kwargs) diff --git a/src/compressed_tensors/transform/factory/base.py b/src/compressed_tensors/transform/factory/base.py index 93d531f5..0849917a 100644 --- a/src/compressed_tensors/transform/factory/base.py +++ b/src/compressed_tensors/transform/factory/base.py @@ -18,8 +18,7 @@ import torch import torch.nn.utils.parametrize as P -from compressed_tensors import InternalModule, match_named_modules -from compressed_tensors.modeling.attention import get_compressed_attention_impl +from compressed_tensors.modeling.attention import CompressedAttentionImpl from compressed_tensors.registry.registry import RegistryMixin, T from compressed_tensors.transform import ( TransformArgs, @@ -27,9 +26,11 @@ TransformScheme, ) from compressed_tensors.utils import ( + InternalModule, align_module_device, delete_offload_module, has_offloaded_params, + match_named_modules, 
patch_attr, register_offload_module, update_offload_parameter, @@ -112,12 +113,12 @@ def _apply_to_module(self, module: Module, args: TransformArgs): # create transform as submodule transform_name = f"{self.name}_{args.location}" - transform = self.create_transform(module, args) - self.transforms.append(transform) - register_offload_module(module, transform_name, transform) # register input transformation hook if args.location == TransformLocation.INPUT: + transform = self.create_transform(module, args) + self.transforms.append(transform) + register_offload_module(module, transform_name, transform) def input_hook(_, args): input = args[0] if isinstance(args, tuple) else args @@ -130,6 +131,9 @@ def input_hook(_, args): TransformLocation.WEIGHT_INPUT, TransformLocation.WEIGHT_OUTPUT, ): + transform = self.create_transform(module, args) + register_offload_module(module, transform_name, transform) + # fuse transform into weight assert hasattr(module, "weight") with torch.no_grad(), align_module_device(module): @@ -141,6 +145,7 @@ def input_hook(_, args): if has_offloaded_params(module): raise ValueError("Offloaded training is not supported") P.register_parametrization(module, "weight", transform) + self.transforms.append(transform) else: # transform is no longer needed (unfusing is not supported) @@ -148,6 +153,9 @@ def input_hook(_, args): # register output transformation hook elif args.location == TransformLocation.OUTPUT: + transform = self.create_transform(module, args) + self.transforms.append(transform) + register_offload_module(module, transform_name, transform) def output_hook(_, _input, output): return transform(output) @@ -156,7 +164,13 @@ def output_hook(_, _input, output): # query hook registered to `CompressedAttentionImpl` elif args.location in TransformLocation.ATTN_Q: - attention_impl = get_compressed_attention_impl() + # TODO: makes name assumptions. Maybe we can target q_proj in the config + # then assume parent? Not sure + transform = self.create_transform(module.q_proj, args) + self.transforms.append(transform) + register_offload_module(module, transform_name, transform) + + attention_impl = CompressedAttentionImpl.from_module(module) def query_hook(_, query): return transform(query) @@ -165,7 +179,13 @@ def query_hook(_, query): # key hook registered to `CompressedAttentionImpl` elif args.location in TransformLocation.ATTN_K: - attention_impl = get_compressed_attention_impl() + # TODO: makes name assumptions. Maybe we can target k_proj in the config + # then assume parent? Not sure + transform = self.create_transform(module.k_proj, args) + self.transforms.append(transform) + register_offload_module(module, transform_name, transform) + + attention_impl = CompressedAttentionImpl.from_module(module) def key_hook(_, key): return transform(key) diff --git a/src/compressed_tensors/transform/utils/matrix.py b/src/compressed_tensors/transform/utils/matrix.py index 18a7dc3a..4aa7b746 100644 --- a/src/compressed_tensors/transform/utils/matrix.py +++ b/src/compressed_tensors/transform/utils/matrix.py @@ -59,47 +59,13 @@ def get_transform_size( def apply_transform_weight( - weight: torch.Tensor, + transform_weight: torch.Tensor, value: torch.Tensor, location: TransformLocation, module_type: type[torch.nn.Module], ) -> torch.Tensor: """ - :param weight: transform weight to apply - :param value: value to apply weight to - :param location: determines how weight should be applied - :param model_type: result of type(module), passed in to determine application of - weight transform. 
This is needed because torch uses convention: - - torch.nn.Linear(in_features,out_features) has weight shape - (out_features, in_features) - - torch.nn.Embedding(num_embeddings, embedding_dim) has weight shape - (num_embeddings, embedding_dim) - The transform has to account for Linear's transposed weights - :return: value after weight has been applied - """ - # get function used to apply transform - fn, axis = _get_transform_method(module_type, location) - - # reshape for head_dim - head_dim = weight.shape[0] - num_heads = value.shape[axis] // head_dim - value = value.unflatten(axis, (num_heads, head_dim)) - - # apply transform - value = fn(weight, value) - - # [undo] reshape for head_dim - value = value.flatten(axis - 1, axis) - - return value - - -def _get_transform_method( - module_type: type[torch.nn.Module], - location: TransformLocation, -) -> Tuple[Callable[[torch.Tensor, torch.Tensor], torch.Tensor], int]: - """ - Using the transform location, determine how to apply the transform weight to the + Using the transform location, apply the transform_weight to the given value wrt linear weights. For more info on input and output transforms, see `TransformLocation` @@ -129,51 +95,89 @@ def _get_transform_method( = y U = yh - :param weight: transform weight to apply - :param value: value to apply weight to + :param transform_weight: transform weight to apply + :param value: value to apply transform_weight to :param location: determines how weight should be applied - :return: value after transform weight has been applied + :param model_type: result of type(module), passed in to determine application of + weight transform + :return: value after transform_weight has been applied """ - fn = axis = None + + assert transform_weight.shape[0] == transform_weight.shape[1] if module_type == torch.nn.Linear: - if location == TransformLocation.INPUT: - fn = lambda weight, value: value @ weight - axis = -1 + if location in ( + TransformLocation.INPUT, + TransformLocation.ATTN_Q, + TransformLocation.ATTN_K, + ): + return _multihead_matmul(value, transform_weight) elif location == TransformLocation.WEIGHT_INPUT: - fn = lambda weight, value: value @ weight.T - axis = -1 + # equivalent to (transform_weight @ value.T).T + return _multihead_matmul(value, transform_weight.T) elif location == TransformLocation.WEIGHT_OUTPUT: - fn = lambda weight, value: weight.T @ value - axis = -2 + # equivalent to (value.T @ transform_weight).T + return _multihead_matmul(transform_weight.T, value) elif location == TransformLocation.OUTPUT: - fn = lambda weight, value: value @ weight - axis = -1 + return _multihead_matmul(value, transform_weight) # similar derivation to torch.nn.Linear, but `y = (x W)` - if module_type == torch.nn.Embedding: + elif module_type == torch.nn.Embedding: if location == TransformLocation.INPUT: - fn = lambda weight, value: value @ weight - axis = -1 + return _multihead_matmul(value, transform_weight) elif location == TransformLocation.WEIGHT_INPUT: - fn = lambda weight, value: weight @ value - axis = -1 + return _multihead_matmul( + transform_weight, + value, + ) elif location == TransformLocation.WEIGHT_OUTPUT: - fn = lambda weight, value: value @ weight - axis = -1 + return _multihead_matmul(value, transform_weight) elif location == TransformLocation.OUTPUT: - fn = lambda weight, value: value @ weight - axis = -1 + return _multihead_matmul(value, transform_weight) + + raise NotImplementedError( + f"Applying transforms to {module_type} {location} is not supported" + ) - if fn is None: - raise 
NotImplementedError( - f"Applying transforms to {module_type} {location} is not supported" - ) - return fn, axis +def _multihead_matmul(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: + """ + Performs A @ B for last two dims of two matrices A and B that possibly + have different shapes, as is the case in multi-headed dimension. If + shapes are different, this is equivalent to converting the last two dims + of the smaller matrix into a block-diagonal matrix with the same shape as + the last two dims of the larger matrix. + + E.g. if A is half the size of B, this function will perform + [[A ] @ B + [ A]] + + If B is a third of the size of A, this function will perform + A @ [[B ] + [ B ] + [ B]] + + This function will error out if the shapes are not evenly divisble + + :param A: left-hand tensor + :param B: right-hand tensor + :return: result + """ + if A.shape[-1] > B.shape[-2]: + head_dim = B.shape[-2] + num_heads = A.shape[-1] // head_dim + A = A.unflatten(-1, (num_heads, head_dim)) + return (A @ B).flatten(-2, -1) + elif A.shape[-1] < B.shape[-2]: + head_dim = A.shape[-1] + num_heads = B.shape[-2] // head_dim + B = B.unflatten(-2, (num_heads, head_dim)) + return (A @ B).flatten(-3, -2) + else: + return A @ B diff --git a/tests/test_transform/conftest.py b/tests/test_transform/conftest.py index e08e4d49..83b2c8a6 100644 --- a/tests/test_transform/conftest.py +++ b/tests/test_transform/conftest.py @@ -14,8 +14,13 @@ import pytest import torch +from compressed_tensors.modeling.attention import call_attn_impl from compressed_tensors.transform import TransformArgs, TransformFactory from transformers import PretrainedConfig, PreTrainedModel +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaRotaryEmbedding, +) class TransformableModel(PreTrainedModel): @@ -34,65 +39,46 @@ def forward(self, x): return x -class MockAttention(torch.nn.Module): +class MockAttentionModel(PreTrainedModel): def __init__( - self, hidden_size: int, num_attention_heads: int, num_key_value_heads: int + self, + hidden_size: int, + num_attention_heads: int, + num_key_value_heads: int, + skip_pos_embeddings: bool = False, + attn_implementation: str = "eager", ): - super().__init__() - self.num_attention_heads = num_attention_heads - self.num_key_value_heads = num_key_value_heads - - self.num_key_value_groups = num_attention_heads // num_key_value_heads - self.head_dim = hidden_size // num_attention_heads - self.scaling = self.head_dim**-0.5 - assert hidden_size >= num_attention_heads * self.head_dim - - self.q_proj = torch.nn.Linear( - hidden_size, num_attention_heads * self.head_dim, bias=False - ) - self.k_proj = torch.nn.Linear( - hidden_size, num_key_value_heads * self.head_dim, bias=False - ) - self.v_proj = torch.nn.Linear( - hidden_size, num_key_value_heads * self.head_dim, bias=False - ) - self.o_proj = torch.nn.Linear( - num_attention_heads * self.head_dim, hidden_size, bias=False + config = PretrainedConfig( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + attention_dropout=0.0, + attention_bias=False, + max_position_embeddings=128, + rope_theta=500000.0, + _attn_implementation_internal=attn_implementation, + _attn_implementation_autoset=False, ) + super().__init__(config) + self.rotary_emb = LlamaRotaryEmbedding(config=config) + self.attn = LlamaAttention(config, layer_idx=0) + self.skip_pos_embeddings = skip_pos_embeddings def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - batch_size, seq_len, 
hidden_size = hidden_states.shape - hidden_shape = (batch_size, seq_len, -1, self.head_dim) + assert hidden_states.size(1) <= self.config.max_position_embeddings - query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + if not self.skip_pos_embeddings: + position_ids = torch.arange(hidden_states.size(1)).unsqueeze(0) + position_embeddings = self.rotary_emb(hidden_states, position_ids) + else: + zeros = torch.zeros(hidden_states.size(1), dtype=hidden_states.dtype) + position_embeddings = (zeros, zeros) - key_states = self.repeat_kv(key_states, self.num_key_value_groups) - value_states = self.repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = ( - torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling + attn_output, _attn_weights = self.attn( + hidden_states, position_embeddings=position_embeddings, attention_mask=None ) - attn_weights = torch.nn.functional.softmax( - attn_weights, dim=-1, dtype=torch.float32 - ).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - - attn_output = attn_output.reshape((batch_size, seq_len, -1)).contiguous() - - return self.o_proj(attn_output) - - def repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand( - batch, num_key_value_heads, n_rep, slen, head_dim - ) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + return attn_output @pytest.fixture(scope="function") diff --git a/tests/test_transform/factory/test_correctness.py b/tests/test_transform/factory/test_correctness.py index c0225636..f5face63 100644 --- a/tests/test_transform/factory/test_correctness.py +++ b/tests/test_transform/factory/test_correctness.py @@ -22,7 +22,7 @@ apply_transform_config, ) from compressed_tensors.utils import offloaded_dispatch -from tests.test_transform.conftest import MockAttention +from tests.test_transform.conftest import MockAttentionModel from tests.testing_utils import requires_accelerate, requires_gpu @@ -122,33 +122,53 @@ def test_correctness_attention_heads(type, randomize, head_dim): hidden_size = 64 num_attention_heads = 8 - attention = MockAttention( + model = MockAttentionModel( hidden_size=hidden_size, num_attention_heads=num_attention_heads, num_key_value_heads=head_dim, + skip_pos_embeddings=False, + attn_implementation="eager", # TODO: fails with sdpa ) input = torch.rand(17, 5, hidden_size) - true_output = attention(input) + true_output = model(input) config = TransformConfig( config_groups={ - "": TransformScheme( + # "R3": TransformScheme( + # type=type, + # randomize=randomize, + # head_dim=head_dim, + # apply=[ + # TransformArgs(targets="attn.q_proj", location="output"), + # TransformArgs(targets="attn.k_proj", location="output"), + # ], + # ), + "R3": TransformScheme( + type=type, + randomize=randomize, + head_dim=head_dim, + apply=[ + TransformArgs(targets="attn", location="attn_q"), + TransformArgs(targets="attn", location="attn_k"), + ], + ), + "R2": TransformScheme( type=type, randomize=randomize, head_dim=head_dim, apply=[ - TransformArgs(targets="v_proj", location="weight_output"), + TransformArgs(targets="attn.v_proj", 
location="weight_output"), TransformArgs( - targets="o_proj", location="weight_input", inverse=True + targets="attn.o_proj", location="weight_input", inverse=True ), ], - ) + ), } ) - apply_transform_config(attention, config) + apply_transform_config(model, config) - output = attention(input) + output = model(input) assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0) From 4f03325d568327651dec422f2127672e03aff0f7 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 17 Jul 2025 18:18:02 -0400 Subject: [PATCH 4/4] merge in model_decompress_safeguard Signed-off-by: Kyle Sayers --- .../model_compressors/model_compressor.py | 28 +++++++++++++------ 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index ff341a68..bb3444a0 100644 --- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -392,15 +392,18 @@ def compress_model(self, model: Module): for prefix, module in tqdm(model.named_modules(), desc="Compressing model"): if prefix in module_to_scheme or prefix in sparse_compression_targets: - module_device = get_execution_device(module).type - is_meta = module_device == "meta" + module_device = get_execution_device(module) + is_meta = module_device.type == "meta" exec_device = "meta" if is_meta else "cpu" onloading_device = "meta" if is_meta else module_device # in the future, support compression on same device with align_module_device(module, execution_device=exec_device): - state_dict = module.state_dict(prefix=f"{prefix}.") + state_dict = { + f"{prefix}.{name}": param + for name, param in module.named_parameters(recurse=False) + } # quantization first if prefix in module_to_scheme: @@ -421,7 +424,7 @@ def compress_model(self, model: Module): # remove any existing parameters offload_device = get_offloaded_device(module) - for name, _ in list(module.named_parameters()): + for name, _ in list(module.named_parameters(recurse=False)): delete_offload_parameter(module, name) # replace with compressed parameters @@ -458,7 +461,10 @@ def decompress_model(self, model: Module): if prefix in module_to_scheme or prefix in sparse_compression_targets: # in the future, support decompression on same device with align_module_device(module, execution_device="cpu"): - state_dict = module.state_dict(prefix=f"{prefix}.") + state_dict = { + f"{prefix}.{name}": param + for name, param in module.named_parameters(recurse=False) + } # sparsity first if prefix in sparse_compression_targets: @@ -483,7 +489,7 @@ def decompress_model(self, model: Module): # remove any existing parameters exec_device = get_execution_device(module) offload_device = get_offloaded_device(module) - for name, _ in list(module.named_parameters()): + for name, _ in list(module.named_parameters(recurse=False)): delete_offload_parameter(module, name) # replace with decompressed parameters @@ -747,12 +753,16 @@ def _replace_weights(self, dense_weight_generator, model: Module): def map_module_to_scheme(model: Module) -> Dict[str, QuantizationScheme]: """ - Returns a dictionary which maps quantized module names to their quantization schemes + Returns a dictionary which maps quantized module names to their quantization + schemes. 
Only includes modules with weight quantization """ return { fix_fsdp_module_name(name): module.quantization_scheme for name, module in model.named_modules() - if is_module_quantized(module) + if ( + hasattr(module, "quantization_scheme") and + module.quantization_scheme.weights is not None + ) } @@ -785,4 +795,4 @@ def override_quantization_status( try: yield finally: - config.quantization_status = original_status + config.quantization_status = original_status \ No newline at end of file
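
Usage sketch (illustrative, not part of the patches above): the snippet below shows how the new `attn_q`/`attn_k` transform locations and the registered `compressed_attention` implementation are expected to be exercised, mirroring the updated `test_correctness_attention_heads` test from PATCH 3/4. `MockAttentionModel` is a test fixture from `tests/test_transform/conftest.py`, and the `"hadamard"` type and `head_dim=8` values are assumed placeholders rather than requirements of the API.

import torch
from compressed_tensors.transform import (
    TransformArgs,
    TransformConfig,
    TransformScheme,
    apply_transform_config,
)
from tests.test_transform.conftest import MockAttentionModel  # test fixture, not a library API

# eager only for now; the PATCH 3/4 commit message notes that sdpa does not yet work
model = MockAttentionModel(
    hidden_size=64,
    num_attention_heads=8,
    num_key_value_heads=4,
    attn_implementation="eager",
)

config = TransformConfig(
    config_groups={
        # R3: rotate query/key states online, inside the compressed attention impl
        "R3": TransformScheme(
            type="hadamard",  # assumed transform type
            head_dim=8,       # assumed head dimension for this sketch
            apply=[
                TransformArgs(targets="attn", location="attn_q"),
                TransformArgs(targets="attn", location="attn_k"),
            ],
        ),
        # R2: fuse an inverse pair offline into the v_proj / o_proj weights
        "R2": TransformScheme(
            type="hadamard",
            head_dim=8,
            apply=[
                TransformArgs(targets="attn.v_proj", location="weight_output"),
                TransformArgs(
                    targets="attn.o_proj", location="weight_input", inverse=True
                ),
            ],
        ),
    }
)

# apply_transform_config calls enable_compressed_attention(model), which registers
# the "compressed_attention" AttentionInterface implementation and attaches the
# query/key hooks that run before the original attention function
apply_transform_config(model, config)

output = model(torch.rand(1, 5, 64))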