 if is_torch_available():
-    from transformers import PreTrainedModel, StaticCache
+    from transformers import HybridCache, PreTrainedModel, StaticCache
     from transformers.pytorch_utils import is_torch_greater_or_equal, is_torch_greater_or_equal_than_2_3


+class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
+    """
+    A recipe module designed to make a `PreTrainedModel` exportable with `torch.export`,
+    specifically for decoder-only LMs with a cache. This module ensures that the
+    exported model is compatible with further lowering and execution in `ExecuTorch`.
+    """
+
+    def __init__(
+        self,
+        model: PreTrainedModel,
+        max_batch_size: int = 1,
+        max_cache_len: int = 4096,
+    ):
+        """
+        Initializes the exportable module with the cache implementation declared in the model config
+        (`StaticCache` or `HybridCache`).
+
+        Args:
+            model (`PreTrainedModel`): The pretrained model to wrap.
+            max_batch_size (int): Maximum batch size for the cache.
+            max_cache_len (int): Maximum sequence length for the cache.
+
+        Raises:
+            ValueError: If the model is configured with an unsupported cache implementation.
+        """
+        super().__init__()
+
+        if model.config.cache_implementation == "static":
+            self.model = TorchExportableModuleWithStaticCache(model)
+        elif model.config.cache_implementation == "hybrid":
+            self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len)
+        else:
+            raise ValueError(
+                f"Unsupported cache implementation in this export recipe: '{model.config.cache_implementation}'"
+            )
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        cache_position: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Forward pass of the module, which is compatible with the ExecuTorch llm runner.
+
+        Args:
+            input_ids (`torch.Tensor`): Tensor representing current input token id to the module.
+            cache_position (`torch.Tensor`): Tensor representing current input position in the cache.
+
+        Returns:
+            torch.Tensor: Logits output from the model.
+        """
+        return self.model.forward(input_ids, cache_position)
+
+    def export(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        cache_position: Optional[torch.Tensor] = None,
+        dynamic_shapes: Optional[dict] = None,
+        strict: Optional[bool] = None,
+    ) -> torch.export.ExportedProgram:
+        """
+        Export the wrapped module using `torch.export`.
+
+        Args:
+            input_ids (`Optional[torch.Tensor]`):
+                Tensor representing current input token id to the module. If not provided, a default tensor will be used.
+            cache_position (`Optional[torch.Tensor]`):
+                Tensor representing current input position in the cache. If not provided, a default tensor will be used.
+            dynamic_shapes (`Optional[dict]`):
+                Dynamic shapes to use for export if specified.
+            strict (`Optional[bool]`):
+                Flag to instruct `torch.export` to use `torchdynamo`.
+        """
+        example_input_ids = input_ids if input_ids is not None else torch.tensor([[1]], dtype=torch.long)
+        example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long)
+
+        return torch.export.export(
+            self.model,
+            args=(example_input_ids, example_cache_position),
+            kwargs={},
+            dynamic_shapes=dynamic_shapes,
+            strict=strict if strict is not None else True,
+        )
+
+    @staticmethod
+    def generate(
+        exported_program: torch.export.ExportedProgram,
+        tokenizer,
+        prompt: str,
+        max_new_tokens: int = 20,
+        do_sample: bool = False,
+        temperature: float = 1.0,
+        top_k: int = 50,
+        top_p: float = 1.0,
+        device: str = "cpu",
+    ) -> str:
+        """
+        Generate a sequence of tokens using an exported program.
+
+        Args:
+            exported_program (`torch.export.ExportedProgram`): The exported model being used for generate.
+            tokenizer: The tokenizer to use.
+            prompt (str): The input prompt.
+            max_new_tokens (int): Maximum number of new tokens to generate.
+            do_sample (bool): Whether to use sampling or greedy decoding.
+            temperature (float): The temperature for sampling.
+            top_k (int): The number of highest probability tokens to keep for top-k sampling.
+            top_p (float): The cumulative probability for nucleus sampling.
+            device (str): The device to use.
+
+        Returns:
+            str: The generated text.
+        """
+        # Get the module from the exported program
+        exported_module = exported_program.module()
+
+        # Tokenize the prompt
+        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
+
+        # Initialize with the prompt
+        generated_ids = input_ids.clone()
+
+        # Process the prompt tokens first
+        curr_position = 0
+        for i in range(input_ids.shape[1]):
+            # Process one token at a time
+            curr_input_ids = input_ids[:, i : i + 1]
+            curr_cache_position = torch.tensor([curr_position], dtype=torch.long, device=device)
+
+            # Forward pass
+            _ = exported_module(curr_input_ids, curr_cache_position)
+            curr_position += 1
+
+        # Generate new tokens
+        for _ in range(max_new_tokens):
+            # Get the last token as input
+            curr_input_ids = generated_ids[:, -1:]
+            curr_cache_position = torch.tensor([curr_position], dtype=torch.long, device=device)
+
+            # Forward pass to get next token logits
+            outputs = exported_module(curr_input_ids, curr_cache_position)
+
+            # Get the next token ID
+            if do_sample:
+                # Apply temperature
+                if temperature > 0:
+                    logits = outputs / temperature
+                else:
+                    logits = outputs
+
+                # Apply top-k filtering
+                if top_k > 0:
+                    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+                    logits[indices_to_remove] = float("-inf")
+
+                # Apply top-p (nucleus) filtering
+                if top_p < 1.0:
+                    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+                    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
+
+                    # Remove tokens with cumulative probability above the threshold
+                    sorted_indices_to_remove = cumulative_probs > top_p
+                    # Shift the indices to the right to keep also the first token above the threshold
+                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+                    sorted_indices_to_remove[..., 0] = 0
+
+                    # Scatter sorted tensors to original indexing
+                    indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
+                    logits[indices_to_remove] = float("-inf")
+
+                # Sample from the filtered distribution, using only the last position's
+                # logits so that `torch.multinomial` receives a 2D [batch, vocab] tensor
+                probs = torch.softmax(logits[:, -1, :], dim=-1)
+                next_token_id = torch.multinomial(probs, num_samples=1)
+            else:
+                # Greedy decoding
+                next_token_id = outputs.argmax(dim=-1, keepdim=True)
+
+            # Ensure next_token_id has the right shape before concatenation
+            if next_token_id.dim() > 2:
+                next_token_id = next_token_id.squeeze(-1)
+
+            # Append to the generated sequence
+            generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)
+            curr_position += 1
+
+            # Stop if we generate an EOS token
+            if next_token_id.item() == tokenizer.eos_token_id:
+                break
+
+        # Decode the generated text
+        return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+
+
 class TorchExportableModuleWithStaticCache(torch.nn.Module):
     """
-    A wrapper module designed to make a `PreTrainedModel` exportable with `torch.export`,
-    specifically for use with static caching. This module ensures that the exported model
-    is compatible with further lowering and execution in `ExecuTorch`.
+    A recipe module designed to make a `PreTrainedModel` exportable with `torch.export`,
+    specifically for decoder-only LMs with `StaticCache`. This module ensures that the
+    exported model is compatible with further lowering and execution in `ExecuTorch`.

     Note:
         This class is specifically designed to support export process using `torch.export`
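For reference, the new recipe is meant to be driven roughly as follows. This is a minimal usage sketch, not part of the diff: the checkpoint path is a placeholder, the import path assumes the module lives in `transformers.integrations.executorch` alongside the existing export recipe, and the model config is assumed to set `cache_implementation` to "static" or "hybrid".

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

# Placeholder checkpoint; any decoder-only LM whose config sets
# `cache_implementation` to "static" or "hybrid" should work.
model_id = "path/to/decoder-only-checkpoint"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Wrap the model; the recipe picks StaticCache or HybridCache based on
# model.config.cache_implementation and raises ValueError otherwise.
exportable = TorchExportableModuleForDecoderOnlyLM(model, max_batch_size=1, max_cache_len=1024)

# Export a single-token decode step (the same defaults `export()` falls back to).
exported_program = exportable.export(
    input_ids=torch.tensor([[1]], dtype=torch.long),
    cache_position=torch.tensor([0], dtype=torch.long),
)

# Run eager token-by-token generation on top of the exported program.
text = TorchExportableModuleForDecoderOnlyLM.generate(
    exported_program, tokenizer, "Hello, my name is", max_new_tokens=20
)
print(text)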
@@ -178,6 +370,94 @@ def generate(
         return torch.tensor([response_tokens], dtype=torch.long)


+class TorchExportableModuleWithHybridCache(torch.nn.Module):
+    """
+    A recipe module designed to make a `PreTrainedModel` exportable with `torch.export`,
+    specifically for decoder-only LMs with `HybridCache`. This module ensures that the
+    exported model is compatible with further lowering and execution in `ExecuTorch`.
+    """
+
+    def __init__(
+        self,
+        model: PreTrainedModel,
+        max_batch_size: int = 1,
+        max_cache_len: int = 4096,
+    ):
+        """
+        Initializes the exportable module with `HybridCache`.
+
+        Args:
+            model (`PreTrainedModel`): The pretrained model to wrap.
+            max_batch_size (int): Maximum batch size for the cache.
+            max_cache_len (int): Maximum sequence length for the cache.
+
+        Raises:
+            AssertionError: If the model doesn't have the expected configuration for HybridCache.
+        """
+        super().__init__()
+        self.model = model
+
+        # Verify the model is configured for HybridCache
+        if not self.model.config.use_cache:
+            raise AssertionError("Model must have caching enabled")
+
+        if (
+            not hasattr(self.model.config, "cache_implementation")
+            or self.model.config.cache_implementation != "hybrid"
+        ):
+            raise AssertionError("Model must use 'hybrid' cache implementation")
+
+        # Initialize the HybridCache
+        self.cache = HybridCache(
+            config=self.model.config,
+            max_batch_size=max_batch_size,
+            max_cache_len=max_cache_len,
+            device=self.model.device,
+            dtype=self.model.dtype,
+        )
+
+        # Register all key and value cache tensors as buffers
+        for i in range(len(self.cache.key_cache)):
+            self.register_buffer(f"key_cache_{i}", self.cache.key_cache[i], persistent=False)
+            self.register_buffer(f"value_cache_{i}", self.cache.value_cache[i], persistent=False)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        cache_position: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Forward pass of the module, which is compatible with the ExecuTorch llm runner.
+
+        Args:
+            input_ids (`torch.Tensor`): Tensor representing current input token id to the module.
+            cache_position (`torch.Tensor`): Tensor representing current input position in the cache.
+
+        Returns:
+            torch.Tensor: Logits output from the model.
+        """
+        batch_size, seq_len = input_ids.shape
+
+        # Generate position_ids from cache_position
+        position_ids = cache_position.unsqueeze(0).expand(batch_size, -1)
+
+        # Create attention mask (always ones for token-by-token generation)
+        attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long, device=input_ids.device)
+
+        # Forward pass with the model
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=self.cache,
+            use_cache=True,
+            cache_position=cache_position,
+        )
+
+        # Return only the logits to simplify the export
+        return outputs.logits
+
+
 def convert_and_export_with_cache(
     model: PreTrainedModel,
     example_input_ids: Optional[torch.Tensor] = None,
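For reference, a minimal sketch (not part of the diff) of the single-token calling convention `TorchExportableModuleWithHybridCache` expects, where `model` and `tokenizer` are assumed to come from a checkpoint with `config.use_cache=True` and `config.cache_implementation == "hybrid"`:

import torch

# Wrap the hybrid-cache model; KV state is owned by the HybridCache created in __init__
# and exposed through the registered key_cache_*/value_cache_* buffers.
wrapper = TorchExportableModuleWithHybridCache(model, max_batch_size=1, max_cache_len=4096)

prompt_ids = tokenizer("Hello", return_tensors="pt").input_ids
for pos in range(prompt_ids.shape[1]):
    # One token per call, with its absolute position in the cache.
    logits = wrapper(
        input_ids=prompt_ids[:, pos : pos + 1],
        cache_position=torch.tensor([pos], dtype=torch.long),
    )

# `logits` now scores the token that follows the prompt; the cache already
# holds the key/value state for every position processed so far.
next_token = logits[:, -1, :].argmax(dim=-1)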