Commit 6257220

Made changes in gemma2 as suggested
1 parent 387280b commit 6257220

File tree

1 file changed: +45 −63 lines changed


docs/source/en/model_doc/gemma2.md

+45 −63
@@ -14,8 +14,6 @@ specific language governing permissions and limitations under the License.
 rendered properly in your Markdown viewer.
 
 -->
-
-# Gemma2
 <div style="float: right;">
     <div class="flex flex-wrap space-x-1">
         <img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
@@ -27,16 +25,18 @@ rendered properly in your Markdown viewer.
     </div>
 </div>
 
+# Gemma2
+
 ## Overview
 
-**[Gemma 2](https://arxiv.org/pdf/2408.00118)** is Google's open-weight language model family (2B, 9B, 27B parameters) featuring interleaved local-global attention (4K sliding window + 8K global context), knowledge distillation for smaller models, and GQA for efficient inference. The 27B variant rivals models twice its size, scoring 75.2 on MMLU and 74.0 on GSM8K, while the instruction-tuned versions excel in multi-turn chat.
+[Gemma 2](https://huggingface.co/papers/2408.00118) is a family of language models with pretrained and instruction-tuned variants, available in 2B, 9B, and 27B parameter sizes. The architecture is similar to the previous Gemma, except it features interleaved local attention (4096 tokens) and global attention (8192 tokens), and uses grouped-query attention (GQA) to increase inference performance.
 
-Key improvements over Gemma 1 include deeper networks, logit soft-capping, and stricter safety filters (<0.1% memorization). Available in base and instruction-tuned variants.
+The 2B and 9B models are trained with knowledge distillation, and the instruction-tuned variants are post-trained with supervised fine-tuning and reinforcement learning.
 
-The original checkpoints of Gemma 2 can be found [here](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315).
+You can find all the original Gemma 2 checkpoints under the [Gemma 2](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315) release.
 
 > [!TIP]
-> Click on the CLIP models in the right sidebar for more examples of how to apply CLIP to different image and language tasks.
+> Click on the Gemma 2 models in the right sidebar for more examples of how to apply Gemma 2 to different language tasks.
 
 
 <Tip warning={true}>
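
The interleaved local/global attention and GQA setup described in the overview above can be read straight off the model config. A minimal sketch, assuming the attribute names exposed by the current `Gemma2Config` (`sliding_window`, `max_position_embeddings`, `num_attention_heads`, `num_key_value_heads`):

```python
from transformers import Gemma2Config

# Load the released config and inspect the attention layout.
config = Gemma2Config.from_pretrained("google/gemma-2-9b")

print(config.sliding_window)           # local attention span per interleaved layer
print(config.max_position_embeddings)  # global context length
print(config.num_attention_heads, config.num_key_value_heads)  # GQA: fewer key/value heads than query heads
```

For `google/gemma-2-9b` this should report a 4096-token sliding window, an 8192-token context length, and more query heads than key/value heads.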
@@ -48,106 +48,88 @@ The original checkpoints of Gemma 2 can be found [here](https://huggingface.co/c
 
 This model was contributed by [Arthur Zucker](https://huggingface.co/ArthurZ), [Pedro Cuenca](https://huggingface.co/pcuenq) and [Tom Arsen]().
 
-<Tip>
-Click the right sidebar's Gemma 2 models for additional task examples.
-</Tip>
-
-The example below demonstrates how to generate text based on an image with [`Pipeline`] or the [`AutoModel`] class.
+The example below demonstrates how to chat with the model using [`Pipeline`] or the [`AutoModel`] class, and from the command line.
 
 <hfoptions id="usage">
 <hfoption id="Pipeline">
 
 
-### Text Generation with `Pipeline`
-
 ```python
-from transformers import pipeline
 import torch
+from transformers import pipeline
 
 pipe = pipeline(
-    "text-generation",
-    model="google/gemma-2-9b-it",
-    model_kwargs={"torch_dtype": torch.bfloat16},
+    task="text-generation",
+    model="google/gemma-2-9b",
+    torch_dtype=torch.bfloat16,
     device="cuda",
 )
 
-messages = [
-    {"role": "user", "content": "Explain quantum computing simply"},
-]
-outputs = pipe(messages, max_new_tokens=256)
-print(outputs[0]["generated_text"])
+pipe("Explain quantum computing simply.", max_new_tokens=50)
 ```
-### Text Generation with `AutoModel`
+
+</hfoption>
+<hfoption id="AutoModel">
+
 ```python
-from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
 
-tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b-it")
+tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
 model = AutoModelForCausalLM.from_pretrained(
-    "google/gemma-2-9b-it",
+    "google/gemma-2-9b",
+    torch_dtype=torch.bfloat16,
     device_map="auto",
+    attn_implementation="sdpa"
 )
 
-input_text = "Write me a poem about Machine Learning."
+input_text = "Explain quantum computing simply."
 input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
 
-outputs = model.generate(**input_ids, max_new_tokens=32)
-print(tokenizer.decode(outputs[0]))
+outputs = model.generate(**input_ids, max_new_tokens=32, cache_implementation="static")
+print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+
 ```
-### Using `transformers-cli`
+
+</hfoption>
+<hfoption id="transformers-cli">
+
 ```
-echo -e "Plants create energy through a process known as" | transformers-cli run --task text-generation --model google/gemma-2-2b --device 0
+echo -e "Explain quantum computing simply." | transformers-cli run --task text-generation --model google/gemma-2-2b --device 0
 ```
 
-### Quantized version through `bitsandbytes`
-
-Quantization reduces model size and speeds up inference by converting high-precision numbers (e.g., 32-bit floats) to lower-precision formats (e.g., 8-bit integers), with minimal accuracy loss
-#### Using 8-bit precision (int8)
+Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.
+
+The example below uses [bitsandbytes](../quantization/bitsandbytes) to quantize the weights to int4 only.
 
 ```python
+import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
 
-quantization_config = BitsAndBytesConfig(load_in_8bit=True)
-
-tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-27b-it")
+quantization_config = BitsAndBytesConfig(load_in_4bit=True)
+tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-27b")
 model = AutoModelForCausalLM.from_pretrained(
-    "google/gemma-2-27b-it",
-    quantization_config=quantization_config,
+    "google/gemma-2-27b",
+    quantization_config=quantization_config,
+    torch_dtype=torch.bfloat16,
+    device_map="auto",
+    attn_implementation="sdpa"
 )
 
-input_text = "Write me a poem about Machine Learning."
+input_text = "Explain quantum computing simply."
 input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
 
-outputs = model.generate(**input_ids, max_new_tokens=32)
-print(tokenizer.decode(outputs[0]))
+outputs = model.generate(**input_ids, max_new_tokens=32, cache_implementation="static")
+print(tokenizer.decode(outputs[0], skip_special_tokens=True))
 ```
-#### Using 4-bit precision
-```python
-from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
 
-quantization_config = BitsAndBytesConfig(load_in_4bit=True)
+Use the [AttentionMaskVisualizer](https://github.com/huggingface/transformers/blob/beb9b5b02246b9b7ee81ddf938f93f44cfeaad19/src/transformers/utils/attention_visualizer.py#L139) to better understand what tokens the model can and cannot attend to.
 
-tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-27b-it")
-model = AutoModelForCausalLM.from_pretrained(
-    "google/gemma-2-27b-it",
-    quantization_config=quantization_config,
-)
-
-input_text = "Write me a poem about Machine Learning."
-input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
-
-outputs = model.generate(**input_ids, max_new_tokens=32)
-print(tokenizer.decode(outputs[0]))
-
-```
-### AttentionMaskVisualizer
 
 ```python
+from transformers.utils.attention_visualizer import AttentionMaskVisualizer
 visualizer = AttentionMaskVisualizer("google/gemma-2b")
 visualizer("You are an assistant. Make sure you print me")
 ```
-## Notes
-- Gemma 2's sliding window attention enables efficient long-context processing - see sidebar examples for >4K token use cases
 
 ## Gemma2Config
 