
[Transforms] Transform Registry Support #274

Open · wants to merge 6 commits into base: dsikka/transform_support
Changes from 4 commits
73 changes: 46 additions & 27 deletions src/compressed_tensors/transforms/base.py
@@ -17,6 +17,7 @@
import torch
from compressed_tensors.registry.registry import RegistryMixin
from compressed_tensors.transforms.utils import apply_matrix_transform
from compressed_tensors.utils import register_offload_parameter, update_parameter_data


__all__ = ["Transforms"]
@@ -27,18 +28,16 @@
# first or second matrix in torch.matmul depending on dimensions, can be inferred
# by the layer time likely.

MATIRX_TRANSFORMS = ["matrix-mul", "hadamard", "random-hadamard"]


class Transforms(RegistryMixin):
def __new__(
cls,
def __init__(
self,
transform: torch.Tensor,
learnable: Optional[bool] = True,
device: Optional[Union[str, torch.device]] = "cuda",
dtype: Optional[torch.dtype] = torch.bfloat16,
*args,
**kwargs,
):
self.learnable = learnable
"""
Base class for setting up transforms. The registry creates transforms
as parameters which can be attached to modules.
@@ -48,38 +47,58 @@ def __new__(
size = 1024
dtype = torch.bfloat16
module = torch.nn.Linear(size, size)
name = "weight_transform"

hadamard_transform = Transforms.load_from_registry(
"random_hadamard", size=size, dtype=dtype
)
hadamard_apply = Transforms.fetch_apply("random_hadamard")
module.weight_transform = hadamard_transform

transformed_output = hadamard_apply(input_tensor=module.weight,
transform=moduel.weight_transform)
hadamard_transform.register_to_module(name, module)
module.transform_data = {name: {"call_args": dict, "class": hadamard_transform}}

hadamard_inverse = Transforms.fetch_inverse_apply("random_hadamard")
original_weight = hadamard_inverse(input_tensor=transformed_output,
transform=model.weight_trainsform,
transpose=True)
transformed_output = hadamard_transform.apply(input_tensor=module.weight)
original_weight = hadamard_transform.inverse_apply(
input_tensor=transformed_output)

:param transform: transform (e.g. torch.Tensor, scalar) to be applied
"""
return torch.nn.Parameter(transform.to(device).to(dtype), requires_grad=False)

@classmethod
def fetch_apply(cls, name: str):
if name in MATIRX_TRANSFORMS:
return apply_matrix_transform
raise NotImplementedError("Only matrix transforms are supported")

@classmethod
def fetch_inverse_apply(cls, name: str):
return cls.get_value_from_registry(name=name).inverse_apply
if self.learnable:
self.transform = torch.nn.Parameter(transform.to(dtype).to(device))
else:
self.transform = torch.nn.Buffer(transform.to(dtype).to(device))

# register to module for easy offloading, serialization, deserialization
def register_to_module(self, name: str, module: torch.nn.Module):
if self.learnable:
register_offload_parameter(module, name, self.transform)
else:
# TODO: have to verify serialization/offloading
module.register_buffer(name, self.transform)

def update_transform(
self,
data: torch.Tensor,
module: Optional[torch.nn.Module] = None,
name: Optional[str] = None,
):
if module is None:
self.transform.data.copy_(data)
else:
# If updating the module parameter data, assumes this is also the transform
# data
if name is None:
raise ValueError("Name and module are required to update parma data")
Review comment (Member): param data?

update_parameter_data(module, data, name)

def apply(self, input_tensor: torch.Tensor, *args, **kwargs) -> torch.Tensor:
"""
Apply the transform to the module
"""
raise NotImplementedError()

@staticmethod
# TODO: potentially split into its own transform using the same shared set-up
def inverse_apply(
transform: torch.Tensor, input_tensor: torch.Tensor, *args, **kwargs
self, input_tensor: torch.Tensor, *args, **kwargs
) -> torch.Tensor:
"""
Apply the inverse of the operation performed by the apply method
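For orientation, here is a minimal usage sketch of the reworked base class, pieced together from the docstring and the new methods in this file. The registry name, keyword arguments, and `transform_data` layout are taken from this diff; `device="cpu"`, `torch.float32`, and the identity tensor passed to `update_transform` are assumptions made so the sketch runs without a GPU.

```python
import torch
from compressed_tensors.transforms import Transforms

size, dtype = 1024, torch.float32
module = torch.nn.Linear(size, size, dtype=dtype)
name = "weight_transform"

# Instantiate a registered transform; the tensor now lives on the instance
# as `self.transform` instead of the class returning a bare nn.Parameter.
hadamard_transform = Transforms.load_from_registry(
    "random-hadamard", size=size, dtype=dtype, device="cpu"
)

# Attach the parameter to the module and record how it should be called
# (empty call args here).
hadamard_transform.register_to_module(name, module)
module.transform_data = {name: {"call_args": {}, "class": hadamard_transform}}

# Apply / invert through the instance rather than the old fetch_apply helpers.
transformed = hadamard_transform.apply(input_tensor=module.weight)
restored = hadamard_transform.inverse_apply(input_tensor=transformed)

# Swap in new data (e.g. after deserialization); passing module and name
# routes through update_parameter_data so the registered parameter is updated.
hadamard_transform.update_transform(
    torch.eye(size, dtype=dtype), module=module, name=name
)
```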
28 changes: 21 additions & 7 deletions src/compressed_tensors/transforms/hadamard.py
@@ -22,12 +22,14 @@

@Transforms.register("hadamard")
class Hadamard(Transforms):
def __new__(
cls,
def __init__(
self,
size: int,
empty: Optional[bool] = False,
device: Optional[Union[str, torch.device]] = "cuda",
dtype: Optional[torch.dtype] = torch.bfloat16,
*args,
**kwargs,
):
"""
Produces a hadamard matrix with dims (size, size), with values
@@ -50,11 +52,23 @@ def __new__(
else:
transform = torch.empty((size, size))

return super().__new__(cls, transform=transform, device=device, dtype=dtype)
super().__init__(transform=transform, dtype=dtype, device=device)

def apply(
self,
input_tensor: torch.Tensor,
transpose: bool = False,
first: bool = True,
) -> torch.Tensor:
return apply_matrix_transform(
transform=self.transform,
input_tensor=input_tensor,
transpose=transpose,
first=first,
)

@staticmethod
def inverse_apply(
transform: torch.Tensor,
self,
input_tensor: torch.Tensor,
transpose: bool = False,
Review comment (@kylesayrs, Contributor, Apr 25, 2025): I think having first and call_args in general seems to be an extensible pattern, but maybe transpose should not be a call arg, as it seems to be determined by the specific Transform and whether you're applying the inverse, not something a caller should have control over.

Review comment (Contributor): Afaict transpose is only used when applying an inverse Hadamard, and therefore should not be controllable by the user, to avoid footgunning.

first: bool = True,
@@ -73,10 +87,10 @@ def inverse_apply(
# need to normalize before sending back
return (
apply_matrix_transform(
transform=transform,
transform=self.transform,
input_tensor=input_tensor,
transpose=transpose,
first=first,
)
/ transform.shape[0]
/ self.transform.shape[0]
)
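On the transpose discussion in the review comments above: a hypothetical sketch (not part of this PR) of one way to follow that suggestion, baking the transpose into inverse_apply so it is no longer a call arg. It assumes an unnormalized ±1 Hadamard matrix stored on the instance and the apply_matrix_transform semantics used throughout this diff.

```python
import torch
from compressed_tensors.transforms.utils import apply_matrix_transform


class HadamardLike:
    """Hypothetical variant where the inverse always transposes internally."""

    def __init__(self, transform: torch.Tensor):
        # expected to hold an unnormalized +/-1 Hadamard matrix
        self.transform = transform

    def apply(self, input_tensor: torch.Tensor, first: bool = True) -> torch.Tensor:
        return apply_matrix_transform(
            transform=self.transform,
            input_tensor=input_tensor,
            transpose=False,
            first=first,
        )

    def inverse_apply(self, input_tensor: torch.Tensor, first: bool = True) -> torch.Tensor:
        # transpose is fixed here rather than exposed to the caller;
        # dividing by size undoes the unnormalized H @ x product
        return (
            apply_matrix_transform(
                transform=self.transform,
                input_tensor=input_tensor,
                transpose=True,
                first=first,
            )
            / self.transform.shape[0]
        )


# round-trip with a 2x2 Hadamard: H.T @ (H @ x) / 2 == x
H2 = torch.tensor([[1.0, 1.0], [1.0, -1.0]])
x = torch.rand(2, 3)
t = HadamardLike(H2)
assert torch.allclose(t.inverse_apply(t.apply(x)), x)
```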
20 changes: 16 additions & 4 deletions src/compressed_tensors/transforms/matrix_multiply.py
@@ -14,14 +14,27 @@

import torch
from compressed_tensors.transforms import Transforms
from compressed_tensors.transforms.utils import apply_matrix_transform


# TODO: fix loading
@Transforms.register("matrix-mul")
class MatrixMultiply(Transforms):
@staticmethod
def apply(
self,
input_tensor: torch.Tensor,
transpose: bool = False,
first: bool = True,
) -> torch.Tensor:
return apply_matrix_transform(
transform=self.transform,
input_tensor=input_tensor,
transpose=transpose,
first=first,
)

def inverse_apply(
transform: torch.Tensor,
self,
input_tensor: torch.Tensor,
transpose: bool = False,
first: bool = True,
@@ -38,9 +51,8 @@ def inverse_apply(
"""

# Note: not implemented for lower precision than float32
transform = torch.linalg.inv(transform)
return apply_matrix_transform(
transform=transform,
transform=torch.linalg.inv(self.transform),
input_tensor=input_tensor,
transpose=transpose,
first=first,
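A usage sketch for the matrix-multiply transform, following the pattern in tests/test_transforms/test_transforms.py further down; the identity multiplier, size, and CPU device are illustrative choices. Per the note in the diff, the inverse path goes through torch.linalg.inv, so the transform needs float32 or higher precision.

```python
import torch
from compressed_tensors.transforms import Transforms

size, dtype = 64, torch.float32  # torch.linalg.inv is not implemented below float32
multiplier = torch.eye(size, dtype=dtype)

mm_transform = Transforms.load_from_registry(
    "matrix-mul", transform=multiplier, device="cpu", dtype=dtype
)

x = torch.rand((size, size), dtype=dtype)
y = mm_transform.apply(x)                             # multiplier @ x
x_back = mm_transform.inverse_apply(input_tensor=y)   # inv(multiplier) @ y

assert torch.equal(y, x)          # identity multiplier leaves x unchanged
assert torch.allclose(x_back, x)
```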
24 changes: 18 additions & 6 deletions src/compressed_tensors/transforms/random_hadamard.py
@@ -22,8 +22,8 @@

@Transforms.register("random-hadamard")
class RandomHadamard(Transforms):
def __new__(
cls,
def __init__(
self,
size: int,
empty: Optional[bool] = False,
device: Optional[Union[str, torch.device]] = "cuda",
@@ -58,11 +58,23 @@ def __new__(
else:
transform = torch.empty((size, size))

return super().__new__(cls, transform=transform, device=device, dtype=dtype)
super().__init__(transform=transform, device=device, dtype=dtype)

def apply(
self,
input_tensor: torch.Tensor,
transpose: bool = False,
first: bool = True,
) -> torch.Tensor:
return apply_matrix_transform(
transform=self.transform,
input_tensor=input_tensor,
transpose=transpose,
first=first,
)

@staticmethod
def inverse_apply(
transform: torch.Tensor,
self,
input_tensor: torch.Tensor,
transpose: bool = False,
first: bool = True,
@@ -80,7 +92,7 @@

transpose = not transpose
return apply_matrix_transform(
transform=transform,
transform=self.transform,
input_tensor=input_tensor,
transpose=transpose,
first=first,
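Because the random Hadamard matrix is generated already normalized (the random-hadamard test further down checks that sqrt(size) * transform contains only ±1), inverse_apply just flips transpose and needs no rescaling. A quick round-trip sketch, with CPU and float32 assumed:

```python
import torch
from compressed_tensors.transforms import Transforms

size, dtype = 1024, torch.float32
random_hadamard = Transforms.load_from_registry(
    "random-hadamard", size=size, dtype=dtype, device="cpu"
)

x = torch.rand((size, size), dtype=dtype)
y = random_hadamard.apply(input_tensor=x)               # Q @ x
x_back = random_hadamard.inverse_apply(input_tensor=y)  # Q.T @ y, since Q.T @ Q == I

# loose tolerance to absorb float32 accumulation error at this size
assert torch.allclose(x, x_back, atol=1e-3, rtol=1e-3)
```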
2 changes: 1 addition & 1 deletion tests/test_transforms/test_hadamards.py
@@ -51,7 +51,7 @@ def test_random_hadamard_matrix_compliant(size):

@pytest.mark.parametrize(
"size",
[1024, 2048],
[1024],
)
def test_deterministic_hadamard_compliant(size):
had_matrix = deterministic_hadamard_matrix(size)
21 changes: 9 additions & 12 deletions tests/test_transforms/test_transforms.py
@@ -44,19 +44,18 @@ def test_random_hadamard_transform(size: int, dtype: torch.dtype):
# check initialize
assert hadamard_transform is not None

val_1 = torch.round(hadamard_transform @ hadamard_transform.T)
val_1 = torch.round(hadamard_transform.transform @ hadamard_transform.transform.T)

# output will be normalized, multiply by sqrt(size) to ensure form
normalized = math.sqrt(size) * hadamard_transform
normalized = math.sqrt(size) * hadamard_transform.transform
# all values should be -1 or +1
assert torch.all(torch.isin(normalized, torch.Tensor([-1, +1])))
# check creation; HH.T == I
assert torch.equal(val_1, torch.eye(size))

# check apply
x = torch.rand((size, size), dtype=dtype)
apply = Transforms.fetch_apply("random-hadamard")
transformed_value = apply(input_tensor=x, transform=hadamard_transform)
transformed_value = hadamard_transform.apply(input_tensor=x)
# TODO: check to make sure the matrix was applied correctly?
assert transformed_value.shape == (size, size)

@@ -75,16 +74,15 @@ def test_deterministic_hadamard_transform(size: int, dtype: torch.dtype):

# check initialize
assert hadamard_transform is not None
assert torch.all(torch.isin(hadamard_transform, torch.Tensor([-1, +1])))
assert torch.all(torch.isin(hadamard_transform.transform, torch.Tensor([-1, +1])))

val_1 = hadamard_transform @ hadamard_transform.T
val_1 = hadamard_transform.transform @ hadamard_transform.transform.T
# check creation; HH.T == nI
assert torch.equal(val_1 / size, torch.eye(size))

# check apply
x = torch.rand((size, size), dtype=dtype)
apply = Transforms.fetch_apply("hadamard")
transformed_value = apply(input_tensor=x, transform=hadamard_transform)
transformed_value = hadamard_transform.apply(input_tensor=x)
# TODO: check to make sure the matrix was applied correctly?
assert transformed_value.shape == (size, size)

@@ -103,9 +101,8 @@ def test_multiplier_transform(size: int, dtype: torch.dtype):
"matrix-mul", transform=multiplier, device="cpu", dtype=dtype
)
assert multiplier_transform is not None
assert torch.equal(multiplier_transform, multiplier)
assert torch.equal(multiplier_transform.transform, multiplier)

x = torch.rand((size, size), dtype=dtype)
apply = Transforms.fetch_apply("matrix-mul")
transformed_value = apply(input_tensor=x, transform=multiplier_transform)
assert torch.equal(transformed_value, x)
transformed_output = multiplier_transform.apply(x)
assert torch.equal(transformed_output, x)