[Transforms] Transform Registry Support #274

Open

wants to merge 6 commits into base: dsikka/transform_support

Changes from 1 commit
18 changes: 18 additions & 0 deletions src/compressed_tensors/transforms/__init__.py
@@ -0,0 +1,18 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .base import Transforms
from .hadamard import Hadamard
from .matrix_multiply import MatrixMultiply
from .random_hadamard import RandomHadamard
87 changes: 87 additions & 0 deletions src/compressed_tensors/transforms/base.py
@@ -0,0 +1,87 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Optional, Union

import torch
from compressed_tensors.registry.registry import RegistryMixin
from compressed_tensors.transforms.utils import apply_matrix_transform


__all__ = ["Transforms"]


# TODO: We don't need to save all the __call__ args for serialization or even have
# them defined by a recipe. Some of them, such as whether the transformation should be
# the first or second matrix in torch.matmul depending on dimensions, can likely be
# inferred from the layer itself.

MATRIX_TRANSFORMS = ["matrix-mul", "hadamard", "random-hadamard"]


class Transforms(RegistryMixin):
def __new__(
cls,
transform: torch.Tensor,
device: Optional[Union[str, torch.device]] = "cuda",
dtype: Optional[torch.dtype] = torch.bfloat16,
*args,
**kwargs,
):
"""
Base class for setting up transforms. The registry creates transforms
as parameters which can be attached to modules.

import torch

size = 1024
dtype = torch.bfloat16
module = torch.nn.Linear(size, size)

hadamard_transform = Transforms.load_from_registry(
"random_hadamard", size=size, dtype=dtype
)
hadamard_apply = Transforms.fetch_apply("random_hadamard")
module.weight_transform = hadamard_transform

transformed_output = hadamard_apply(input_tensor=module.weight,
transform=module.weight_transform)

hadamard_inverse = Transforms.fetch_inverse_apply("random_hadamard")
original_weight = hadamard_inverse(input_tensor=transformed_output,
transform=module.weight_transform,
transpose=True)

:param transform: transform (e.g. torch.Tensor, scalar) to be applied
:param device: device to move the transform parameter to
:param dtype: dtype to cast the transform parameter to
"""
return torch.nn.Parameter(transform.to(device).to(dtype), requires_grad=False)

@classmethod
def fetch_apply(cls, name: str):
if name in MATRIX_TRANSFORMS:
return apply_matrix_transform
raise NotImplementedError("Only matrix transforms are supported")

@classmethod
def fetch_inverse_apply(cls, name: str):
return cls.get_value_from_registry(name=name).inverse_apply

@staticmethod
def inverse_apply(
transform: torch.Tensor, input_tensor: torch.Tensor, *args, **kwargs
) -> torch.Tensor:
"""
Apply the inverse of the operation applied by the apply method
"""
raise NotImplementedError()
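
For reference, a minimal usage sketch (assuming the "hadamard" registration from hadamard.py below, and run on CPU for portability): the registry returns the transform as a frozen torch.nn.Parameter, so attaching it to a module makes it part of that module's state_dict for serialization.

```python
import torch

from compressed_tensors.transforms import Transforms

module = torch.nn.Linear(64, 64)

# load_from_registry dispatches to the registered transform class and returns
# the transform as a frozen parameter (requires_grad=False)
transform = Transforms.load_from_registry(
    "hadamard", size=64, device="cpu", dtype=torch.float32
)
assert isinstance(transform, torch.nn.Parameter)
assert not transform.requires_grad

# attaching the parameter to a module registers it in the module's state_dict
module.weight_transform = transform
assert "weight_transform" in module.state_dict()
```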
82 changes: 82 additions & 0 deletions src/compressed_tensors/transforms/hadamard.py
@@ -0,0 +1,82 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Union

import torch
from compressed_tensors.transforms import Transforms
from compressed_tensors.transforms.hadamard_utils import deterministic_hadamard_matrix
from compressed_tensors.transforms.utils import apply_matrix_transform


@Transforms.register("hadamard")
class Hadamard(Transforms):
def __new__(
cls,
size: int,
empty: Optional[bool] = False,
device: Optional[Union[str, torch.device]] = "cuda",
dtype: Optional[torch.dtype] = torch.bfloat16,
):
"""
Produces a Hadamard matrix with dims (size, size), with values
-1 and 1, and the property HH.T == nI, i.e. the transformation
matrix multiplied by its transpose is a multiple of the identity.
All rows and columns are mutually orthogonal. The matrix returned
is not normalized and will be deterministic.

:param size: size of the matrix, if generating a new Hadamard matrix.
The generated matrix will have dimensions (size, size)
:param empty: if True, skip generating the matrix and allocate an empty
(size, size) placeholder instead, e.g. when a previously generated
matrix will be loaded into this transform
:param device: device to move the transform parameter to
:param dtype: type to cast the transform matrix to

"""
if not empty:
# TODO: this is deterministic; we should just serialize the size
transform = torch.Tensor(deterministic_hadamard_matrix(size=size))
else:
transform = torch.empty((size, size))

return super().__new__(cls, transform=transform, device=device, dtype=dtype)

@staticmethod
def inverse_apply(
transform: torch.Tensor,
input_tensor: torch.Tensor,
transpose: bool = False,
Contributor
@kylesayrs kylesayrs Apr 25, 2025

I think having first and call_args in general seems to be an extensible pattern, but maybe transpose should not be a call arg, as it seems to be determined by the specific Transform and whether you're applying the inverse, not something that a caller should have control over?

Contributor

Afaict transpose is only used when applying an inverse Hadamard, and therefore should not be controllable by the user, to avoid footgunning.

first: bool = True,
) -> torch.Tensor:
"""
Apply the inverse operation of `apply`

:param transform: hadamard tensor
:param input_tensor: tensor to which the transform matrix is applied
:param transpose: whether or not the transform matrix is transposed before
being applied.
:param first: if the transform matrix will be the first or second matrix to be
multiplied
"""
transpose = not transpose
# need to normalize before sending back
return (
apply_matrix_transform(
transform=transform,
input_tensor=input_tensor,
transpose=transpose,
first=first,
)
/ transform.shape[0]
)
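
A round-trip sketch of the apply/inverse pair above (this assumes apply_matrix_transform left-multiplies the transform onto input_tensor with its default call args, as the docstrings suggest; sizes and tolerances are illustrative): because H @ H.T == n * I, applying the transform and then its inverse recovers the original tensor up to floating point error.

```python
import torch

from compressed_tensors.transforms import Transforms

size = 64
weight = torch.randn(size, size, dtype=torch.float32)

hadamard = Transforms.load_from_registry(
    "hadamard", size=size, device="cpu", dtype=torch.float32
)
apply_fn = Transforms.fetch_apply("hadamard")
inverse_fn = Transforms.fetch_inverse_apply("hadamard")

# forward: H @ W; inverse: H.T @ (H @ W) / n == W, since H @ H.T == n * I
transformed = apply_fn(transform=hadamard, input_tensor=weight)
recovered = inverse_fn(transform=hadamard, input_tensor=transformed)

assert torch.allclose(recovered, weight, atol=1e-3)
```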
166 changes: 166 additions & 0 deletions src/compressed_tensors/transforms/hadamard_utils.py
@@ -0,0 +1,166 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import numpy
import torch


__all__ = ["random_hadamard_matrix", "deterministic_hadamard_matrix"]

# adapted from:
# https://github.com/scipy/scipy/blob/v1.15.2/scipy/linalg/_special_matrices.py
def deterministic_hadamard_matrix(size: int) -> numpy.ndarray:
"""
Construct a Hadamard matrix.

Constructs a size-by-size Hadamard matrix using Sylvester's
construction. `size` must be a power of 2.

:param size: order of the matrix; must be a power of 2

:return: a (size, size) Hadamard matrix
"""

dtype = int
if size < 1:
lg2 = 0
else:
lg2 = int(math.log(size, 2))
if 2**lg2 != size:
raise ValueError("size must be a positive integer and a power of 2")

H = numpy.array([[1]], dtype=dtype)

# Sylvester's construction
for i in range(0, lg2):
H = numpy.vstack((numpy.hstack((H, H)), numpy.hstack((H, -H))))

return H
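
A quick numerical check of the Sylvester construction (a sketch; size=8 is arbitrary): every entry is +/-1 and H @ H.T == size * I.

```python
import numpy

from compressed_tensors.transforms.hadamard_utils import deterministic_hadamard_matrix

H = deterministic_hadamard_matrix(size=8)

# entries are +/-1 and rows are mutually orthogonal
assert set(numpy.unique(H)) == {-1, 1}
assert numpy.array_equal(H @ H.T, 8 * numpy.eye(8, dtype=int))
```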


# adapted from:
# https://github.com/facebookresearch/SpinQuant/blob/main/utils/hadamard_utils.py

# TODO: the following library exists for online rotations and should be considered
# in the future:
# https://github.com/Dao-AILab/fast-hadamard-transform/tree/master


def random_hadamard_matrix(size: int) -> torch.Tensor:
"""
Produces a randomly generated Hadamard matrix.
See https://cornell-relaxml.github.io/quip-sharp/ ,
Section "Randomized Hadamard Transformation"

:param size: The dimension of the matrix. Matrix generated will have dimensions
(size, size)

"""
# TODO: potentially update to add "seed" as an argument, to allow
# the matrix generated to be reproducible

# Benefits: support other shapes / non powers of 2, support randomization
Q = torch.randint(low=0, high=2, size=(size,)).to(torch.float64)
Q = Q * 2 - 1
Q = torch.diag(Q)
return _matmul_hadU(Q)
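
Unlike the deterministic construction, this matrix is normalized by sqrt(n), so a small sanity-check sketch (size 16 chosen arbitrarily) is that the result is orthogonal: Q @ Q.T is numerically the identity.

```python
import torch

from compressed_tensors.transforms.hadamard_utils import random_hadamard_matrix

Q = random_hadamard_matrix(size=16)

# rows are (normalized) Hadamard rows with random +/-1 signs, so Q is
# orthogonal up to floating point error
assert torch.allclose(Q @ Q.T, torch.eye(16, dtype=torch.float64), atol=1e-6)
```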


def _get_hadK(n, transpose=False):
# NOTE: we can easily extend the list of supported shapes/sizes
# by adding to these methods
hadK, K = None, None
if n % 20 == 0:
assert _is_pow2(n // 20)
K = 20
hadK = _get_had20().T if transpose else _get_had20()
elif n % 12 == 0:
assert _is_pow2(n // 12)
K = 12
hadK = _get_had12().T if transpose else _get_had12()
else:
assert _is_pow2(n)
K = 1

return hadK, K
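
For illustration, a short sketch of which factorization _get_hadK selects for a few sizes (the sizes are arbitrary examples, not tied to any particular model):

```python
from compressed_tensors.transforms.hadamard_utils import _get_hadK

# the second element of the returned tuple is K
assert _get_hadK(1024)[1] == 1   # 1024 is already a power of 2
assert _get_hadK(3072)[1] == 12  # 3072 = 12 * 256, and 256 is a power of 2
assert _get_hadK(5120)[1] == 20  # 5120 = 20 * 256, and 256 is a power of 2
```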


def _matmul_hadU(X, transpose=False):
n = X.shape[-1]
# Check whether we have a predetermined Hadamard matrix for this size
hadK, K = _get_hadK(n, transpose)
# Reshape diag matrix with randomized -1/+1
input = X.clone().view(-1, n, 1)
output = input.clone()

# for cases when hadK is not predetermined, determine hadamard matrix
while input.shape[1] > K:
input = input.view(input.shape[0], input.shape[1] // 2, 2, input.shape[2])
output = output.view(input.shape)
output[:, :, 0, :] = input[:, :, 0, :] + input[:, :, 1, :]
output[:, :, 1, :] = input[:, :, 0, :] - input[:, :, 1, :]
output = output.view(input.shape[0], input.shape[1], -1)
(input, output) = (output, input)
del output

# K == 1 when hadK is None; this happens when the size dim (n)
# is not compatible with any of the maintained hadamard matrices

if K > 1:
# Do not explicitly repeat - OOM
# input = torch.bmm(
# hadK.repeat(len(input), 1, 1).to(input.device).to(input.dtype), input)
# Use bcast instead

# for cases when hadK is pre-determined
input = hadK.view(1, K, K).to(input) @ input

# normalize
return input.view(X.shape) / torch.tensor(n).sqrt()


def _is_pow2(n):
return (n & (n - 1) == 0) and (n > 0)


def _reshape_bits(packed_bits, original_size):
had_unpacked = numpy.unpackbits(packed_bits)
had_unpacked = [1 if x == 1 else -1 for x in had_unpacked]
had_unpacked = numpy.array(had_unpacked).reshape((original_size, original_size))
return had_unpacked


# http://www.neilsloane.com/hadamard/index.html
def _get_had12():
# fmt: off
had_12 = numpy.array([128, 13, 29, 232, 235, 71, 218,
62, 209, 246, 139, 180, 157, 168, 237, 199, 106, 59], dtype=numpy.uint8)
# fmt: on
# TODO: just unpack during apply
had_12_unpacked = _reshape_bits(had_12, original_size=12)
return torch.FloatTensor(had_12_unpacked)


def _get_had20():
# fmt: off
had_20 = numpy.array([128, 0, 13, 133, 121, 236, 43, 203, 97, 94, 155, 10, 252,
216, 87, 230, 194, 191, 54, 21, 249, 176, 171, 205, 133, 222, 108, 42, 243,
97, 215, 155, 10, 188, 216, 149, 230, 200, 175, 54, 133, 121, 188, 43,
205, 225, 94, 107, 10, 243], dtype=numpy.uint8)
# fmt: on
# TODO: just unpack during apply
had_20_unpacked = _reshape_bits(had_20, original_size=20)
return torch.FloatTensor(had_20_unpacked)