Commit 387280b

Improved Model card for Gemma2
1 parent b7fc2da commit 387280b

File tree

1 file changed

docs/source/en/model_doc/gemma2.md

Lines changed: 115 additions & 11 deletions
@@ -16,34 +16,138 @@ rendered properly in your Markdown viewer.
-->

# Gemma2

<div style="float: right;">
    <div class="flex flex-wrap space-x-1">
        <img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
        <img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
        <img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuhYkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC">
        <img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
        <img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
    </div>
</div>

## Overview

**[Gemma 2](https://arxiv.org/pdf/2408.00118)** is Google's open-weight language model family, released in 2B, 9B, and 27B parameter sizes. It features interleaved local-global attention (a 4K-token sliding window alternating with 8K-token global context), knowledge distillation for the smaller models, and grouped-query attention (GQA) for efficient inference. The 27B variant rivals models more than twice its size, scoring 75.2 on MMLU and 74.0 on GSM8K, and the instruction-tuned versions excel in multi-turn chat.
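
These architectural settings are visible directly on the released configuration. A minimal sketch, assuming access to the `google/gemma-2-9b` checkpoint (the field names below come from the published `Gemma2Config`):

```python
from transformers import Gemma2Config

# Load the released configuration (gated repos may require accepting the license first).
config = Gemma2Config.from_pretrained("google/gemma-2-9b")

print(config.sliding_window)           # 4096-token local window on alternating layers
print(config.max_position_embeddings)  # 8192-token global context
print(config.num_key_value_heads)      # fewer KV heads than query heads, i.e. GQA
```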

Key improvements over Gemma 1 include deeper networks, logit soft-capping, and stricter safety filtering (<0.1% training-data memorization). Both base and instruction-tuned variants are available.
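
Logit soft-capping squashes a tensor of logits with a tanh so every value stays inside a fixed range. A minimal sketch of the operation; the caps 50.0 and 30.0 are the `attn_logit_softcapping` and `final_logit_softcapping` values from the released configs:

```python
import torch

def soft_cap(logits: torch.Tensor, cap: float) -> torch.Tensor:
    # Rescale, squash with tanh, then scale back: outputs are bounded in (-cap, cap).
    return cap * torch.tanh(logits / cap)

raw = torch.randn(2, 8) * 100                # stand-in for unbounded logits
print(soft_cap(raw, 50.0).abs().max() < 50)  # attention-logit cap -> True
print(soft_cap(raw, 30.0).abs().max() < 30)  # final-logit cap -> True
```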

The original Gemma 2 checkpoints can be found [here](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315).

> [!TIP]
> Click on the Gemma 2 models in the right sidebar for more examples of how to apply Gemma 2 to different language tasks.

<Tip warning={true}>

- Gemma2 uses sliding window attention every second layer, which makes it unsuitable for typical kv caching with [`~DynamicCache`] or tuples of tensors. To enable caching in the Gemma2 forward call, you must initialize a [`~HybridCache`] instance and pass it as `past_key_values`, as shown in the sketch after this tip. Note that you also have to prepare `cache_position` if the `past_key_values` already contains previous keys and values.

</Tip>
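
A minimal sketch of that setup, assuming the `google/gemma-2-2b` checkpoint; the [`~HybridCache`] argument names follow recent Transformers releases:

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b", torch_dtype=torch.bfloat16)

inputs = tokenizer("Hello, my name is", return_tensors="pt")
max_generated_length = inputs.input_ids.shape[1] + 10

# HybridCache pairs a sliding-window cache for the local layers
# with a static cache for the global layers.
past_key_values = HybridCache(
    config=model.config,
    max_batch_size=1,
    max_cache_len=max_generated_length,
    device=model.device,
    dtype=model.dtype,
)
outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
```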

This model was contributed by [Arthur Zucker](https://huggingface.co/ArthurZ), [Pedro Cuenca](https://huggingface.co/pcuenq) and [Tom Arsen]().

The example below demonstrates how to generate text with [`Pipeline`] or the [`AutoModel`] class.
<hfoptions id="usage">
<hfoption id="Pipeline">
### Text Generation with `Pipeline`
```python
from transformers import pipeline
import torch

pipe = pipeline(
    "text-generation",
    model="google/gemma-2-9b-it",
    model_kwargs={"torch_dtype": torch.bfloat16},
    device="cuda",
)

messages = [
    {"role": "user", "content": "Explain quantum computing simply"},
]
outputs = pipe(messages, max_new_tokens=256)
print(outputs[0]["generated_text"])
```

</hfoption>
<hfoption id="AutoModel">

### Text Generation with `AutoModel`
```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b-it")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-9b-it",
    torch_dtype=torch.bfloat16,  # matches the Pipeline example above
    device_map="auto",
)

input_text = "Write me a poem about Machine Learning."
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids, max_new_tokens=32)
print(tokenizer.decode(outputs[0]))
```
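
For the instruction-tuned checkpoints you can also build the prompt with the tokenizer's chat template instead of raw text. A short sketch reusing the `tokenizer` and `model` from the block above:

```python
messages = [{"role": "user", "content": "Write me a poem about Machine Learning."}]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to("cuda")

outputs = model.generate(input_ids, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```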

</hfoption>
<hfoption id="transformers-cli">

### Using `transformers-cli`
```bash
echo -e "Plants create energy through a process known as" | transformers-cli run --task text-generation --model google/gemma-2-2b --device 0
```

</hfoption>
</hfoptions>

### Quantized version through `bitsandbytes`
Quantization reduces model size and speeds up inference by converting high-precision numbers (e.g., 32-bit floats) to lower-precision formats (e.g., 8-bit integers), with minimal accuracy loss.
#### Using 8-bit precision (int8)
```python
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_8bit=True)

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-27b-it")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-27b-it",
    quantization_config=quantization_config,
    device_map="auto",
)

input_text = "Write me a poem about Machine Learning."
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids, max_new_tokens=32)
print(tokenizer.decode(outputs[0]))
```
#### Using 4-bit precision
```python
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_4bit=True)

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-27b-it")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-27b-it",
    quantization_config=quantization_config,
    device_map="auto",
)

input_text = "Write me a poem about Machine Learning."
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids, max_new_tokens=32)
print(tokenizer.decode(outputs[0]))
```
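
`BitsAndBytesConfig` exposes further documented options if you want more control over the 4-bit scheme. A sketch using NF4 quantization with bfloat16 compute:

```python
import torch
from transformers import BitsAndBytesConfig

# NF4 weights plus bfloat16 compute typically retain more quality than plain 4-bit.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
```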
### AttentionMaskVisualizer
The visualizer below shows which positions each token can attend to, which makes the alternating sliding-window pattern visible (import path assumed from `transformers.utils.attention_visualizer`):

```python
from transformers.utils.attention_visualizer import AttentionMaskVisualizer

visualizer = AttentionMaskVisualizer("google/gemma-2b")
visualizer("You are an assistant. Make sure you print me")
```
## Notes
- Gemma 2's sliding window attention enables efficient long-context processing; see the sidebar examples for use cases beyond 4K tokens.

## Gemma2Config
