[Transforms] Transform Registry Support #274
base: dsikka/transform_support
Conversation
LGTM pending normalization of name in fetch_apply
Meant to request changes
The registry setup looks good to me, especially towards the future ability to cache the Hadamard matrices. I have a few high-level concerns regarding the current architecture, though, and believe some adjustments could significantly enhance flexibility and usability. At present, it feels primarily like a convenient wrapper around functional calls, shifting much of the complexity and implementation responsibility onto the calling code rather than encapsulating it. Specifically:
- External State Management: Currently, the required state for a transform is handled external to the Transform class, forcing callers to manage parameters, state, and execution flow explicitly and reducing the transforms to simple functional wrappers.
- Static Methods and Extensibility: Implementing transforms solely as static methods severely limits subclassing and future extensibility. At a minimum, they should be class methods to allow overriding behavior in the future when needed; ideally they would be instance methods supporting better OOP flows, as mentioned in the other feedback.
- Alignment with PyTorch Standards: The current implementation doesn't leverage standard PyTorch patterns such as Modules, Hooks, or the Parametrization API. As a result, it misses out on key ecosystem benefits like parameter and device handling.
- Separate Forward and Inverse Operations: Combining forward and inverse operations into a single transform reduces readability, IMO, and deviates from typical PyTorch patterns such as the standard Quant and DeQuant setup. By limiting the responsibility of a single instance to just a forward, it's much easier to construct graphs without additional conditional logic in the caller for deciding when to inject and when to invert.
- Hardcoded Single Parameter Assumption: The current implementation assumes a single, hardcoded parameter managed externally, restricting flexibility and easy extension in the modifiers, whether for different transform types or multiple transforms of the same type.
- Lack of Iterative Update Mechanisms: There are currently no straightforward mechanisms for updating transform values, which are crucial for iterative algorithms like SpinQuant.
An ideal solution in my mind would expand on what exists here and implement a more native PyTorch solution that is also more object oriented, as noted in some of the other review comments. I think that boils down to either wrapper modules for transforms or utilizing hooks and the parametrization API. I've included an example implementation of the latter below, which would enable the following simple representation for a hadamard quantization setup:
layer = Linear(...)
HadamardTransform(layer, "weight", ...)
QuantizationTransform(layer, "weight", ...)
DeQuantizationTransform(layer, "weight", ...)
InverseHadamardTransform(layer, "weight", ...)
Ultimately we can utilize the above in numerous places, and it significantly simplifies the logic needed in the calling code. It sets up a black-box API for the caller, which only needs to know which transforms to apply in a given case. Additionally, it's easily and quickly extensible to channelwise scaling for things like AWQ and SmoothQuant, pruning, etc. (see the rough sketch after the example implementation below).
from abc import ABC, abstractmethod
from typing import Any, Optional, Union, Literal

import torch
from torch.nn import Module
from torch.nn.utils.parametrize import register_parametrization, remove_parametrizations

from compressed_tensors.registry.registry import RegistryMixin

# NOTE: random_hadamard_matrix is assumed to come from the hadamard utilities
# introduced alongside this PR


class Transforms(ABC, Module, RegistryMixin):
    def __init__(
        self, module: Module, target: Union[Literal["inputs", "outputs"], str]
    ):
        super().__init__()
        self.module = module
        self.target = target
        self.hooks = []
        self.add_hooks()

    def add_hooks(self):
        if self.target == "inputs":
            self.hooks.append(
                self.module.register_forward_pre_hook(
                    # assumes a single positional input to the wrapped module
                    lambda module, args: self.forward(args[0])
                )
            )
        elif self.target == "outputs":
            self.hooks.append(
                self.module.register_forward_hook(
                    lambda module, inputs, outputs: self.forward(outputs)
                )
            )
        else:
            # targeting a parameter
            register_parametrization(self.module, self.target, self)

    def remove_hooks(self):
        for hook in self.hooks:
            hook.remove()
        self.hooks = []
        # TODO: figure out how to remove just this parametrization and not all
        remove_parametrizations(self.module, self.target)

    @abstractmethod
    def update(self, *args, **kwargs): ...

    @abstractmethod
    def forward(self, inputs: Any) -> Any: ...


@Transforms.register("hadamard")
class HadamardTransform(Transforms):
    def __init__(
        self,
        module: Module,
        target: Union[Literal["inputs", "outputs"], str],
        size: int,
        left_multiply: bool = False,
        preinitialize: bool = True,
        device: Optional[Union[str, torch.device]] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        super().__init__(module, target)
        self.target = target
        self.size = size
        self.left_multiply = left_multiply

        self.module.register_buffer(
            name=self.hadamard_buffer_name,
            tensor=torch.empty((size, size), device=device, dtype=dtype),
        )
        self.module.register_buffer(
            name=self.permute_buffer_name,
            tensor=torch.empty((size, size), device=device, dtype=dtype),
        )

        if preinitialize:
            self.update(
                hadamard=random_hadamard_matrix(size),
                permutation=torch.eye(size),  # start from the identity permutation
            )

    @property
    def hadamard_buffer_name(self) -> str:
        return f"{self.target}_hadamard_transform"

    @property
    def permute_buffer_name(self) -> str:
        return f"{self.target}_hadamard_permute_transform"

    def update(
        self,
        hadamard: Optional[torch.Tensor] = None,
        permutation: Optional[torch.Tensor] = None,
    ):
        if hadamard is not None:
            getattr(self.module, self.hadamard_buffer_name).copy_(hadamard)
        if permutation is not None:
            getattr(self.module, self.permute_buffer_name).copy_(permutation)

    def forward(self, inputs: Any) -> Any:
        hadamard = getattr(self.module, self.hadamard_buffer_name)
        permutation = getattr(self.module, self.permute_buffer_name)

        if self.left_multiply:
            return hadamard @ inputs @ permutation
        else:
            return inputs @ permutation @ hadamard


@Transforms.register("hadamard_inverse")
class HadamardInverseTransform(HadamardTransform):
    def forward(self, inputs: Any) -> Any:
        hadamard = getattr(self.module, self.hadamard_buffer_name)
        permutation = getattr(self.module, self.permute_buffer_name)

        if self.left_multiply:
            return permutation.T @ hadamard @ inputs
        else:
            return inputs @ hadamard @ permutation.T


@Transforms.register("quantization")
class QuantizationTransform(Transforms):
    def __init__(
        self,
        module: Module,
        target: Union[Literal["inputs", "outputs"], str],
        scale: float,
        zero_point: int,
        quant_dtype: torch.dtype,
    ):
        super().__init__(module, target)
        self.module.register_buffer(
            name=self.zp_buffer_name, tensor=torch.tensor(zero_point, dtype=torch.int)
        )
        self.module.register_buffer(
            name=self.scale_buffer_name, tensor=torch.tensor(scale, dtype=torch.float)
        )
        self.quant_dtype = quant_dtype

    @property
    def zp_buffer_name(self) -> str:
        return f"{self.target}_zero_point"

    @property
    def scale_buffer_name(self) -> str:
        return f"{self.target}_scale"

    def update(self, scale: float, zero_point: int):
        getattr(self.module, self.zp_buffer_name).fill_(zero_point)
        getattr(self.module, self.scale_buffer_name).fill_(scale)

    def forward(self, inputs: Any) -> Any:
        scale = getattr(self.module, self.scale_buffer_name)
        zero_point = getattr(self.module, self.zp_buffer_name)

        quantized = torch.clamp(
            torch.round(inputs / scale) + zero_point,
            torch.iinfo(self.quant_dtype).min,
            torch.iinfo(self.quant_dtype).max,
        )

        return quantized.to(self.quant_dtype)


@Transforms.register("dequantization")
class DeQuantizationTransform(Transforms):
    def __init__(
        self,
        module: Module,
        target: Union[Literal["inputs", "outputs"], str],
        scale: float,
        zero_point: int,
        quant_dtype: torch.dtype,
        dtype: torch.dtype,
    ):
        super().__init__(module, target)
        self.scale = scale
        self.zero_point = zero_point
        self.quant_dtype = quant_dtype
        self.dtype = dtype

    def update(self, scale: float, zero_point: int):
        self.scale = scale
        self.zero_point = zero_point

    def forward(self, inputs: Any) -> Any:
        scale = torch.tensor(self.scale, dtype=torch.float)
        zero_point = torch.tensor(self.zero_point, dtype=torch.int)
        dequantized = (inputs - zero_point) * scale

        return dequantized.to(self.dtype)
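As a rough sketch of the extensibility claim above (not part of this PR), a channelwise scaling transform in the style of AWQ/SmoothQuant could follow the same pattern. It builds directly on the Transforms base and imports defined above; the ChannelScaleTransform name and buffer layout are purely illustrative:

@Transforms.register("channel_scale")
class ChannelScaleTransform(Transforms):
    """Illustrative sketch only: per-channel scaling built on the same base class."""

    def __init__(
        self,
        module: Module,
        target: Union[Literal["inputs", "outputs"], str],
        num_channels: int,
        device: Optional[Union[str, torch.device]] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        super().__init__(module, target)
        self.module.register_buffer(
            name=self.scale_buffer_name,
            tensor=torch.ones(num_channels, device=device, dtype=dtype),
        )

    @property
    def scale_buffer_name(self) -> str:
        return f"{self.target}_channel_scale"

    def update(self, scales: torch.Tensor):
        # overwrite the per-channel scales in place
        getattr(self.module, self.scale_buffer_name).copy_(scales)

    def forward(self, inputs: Any) -> Any:
        # broadcast the per-channel scales over the last dimension
        return inputs * getattr(self.module, self.scale_buffer_name)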
@markurtz
The transforms do make use of parametrization and therefore directly benefit from device handling. Again, please look through the remaining PRs, especially: #276.
The forward and inverse operations make use of the same matrix/data but are two separate operations, as implemented by
I am not sure what this is referring to. Each transform class is dedicated to one type of transform, which is responsible for its specific set-up. I strongly disagree that the registry should be an extension of the Module class. I do agree that there should be functionality to support updating.
LGTM
# If updating the module parameter data, assumes this is also the transform
# data
if name is None:
    raise ValueError("Name and module are required to update parma data")
param data?
def inverse_apply(
    self,
    input_tensor: torch.Tensor,
    transpose: bool = False,
I think having first and call_args in general seems to be an extensible pattern, but maybe transpose should not be a call arg? It seems to be determined by the specific Transform and whether you're applying the inverse, not something a caller should have control over.
Afaict transpose is only used when applying an inverse hadamard, and therefore should not be controllable by the user, to avoid footgunning.
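A minimal sketch of what folding transpose into the transform itself could look like; the class and method bodies here are illustrative, not this PR's exact API:

import torch

class HadamardTransformSketch:
    """Illustrative only: transpose is decided by the transform, not the caller."""

    def __init__(self, hadamard: torch.Tensor):
        self.hadamard = hadamard

    def apply(self, input_tensor: torch.Tensor) -> torch.Tensor:
        # forward rotation
        return input_tensor @ self.hadamard

    def inverse_apply(self, input_tensor: torch.Tensor) -> torch.Tensor:
        # the inverse internally uses the transpose; no transpose call arg exposed
        return input_tensor @ self.hadamard.T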
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Summary:
- Each transform class now includes an inverse_apply method. A helper method fetch_inverse_apply has been set up to fetch the method during runtime.
- An empty option which, if set to True, allows the transform classes to initialize an empty parameter onto which the deserialized transforms can be loaded. If False, a new transform is created and returned.
- The inverse is applied using the transpose of H, where H is the hadamard rotation. When multiplied by its transpose, the matrix produced is a multiple of the identity. For the randomized hadamard the matrix is normalized, while for the deterministic transform it is not.
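For reference, a quick sketch illustrating the identity relation above, using scipy.linalg.hadamard in place of this PR's utilities:

import torch
from scipy.linalg import hadamard

n = 4
H = torch.tensor(hadamard(n), dtype=torch.float32)

# deterministic (unnormalized) hadamard: H @ H.T == n * I
assert torch.allclose(H @ H.T, n * torch.eye(n))

# normalizing by 1/sqrt(n) makes the matrix orthonormal, so its
# transpose is exactly its inverse: H_norm @ H_norm.T == I
H_norm = H / torch.sqrt(torch.tensor(float(n)))
assert torch.allclose(H_norm @ H_norm.T, torch.eye(n))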
A series of utilities have been added to support the generation of the randomized hadamard matrix, leveraged from the SpinQuant and QuaRot repositories. The utilities include hadamard matrices for commonly used sizes. So far, only had12 and had20 have been included; they are saved in their packed bit format so we can experiment with other ways of storing the matrices before deciding to add more sizes. Suggestions are welcome!
When the preset hadamards are used, the matrices can be unpacked using the following code:
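(The original snippet did not survive extraction; the following is a minimal sketch assuming the bits are stored via numpy.packbits with 1 mapping to +1 and 0 mapping to -1. The actual storage layout in the PR may differ.)

import numpy as np
import torch

def unpack_hadamard(packed: np.ndarray, size: int) -> torch.Tensor:
    # sketch: recover the 0/1 bits, trim any padding, and map to -1/+1 entries
    bits = np.unpackbits(packed)[: size * size].reshape(size, size)
    return torch.tensor(bits, dtype=torch.float32) * 2 - 1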
Example Use:
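(The original example did not survive extraction; the following is a hypothetical sketch assuming a Transforms base exposing apply/inverse_apply and the RegistryMixin load_from_registry helper, with "hadamard" as the registered name. The import path and apply signature are assumptions; the real API in this PR may differ.)

import torch
from compressed_tensors.transforms import Transforms  # hypothetical import path

size = 2048
weight = torch.rand(size, size)

# construct a registered transform by name, then apply and invert it
hadamard = Transforms.load_from_registry("hadamard", size=size)
rotated = hadamard.apply(input_tensor=weight)
restored = hadamard.inverse_apply(input_tensor=rotated)

assert torch.allclose(weight, restored, atol=1e-5)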
To Do / Future: