@@ -30,7 +30,6 @@
 from transformers.testing_utils import (
     cleanup,
     require_flash_attn,
-    require_read_token,
     require_torch,
     require_torch_gpu,
     slow,
@@ -664,3 +663,257 @@ def test_generation_beyond_sliding_window_with_generation_config(self):
         model.generation_config.transformers_version = "4.49.0"
         with self.assertRaises(RuntimeError):  # errors out because it is not using hybrid cache
             out = model.generate(**inputs, generation_config=generation_config)
+
+    def test_export_text_only_with_hybrid_cache(self):
+        from transformers import HybridCache
+
+        class Gemma3ExportableModule(torch.nn.Module):
+            """
+            A wrapper module designed to make Gemma3 models exportable with `torch.export`,
+            specifically for use with HybridCache to support interleaved global and local attention.
+
+            This wrapper ensures that the exported model is compatible with further lowering
+            and execution in frameworks like ExecuTorch.
+            """
+
+            def __init__(self, model: PreTrainedModel, max_batch_size: int = 1, max_seq_len: int = 4096):
+                """
+                Initializes the wrapper module with the Gemma3 model.
+
+                Args:
+                    model (`PreTrainedModel`): The Gemma3 model to wrap.
+                    max_batch_size (`int`): Maximum batch size for the cache.
+                    max_seq_len (`int`): Maximum sequence length for the cache.
+
+                Raises:
+                    AssertionError: If the model doesn't have the expected configuration for HybridCache.
+                """
+                super().__init__()
+                self.model = model
+                self.config = model.config
+
+                # Verify the model is configured for HybridCache
+                assert self.config.cache_implementation == "hybrid", "Model must use 'hybrid' cache implementation"
+
+                # Verify sliding window configuration for local attention
+                assert hasattr(self.config, "sliding_window"), "Model config must have sliding_window attribute"
+                assert hasattr(self.config, "sliding_window_pattern"), (
+                    "Model config must have sliding_window_pattern attribute"
+                )
+
+                # Initialize the HybridCache
+                self.cache = HybridCache(
+                    config=self.config,
+                    max_batch_size=max_batch_size,
+                    max_cache_len=max_seq_len,
+                    device=model.device,
+                    dtype=model.dtype,
+                )
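+                # Note: HybridCache allocates static-shaped key/value buffers up front
+                # (max_cache_len for global-attention layers, sliding_window for local ones),
+                # so the cache shapes stay fixed, which is what makes the graph exportable.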
+
+                # Register buffers for tracking state
+                self.register_buffer("last_position", torch.tensor([-1], dtype=torch.long))
+
+                # Store the sliding window pattern for reference
+                self.sliding_window = self.config.sliding_window
+                self.sliding_window_pattern = self.config.sliding_window_pattern
+
+                # Determine which layers use global vs local attention
+                # In Gemma3, every `sliding_window_pattern`-th layer (by default every 6th) uses global attention
+                self.global_attention_layers = [
+                    i
+                    for i in range(self.config.num_hidden_layers)
+                    if i % self.sliding_window_pattern == (self.sliding_window_pattern - 1)
+                ]
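+                # Note: this list is informational only; it is not read by forward(), since
+                # HybridCache decides per layer whether the sliding-window or the global cache is used.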
+
+            def forward(
+                self,
+                input_ids: torch.Tensor,
+                cache_position: torch.Tensor,
+            ) -> torch.Tensor:
+                """
+                Forward pass of the module, compatible with torch.export.
+
+                Args:
+                    input_ids (`torch.Tensor`): Tensor representing current input token id(s).
+                    cache_position (`torch.Tensor`): Tensor representing current position(s) in the cache.
+
+                Returns:
+                    `torch.Tensor`: Logits output from the model.
+                """
+                batch_size, seq_len = input_ids.shape
+
+                # Update the last_position for tracking
+                self.last_position = cache_position[-1].unsqueeze(0)
+
+                # Generate position_ids from cache_position
+                position_ids = cache_position.unsqueeze(0).expand(batch_size, -1)
+
+                # Create attention mask (always ones for token-by-token generation)
+                attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long, device=input_ids.device)
+
+                # Forward pass with the model
+                outputs = self.model(
+                    input_ids=input_ids,
+                    attention_mask=attention_mask,
+                    position_ids=position_ids,
+                    past_key_values=self.cache,
+                    use_cache=True,
+                    cache_position=cache_position,
+                )
+
+                # Return only the logits to simplify the export
+                return outputs.logits
+
+            @staticmethod
+            def generate(
+                exported_model: torch.export.ExportedProgram,
+                tokenizer,
+                prompt: str,
+                max_new_tokens: int = 20,
+                do_sample: bool = False,
+                temperature: float = 1.0,
+                top_k: int = 50,
+                top_p: float = 1.0,
+                device: str = "cpu",
+            ) -> str:
+                """
+                Generate text using an exported Gemma3 model.
+
+                Args:
+                    exported_model (`torch.export.ExportedProgram`): The exported model to generate with.
+                    tokenizer: The tokenizer to use.
+                    prompt (`str`): The input prompt.
+                    max_new_tokens (`int`): Maximum number of new tokens to generate.
+                    do_sample (`bool`): Whether to use sampling or greedy decoding.
+                    temperature (`float`): The temperature for sampling.
+                    top_k (`int`): The number of highest-probability tokens to keep for top-k sampling.
+                    top_p (`float`): The cumulative probability threshold for nucleus sampling.
+                    device (`str`): The device to use.
+
+                Returns:
+                    `str`: The generated text.
+                """
+                # Get the module from the exported program
+                exported_module = exported_model.module()
+
+                # Tokenize the prompt
+                input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
+
+                # Initialize with the prompt
+                generated_ids = input_ids.clone()
+
+                # Process the prompt tokens first
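+                # (one at a time: the wrapper is exported further below with a fixed (1, 1) input
+                # shape, so prefill has to feed the prompt token by token)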
+                curr_position = 0
+                for i in range(input_ids.shape[1]):
+                    # Process one token at a time
+                    curr_input_ids = input_ids[:, i : i + 1]
+                    curr_cache_position = torch.tensor([curr_position], dtype=torch.long, device=device)
+
+                    # Forward pass
+                    _ = exported_module(curr_input_ids, curr_cache_position)
+                    curr_position += 1
+
+                # Generate new tokens
+                for _ in range(max_new_tokens):
+                    # Get the last token as input
+                    curr_input_ids = generated_ids[:, -1:]
+                    curr_cache_position = torch.tensor([curr_position], dtype=torch.long, device=device)
+
+                    # Forward pass to get next token logits
+                    outputs = exported_module(curr_input_ids, curr_cache_position)
+
+                    # Get the next token ID
+                    if do_sample:
+                        # Keep only the logits for the last position so that the filtering and
+                        # sampling below operate on a 2D (batch, vocab) tensor
+                        logits = outputs[:, -1, :]
+
+                        # Apply temperature
+                        if temperature > 0:
+                            logits = logits / temperature
+
+                        # Apply top-k filtering
+                        if top_k > 0:
+                            indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+                            logits[indices_to_remove] = float("-inf")
+
+                        # Apply top-p (nucleus) filtering
+                        if top_p < 1.0:
+                            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+                            cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
+
+                            # Remove tokens with cumulative probability above the threshold
+                            sorted_indices_to_remove = cumulative_probs > top_p
+                            # Shift the indices to the right to also keep the first token above the threshold
+                            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+                            sorted_indices_to_remove[..., 0] = 0
+
+                            # Scatter the sorted mask back to the original indexing
+                            indices_to_remove = sorted_indices_to_remove.scatter(
+                                -1, sorted_indices, sorted_indices_to_remove
+                            )
+                            logits[indices_to_remove] = float("-inf")
+
+                        # Sample from the filtered distribution
+                        probs = torch.softmax(logits, dim=-1)
+                        next_token_id = torch.multinomial(probs, num_samples=1)
+                    else:
+                        # Greedy decoding
+                        next_token_id = outputs.argmax(dim=-1, keepdim=True)
+
+                    # Ensure next_token_id has the right shape before concatenation
+                    if next_token_id.dim() > 2:
+                        next_token_id = next_token_id.squeeze(-1)
+
+                    # Append to the generated sequence
+                    generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)
+                    curr_position += 1
+
+                    # Stop if we generate an EOS token
+                    if next_token_id.item() == tokenizer.eos_token_id:
+                        break
+
+                # Decode the generated text
+                return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+
+        model_id = "google/gemma-3-1b-it"
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        model = AutoModelForCausalLM.from_pretrained(model_id)
+        model.eval()
+        print(f"Model config: {model.config}")
+
+        # Create a wrapper for export with static batch size
+        wrapper = Gemma3ExportableModule(model)
+
+        # Prepare example inputs
+        example_input_ids = torch.tensor([[1]], dtype=torch.long)
+        example_cache_position = torch.tensor([0], dtype=torch.long)
+
+        # Export the model with static shapes
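+        # (strict=False selects non-strict export, i.e. tracing with the Python interpreter rather
+        # than TorchDynamo; presumably needed here because the wrapper mutates the cache inside forward)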
+        exported_program = torch.export.export(
+            wrapper,
+            (example_input_ids, example_cache_position),
+            strict=False,
+        )
+        print(f"Exported program: {exported_program}")
+
+        # Test generation with the exported model
+        # prompt = "What is the capital of France?"
+        prompt = "Write a poem about Machine Learning."
+        max_new_tokens_to_generate = 100
+        # Generate text with the exported model
+        export_generated_text = Gemma3ExportableModule.generate(
+            exported_program, tokenizer, prompt, max_new_tokens=max_new_tokens_to_generate
+        )
+        print(f"Export generated text: '{export_generated_text}'")
+
+        input_text = tokenizer(prompt, return_tensors="pt")
+        with torch.no_grad():
+            eager_outputs = model.generate(
+                **input_text,
+                max_new_tokens=max_new_tokens_to_generate,
+                do_sample=False,  # Use greedy decoding to match the exported model
+            )
+
+        eager_generated_text = tokenizer.decode(eager_outputs[0], skip_special_tokens=True)
+        print(f"Eager generated text: '{eager_generated_text}'")
+
+        self.assertEqual(export_generated_text, eager_generated_text)