|
18 | 18 | import torch |
19 | 19 | import torch.nn.utils.parametrize as P |
20 | 20 | import tqdm |
| 21 | +from compressed_tensors.modeling.attention import ( |
| 22 | + initialize_hooked_attention, |
| 23 | + register_query_hook, |
| 24 | +) |
| 25 | +from compressed_tensors.modeling.kvcache import ( |
| 26 | + initialize_hooked_kv_cache, |
| 27 | + register_key_hook, |
| 28 | +) |
21 | 29 | from compressed_tensors.registry.registry import RegistryMixin, T |
22 | 30 | from compressed_tensors.transform import ( |
23 | 31 | TransformArgs, |
|
36 | 44 | from compressed_tensors.utils.internal import InternalModule |
37 | 45 | from torch import Tensor |
38 | 46 | from torch.nn import Module, Parameter |
| 47 | +from transformers import PreTrainedModel |
39 | 48 |
|
40 | 49 |
|
41 | 50 | __all__ = ["TransformFactory", "TransformBase"] |
@@ -97,12 +106,13 @@ def apply_to_model(self, model: Module, use_tqdm=True): |
97 | 106 |
|
98 | 107 | desc = f"Applying {self.name} transforms" |
99 | 108 | for module, arg in tqdm.tqdm(modules_args, desc=desc, disable=(not use_tqdm)): |
100 | | - self._apply_to_module(module, arg) |
| 109 | + self._apply_to_module(model, module, arg) |
101 | 110 |
|
102 | | - def _apply_to_module(self, module: Module, args: TransformArgs): |
| 111 | + def _apply_to_module(self, model: Module, module: Module, args: TransformArgs): |
103 | 112 | """ |
104 | 113 | Create transforms and apply them to the module |
105 | 114 |
|
| 115 | +        :param model: the model that the target module belongs to |
106 | 116 | :param module: target module to apply transforms to |
107 | 117 | :param args: defines how the transform will be applied to the target module |
108 | 118 | """ |
@@ -156,7 +166,28 @@ def output_hook(_, _input, output): |
156 | 166 |
|
157 | 167 | module.register_forward_hook(output_hook) |
158 | 168 |
|
159 | | - # other locations such as q_attn and k_attn have not been implemented |
| 169 | + # register query hook to attention |
| 170 | + elif args.location == TransformLocation.Q_ATTN: |
| 171 | + if not isinstance(model, PreTrainedModel): |
| 172 | + raise ValueError(f"Cannot hook attention of model: {model}") |
| 173 | + |
| 174 | + def query_hook(_, query_states): |
| 175 | + return transform(query_states) |
| 176 | + |
| 177 | + initialize_hooked_attention(model, module) |
| 178 | + register_query_hook(module, query_hook) |
| 179 | + |
| 180 | + # register key hook to kvcache |
| 181 | + elif args.location == TransformLocation.K_CACHE: |
| 182 | + if not isinstance(model, PreTrainedModel): |
| 183 | +                raise ValueError(f"Cannot hook kv cache of model: {model}") |
| 184 | + |
| 185 | + def key_hook(_, key_states): |
| 186 | + return transform(key_states) |
| 187 | + |
| 188 | + initialize_hooked_kv_cache(model, module) |
| 189 | + register_key_hook(module, key_hook) |
| 190 | + |
160 | 191 | else: |
161 | 192 | raise NotImplementedError() |
162 | 193 |
|
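For reviewers unfamiliar with the hook mechanism used above: `register_query_hook` / `register_key_hook` register callables that receive the hooked attention (or kv-cache) module together with the query/key states and may return replacement states, which is how the factory injects `transform(...)` without rewriting the attention forward pass. Below is a minimal, self-contained sketch of that pattern; `ToyAttention`, `rotation`, and the lambda hooks are illustrative stand-ins and not the compressed_tensors implementation.

```python
import torch
from torch import nn


class ToyAttention(nn.Module):
    """Illustrative attention stub with explicit query/key hook points."""

    def __init__(self, dim: int):
        super().__init__()
        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.k_proj = nn.Linear(dim, dim, bias=False)
        self._query_hooks, self._key_hooks = [], []

    def register_query_hook(self, hook):
        # hook(module, query_states) may return replacement query states
        self._query_hooks.append(hook)

    def register_key_hook(self, hook):
        # hook(module, key_states) may return replacement key states
        self._key_hooks.append(hook)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        query = self.q_proj(hidden_states)
        key = self.k_proj(hidden_states)
        # apply registered hooks before the states are consumed by attention
        for hook in self._query_hooks:
            out = hook(self, query)
            query = out if out is not None else query
        for hook in self._key_hooks:
            out = hook(self, key)
            key = out if out is not None else key
        scores = query @ key.transpose(-1, -2) / key.shape[-1] ** 0.5
        return scores.softmax(dim=-1)


dim = 8
# fixed orthogonal matrix standing in for the factory-created `transform`
rotation = torch.linalg.qr(torch.randn(dim, dim)).Q

attn = ToyAttention(dim)
attn.register_query_hook(lambda _module, q: q @ rotation)
attn.register_key_hook(lambda _module, k: k @ rotation)

probs = attn(torch.randn(2, 4, dim))  # orthogonal rotations cancel in q @ k^T
```

Because the same orthogonal rotation is applied to both queries and keys, the product q @ k^T is unchanged; that invariance is what rotation-style transforms at the Q_ATTN and K_CACHE locations typically rely on.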
|