[Transforms] Transform Registry Support #274

Open · wants to merge 6 commits into base: dsikka/transform_support

Conversation

dsikka
Collaborator

@dsikka dsikka commented Mar 11, 2025

Summary:

  • Add initial support for transforms through the addition of the Transforms registry, along with three commonly used transformations: the randomized Hadamard transform, the deterministic Hadamard transform, and matrix multiplication
  • Transforms are currently set up with the minimal set of arguments required for initialization (either during calibration or after deserialization) and at runtime. These arguments may be extended as recipe support and related features evolve
  • For each transform, we also need to define the "untransform", i.e. the inverse operation to be applied at runtime. This differs by transform type and is therefore defined by each registered transform through its inverse_apply method. A helper method, fetch_inverse_apply, has been added to fetch this method at runtime
  • The transforms also support an empty option: if set to True, the transform classes initialize an empty parameter onto which the deserialized transforms can be loaded; if False, a new transform is created and returned
  • The transforms are stored as parameters so that, going forward, their values can be optimized
  • For the Hadamard transforms, the matrix produced has the following property:
$$H H^T = nI$$
  • Where H is the Hadamard rotation: multiplied by its transpose, it yields a multiple of the identity. The randomized Hadamard matrix is normalized (so the product is the identity itself), while the deterministic transform is not. A quick check of this property on the unpacked had12 matrix follows the unpacking snippet below.

  • A series of utilities has been added to support generating the randomized Hadamard matrix, leveraged from the SpinQuant and QuaRot repositories. The utilities include Hadamard matrices for commonly used sizes. So far, only had12 and had20 have been included; they are saved in their packed-bit format so we can experiment with other ways of storing the matrices before deciding to add more sizes. Suggestions are welcome!

  • When the preset Hadamards are used, the matrices can be unpacked using the following code:

import numpy
from compressed_tensors.transforms.hadamard_utils import _get_had12

original_size = 12

packed_bits = numpy.array([128, 13, 29, 232, 235, 71, 218,
        62, 209, 246, 139, 180, 157, 168, 237, 199, 106, 59], dtype=numpy.uint8)
had_unpacked = numpy.unpackbits(packed_bits)  # get back 0s/1s
had_unpacked = [1 if x == 1 else -1 for x in had_unpacked]  # map 0 to -1
had_unpacked = numpy.array(had_unpacked).reshape((original_size, original_size))  # reshape to 12x12
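
As a sanity check (assuming the packed bits above decode to a valid Hadamard matrix), the property from the summary can be verified on the unpacked, unnormalized matrix:

# For the deterministic (unnormalized) Hadamard, H @ H.T should equal n * I
assert numpy.array_equal(had_unpacked @ had_unpacked.T, original_size * numpy.eye(original_size))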

Example Use:

from compressed_tensors.transforms import Transforms
import torch

# Create
size = 2048
dtype = torch.float32
hadamard_transform = Transforms.load_from_registry(
    "hadamard", size=size, dtype=dtype
)

# Add to module
module = torch.nn.Linear(size, size)
hadamard_transform.register_to_module(module, "weight_transform")

# Apply during runtime
transformed_weight = hadamard_transform.apply(module.weight)

# Apply inverse operation to remove the effect of the transform on the weight
original_weight = hadamard_transform.inverse_apply(input_tensor=transformed_weight, transpose=True)
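
If the inverse exactly undoes the forward apply (as described above), the round trip should recover the module's weight up to floating-point error; a minimal check under that assumption:

# sanity check: inverse_apply should undo apply, up to numerical tolerance
assert torch.allclose(original_weight, module.weight, rtol=1e-3, atol=1e-4)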

To Do / Future:

  • We do not necessarily have to serialize the entire Hadamard matrix for either the randomized or deterministic case; we could instead take in a randomly generated seed and serialize it along with the size, regenerating the matrix on load. A minimal sketch of this idea follows.
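
The sketch below is illustrative only and not part of this PR: it assumes a power-of-two size and the usual randomized-Hadamard construction (a deterministic Hadamard multiplied by a random diagonal ±1 matrix, then normalized). The helper names are hypothetical.

import torch

def sylvester_hadamard(size: int) -> torch.Tensor:
    # deterministic Hadamard via Sylvester's construction (size must be a power of two)
    assert size > 0 and size & (size - 1) == 0
    h = torch.ones(1, 1)
    while h.shape[0] < size:
        h = torch.cat([torch.cat([h, h], dim=1), torch.cat([h, -h], dim=1)], dim=0)
    return h

def randomized_hadamard_from_seed(seed: int, size: int) -> torch.Tensor:
    # regenerate the normalized randomized Hadamard from just (seed, size);
    # only these two values would need to be serialized instead of the full matrix
    generator = torch.Generator().manual_seed(seed)
    signs = torch.randint(0, 2, (size,), generator=generator) * 2.0 - 1.0
    return (sylvester_hadamard(size) * signs) / (size ** 0.5)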

@dsikka dsikka changed the title transform registry support [Transforms] Transform Registry Support Mar 11, 2025
@dsikka dsikka changed the base branch from main to dsikka/transform_support March 11, 2025 21:06
@dsikka dsikka changed the base branch from dsikka/transform_support to main March 11, 2025 21:06
@dsikka dsikka changed the base branch from main to dsikka/transform_support March 11, 2025 21:06
@dsikka dsikka marked this pull request as ready for review March 11, 2025 21:23
Member

@rahul-tuli rahul-tuli left a comment

LGTM pending normalization of name in fetch_apply

Member

@rahul-tuli rahul-tuli left a comment

Meant to request changes

Member

@markurtz markurtz left a comment


The registry setup looks good to me, especially towards the future ability to cache the Hadamard matrices. I have a few high-level concerns regarding the current architecture, though, and believe some adjustments could significantly enhance flexibility and usability. At present, it feels primarily like a convenient wrapper around functional calls, shifting much of the complexity and implementation responsibility onto the calling code rather than encapsulating it. Specifically:

  • External State Management: Currently, the required state for a transform is handled externally to the Transform class, forcing callers to manage parameters, state, and execution flow explicitly and reducing the transforms to simple functional wrappers.
  • Static Methods and Extensibility: Implementing transforms solely as static methods severely limits subclassing and future extensibility. At a minimum, they should be class methods to allow future override behavior when needed, but ideally they would be instance methods supporting better OOP flows, as mentioned in the other feedback.
  • Alignment with PyTorch standards: The current implementation doesn't leverage standard PyTorch patterns such as Modules, Hooks, or the Parameterization API. With that, it misses out on some key ecosystem benefits like handling parameters and devices.
  • Separate Forward and Inverse Operations: Combining forward and inverse operations into a single transform reduces readability, IMO, and deviates from typical PyTorch patterns such as the standard setup with Quant and DeQuant. By limiting a single instance's responsibility to just a forward, it's much easier to construct graphs without additional conditional call logic in the caller code for when to inject and invert.
  • Hardcoded Single Parameter Assumption: The current implementation assumes a single, hardcoded parameter managed externally, restricting flexibility and easy extension to the modifiers, either for different types or multiple of the same type.
  • Lack of Iterative Update Mechanisms: There are currently no straightforward mechanisms for updating transform values, which is crucial for iterative algorithms like SpinQuant.

An ideal solution in my mind would expand on what exists here and implement a more native PyTorch solution that would also be more object oriented, as noted in some of the other review comments. I think that would boil down to either wrapper modules for transforms or utilizing hooks and the parametrization API. I've included an example implementation of the latter below, which would enable the following simple representation for a Hadamard quantization setup:

layer = Linear(...)
HadamardTransform(layer, "weight", ...)
QuantizationTransform(layer, "weight", ...)
DeQuantizationTransform(layer, "weight", ...)
InverseHadamardTransform(layer, "weight", ...)

Ultimately we can utilize the above in numerous places, and it significantly simplifies the logic needed in the calling code. It sets up a black-box API for the caller: all it needs to know is which transforms to apply in a given case. Additionally, it's easily and quickly extensible to channelwise scaling for things like AWQ and SmoothQuant, pruning, etc.

from abc import ABC, abstractmethod
from typing import Any, Optional, Union, Literal
import torch
from torch.nn import Module
from torch.nn.utils.parametrize import register_parametrization, remove_parametrizations
from compressed_tensors.registry.registry import RegistryMixin


class Transforms(ABC, Module, RegistryMixin):
    def __init__(
        self, module: Module, target: Union[Literal["inputs", "outputs"], str]
    ):
        super().__init__()  # initialize torch.nn.Module before assigning attributes
        self.module = module
        self.target = target
        self.hooks = []
        self.add_hooks()

    def add_hooks(self):
        if self.target == "inputs":
            self.hooks.append(
                self.module.register_forward_pre_hook(
                    lambda module, inputs: self.forward(inputs)
                )
            )
        elif self.target == "outputs":
            self.hooks.append(
                self.module.register_forward_hook(
                    lambda module, inputs, outputs: self.forward(outputs)
                )
            )
        else:
            # targeting a parameter
            register_parametrization(self.module, self.target, self)

    def remove_hooks(self):
        for hook in self.hooks:
            hook.remove()
        self.hooks = []

        # TODO: figure out how to remove just this parametrization and not all
        remove_parametrizations(self.module, self.target)

    @abstractmethod
    def update(*args, **kwargs): ...

    @abstractmethod
    def forward(self, inputs: Any) -> Any: ...


@Transforms.register("hadamard")
class HadamardTransform(Transforms):
    def __init__(
        self,
        module: Module,
        target: Union[Literal["inputs", "outputs"], str],
        size: int,
        left_multiply: bool = False,
        preinitialize: bool = True,
        device: Optional[Union[str, torch.device]] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        super().__init__(module, target)
        self.target = target
        self.size = size
        self.left_multiply = left_multiply
        self.module.register_buffer(
            name=self.hadamard_buffer_name,
            tensor=torch.empty((size, size), device=device, dtype=dtype),
        )
        self.module.register_buffer(
            name=self.permute_buffer_name,
            tensor=torch.empty((size, size), device=device, dtype=dtype),
        )

        if preinitialize:
            self.update(
                # assumes a random_hadamard_matrix helper (e.g. from the utilities added in this PR)
                hadamard=random_hadamard_matrix(size),
                permutation=torch.eye(size),  # identity permutation to start
            )

    @property
    def hadamard_buffer_name(self) -> str:
        return f"{self.target}_hadamard_transform"

    @property
    def permute_buffer_name(self) -> str:
        return f"{self.target}_hadamard_permute_transform"

    def update(
        self,
        hadamard: Optional[torch.Tensor] = None,
        permutation: Optional[torch.Tensor] = None,
    ):
        if hadamard is not None:
            getattr(self.module, self.hadamard_buffer_name).copy_(hadamard)
        if permutation is not None:
            getattr(self.module, self.permute_buffer_name).copy_(permutation)

    def forward(self, inputs: Any) -> Any:
        hadamard = getattr(self.module, self.hadamard_buffer_name)
        permutation = getattr(self.module, self.permute_buffer_name)

        if self.left_multiply:
            return hadamard @ inputs @ permutation
        else:
            return inputs @ permutation @ hadamard


@Transforms.register("hadamard_inverse")
class HadamardInverseTransform(HadamardTransform):
    def forward(self, inputs: Any) -> Any:
        hadamard = getattr(self.module, self.hadamard_buffer_name)
        permutation = getattr(self.module, self.permute_buffer_name)

        if self.left_multiply:
            return permutation.T @ hadamard @ inputs
        else:
            return inputs @ hadamard @ permutation.T


@Transforms.register("quantization")
class QuantizationTransform(Transforms):
    def __init__(
        self,
        module: Module,
        target: Union[Literal["inputs", "outputs"], str],
        scale: float,
        zero_point: int,
        quant_dtype: torch.dtype,
    ):
        super().__init__(module, target)
        self.module.register_buffer(
            name=self.zp_buffer_name, tensor=torch.tensor(zero_point, dtype=torch.int)
        )
        self.module.register_buffer(
            name=self.scale_buffer_name, tensor=torch.tensor(scale, dtype=torch.float)
        )
        self.quant_dtype = quant_dtype

    @property
    def zp_buffer_name(self) -> str:
        return f"{self.target}_zero_point"

    @property
    def scale_buffer_name(self) -> str:
        return f"{self.target}_scale"

    def update(self, scale: float, zero_point: int):
        getattr(self.module, self.zp_buffer_name).fill_(zero_point)
        getattr(self.module, self.scale_buffer_name).fill_(
            torch.tensor(scale, dtype=torch.float)
        )

    def forward(self, inputs: Any) -> Any:
        scale = getattr(self.module, self.scale_buffer_name)
        zero_point = getattr(self.module, self.zp_buffer_name)
        quantized = torch.clamp(
            torch.round(inputs / scale) + zero_point,
            torch.iinfo(self.quant_dtype).min,
            torch.iinfo(self.quant_dtype).max,
        )

        return quantized.to(self.quant_dtype)


@Transforms.register("dequantization")
class DeQuantizationTransform(Transforms):
    def __init__(
        self,
        module: Module,
        target: Union[Literal["inputs", "outputs"], str],
        scale: float,
        zero_point: int,
        quant_dtype: torch.dtype,
        dtype: torch.dtype,
    ):
        super().__init__(module, target)
        self.scale = scale
        self.zero_point = zero_point
        self.quant_dtype = quant_dtype
        self.dtype = dtype

    def update(self, scale: float, zero_point: int):
        self.scale = scale
        self.zero_point = zero_point

    def forward(self, inputs: Any) -> Any:
        scale = torch.tensor(self.scale, dtype=torch.float)
        zero_point = torch.tensor(self.zero_point, dtype=torch.int)
        dequantized = (inputs - zero_point) * scale

        return dequantized.to(self.dtype)

@dsikka
Collaborator Author

dsikka commented Mar 20, 2025

(quoting @markurtz's review comment above in full)

@markurtz
Please look through the remaining PRs to understand how the transforms are applied to the module's parameters. This was also illustrated in the original design documents that were used to scope out these PRs.

Alignment with PyTorch standards: The current implementation doesn't leverage standard PyTorch patterns such as Modules, Hooks, or Parameterization API. With that, it misses out on some key ecosystem benefits like handling parameters and devices

The transforms do make use of Parameterization and therefore directly benefit from device handling. Again, please look through the remaining PRs, especially #276.
Their utilization in the context of Modules/Hooks is what will be implemented through Activation Support.
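
For illustration only (not the PR's actual code): a transform tensor registered as a torch.nn.Parameter on a module automatically follows the module's device and dtype conversions.

import torch

module = torch.nn.Linear(8, 8)
module.register_parameter(
    "weight_transform", torch.nn.Parameter(torch.eye(8), requires_grad=False)
)
module.half()  # floating-point parameters registered on the module are converted with it
assert module.weight_transform.dtype == torch.float16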

Separate Forward and Inverse Operations: Combining forward and inverse operations into a single transform reduces readability, IMO, and deviates from typical PyTorch patterns such as the standard setup with Quant and DeQuant. By limiting a single instance's responsibility to just a forward, it's much easier to construct graphs without additional conditional call logic in the caller code for when to inject and invert.

The forward and inverse operations make use of the same matrix/data but are two separate operations, implemented by apply and inverse_apply. They follow the same pattern as Quant/DeQuant.

Hardcoded Single Parameter Assumption: The current implementation assumes a single, hardcoded parameter managed externally, restricting flexibility and easy extension to the modifiers, either for different types or multiple of the same type.

I am not sure what this is referring to. Each transform class is dedicated to one type of transform, which is responsible for its own set-up.

I strongly disagree that the registry should be an extension of the Module class.

I do agree that there should be functionality to support updating.

Contributor

@brian-dellabetta brian-dellabetta left a comment

:shipit: LGTM

# If updating the module parameter data, assumes this is also the transform
# data
if name is None:
    raise ValueError("Name and module are required to update parma data")

param data?

def inverse_apply(
    self,
    input_tensor: torch.Tensor,
    transpose: bool = False,
Contributor

@kylesayrs kylesayrs Apr 25, 2025

I think having first and call_args in general seems like an extensible pattern, but maybe transpose should not be a call arg, as it seems to be determined by the specific Transform and by whether you're applying the inverse, not something the caller should have control over?


Afaict transpose is only used when applying an inverse Hadamard, and therefore should not be controllable by the user, to avoid footgunning.
