[Transforms] Transform Registry Support #274
@@ -17,6 +17,7 @@
import torch
from compressed_tensors.registry.registry import RegistryMixin
from compressed_tensors.transforms.utils import apply_matrix_transform
from compressed_tensors.utils import register_offload_parameter, update_parameter_data


__all__ = ["Transforms"]
@@ -27,18 +28,16 @@
# first or second matirx in torch.matmul depending on dimensions, can be inferred
# by the layer time likely.

MATIRX_TRANSFORMS = ["matrix-mul", "hadamard", "random-hadamard"]


class Transforms(RegistryMixin):
    def __new__(
        cls,
    def __init__(
        self,
        transform: torch.Tensor,
        learnable: Optional[bool] = True,
        device: Optional[Union[str, torch.device]] = "cuda",
        dtype: Optional[torch.dtype] = torch.bfloat16,
        *args,
        **kwargs,
    ):
        self.learnable = learnable
        """
        Base class for setting up transforms. The registry creates transforms
        as parameters which can be attached to modules.
@@ -48,38 +47,58 @@ def __new__(
        size = 1024
        dtype = torch.bfloat16
        module = torch.nn.Linear(size, size)
        name = "weight_transform"

        hadamard_transform = Transforms.load_from_registry(
            "random_hadamard", size=size, dtype=dtype
        )
        hadamard_apply = Transforms.fetch_apply("random_hadamard")
        module.weight_transform = hadamard_transform

        transformed_output = hadamard_apply(input_tensor=module.weight,
            transform=moduel.weight_transform)
        hadamard_transform.register_to_module(name, module)
        module.transform_data = {name: {"call_args": dict, "class": hadamard_transform}}

        hadamard_inverse = Transforms.fetch_inverse_apply("random_hadamard")
        original_weight = hadamard_inverse(input_tensor=transformed_output,
            transform=model.weight_trainsform,
            transpose=True)
        transformed_output = hadamard_transform.apply(input_tensor=module.weight)
        original_weight = hadamard_transform.inverse_apply(
            input_tensor=transformed_output)

        :param transform: transform (e.g. torch.Tensor, scalar) to be applied
        """
        return torch.nn.Parameter(transform.to(device).to(dtype), requires_grad=False)

    @classmethod
    def fetch_apply(cls, name: str):
        if name in MATIRX_TRANSFORMS:
            return apply_matrix_transform
        raise NotImplementedError("Only matrix transforms are supported")

    @classmethod
    def fetch_inverse_apply(cls, name: str):
        return cls.get_value_from_registry(name=name).inverse_apply
        if self.learnable:
            self.transform = torch.nn.Parameter(transform.to(dtype).to(device))
        else:
            self.transform = torch.nn.Buffer(transform.to(dtype).to(device))

    # register to class for easy offloading, serialization, deserialization
    def register_to_module(self, name: str, module: torch.nn.Module):
        if self.learnable:
            register_offload_parameter(module, name, self.transform)
        else:
            # TODO: have to verify serialization/offloading
            module.register_buffer(name, self.transform)

    def update_transform(
        self,
        data: torch.Tensor,
        module: Optional[torch.nn.Module] = None,
        name: Optional[str] = None,
    ):
        if module is None:
            self.transform.data.copy_(data)
        else:
            # If updating the module parameter data, assumes this is also the transform
            # data
            if name is None:
                raise ValueError("Name and module are required to update parma data")
Review comment: param data?
            update_parameter_data(module, data, name)

    def apply(self, input_tensor: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """
        Apply the transform to the module
        """
        raise NotImplementedError()

    @staticmethod
    # TODO: potentially split into its own transform using the same shared set-up
    def inverse_apply(
        transform: torch.Tensor, input_tensor: torch.Tensor, *args, **kwargs
        self, input_tensor: torch.Tensor, *args, **kwargs
    ) -> torch.Tensor:
        """
        Apply the inverse operation applied by the apply method
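Note: the diff above replaces the old __new__ / fetch_apply flow with an instance-based API (register_to_module, update_transform, apply, inverse_apply). The following is a minimal usage sketch of that new API, not code from the PR; the import path compressed_tensors.transforms and the use of the "hadamard" registry name (registered in the second file below) are assumptions.

import torch
from compressed_tensors.transforms import Transforms  # assumed import path

size = 64
module = torch.nn.Linear(size, size, dtype=torch.bfloat16)
name = "weight_transform"

# Build the transform through the registry; learnable=True (the default) stores it
# as a torch.nn.Parameter, learnable=False as a torch.nn.Buffer.
transform = Transforms.load_from_registry(
    "hadamard", size=size, dtype=torch.bfloat16, device="cpu"
)

# Attach it to the module: register_offload_parameter when learnable,
# module.register_buffer otherwise.
transform.register_to_module(name, module)

# Apply and invert through the new instance methods.
transformed = transform.apply(input_tensor=module.weight)
restored = transform.inverse_apply(input_tensor=transformed)

# Update the transform data in place, or route it through the module parameter by name.
new_data = torch.randn(size, size, dtype=torch.bfloat16)
transform.update_transform(data=new_data)
transform.update_transform(data=new_data, module=module, name=name)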
@@ -22,12 +22,14 @@

@Transforms.register("hadamard")
class Hadamard(Transforms):
    def __new__(
        cls,
    def __init__(
        self,
        size: int,
        empty: Optional[bool] = False,
        device: Optional[Union[str, torch.device]] = "cuda",
        dtype: Optional[torch.dtype] = torch.bfloat16,
        *args,
        **kwargs,
    ):
        """
        Produces a hadamard matrix with dims (size, size), with values

@@ -50,11 +52,23 @@ def __new__(
        else:
            transform = torch.empty((size, size))

        return super().__new__(cls, transform=transform, device=device, dtype=dtype)
        super().__init__(transform=transform, dtype=dtype, device=device)

    def apply(
        self,
        input_tensor: torch.Tensor,
        transpose: bool = False,
        first: bool = True,
    ) -> torch.Tensor:
        return apply_matrix_transform(
            transform=self.transform,
            input_tensor=input_tensor,
            transpose=transpose,
            first=first,
        )

    @staticmethod
    def inverse_apply(
        transform: torch.Tensor,
        self,
        input_tensor: torch.Tensor,
        transpose: bool = False,
Review comment: I think having […]
Review comment: Afaict transpose is only used when applying an inverse hadamard, and therefore should not be controllable by the user to avoid footgunning.

        first: bool = True,
@@ -73,10 +87,10 @@ def inverse_apply(
        # need to normalize before sending back
        return (
            apply_matrix_transform(
                transform=transform,
                transform=self.transform,
                input_tensor=input_tensor,
                transpose=transpose,
                first=first,
            )
            / transform.shape[0]
            / self.transform.shape[0]
        )
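Note on the normalization in inverse_apply: a Hadamard matrix H of size n satisfies H @ H.T == n * I, so the inverse of multiplying by H is multiplying by H.T and dividing by n. That is why the inverse divides by self.transform.shape[0], and why the review comment above suggests that transpose should be fixed for the inverse path rather than user-controllable. A standalone numerical sketch in plain torch, assuming a right-multiplication orientation (the exact orientation handled by apply_matrix_transform's first/transpose arguments is not shown in this diff):

import torch

# 4x4 Hadamard matrix via the Sylvester construction; satisfies H @ H.T == n * I
h2 = torch.tensor([[1.0, 1.0], [1.0, -1.0]])
h = torch.kron(h2, h2)
n = h.shape[0]

x = torch.randn(3, n)       # rows to transform
y = x @ h                   # forward "apply"
x_back = (y @ h.T) / n      # inverse: apply the transpose, then normalize by the size

assert torch.allclose(x, x_back, atol=1e-5)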