Commit 41ccd3d

Author: Guang Yang (committed)

Gemma3 is Torch Exportable

1 parent 5cd6b64 · commit 41ccd3d

2 files changed: +255, -2 lines

src/transformers/models/gemma3/modeling_gemma3.py (+1, -1)

@@ -410,7 +410,7 @@ def forward(
                 # In case we are beyond the sliding window, we need to correctly offset the mask slicing
                 offset = cache_position[-1] - effective_seq_len + 1
                 # Should only be used when beyond the sliding window (i.e. offset > 0)
-                offset = max(0, offset)
+                offset = torch.clamp(offset, min=0)
                 # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
                 # but without data-dependent slicing (i.e. torch.compile friendly)
                 mask_indexes = torch.arange(
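Side note (not part of the commit): the switch from Python's `max(0, offset)` to `torch.clamp(offset, min=0)` keeps the offset computation as a tensor op. Calling the built-in `max` on a 0-dim tensor forces a `bool()` comparison, which `torch.export` either rejects or bakes in as a constant when the value is data-dependent, whereas `clamp` stays in the exported graph. A minimal, hypothetical sketch of the traceable pattern (names and the constant are made up for illustration):

import torch


class ClampOffset(torch.nn.Module):
    # Mirrors the patched pattern: keep the offset as a tensor op instead of Python max()
    def forward(self, cache_position: torch.Tensor) -> torch.Tensor:
        effective_seq_len = 512  # stand-in constant, for illustration only
        offset = cache_position[-1] - effective_seq_len + 1
        return torch.clamp(offset, min=0)  # traceable; no Python branching on tensor data


exported = torch.export.export(ClampOffset(), (torch.arange(8),))
print(exported.module()(torch.arange(8)))  # tensor(0): 7 - 512 + 1 is clamped to 0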

tests/models/gemma3/test_modeling_gemma3.py (+254, -1)

@@ -30,7 +30,6 @@
 from transformers.testing_utils import (
     cleanup,
     require_flash_attn,
-    require_read_token,
     require_torch,
     require_torch_gpu,
     slow,
@@ -664,3 +663,257 @@ def test_generation_beyond_sliding_window_with_generation_config(self):
         model.generation_config.transformers_version = "4.49.0"
         with self.assertRaises(RuntimeError):  # errors out because it is not using hybrid cache
             out = model.generate(**inputs, generation_config=generation_config)
+
+    def test_export_text_only_with_hybrid_cache(self):
+        from transformers import HybridCache
+
+        class Gemma3ExportableModule(torch.nn.Module):
+            """
+            A wrapper module designed to make Gemma3 models exportable with `torch.export`,
+            specifically for use with HybridCache to support interleaved global and local attention.
+
+            This wrapper ensures that the exported model is compatible with further lowering
+            and execution in frameworks like ExecuTorch.
+            """
+
+            def __init__(self, model: PreTrainedModel, max_batch_size: int = 1, max_seq_len: int = 4096):
+                """
+                Initializes the wrapper module with the Gemma3 model.
+
+                Args:
+                    model (`PreTrainedModel`): The Gemma3 model to wrap.
+                    max_batch_size (int): Maximum batch size for the cache.
+                    max_seq_len (int): Maximum sequence length for the cache.
+
+                Raises:
+                    AssertionError: If the model doesn't have the expected configuration for HybridCache.
+                """
+                super().__init__()
+                self.model = model
+                self.config = model.config
+
+                # Verify the model is configured for HybridCache
+                assert self.config.cache_implementation == "hybrid", "Model must use 'hybrid' cache implementation"
+
+                # Verify sliding window configuration for local attention
+                assert hasattr(self.config, "sliding_window"), "Model config must have sliding_window attribute"
+                assert hasattr(self.config, "sliding_window_pattern"), (
+                    "Model config must have sliding_window_pattern attribute"
+                )
+
+                # Initialize the HybridCache
+                self.cache = HybridCache(
+                    config=self.config,
+                    max_batch_size=max_batch_size,
+                    max_cache_len=max_seq_len,
+                    device=model.device,
+                    dtype=model.dtype,
+                )
+
+                # Register buffers for tracking state
+                self.register_buffer("last_position", torch.tensor([-1], dtype=torch.long))
+
+                # Store the sliding window pattern for reference
+                self.sliding_window = self.config.sliding_window
+                self.sliding_window_pattern = self.config.sliding_window_pattern
+
+                # Determine which layers use global vs local attention
+                # In Gemma3, typically every 6th layer (0-indexed) uses global attention
+                self.global_attention_layers = [
+                    i
+                    for i in range(self.config.num_hidden_layers)
+                    if i % self.sliding_window_pattern == (self.sliding_window_pattern - 1)
+                ]
+
+            def forward(
+                self,
+                input_ids: torch.Tensor,
+                cache_position: torch.Tensor,
+            ) -> torch.Tensor:
+                """
+                Forward pass of the module, compatible with torch.export.
+
+                Args:
+                    input_ids (`torch.Tensor`): Tensor representing current input token id(s).
+                    cache_position (`torch.Tensor`): Tensor representing current position(s) in the cache.
+
+                Returns:
+                    torch.Tensor: Logits output from the model.
+                """
+                batch_size, seq_len = input_ids.shape
+
+                # Update the last_position for tracking
+                self.last_position = cache_position[-1].unsqueeze(0)
+
+                # Generate position_ids from cache_position
+                position_ids = cache_position.unsqueeze(0).expand(batch_size, -1)
+
+                # Create attention mask (always ones for token-by-token generation)
+                attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long, device=input_ids.device)
+
+                # Forward pass with the model
+                outputs = self.model(
+                    input_ids=input_ids,
+                    attention_mask=attention_mask,
+                    position_ids=position_ids,
+                    past_key_values=self.cache,
+                    use_cache=True,
+                    cache_position=cache_position,
+                )
+
+                # Return only the logits to simplify the export
+                return outputs.logits
+
+            @staticmethod
+            def generate(
+                exported_model: torch.export.ExportedProgram,
+                tokenizer,
+                prompt: str,
+                max_new_tokens: int = 20,
+                do_sample: bool = False,
+                temperature: float = 1.0,
+                top_k: int = 50,
+                top_p: float = 1.0,
+                device: str = "cpu",
+            ) -> str:
+                """
+                Generate text using an exported Gemma3 model.
+
+                Args:
+                    exported_model (`torch.export.ExportedProgram`): The exported model being used for generate.
+                    tokenizer: The tokenizer to use.
+                    prompt (str): The input prompt.
+                    max_new_tokens (int): Maximum number of new tokens to generate.
+                    do_sample (bool): Whether to use sampling or greedy decoding.
+                    temperature (float): The temperature for sampling.
+                    top_k (int): The number of highest probability tokens to keep for top-k sampling.
+                    top_p (float): The cumulative probability for nucleus sampling.
+                    device (str): The device to use.
+
+                Returns:
+                    str: The generated text.
+                """
+                # Get the module from the exported program
+                exported_module = exported_model.module()
+
+                # Tokenize the prompt
+                input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
+
+                # Initialize with the prompt
+                generated_ids = input_ids.clone()
+
+                # Process the prompt tokens first
+                curr_position = 0
+                for i in range(input_ids.shape[1]):
+                    # Process one token at a time
+                    curr_input_ids = input_ids[:, i : i + 1]
+                    curr_cache_position = torch.tensor([curr_position], dtype=torch.long, device=device)
+
+                    # Forward pass
+                    _ = exported_module(curr_input_ids, curr_cache_position)
+                    curr_position += 1
+
+                # Generate new tokens
+                for _ in range(max_new_tokens):
+                    # Get the last token as input
+                    curr_input_ids = generated_ids[:, -1:]
+                    curr_cache_position = torch.tensor([curr_position], dtype=torch.long, device=device)
+
+                    # Forward pass to get next token logits
+                    outputs = exported_module(curr_input_ids, curr_cache_position)
+
+                    # Get the next token ID
+                    if do_sample:
+                        # Apply temperature
+                        if temperature > 0:
+                            logits = outputs / temperature
+                        else:
+                            logits = outputs
+
+                        # Apply top-k filtering
+                        if top_k > 0:
+                            indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+                            logits[indices_to_remove] = float("-inf")
+
+                        # Apply top-p (nucleus) filtering
+                        if top_p < 1.0:
+                            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+                            cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
+
+                            # Remove tokens with cumulative probability above the threshold
+                            sorted_indices_to_remove = cumulative_probs > top_p
+                            # Shift the indices to the right to keep also the first token above the threshold
+                            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+                            sorted_indices_to_remove[..., 0] = 0
+
+                            # Scatter sorted tensors to original indexing
+                            indices_to_remove = sorted_indices_to_remove.scatter(
+                                -1, sorted_indices, sorted_indices_to_remove
+                            )
+                            logits[indices_to_remove] = float("-inf")
+
+                        # Sample from the filtered distribution
+                        probs = torch.softmax(logits, dim=-1)
+                        next_token_id = torch.multinomial(probs, num_samples=1)
+                    else:
+                        # Greedy decoding
+                        next_token_id = outputs.argmax(dim=-1, keepdim=True)
+
+                    # Ensure next_token_id has the right shape before concatenation
+                    if next_token_id.dim() > 2:
+                        next_token_id = next_token_id.squeeze(-1)
+
+                    # Append to the generated sequence
+                    generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)
+                    curr_position += 1
+
+                    # Stop if we generate an EOS token
+                    if next_token_id.item() == tokenizer.eos_token_id:
+                        break
+
+                # Decode the generated text
+                return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+
+        model_id = "google/gemma-3-1b-it"
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        model = AutoModelForCausalLM.from_pretrained(model_id)
+        model.eval()
+        print(f"Model config: {model.config}")
+
+        # Create a wrapper for export with static batch size
+        wrapper = Gemma3ExportableModule(model)
+
+        # Prepare example inputs
+        example_input_ids = torch.tensor([[1]], dtype=torch.long)
+        example_cache_position = torch.tensor([0], dtype=torch.long)
+
+        # Export the model with static shapes
+        exported_program = torch.export.export(
+            wrapper,
+            (example_input_ids, example_cache_position),
+            strict=False,
+        )
+        print(f"Exported program: {exported_program}")
+
+        # Test generation with the exported model
+        # prompt = "What is the capital of France?"
+        prompt = "Write a poem about Machine Learning."
+        max_new_tokens_to_generate = 100
+        # Generate text with the exported model
+        export_generated_text = Gemma3ExportableModule.generate(
+            exported_program, tokenizer, prompt, max_new_tokens=max_new_tokens_to_generate
+        )
+        print(f"Export generated texts: '{export_generated_text}'")
+
+        input_text = tokenizer(prompt, return_tensors="pt")
+        with torch.no_grad():
+            eager_outputs = model.generate(
+                **input_text,
+                max_new_tokens=max_new_tokens_to_generate,
+                do_sample=False,  # Use greedy decoding to match the exported model
+            )
+
+        eager_generated_text = tokenizer.decode(eager_outputs[0], skip_special_tokens=True)
+        print(f"Eager generated texts: '{eager_generated_text}'")
+
+        self.assertEqual(export_generated_text, eager_generated_text)
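Not part of the commit, but a natural next step given the wrapper's stated goal of lowering to runtimes such as ExecuTorch: once `torch.export.export` succeeds, the resulting `ExportedProgram` can be serialized and reloaded through the standard `torch.export` API. A minimal sketch, assuming the `exported_program` and example inputs from the test above (untested against a model of this size):

import torch

# `exported_program` is assumed to be the ExportedProgram produced above by
# torch.export.export(wrapper, (example_input_ids, example_cache_position), strict=False)
torch.export.save(exported_program, "gemma3_text_only.pt2")

# Reload (e.g. in a fresh process) and run the same single-token forward pass
reloaded = torch.export.load("gemma3_text_only.pt2")
logits = reloaded.module()(
    torch.tensor([[1]], dtype=torch.long),  # example_input_ids
    torch.tensor([0], dtype=torch.long),    # example_cache_position
)
print(logits.shape)  # expected: (1, 1, vocab_size)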
