Gemma3 is Torch Exportable #37728

Merged · 2 commits · Apr 28, 2025
288 changes: 284 additions & 4 deletions src/transformers/integrations/executorch.py
@@ -20,15 +20,207 @@


if is_torch_available():
- from transformers import PreTrainedModel, StaticCache
+ from transformers import HybridCache, PreTrainedModel, StaticCache
from transformers.pytorch_utils import is_torch_greater_or_equal, is_torch_greater_or_equal_than_2_3


class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
"""
A recipe module designed to make a `PreTrainedModel` exportable with `torch.export`,
specifically for decoder-only LMs with a KV cache. This module ensures that the
exported model is compatible with further lowering and execution in `ExecuTorch`.
"""

def __init__(
self,
model: PreTrainedModel,
max_batch_size: int = 1,
max_cache_len: int = 4096,
):
"""
Initializes the exportable module with the cache implementation specified by the model config (`StaticCache` or `HybridCache`).

Args:
model (`PreTrainedModel`): The pretrained model to wrap.
max_batch_size (int): Maximum batch size for the cache.
max_cache_len (int): Maximum sequence length for the cache.

Raises:
ValueError: If the model is configured with an unsupported cache implementation.
"""
super().__init__()

if model.config.cache_implementation == "static":
self.model = TorchExportableModuleWithStaticCache(model)
elif model.config.cache_implementation == "hybrid":
self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len)
else:
raise ValueError(
f"Unsupported cache implementation in this export recipe: '{model.config.cache_implementation}'"
)

def forward(
self,
input_ids: torch.Tensor,
cache_position: torch.Tensor,
) -> torch.Tensor:
"""
Forward pass of the module, which is compatible with the ExecuTorch llm runner.

Args:
input_ids (`torch.Tensor`): Tensor representing current input token id to the module.
cache_position (`torch.Tensor`): Tensor representing current input position in the cache.

Returns:
torch.Tensor: Logits output from the model.
"""
return self.model.forward(input_ids, cache_position)

def export(
self,
input_ids: Optional[torch.Tensor] = None,
cache_position: Optional[torch.Tensor] = None,
dynamic_shapes: Optional[dict] = None,
strict: Optional[bool] = None,
) -> torch.export.ExportedProgram:
"""
Export the wrapped module using `torch.export`.

Args:
input_ids (`Optional[torch.Tensor]`):
Tensor representing current input token id to the module. If not provided, a default tensor will be used.
cache_position (`Optional[torch.Tensor]`):
Tensor representing current input position in the cache. If not provided, a default tensor will be used.
dynamic_shapes (`Optional[dict]`):
Dynamic shapes to use for export if specified.
strict (`Optional[bool]`):
Flag to instruct `torch.export` to use `torchdynamo`.
"""
example_input_ids = input_ids if input_ids is not None else torch.tensor([[1]], dtype=torch.long)
example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long)

return torch.export.export(
self.model,
args=(example_input_ids, example_cache_position),
kwargs={},
dynamic_shapes=dynamic_shapes,
strict=strict if strict is not None else True,
)

@staticmethod
def generate(
exported_program: torch.export.ExportedProgram,
tokenizer,
prompt: str,
max_new_tokens: int = 20,
do_sample: bool = False,
temperature: float = 1.0,
top_k: int = 50,
top_p: float = 1.0,
device: str = "cpu",
) -> str:
"""
Generate a sequence of tokens using an exported program.

Args:
exported_program (`torch.export.ExportedProgram`): The exported program used for generation.
tokenizer: The tokenizer to use.
prompt (str): The input prompt.
max_new_tokens (int): Maximum number of new tokens to generate.
do_sample (bool): Whether to use sampling or greedy decoding.
temperature (float): The temperature for sampling.
top_k (int): The number of highest probability tokens to keep for top-k sampling.
top_p (float): The cumulative probability for nucleus sampling.
device (str): The device to use.

Returns:
str: The generated text.
"""
# Get the module from the exported program
exported_module = exported_program.module()

# Tokenize the prompt
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

# Initialize with the prompt
generated_ids = input_ids.clone()

# Process the prompt tokens first
curr_position = 0
for i in range(input_ids.shape[1]):
# Process one token at a time
curr_input_ids = input_ids[:, i : i + 1]
curr_cache_position = torch.tensor([curr_position], dtype=torch.long, device=device)

# Forward pass
_ = exported_module(curr_input_ids, curr_cache_position)
curr_position += 1

# Generate new tokens
for _ in range(max_new_tokens):
# Get the last token as input
curr_input_ids = generated_ids[:, -1:]
curr_cache_position = torch.tensor([curr_position], dtype=torch.long, device=device)

# Forward pass to get next token logits
outputs = exported_module(curr_input_ids, curr_cache_position)

# Get the next token ID
if do_sample:
# Apply temperature
if temperature > 0:
logits = outputs / temperature
else:
logits = outputs

# Apply top-k filtering
if top_k > 0:
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = float("-inf")

# Apply top-p (nucleus) filtering
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the indices to the right so the first token above the threshold is kept as well
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0

# Scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = float("-inf")

# Sample from the filtered distribution
probs = torch.softmax(logits, dim=-1)
next_token_id = torch.multinomial(probs, num_samples=1)
else:
# Greedy decoding
next_token_id = outputs.argmax(dim=-1, keepdim=True)

# Ensure next_token_id has the right shape before concatenation
if next_token_id.dim() > 2:
next_token_id = next_token_id.squeeze(-1)

# Append to the generated sequence
generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)
curr_position += 1

# Stop if we generate an EOS token
if next_token_id.item() == tokenizer.eos_token_id:
break

# Decode the generated text
return tokenizer.decode(generated_ids[0], skip_special_tokens=True)


class TorchExportableModuleWithStaticCache(torch.nn.Module):
"""
- A wrapper module designed to make a `PreTrainedModel` exportable with `torch.export`,
- specifically for use with static caching. This module ensures that the exported model
- is compatible with further lowering and execution in `ExecuTorch`.
+ A recipe module designed to make a `PreTrainedModel` exportable with `torch.export`,
+ specifically for decoder-only LMs with `StaticCache`. This module ensures that the
+ exported model is compatible with further lowering and execution in `ExecuTorch`.

Note:
This class is specifically designed to support export process using `torch.export`
@@ -178,6 +370,94 @@ def generate(
return torch.tensor([response_tokens], dtype=torch.long)


class TorchExportableModuleWithHybridCache(torch.nn.Module):
"""
A recipe module designed to make a `PreTrainedModel` exportable with `torch.export`,
specifically for decoder-only LMs with `HybridCache`. This module ensures that the
exported model is compatible with further lowering and execution in `ExecuTorch`.
"""

def __init__(
self,
model: PreTrainedModel,
max_batch_size: int = 1,
max_cache_len: int = 4096,
):
"""
Initializes the exportable module with `HybridCache`.

Args:
model (`PreTrainedModel`): The pretrained model to wrap.
max_batch_size (int): Maximum batch size for the cache.
max_cache_len (int): Maximum sequence length for the cache.

Raises:
AssertionError: If the model doesn't have the expected configuration for HybridCache.
"""
super().__init__()
self.model = model

# Verify the model is configured for HybridCache
if not self.model.config.use_cache:
raise AssertionError("Model must have caching enabled")

if (
not hasattr(self.model.config, "cache_implementation")
or self.model.config.cache_implementation != "hybrid"
):
raise AssertionError("Model must use 'hybrid' cache implementation")

# Initialize the HybridCache
self.cache = HybridCache(
config=self.model.config,
max_batch_size=max_batch_size,
max_cache_len=max_cache_len,
device=self.model.device,
dtype=self.model.dtype,
)

# Register all key and value cache tensors as buffers
for i in range(len(self.cache.key_cache)):
self.register_buffer(f"key_cache_{i}", self.cache.key_cache[i], persistent=False)
self.register_buffer(f"value_cache_{i}", self.cache.value_cache[i], persistent=False)

def forward(
self,
input_ids: torch.Tensor,
cache_position: torch.Tensor,
) -> torch.Tensor:
"""
Forward pass of the module, which is compatible with the ExecuTorch llm runner.

Args:
input_ids (`torch.Tensor`): Tensor representing current input token id to the module.
cache_position (`torch.Tensor`): Tensor representing current input position in the cache.

Returns:
torch.Tensor: Logits output from the model.
"""
batch_size, seq_len = input_ids.shape

# Generate position_ids from cache_position
position_ids = cache_position.unsqueeze(0).expand(batch_size, -1)

# Create attention mask (always ones for token-by-token generation)
attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long, device=input_ids.device)

# Forward pass with the model
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=self.cache,
use_cache=True,
cache_position=cache_position,
)

# Return only the logits to simplify the export
return outputs.logits


def convert_and_export_with_cache(
model: PreTrainedModel,
example_input_ids: Optional[torch.Tensor] = None,
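How the new recipe is meant to be driven end to end, as a minimal usage sketch (not part of the diff itself): wrap a decoder-only checkpoint, export it, then decode through the exported program. The checkpoint name and `max_cache_len` are illustrative assumptions; any model whose config sets `cache_implementation` to "static" or "hybrid" should follow the same path.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

model_id = "google/gemma-3-1b-it"  # illustrative checkpoint, assumed to use a hybrid cache
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# The recipe dispatches to TorchExportableModuleWithStaticCache or
# TorchExportableModuleWithHybridCache based on model.config.cache_implementation.
exportable = TorchExportableModuleForDecoderOnlyLM(model, max_batch_size=1, max_cache_len=1024)

# Export with the default example inputs: a single token id at cache position 0.
exported_program = exportable.export()

# Greedy decoding, one token per forward call through the exported program.
text = TorchExportableModuleForDecoderOnlyLM.generate(
    exported_program, tokenizer, prompt="The capital of France is", max_new_tokens=10
)
print(text)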
2 changes: 1 addition & 1 deletion src/transformers/models/cohere2/modeling_cohere2.py
@@ -351,7 +351,7 @@ def forward(
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
offset = cache_position[-1] - effective_seq_len + 1
# Should only be used when beyond the sliding window (i.e. offset > 0)
- offset = max(0, offset)
+ offset = torch.clamp(offset, min=0)
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
# but without data-dependent slicing (i.e. torch.compile friendly)
mask_indexes = torch.arange(
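The one-line change above (mirrored in the modular file and repeated for Gemma2 and Gemma3 below) is what makes the sliding-window mask offset export-friendly: Python's built-in `max` has to call `bool()` on the tensor comparison, which is data-dependent control flow that `torch.export` cannot trace, whereas `torch.clamp` stays a pure tensor op inside the graph. A standalone sketch, with a made-up window length of 4:

import torch

class ClampOffset(torch.nn.Module):
    def forward(self, cache_position: torch.Tensor) -> torch.Tensor:
        # Same shape of computation as the sliding-window mask offset.
        offset = cache_position[-1] - 4 + 1
        # `max(0, offset)` here would evaluate `bool(offset > 0)` while tracing
        # and raise a data-dependent control-flow error; clamp traces cleanly.
        return torch.clamp(offset, min=0)

exported = torch.export.export(ClampOffset(), (torch.arange(8),))
print(exported.module()(torch.arange(8)))  # tensor(4)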
2 changes: 1 addition & 1 deletion src/transformers/models/cohere2/modular_cohere2.py
@@ -400,7 +400,7 @@ def forward(
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
offset = cache_position[-1] - effective_seq_len + 1
# Should only be used when beyond the sliding window (i.e. offset > 0)
- offset = max(0, offset)
+ offset = torch.clamp(offset, min=0)
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
# but without data-dependent slicing (i.e. torch.compile friendly)
mask_indexes = torch.arange(
2 changes: 1 addition & 1 deletion src/transformers/models/gemma2/modeling_gemma2.py
@@ -317,7 +317,7 @@ def forward(
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
offset = cache_position[-1] - effective_seq_len + 1
# Should only be used when beyond the sliding window (i.e. offset > 0)
- offset = max(0, offset)
+ offset = torch.clamp(offset, min=0)
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
# but without data-dependent slicing (i.e. torch.compile friendly)
mask_indexes = torch.arange(
2 changes: 1 addition & 1 deletion src/transformers/models/gemma2/modular_gemma2.py
@@ -364,7 +364,7 @@ def forward(
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
offset = cache_position[-1] - effective_seq_len + 1
# Should only be used when beyond the sliding window (i.e. offset > 0)
- offset = max(0, offset)
+ offset = torch.clamp(offset, min=0)
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
# but without data-dependent slicing (i.e. torch.compile friendly)
mask_indexes = torch.arange(
2 changes: 1 addition & 1 deletion src/transformers/models/gemma3/modeling_gemma3.py
@@ -410,7 +410,7 @@ def forward(
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
offset = cache_position[-1] - effective_seq_len + 1
# Should only be used when beyond the sliding window (i.e. offset > 0)
- offset = max(0, offset)
+ offset = torch.clamp(offset, min=0)
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
# but without data-dependent slicing (i.e. torch.compile friendly)
mask_indexes = torch.arange(
2 changes: 1 addition & 1 deletion src/transformers/models/gemma3/modular_gemma3.py
@@ -494,7 +494,7 @@ def forward(
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
offset = cache_position[-1] - effective_seq_len + 1
# Should only be used when beyond the sliding window (i.e. offset > 0)
- offset = max(0, offset)
+ offset = torch.clamp(offset, min=0)
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
# but without data-dependent slicing (i.e. torch.compile friendly)
mask_indexes = torch.arange(