Commit 6f389db: Gemma3 is Torch Exportable
Author: Guang Yang (committed)
1 parent 5cd6b64

3 files changed: +326 -7 lines

src/transformers/integrations/executorch.py (+284 -4)
@@ -20,15 +20,207 @@
 
 
 if is_torch_available():
-    from transformers import PreTrainedModel, StaticCache
+    from transformers import HybridCache, PreTrainedModel, StaticCache
     from transformers.pytorch_utils import is_torch_greater_or_equal, is_torch_greater_or_equal_than_2_3
 
 
+class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
+    """
+    A recipe module designed to make a `PreTrainedModel` exportable with `torch.export`,
+    specifically for decoder-only LM with cache. This module ensures that the
+    exported model is compatible with further lowering and execution in `ExecuTorch`.
+    """
+
+    def __init__(
+        self,
+        model: PreTrainedModel,
+        max_batch_size: int = 1,
+        max_cache_len: int = 4096,
+    ):
+        """
+        Initializes the exportable module with `HybridCache`.
+
+        Args:
+            model (`PreTrainedModel`): The pretrained model to wrap.
+            max_batch_size (int): Maximum batch size for the cache.
+            max_cache_len (int): Maximum sequence length for the cache.
+
+        Raises:
+            ValueError: If the model is configured with an unsupported cache implementation.
+        """
+        super().__init__()
+
+        if model.config.cache_implementation == "static":
+            self.model = TorchExportableModuleWithStaticCache(model)
+        elif model.config.cache_implementation == "hybrid":
+            self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len)
+        else:
+            raise ValueError(
+                f"Unsupported cache implementation in this export recipe: '{model.config.cache_implementation}'"
+            )
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        cache_position: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Forward pass of the module, which is compatible with the ExecuTorch llm runner.
+
+        Args:
+            input_ids (`torch.Tensor`): Tensor representing current input token id to the module.
+            cache_position (`torch.Tensor`): Tensor representing current input position in the cache.
+
+        Returns:
+            torch.Tensor: Logits output from the model.
+        """
+        return self.model.forward(input_ids, cache_position)
+
+    def export(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        cache_position: Optional[torch.Tensor] = None,
+        dynamic_shapes: Optional[dict] = None,
+        strict: Optional[bool] = None,
+    ) -> torch.export.ExportedProgram:
+        """
+        Export the wrapped module using `torch.export`.
+
+        Args:
+            input_ids (`Optional[torch.Tensor]`):
+                Tensor representing current input token id to the module. If not provided, a default tensor will be used.
+            cache_position (`Optional[torch.Tensor]`):
+                Tensor representing current input position in the cache. If not provided, a default tensor will be used.
+            dynamic_shapes (`Optional[dict]`):
+                Dynamic shapes to use for export if specified.
+            strict (`Optional[bool]`):
+                Flag to instruct `torch.export` to use `torchdynamo`.
+        """
+        example_input_ids = input_ids if input_ids is not None else torch.tensor([[1]], dtype=torch.long)
+        example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long)
+
+        return torch.export.export(
+            self.model,
+            args=(example_input_ids, example_cache_position),
+            kwargs={},
+            dynamic_shapes=dynamic_shapes,
+            strict=strict if strict is not None else True,
+        )
+
+    @staticmethod
+    def generate(
+        exported_program: torch.export.ExportedProgram,
+        tokenizer,
+        prompt: str,
+        max_new_tokens: int = 20,
+        do_sample: bool = False,
+        temperature: float = 1.0,
+        top_k: int = 50,
+        top_p: float = 1.0,
+        device: str = "cpu",
+    ) -> str:
+        """
+        Generate a sequence of tokens using an exported program.
+
+        Args:
+            exported_program (`torch.export.ExportedProgram`): The exported model being used for generate.
+            tokenizer: The tokenizer to use.
+            prompt (str): The input prompt.
+            max_new_tokens (int): Maximum number of new tokens to generate.
+            do_sample (bool): Whether to use sampling or greedy decoding.
+            temperature (float): The temperature for sampling.
+            top_k (int): The number of highest probability tokens to keep for top-k sampling.
+            top_p (float): The cumulative probability for nucleus sampling.
+            device (str): The device to use.
+
+        Returns:
+            str: The generated text.
+        """
+        # Get the module from the exported program
+        exported_module = exported_program.module()
+
+        # Tokenize the prompt
+        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
+
+        # Initialize with the prompt
+        generated_ids = input_ids.clone()
+
+        # Process the prompt tokens first
+        curr_position = 0
+        for i in range(input_ids.shape[1]):
+            # Process one token at a time
+            curr_input_ids = input_ids[:, i : i + 1]
+            curr_cache_position = torch.tensor([curr_position], dtype=torch.long, device=device)
+
+            # Forward pass
+            _ = exported_module(curr_input_ids, curr_cache_position)
+            curr_position += 1
+
+        # Generate new tokens
+        for _ in range(max_new_tokens):
+            # Get the last token as input
+            curr_input_ids = generated_ids[:, -1:]
+            curr_cache_position = torch.tensor([curr_position], dtype=torch.long, device=device)
+
+            # Forward pass to get next token logits
+            outputs = exported_module(curr_input_ids, curr_cache_position)
+
+            # Get the next token ID
+            if do_sample:
+                # Apply temperature
+                if temperature > 0:
+                    logits = outputs / temperature
+                else:
+                    logits = outputs
+
+                # Apply top-k filtering
+                if top_k > 0:
+                    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+                    logits[indices_to_remove] = float("-inf")
+
+                # Apply top-p (nucleus) filtering
+                if top_p < 1.0:
+                    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+                    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
+
+                    # Remove tokens with cumulative probability above the threshold
+                    sorted_indices_to_remove = cumulative_probs > top_p
+                    # Shift the indices to the right to keep also the first token above the threshold
+                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+                    sorted_indices_to_remove[..., 0] = 0
+
+                    # Scatter sorted tensors to original indexing
+                    indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
+                    logits[indices_to_remove] = float("-inf")
+
+                # Sample from the filtered distribution
+                probs = torch.softmax(logits, dim=-1)
+                next_token_id = torch.multinomial(probs, num_samples=1)
+            else:
+                # Greedy decoding
+                next_token_id = outputs.argmax(dim=-1, keepdim=True)
+
+            # Ensure next_token_id has the right shape before concatenation
+            if next_token_id.dim() > 2:
+                next_token_id = next_token_id.squeeze(-1)
+
+            # Append to the generated sequence
+            generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)
+            curr_position += 1
+
+            # Stop if we generate an EOS token
+            if next_token_id.item() == tokenizer.eos_token_id:
+                break
+
+        # Decode the generated text
+        return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+
+
 class TorchExportableModuleWithStaticCache(torch.nn.Module):
     """
-    A wrapper module designed to make a `PreTrainedModel` exportable with `torch.export`,
-    specifically for use with static caching. This module ensures that the exported model
-    is compatible with further lowering and execution in `ExecuTorch`.
+    A recipe module designed to make a `PreTrainedModel` exportable with `torch.export`,
+    specifically for decoder-only LM to `StaticCache`. This module ensures that the
+    exported model is compatible with further lowering and execution in `ExecuTorch`.
 
     Note:
         This class is specifically designed to support export process using `torch.export`
@@ -178,6 +370,94 @@ def generate(
         return torch.tensor([response_tokens], dtype=torch.long)
 
 
+class TorchExportableModuleWithHybridCache(torch.nn.Module):
+    """
+    A recipe module designed to make a `PreTrainedModel` exportable with `torch.export`,
+    specifically for decoder-only LM to `HybridCache`. This module ensures that the
+    exported model is compatible with further lowering and execution in `ExecuTorch`.
+    """
+
+    def __init__(
+        self,
+        model: PreTrainedModel,
+        max_batch_size: int = 1,
+        max_cache_len: int = 4096,
+    ):
+        """
+        Initializes the exportable module with `HybridCache`.
+
+        Args:
+            model (`PreTrainedModel`): The pretrained model to wrap.
+            max_batch_size (int): Maximum batch size for the cache.
+            max_cache_len (int): Maximum sequence length for the cache.
+
+        Raises:
+            AssertionError: If the model doesn't have the expected configuration for HybridCache.
+        """
+        super().__init__()
+        self.model = model
+
+        # Verify the model is configured for HybridCache
+        if not self.model.config.use_cache:
+            raise AssertionError("Model must have caching enabled")
+
+        if (
+            not hasattr(self.model.config, "cache_implementation")
+            or self.model.config.cache_implementation != "hybrid"
+        ):
+            raise AssertionError("Model must use 'hybrid' cache implementation")
+
+        # Initialize the HybridCache
+        self.cache = HybridCache(
+            config=self.model.config,
+            max_batch_size=max_batch_size,
+            max_cache_len=max_cache_len,
+            device=self.model.device,
+            dtype=self.model.dtype,
+        )
+
+        # Register all key and value cache tensors as buffers
+        for i in range(len(self.cache.key_cache)):
+            self.register_buffer(f"key_cache_{i}", self.cache.key_cache[i], persistent=False)
+            self.register_buffer(f"value_cache_{i}", self.cache.value_cache[i], persistent=False)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        cache_position: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Forward pass of the module, which is compatible with the ExecuTorch llm runner.
+
+        Args:
+            input_ids (`torch.Tensor`): Tensor representing current input token id to the module.
+            cache_position (`torch.Tensor`): Tensor representing current input position in the cache.
+
+        Returns:
+            torch.Tensor: Logits output from the model.
+        """
+        batch_size, seq_len = input_ids.shape
+
+        # Generate position_ids from cache_position
+        position_ids = cache_position.unsqueeze(0).expand(batch_size, -1)
+
+        # Create attention mask (always ones for token-by-token generation)
+        attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long, device=input_ids.device)
+
+        # Forward pass with the model
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=self.cache,
+            use_cache=True,
+            cache_position=cache_position,
+        )
+
+        # Return only the logits to simplify the export
+        return outputs.logits
+
+
 def convert_and_export_with_cache(
     model: PreTrainedModel,
     example_input_ids: Optional[torch.Tensor] = None,
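For reference, here is a minimal usage sketch of the new export recipe, closely mirroring the test added in this commit; the model id, prompt, and token budget are just example values, and torch >= 2.6 is assumed.

# Minimal sketch of the new recipe (mirrors the test in tests/models/gemma3/test_modeling_gemma3.py below).
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

model_id = "google/gemma-3-1b-it"  # example: any model with cache_implementation "hybrid" or "static"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model.eval()

# The wrapper picks TorchExportableModuleWithHybridCache here because
# model.config.cache_implementation == "hybrid".
exportable_module = TorchExportableModuleForDecoderOnlyLM(model, max_batch_size=1, max_cache_len=4096)
exported_program = exportable_module.export()

# Greedy decoding with the exported program, one token per forward call.
print(
    TorchExportableModuleForDecoderOnlyLM.generate(
        exported_program, tokenizer, "What is the capital of France?", max_new_tokens=20
    )
)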

src/transformers/models/gemma3/modeling_gemma3.py (+1 -1)
@@ -410,7 +410,7 @@ def forward(
                 # In case we are beyond the sliding window, we need to correctly offset the mask slicing
                 offset = cache_position[-1] - effective_seq_len + 1
                 # Should only be used when beyond the sliding window (i.e. offset > 0)
-                offset = max(0, offset)
+                offset = torch.clamp(offset, min=0)
                 # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
                 # but without data-dependent slicing (i.e. torch.compile friendly)
                 mask_indexes = torch.arange(
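The one-line change above replaces Python's builtin max with torch.clamp so the sliding-window offset stays a tensor: the builtin forces a data-dependent comparison on a 0-d tensor, which torch.export and torch.compile cannot trace without guards. A standalone illustrative sketch (made-up numbers, not the model code):

import torch

cache_position = torch.tensor([7])
effective_seq_len = 10

offset = cache_position[-1] - effective_seq_len + 1  # 0-d tensor, here tensor(-2)
offset_eager = max(0, offset.item())                 # plain Python value; fine eagerly, but data-dependent
offset_traceable = torch.clamp(offset, min=0)        # stays a tensor, export/compile friendly

assert offset_eager == offset_traceable.item() == 0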

tests/models/gemma3/test_modeling_gemma3.py (+41 -2)
@@ -13,6 +13,7 @@
 # limitations under the License.
 """Testing suite for the PyTorch Gemma3 model."""
 
+import logging
 import tempfile
 import unittest
 
@@ -30,7 +31,6 @@
 from transformers.testing_utils import (
     cleanup,
     require_flash_attn,
-    require_read_token,
     require_torch,
     require_torch_gpu,
     slow,
@@ -52,6 +52,7 @@
     Gemma3Processor,
     Gemma3TextModel,
 )
+from transformers.pytorch_utils import is_torch_greater_or_equal
 
 
 class Gemma3ModelTester(GemmaModelTester):
@@ -360,7 +361,6 @@ def test_automodelforcausallm(self):
 
 @slow
 @require_torch_gpu
-@require_read_token
 class Gemma3IntegrationTest(unittest.TestCase):
     def setUp(self):
         self.processor = Gemma3Processor.from_pretrained("google/gemma-3-4b-it", padding_side="left")
@@ -664,3 +664,42 @@ def test_generation_beyond_sliding_window_with_generation_config(self):
         model.generation_config.transformers_version = "4.49.0"
         with self.assertRaises(RuntimeError): # errors out because it is not using hybrid cache
             out = model.generate(**inputs, generation_config=generation_config)
+
+    def test_export_text_only_with_hybrid_cache(self):
+        if not is_torch_greater_or_equal("2.6.0"):
+            self.skipTest(reason="This test requires torch >= 2.6 to run.")
+
+        from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
+
+        model_id = "google/gemma-3-1b-it"
+        model = AutoModelForCausalLM.from_pretrained(model_id)
+        self.assertEqual(model.config.cache_implementation, "hybrid")
+
+        # Export + HybridCache
+        model.eval()
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
+        exported_program = exportable_module.export()
+        logging.info(f"\nExported program: {exported_program}")
+
+        # Test generation with the exported model
+        prompt = "What is the capital of France?"
+        max_new_tokens_to_generate = 20
+        # Generate text with the exported model
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        export_generated_text = TorchExportableModuleForDecoderOnlyLM.generate(
+            exported_program, tokenizer, prompt, max_new_tokens=max_new_tokens_to_generate
+        )
+        logging.info(f"\nExport generated texts: '{export_generated_text}'")
+
+        input_text = tokenizer(prompt, return_tensors="pt")
+        with torch.no_grad():
+            eager_outputs = model.generate(
+                **input_text,
+                max_new_tokens=max_new_tokens_to_generate,
+                do_sample=False, # Use greedy decoding to match the exported model
+            )
+
+        eager_generated_text = tokenizer.decode(eager_outputs[0], skip_special_tokens=True)
+        logging.info(f"\nEager generated texts: '{eager_generated_text}'")
+
+        self.assertEqual(export_generated_text, eager_generated_text)
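Beyond this commit, the exported program is meant to be lowered further for on-device execution. A hypothetical follow-up sketch, assuming ExecuTorch's to_edge API is available (not part of this diff; exact calls and backend configuration depend on the ExecuTorch version):

# Hypothetical lowering step, assuming the `executorch` package is installed.
from executorch.exir import to_edge

edge_program = to_edge(exported_program)  # `exported_program` from TorchExportableModuleForDecoderOnlyLM.export()
et_program = edge_program.to_executorch()
with open("gemma3_1b_it.pte", "wb") as f:  # example output path
    f.write(et_program.buffer)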
