
Commit 23a6ef2

devesh-2002 and stevhliu authored and jaycha committed

Improvements in Gemma2 model card (huggingface#37076)

* Improved Model card for Gemma2
* Made changes in gemma2 as suggested
* Made more changes in the doc (adding image, notes, closing hfoptions)
* minor fixes

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

1 parent 8a6fb39 commit 23a6ef2

File tree

1 file changed: docs/source/en/model_doc/gemma2.md (+113, -16 lines)
@@ -14,36 +14,133 @@ specific language governing permissions and limitations under the License.
 rendered properly in your Markdown viewer.

 -->
+<div style="float: right;">
+<div class="flex flex-wrap space-x-1">
+<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
+<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
+<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuhYkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC
+">
+<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
+<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
+</div>
+</div>

 # Gemma2

-<div class="flex flex-wrap space-x-1">
-<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
-<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
-<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
-</div>
+[Gemma 2](https://huggingface.co/papers/2408.00118) is a family of language models with pretrained and instruction-tuned variants, available in 2B, 9B, and 27B parameter sizes. The architecture is similar to the previous Gemma, except it features interleaved local attention (4096 tokens) and global attention (8192 tokens), and grouped-query attention (GQA) to increase inference performance.
+
+The 2B and 9B models are trained with knowledge distillation, and the instruction-tuned variants were post-trained with supervised fine-tuning and reinforcement learning.
+
+You can find all the original Gemma 2 checkpoints under the [Gemma 2](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315) collection.
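As context for the architecture paragraph above, the local/global alternation and the GQA head layout are visible directly in the model config. A minimal sketch, assuming `Gemma2Config` exposes `sliding_window`, `num_hidden_layers`, `num_attention_heads`, and `num_key_value_heads`, and that even-indexed layers are the sliding-window ones (the parity is an implementation detail, not stated in the card):

```python
from transformers import Gemma2Config

config = Gemma2Config.from_pretrained("google/gemma-2-9b")
print(config.sliding_window)  # local attention window (4096 tokens)
# GQA: fewer key/value heads than query heads
print(config.num_attention_heads, config.num_key_value_heads)

for layer_idx in range(config.num_hidden_layers):
    # assumption: even-indexed layers use the 4096-token sliding window,
    # odd-indexed layers attend globally
    kind = "local (sliding window)" if layer_idx % 2 == 0 else "global"
    print(layer_idx, kind)
```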
+
+> [!TIP]
+> Click on the Gemma 2 models in the right sidebar for more examples of how to apply Gemma to different language tasks.
+
+The example below demonstrates how to generate text with [`Pipeline`] or the [`AutoModel`] class, and from the command line.
+
+<hfoptions id="usage">
+<hfoption id="Pipeline">
+
+```python
+import torch
+from transformers import pipeline
+
+pipe = pipeline(
+    task="text-generation",
+    model="google/gemma-2-9b",
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+)
+
+pipe("Explain quantum computing simply.", max_new_tokens=50)
+```

-## Overview
+</hfoption>
+<hfoption id="AutoModel">
+
+```python
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM

-The Gemma2 model was proposed in [Gemma2: Open Models Based on Gemini Technology and Research](https://blog.google/technology/developers/google-gemma-2/) by Gemma2 Team, Google.
-Two Gemma2 models are released, with parameters sizes of 9 billion (9B) and 27 billion (27B).
+tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
+model = AutoModelForCausalLM.from_pretrained(
+    "google/gemma-2-9b",
+    torch_dtype=torch.bfloat16,
+    device_map="auto",
+    attn_implementation="sdpa"
+)

-The abstract from the blog post is the following:
+input_text = "Explain quantum computing simply."
+input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

-*Now we’re officially releasing Gemma 2 to researchers and developers globally. Available in both 9 billion (9B) and 27 billion (27B) parameter sizes, Gemma 2 is higher-performing and more efficient at inference than the first generation, with significant safety advancements built in. In fact, at 27B, it offers competitive alternatives to models more than twice its size, delivering the kind of performance that was only possible with proprietary models as recently as December.*
+outputs = model.generate(**input_ids, max_new_tokens=32, cache_implementation="static")
+print(tokenizer.decode(outputs[0], skip_special_tokens=True))

-Tips:
+```

-- The original checkpoints can be converted using the conversion script `src/transformers/models/Gemma2/convert_Gemma2_weights_to_hf.py`
+</hfoption>
+<hfoption id="transformers-cli">
+
+```bash
+echo -e "Explain quantum computing simply." | transformers-cli run --task text-generation --model google/gemma-2-2b --device 0
+```
+</hfoption>
+</hfoptions>
+
+Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for the available quantization backends.
+
+The example below uses [bitsandbytes](../quantization/bitsandbytes) to quantize only the weights to int4.
+
+```python
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
+
+quantization_config = BitsAndBytesConfig(load_in_4bit=True)
+tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-27b")
+model = AutoModelForCausalLM.from_pretrained(
+    "google/gemma-2-27b",
+    torch_dtype=torch.bfloat16,
+    device_map="auto",
+    attn_implementation="sdpa",
+    quantization_config=quantization_config
+)
+
+input_text = "Explain quantum computing simply."
+input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
+
+outputs = model.generate(**input_ids, max_new_tokens=32, cache_implementation="static")
+print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+```
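As a quick sanity check on why int4 matters at this scale, a back-of-envelope estimate (an editorial sketch, not from the card; it counts weight storage only and ignores activations, the KV cache, and quantization overhead):

```python
params = 27e9  # Gemma 2 27B parameter count

# bytes per parameter: bfloat16 = 2, int4 = 0.5
print(f"bf16 weights: ~{params * 2 / 1e9:.0f} GB")    # ~54 GB
print(f"int4 weights: ~{params * 0.5 / 1e9:.1f} GB")  # ~13.5 GB
```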
+
+Use the [AttentionMaskVisualizer](https://github.com/huggingface/transformers/blob/beb9b5b02246b9b7ee81ddf938f93f44cfeaad19/src/transformers/utils/attention_visualizer.py#L139) to better understand what tokens the model can and cannot attend to.
+
+```python
+from transformers.utils.attention_visualizer import AttentionMaskVisualizer
+
+visualizer = AttentionMaskVisualizer("google/gemma-2-2b")
+visualizer("You are an assistant. Make sure you print me")
+```
+
+<div class="flex justify-center">
+    <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/gemma-2-attn-mask.png"/>
+</div>

-<Tip warning={true}>
+## Notes

-- Gemma2 uses sliding window attention every second layer, which makes it unsuitable for typical kv caching with [`~DynamicCache`] or tuples of tensors. To enable caching in Gemma2 forward call, you must initialize a [`~HybridCache`] instance and pass it as `past_key_values` to the forward call. Note, that you also have to prepare `cache_position` if the `past_key_values` already contains previous keys and values.
+- Use a [`HybridCache`] instance to enable caching in Gemma 2. Gemma 2 doesn't support kv-caching strategies like [`DynamicCache`] or tuples of tensors because it uses sliding window attention every second layer.

-</Tip>
+    ```python
+    from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache

-This model was contributed by [Arthur Zucker](https://huggingface.co/ArthurZ), [Pedro Cuenca](https://huggingface.co/pcuenq) and [Tom Arsen]().
+    model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")
+    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")

+    inputs = tokenizer(text="My name is Gemma", return_tensors="pt")
+    max_generated_length = inputs.input_ids.shape[1] + 10
+    past_key_values = HybridCache(config=model.config, max_batch_size=1,
+                                  max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
+    outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
+    ```
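The removed tip above also warned that `cache_position` must be prepared when `past_key_values` already contains keys and values; the new note drops that detail. A minimal continuation of the example above showing what that might look like (an editorial sketch; the greedy next-token step and the position arithmetic are assumptions, not from the card):

```python
import torch

# the cache already holds the prompt from the forward pass above;
# positions for the new token must start where the cache left off
next_token = outputs.logits[:, -1].argmax(dim=-1, keepdim=True)
past_length = inputs.input_ids.shape[1]
cache_position = torch.arange(past_length, past_length + 1)

outputs = model(next_token, past_key_values=past_key_values,
                cache_position=cache_position, use_cache=True)
```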

 ## Gemma2Config
