From bff3c2fc48715c78f94b508580812c73c2c2a465 Mon Sep 17 00:00:00 2001
From: Guang Yang
Date: Wed, 23 Apr 2025 16:36:57 -0700
Subject: [PATCH 1/2] Gemma3 is Torch Exportable

---
 src/transformers/integrations/executorch.py | 288 +++++++++++++++++-
 .../models/gemma3/modeling_gemma3.py         |   2 +-
 tests/models/gemma3/test_modeling_gemma3.py  |  41 +++
 3 files changed, 326 insertions(+), 5 deletions(-)

diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py
index 591c556e59f0..7a5f1fd79763 100644
--- a/src/transformers/integrations/executorch.py
+++ b/src/transformers/integrations/executorch.py
@@ -20,15 +20,207 @@
 if is_torch_available():
-    from transformers import PreTrainedModel, StaticCache
+    from transformers import HybridCache, PreTrainedModel, StaticCache
     from transformers.pytorch_utils import is_torch_greater_or_equal, is_torch_greater_or_equal_than_2_3
 
 
+class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
+    """
+    A recipe module designed to make a `PreTrainedModel` exportable with `torch.export`,
+    specifically for decoder-only LM with cache. This module ensures that the
+    exported model is compatible with further lowering and execution in `ExecuTorch`.
+    """
+
+    def __init__(
+        self,
+        model: PreTrainedModel,
+        max_batch_size: int = 1,
+        max_cache_len: int = 4096,
+    ):
+        """
+        Initializes the exportable module for the model's configured cache implementation.
+
+        Args:
+            model (`PreTrainedModel`): The pretrained model to wrap.
+            max_batch_size (int): Maximum batch size for the cache.
+            max_cache_len (int): Maximum sequence length for the cache.
+
+        Raises:
+            ValueError: If the model is configured with an unsupported cache implementation.
+        """
+        super().__init__()
+
+        if model.config.cache_implementation == "static":
+            self.model = TorchExportableModuleWithStaticCache(model)
+        elif model.config.cache_implementation == "hybrid":
+            self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len)
+        else:
+            raise ValueError(
+                f"Unsupported cache implementation in this export recipe: '{model.config.cache_implementation}'"
+            )
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        cache_position: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Forward pass of the module, which is compatible with the ExecuTorch llm runner.
+
+        Args:
+            input_ids (`torch.Tensor`): Tensor representing the current input token id to the module.
+            cache_position (`torch.Tensor`): Tensor representing the current input position in the cache.
+
+        Returns:
+            torch.Tensor: Logits output from the model.
+        """
+        return self.model.forward(input_ids, cache_position)
+
+    def export(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        cache_position: Optional[torch.Tensor] = None,
+        dynamic_shapes: Optional[dict] = None,
+        strict: Optional[bool] = None,
+    ) -> torch.export.ExportedProgram:
+        """
+        Export the wrapped module using `torch.export`.
+
+        Args:
+            input_ids (`Optional[torch.Tensor]`):
+                Tensor representing the current input token id to the module. If not provided, a default tensor will be used.
+            cache_position (`Optional[torch.Tensor]`):
+                Tensor representing the current input position in the cache. If not provided, a default tensor will be used.
+            dynamic_shapes (`Optional[dict]`):
+                Dynamic shapes to use for export if specified.
+            strict (`Optional[bool]`):
+                Flag to instruct `torch.export` to use `torchdynamo`.
+        """
+        example_input_ids = input_ids if input_ids is not None else torch.tensor([[1]], dtype=torch.long)
+        example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long)
+
+        return torch.export.export(
+            self.model,
+            args=(example_input_ids, example_cache_position),
+            kwargs={},
+            dynamic_shapes=dynamic_shapes,
+            strict=strict if strict is not None else True,
+        )
+
+    @staticmethod
+    def generate(
+        exported_program: torch.export.ExportedProgram,
+        tokenizer,
+        prompt: str,
+        max_new_tokens: int = 20,
+        do_sample: bool = False,
+        temperature: float = 1.0,
+        top_k: int = 50,
+        top_p: float = 1.0,
+        device: str = "cpu",
+    ) -> str:
+        """
+        Generate a sequence of tokens using an exported program.
+
+        Args:
+            exported_program (`torch.export.ExportedProgram`): The exported model being used for generate.
+            tokenizer: The tokenizer to use.
+            prompt (str): The input prompt.
+            max_new_tokens (int): Maximum number of new tokens to generate.
+            do_sample (bool): Whether to use sampling or greedy decoding.
+            temperature (float): The temperature for sampling.
+            top_k (int): The number of highest probability tokens to keep for top-k sampling.
+            top_p (float): The cumulative probability for nucleus sampling.
+            device (str): The device to use.
+
+        Returns:
+            str: The generated text.
+        """
+        # Get the module from the exported program
+        exported_module = exported_program.module()
+
+        # Tokenize the prompt
+        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
+
+        # Initialize with the prompt
+        generated_ids = input_ids.clone()
+
+        # Process the prompt tokens first
+        curr_position = 0
+        for i in range(input_ids.shape[1]):
+            # Process one token at a time
+            curr_input_ids = input_ids[:, i : i + 1]
+            curr_cache_position = torch.tensor([curr_position], dtype=torch.long, device=device)
+
+            # Forward pass
+            _ = exported_module(curr_input_ids, curr_cache_position)
+            curr_position += 1
+
+        # Generate new tokens
+        for _ in range(max_new_tokens):
+            # Get the last token as input
+            curr_input_ids = generated_ids[:, -1:]
+            curr_cache_position = torch.tensor([curr_position], dtype=torch.long, device=device)
+
+            # Forward pass to get next token logits
+            outputs = exported_module(curr_input_ids, curr_cache_position)
+
+            # Get the next token ID
+            if do_sample:
+                # Apply temperature
+                if temperature > 0:
+                    logits = outputs / temperature
+                else:
+                    logits = outputs
+
+                # Apply top-k filtering
+                if top_k > 0:
+                    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+                    logits[indices_to_remove] = float("-inf")
+
+                # Apply top-p (nucleus) filtering
+                if top_p < 1.0:
+                    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+                    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
+
+                    # Remove tokens with cumulative probability above the threshold
+                    sorted_indices_to_remove = cumulative_probs > top_p
+                    # Shift the indices to the right to keep also the first token above the threshold
+                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+                    sorted_indices_to_remove[..., 0] = 0
+
+                    # Scatter sorted tensors to original indexing
+                    indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
+                    logits[indices_to_remove] = float("-inf")
+
+                # Sample from the filtered distribution
+                probs = torch.softmax(logits, dim=-1)
+                next_token_id = torch.multinomial(probs, num_samples=1)
+            else:
+                # Greedy decoding
+                next_token_id = outputs.argmax(dim=-1, keepdim=True)
+
+            # Ensure next_token_id has the right shape before concatenation
+            if next_token_id.dim() > 2:
+                next_token_id = next_token_id.squeeze(-1)
+
+            # Append to the generated sequence
+            generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)
+            curr_position += 1
+
+            # Stop if we generate an EOS token
+            if next_token_id.item() == tokenizer.eos_token_id:
+                break
+
+        # Decode the generated text
+        return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+
+
 class TorchExportableModuleWithStaticCache(torch.nn.Module):
     """
-    A wrapper module designed to make a `PreTrainedModel` exportable with `torch.export`,
-    specifically for use with static caching. This module ensures that the exported model
-    is compatible with further lowering and execution in `ExecuTorch`.
+    A recipe module designed to make a `PreTrainedModel` exportable with `torch.export`,
+    specifically for decoder-only LM with `StaticCache`. This module ensures that the
+    exported model is compatible with further lowering and execution in `ExecuTorch`.
 
     Note:
         This class is specifically designed to support export process using `torch.export`
@@ -178,6 +370,94 @@ def generate(
     return torch.tensor([response_tokens], dtype=torch.long)
 
 
+class TorchExportableModuleWithHybridCache(torch.nn.Module):
+    """
+    A recipe module designed to make a `PreTrainedModel` exportable with `torch.export`,
+    specifically for decoder-only LM with `HybridCache`. This module ensures that the
+    exported model is compatible with further lowering and execution in `ExecuTorch`.
+    """
+
+    def __init__(
+        self,
+        model: PreTrainedModel,
+        max_batch_size: int = 1,
+        max_cache_len: int = 4096,
+    ):
+        """
+        Initializes the exportable module with `HybridCache`.
+
+        Args:
+            model (`PreTrainedModel`): The pretrained model to wrap.
+            max_batch_size (int): Maximum batch size for the cache.
+            max_cache_len (int): Maximum sequence length for the cache.
+
+        Raises:
+            AssertionError: If the model doesn't have the expected configuration for HybridCache.
+        """
+        super().__init__()
+        self.model = model
+
+        # Verify the model is configured for HybridCache
+        if not self.model.config.use_cache:
+            raise AssertionError("Model must have caching enabled")
+
+        if (
+            not hasattr(self.model.config, "cache_implementation")
+            or self.model.config.cache_implementation != "hybrid"
+        ):
+            raise AssertionError("Model must use 'hybrid' cache implementation")
+
+        # Initialize the HybridCache
+        self.cache = HybridCache(
+            config=self.model.config,
+            max_batch_size=max_batch_size,
+            max_cache_len=max_cache_len,
+            device=self.model.device,
+            dtype=self.model.dtype,
+        )
+
+        # Register all key and value cache tensors as buffers
+        for i in range(len(self.cache.key_cache)):
+            self.register_buffer(f"key_cache_{i}", self.cache.key_cache[i], persistent=False)
+            self.register_buffer(f"value_cache_{i}", self.cache.value_cache[i], persistent=False)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        cache_position: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Forward pass of the module, which is compatible with the ExecuTorch llm runner.
+
+        Args:
+            input_ids (`torch.Tensor`): Tensor representing the current input token id to the module.
+            cache_position (`torch.Tensor`): Tensor representing the current input position in the cache.
+
+        Returns:
+            torch.Tensor: Logits output from the model.
+        """
+        batch_size, seq_len = input_ids.shape
+
+        # Generate position_ids from cache_position
+        position_ids = cache_position.unsqueeze(0).expand(batch_size, -1)
+
+        # Create attention mask (always ones for token-by-token generation)
+        attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long, device=input_ids.device)
+
+        # Forward pass with the model
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=self.cache,
+            use_cache=True,
+            cache_position=cache_position,
+        )
+
+        # Return only the logits to simplify the export
+        return outputs.logits
+
+
 def convert_and_export_with_cache(
     model: PreTrainedModel,
     example_input_ids: Optional[torch.Tensor] = None,
diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py
index 951e8d78ca9b..e5120c54611d 100644
--- a/src/transformers/models/gemma3/modeling_gemma3.py
+++ b/src/transformers/models/gemma3/modeling_gemma3.py
@@ -410,7 +410,7 @@ def forward(
         # In case we are beyond the sliding window, we need to correctly offset the mask slicing
         offset = cache_position[-1] - effective_seq_len + 1
         # Should only be used when beyond the sliding window (i.e. offset > 0)
-        offset = max(0, offset)
+        offset = torch.clamp(offset, min=0)
         # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
         # but without data-dependent slicing (i.e. torch.compile friendly)
         mask_indexes = torch.arange(
diff --git a/tests/models/gemma3/test_modeling_gemma3.py b/tests/models/gemma3/test_modeling_gemma3.py
index be83749cf8bc..39b7abeaea80 100644
--- a/tests/models/gemma3/test_modeling_gemma3.py
+++ b/tests/models/gemma3/test_modeling_gemma3.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 """Testing suite for the PyTorch Gemma3 model."""
 
+import logging
 import tempfile
 import unittest
 
@@ -52,6 +53,7 @@
         Gemma3Processor,
         Gemma3TextModel,
     )
+    from transformers.pytorch_utils import is_torch_greater_or_equal
 
 
 class Gemma3ModelTester(GemmaModelTester):
@@ -664,3 +666,42 @@ def test_generation_beyond_sliding_window_with_generation_config(self):
         model.generation_config.transformers_version = "4.49.0"
         with self.assertRaises(RuntimeError):  # errors out because it is not using hybrid cache
             out = model.generate(**inputs, generation_config=generation_config)
+
+    def test_export_text_only_with_hybrid_cache(self):
+        if not is_torch_greater_or_equal("2.6.0"):
+            self.skipTest(reason="This test requires torch >= 2.6 to run.")
+
+        from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
+
+        model_id = "google/gemma-3-1b-it"
+        model = AutoModelForCausalLM.from_pretrained(model_id)
+        self.assertEqual(model.config.cache_implementation, "hybrid")
+
+        # Export + HybridCache
+        model.eval()
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
+        exported_program = exportable_module.export()
+        logging.info(f"\nExported program: {exported_program}")
+
+        # Test generation with the exported model
+        prompt = "What is the capital of France?"
+        max_new_tokens_to_generate = 20
+        # Generate text with the exported model
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        export_generated_text = TorchExportableModuleForDecoderOnlyLM.generate(
+            exported_program, tokenizer, prompt, max_new_tokens=max_new_tokens_to_generate
+        )
+        logging.info(f"\nExport generated texts: '{export_generated_text}'")
+
+        input_text = tokenizer(prompt, return_tensors="pt")
+        with torch.no_grad():
+            eager_outputs = model.generate(
+                **input_text,
+                max_new_tokens=max_new_tokens_to_generate,
+                do_sample=False,  # Use greedy decoding to match the exported model
+            )
+
+        eager_generated_text = tokenizer.decode(eager_outputs[0], skip_special_tokens=True)
+        logging.info(f"\nEager generated texts: '{eager_generated_text}'")
+
+        self.assertEqual(export_generated_text, eager_generated_text)

From ae84d3f0296439a1b73c60de0245f4f0ef92ac07 Mon Sep 17 00:00:00 2001
From: Guang Yang
Date: Fri, 25 Apr 2025 10:49:44 -0700
Subject: [PATCH 2/2] Expand support to other models using HybridCache

---
 .../models/cohere2/modeling_cohere2.py      |  2 +-
 .../models/cohere2/modular_cohere2.py       |  2 +-
 .../models/gemma2/modeling_gemma2.py        |  2 +-
 .../models/gemma2/modular_gemma2.py         |  2 +-
 .../models/gemma3/modular_gemma3.py         |  2 +-
 tests/models/gemma2/test_modeling_gemma2.py | 38 +++++++++++++++++++
 6 files changed, 43 insertions(+), 5 deletions(-)

diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py
index cc790124ccde..43101580c496 100644
--- a/src/transformers/models/cohere2/modeling_cohere2.py
+++ b/src/transformers/models/cohere2/modeling_cohere2.py
@@ -351,7 +351,7 @@ def forward(
         # In case we are beyond the sliding window, we need to correctly offset the mask slicing
         offset = cache_position[-1] - effective_seq_len + 1
         # Should only be used when beyond the sliding window (i.e. offset > 0)
-        offset = max(0, offset)
+        offset = torch.clamp(offset, min=0)
         # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
         # but without data-dependent slicing (i.e. torch.compile friendly)
         mask_indexes = torch.arange(
diff --git a/src/transformers/models/cohere2/modular_cohere2.py b/src/transformers/models/cohere2/modular_cohere2.py
index e811aabedbde..9a47f493930f 100644
--- a/src/transformers/models/cohere2/modular_cohere2.py
+++ b/src/transformers/models/cohere2/modular_cohere2.py
@@ -400,7 +400,7 @@ def forward(
         # In case we are beyond the sliding window, we need to correctly offset the mask slicing
         offset = cache_position[-1] - effective_seq_len + 1
         # Should only be used when beyond the sliding window (i.e. offset > 0)
-        offset = max(0, offset)
+        offset = torch.clamp(offset, min=0)
         # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
         # but without data-dependent slicing (i.e. torch.compile friendly)
         mask_indexes = torch.arange(
diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py
index fd63ec26c1c3..283fec4b4317 100644
--- a/src/transformers/models/gemma2/modeling_gemma2.py
+++ b/src/transformers/models/gemma2/modeling_gemma2.py
@@ -317,7 +317,7 @@ def forward(
         # In case we are beyond the sliding window, we need to correctly offset the mask slicing
         offset = cache_position[-1] - effective_seq_len + 1
         # Should only be used when beyond the sliding window (i.e. offset > 0)
-        offset = max(0, offset)
+        offset = torch.clamp(offset, min=0)
         # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
         # but without data-dependent slicing (i.e. torch.compile friendly)
         mask_indexes = torch.arange(
diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py
index 3d16f842ec66..5234938dc72b 100644
--- a/src/transformers/models/gemma2/modular_gemma2.py
+++ b/src/transformers/models/gemma2/modular_gemma2.py
@@ -364,7 +364,7 @@ def forward(
         # In case we are beyond the sliding window, we need to correctly offset the mask slicing
         offset = cache_position[-1] - effective_seq_len + 1
         # Should only be used when beyond the sliding window (i.e. offset > 0)
-        offset = max(0, offset)
+        offset = torch.clamp(offset, min=0)
         # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
         # but without data-dependent slicing (i.e. torch.compile friendly)
         mask_indexes = torch.arange(
diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py
index 90e6a4be2ff3..b4e7301964fd 100644
--- a/src/transformers/models/gemma3/modular_gemma3.py
+++ b/src/transformers/models/gemma3/modular_gemma3.py
@@ -494,7 +494,7 @@ def forward(
         # In case we are beyond the sliding window, we need to correctly offset the mask slicing
         offset = cache_position[-1] - effective_seq_len + 1
         # Should only be used when beyond the sliding window (i.e. offset > 0)
-        offset = max(0, offset)
+        offset = torch.clamp(offset, min=0)
         # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
         # but without data-dependent slicing (i.e. torch.compile friendly)
         mask_indexes = torch.arange(
diff --git a/tests/models/gemma2/test_modeling_gemma2.py b/tests/models/gemma2/test_modeling_gemma2.py
index d1ba0cbec4e6..c05f319d5a95 100644
--- a/tests/models/gemma2/test_modeling_gemma2.py
+++ b/tests/models/gemma2/test_modeling_gemma2.py
@@ -337,6 +337,44 @@ def test_export_static_cache(self):
         ep_generated_text = tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True)
         self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text)
 
+    @slow
+    @require_read_token
+    def test_export_hybrid_cache(self):
+        from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
+        from transformers.pytorch_utils import is_torch_greater_or_equal
+
+        if not is_torch_greater_or_equal("2.6.0"):
+            self.skipTest(reason="This test requires torch >= 2.6 to run.")
+
+        model_id = "google/gemma-2-2b"
+        model = AutoModelForCausalLM.from_pretrained(model_id)
+        self.assertEqual(model.config.cache_implementation, "hybrid")
+
+        # Export + HybridCache
+        model.eval()
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
+        exported_program = exportable_module.export()
+
+        # Test generation with the exported model
+        prompt = "What is the capital of France?"
+        max_new_tokens_to_generate = 20
+        # Generate text with the exported model
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        export_generated_text = TorchExportableModuleForDecoderOnlyLM.generate(
+            exported_program, tokenizer, prompt, max_new_tokens=max_new_tokens_to_generate
+        )
+
+        input_text = tokenizer(prompt, return_tensors="pt")
+        with torch.no_grad():
+            eager_outputs = model.generate(
+                **input_text,
+                max_new_tokens=max_new_tokens_to_generate,
+                do_sample=False,  # Use greedy decoding to match the exported model
+            )
+
+        eager_generated_text = tokenizer.decode(eager_outputs[0], skip_special_tokens=True)
+        self.assertEqual(export_generated_text, eager_generated_text)
+
     @require_read_token
     @tooslow
     def test_model_9b_bf16_flex_attention(self):
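
Below is a minimal usage sketch of the export recipe these patches introduce; it is not part of the diffs above. It mirrors the flow exercised by the new tests and assumes a transformers build that includes this change, torch >= 2.6, and access to the google/gemma-3-1b-it checkpoint.

    from transformers import AutoModelForCausalLM, AutoTokenizer
    from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

    model_id = "google/gemma-3-1b-it"  # any model whose config sets cache_implementation to "hybrid" or "static"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)
    model.eval()

    # Wrap the model so it owns its cache and exposes an (input_ids, cache_position) signature,
    # then export it with torch.export for further lowering to ExecuTorch.
    exportable_module = TorchExportableModuleForDecoderOnlyLM(model, max_batch_size=1, max_cache_len=4096)
    exported_program = exportable_module.export()

    # Greedy decoding through the exported program, one token per forward call.
    print(
        TorchExportableModuleForDecoderOnlyLM.generate(
            exported_program, tokenizer, "What is the capital of France?", max_new_tokens=20
        )
    )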