
Commit 5cd6b64

Process inputs directly in apply_chat_template in image-text-to-text pipeline (#35616)
* tokenize inputs directly in apply_chat_template
* refactor processing
* revert changes processing llava
* Update docs
* fix issue with str being iterable
* add test chat text only
* change function name
1 parent 80ea2c0 commit 5cd6b64

File tree

- docs/source/en/tasks/image_text_to_text.md
- src/transformers/pipelines/image_text_to_text.py
- tests/pipelines/test_pipelines_image_text_to_text.py

3 files changed: +186, -54 lines changed
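
At the core of the change (per the first bullet of the commit message), the pipeline's `preprocess` step now hands the chat straight to the processor's `apply_chat_template` with `tokenize=True` and `return_dict=True`, instead of first rendering a prompt string and processing it in a second pass. Below is a minimal sketch of that call pattern outside the pipeline, using the llava-hf/llava-interleave-qwen-0.5b-hf checkpoint that the new tests also use; the messages are illustrative and not part of the commit.

```python
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/llava-interleave-qwen-0.5b-hf")

messages = [
    {"role": "user", "content": [{"type": "text", "text": "What is the capital of France?"}]}
]

# tokenize=True + return_dict=True return model-ready tensors (input_ids,
# attention_mask, and pixel values when images are present), which is what
# the refactored preprocess() now forwards to the model directly.
model_inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
)
print(sorted(model_inputs.keys()))
```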

docs/source/en/tasks/image_text_to_text.md

Lines changed: 42 additions & 1 deletion
````diff
@@ -160,7 +160,48 @@ outputs[0]["generated_text"]
 # with a yellow center in the foreground. The flower is surrounded by red and white flowers with green stems
 ```
 
-## Streaming
+If you prefer, you can also load the images separately and pass them to the pipeline like so:
+
+```python
+pipe = pipeline("image-text-to-text", model="HuggingFaceTB/SmolVLM-256M-Instruct")
+
+img_urls = [
+    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/cats.png",
+    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg",
+]
+images = [
+    Image.open(requests.get(img_urls[0], stream=True).raw),
+    Image.open(requests.get(img_urls[1], stream=True).raw),
+]
+
+messages = [
+    {
+        "role": "user",
+        "content": [
+            {"type": "image"},
+            {"type": "image"},
+            {"type": "text", "text": "What do you see in these images?"},
+        ],
+    }
+]
+outputs = pipe(text=messages, images=images, max_new_tokens=50, return_full_text=False)
+outputs[0]["generated_text"]
+" In the first image, there are two cats sitting on a plant. In the second image, there are flowers with a pinkish hue."
+```
+
+The images will still be included in the `"input_text"` field of the output:
+
+```python
+outputs[0]['input_text']
+"""
+[{'role': 'user',
+  'content': [{'type': 'image',
+    'image': <PIL.PngImagePlugin.PngImageFile image mode=RGBA size=622x412>},
+   {'type': 'image',
+    'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=5184x3456>},
+   {'type': 'text', 'text': 'What do you see in these images?'}]}]
+"""
+```
 
 We can use [text streaming](./generation_strategies#streaming) for a better generation experience. Transformers supports streaming with the [`TextStreamer`] or [`TextIteratorStreamer`] classes. We will use the [`TextIteratorStreamer`] with IDEFICS-8B.
 
````
src/transformers/pipelines/image_text_to_text.py

Lines changed: 43 additions & 49 deletions
```diff
@@ -58,60 +58,56 @@ def __init__(self, messages: Dict, images: Union[str, List[str], "Image.Image",
         for message in messages:
             if not ("role" in message and "content" in message):
                 raise ValueError("When passing chat dicts as input, each dict must have a 'role' and 'content' key.")
-        images = retrieve_images_in_messages(messages, images)
+        messages = add_images_to_messages(messages, images)
 
         self.messages = messages
-        self.images = images
 
 
-def retrieve_images_in_messages(
+def add_images_to_messages(
     messages: dict, images: Optional[Union[str, List[str], "Image.Image", List["Image.Image"]]]
 ):
     """
     Retrieve and combine images from the chat and the images passed as input.
     """
     if images is None:
         images = []
-    elif not isinstance(images, Iterable):
+    elif not isinstance(images, Iterable) or isinstance(images, str):
         images = [images]
     idx_images = 0
-    retrieved_images = []
     for message in messages:
         for content in message["content"]:
-            if isinstance(content, dict):
-                if content.get("type") == "image":
-                    for key in ["image", "url", "path", "base64"]:
-                        if key in content:
-                            retrieved_images.append(content[key])
-                            break
-                    else:
-                        if idx_images < len(images):
-                            retrieved_images.append(images[idx_images])
-                            idx_images += 1
-                        else:
-                            raise ValueError(
-                                "The number of images in the chat messages should be the same as the number of images passed to the pipeline."
-                            )
-                # Add support for OpenAI/TGI chat format
-                elif content.get("type") == "image_url":
-                    if isinstance(content.get("image_url"), dict) and "url" in content["image_url"]:
-                        retrieved_images.append(content["image_url"]["url"])
-                        # Rewrite content to be in the Transformers chat format
-                        content["type"] = "image"
-                        content["image"] = content["image_url"]["url"]
-                        del content["image_url"]
+            if not isinstance(content, dict):
+                continue
+            content_type = content.get("type")
+            if content_type == "image":
+                if not any(key in content for key in ["image", "url", "path", "base64"]):
+                    if idx_images < len(images):
+                        # Insert the image passed as argument in the chat message
+                        content["image"] = images[idx_images]
+                        idx_images += 1
                     else:
                         raise ValueError(
-                            "Wrong format for 'image_url' content type. The content should have an 'image_url' dict with a 'url' key."
+                            "The number of images in the chat messages should be the same as the number of images passed to the pipeline."
                         )
+            # Add support for OpenAI/TGI chat format
+            elif content_type == "image_url":
+                if isinstance(content.get("image_url"), dict) and "url" in content["image_url"]:
+                    # Rewrite content to be in the Transformers chat format
+                    content["type"] = "image"
+                    content["image"] = content["image_url"]["url"]
+                    del content["image_url"]
+                else:
+                    raise ValueError(
+                        "Wrong format for 'image_url' content type. The content should have an 'image_url' dict with a 'url' key."
+                    )
 
     # The number of images passed should be consistent with the number of images in the chat without an image key
     if idx_images != len(images):
         raise ValueError(
             "The number of images in the chat messages should be the same as the number of images passed to the pipeline."
         )
 
-    return retrieved_images
+    return messages
 
 
 @add_end_docstrings(build_pipeline_init_args(has_processor=True))
@@ -331,32 +327,30 @@ def __call__(
         return super().__call__({"images": images, "text": text}, **kwargs)
 
     def preprocess(self, inputs=None, timeout=None, continue_final_message=None, **processing_kwargs):
+        if isinstance(inputs, Chat):
+            # If the user passes a chat that ends in an assistant message, we treat it as a prefill by default
+            # because very few models support multiple separate, consecutive assistant messages
+            if continue_final_message is None:
+                continue_final_message = inputs.messages[-1]["role"] == "assistant"
+            model_inputs = self.processor.apply_chat_template(
+                inputs.messages,
+                add_generation_prompt=not continue_final_message,
+                continue_final_message=continue_final_message,
+                return_tensors=self.framework,
+                tokenize=True,
+                return_dict=True,
+            )
+            model_inputs["text"] = inputs
+            return model_inputs
         # In case we only have text inputs
         if isinstance(inputs, (list, tuple, str)):
            images = None
            text = inputs
            inputs_text = inputs
        else:
-            if isinstance(inputs, Chat):
-                # If the user passes a chat that ends in an assistant message, we treat it as a prefill by default
-                # because very few models support multiple separate, consecutive assistant messages
-                if continue_final_message is None:
-                    continue_final_message = inputs.messages[-1]["role"] == "assistant"
-                text = self.processor.apply_chat_template(
-                    inputs.messages,
-                    add_generation_prompt=not continue_final_message,
-                    continue_final_message=continue_final_message,
-                    return_tensors=self.framework,
-                    **processing_kwargs,
-                )
-                inputs_text = inputs
-                images = inputs.images
-            else:
-                text = inputs["text"]
-                inputs_text = inputs["text"]
-                images = inputs["images"]
-
-        images = load_images(images, timeout=timeout)
+            images = load_images(inputs["images"], timeout=timeout)
+            text = inputs["text"]
+            inputs_text = inputs["text"]
 
         # if batched text inputs, we set padding to True unless specified otherwise
         if isinstance(text, (list, tuple)) and len(text) > 1:
```
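
A small but easy-to-miss piece of the hunk above is the new `or isinstance(images, str)` guard (the "fix issue with str being iterable" bullet in the commit message): a plain string is itself an `Iterable`, so a single URL passed as `images` was never wrapped into a one-element list. A standalone illustration of the pitfall, in plain Python and independent of the pipeline:

```python
from collections.abc import Iterable

url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"

# A str is an Iterable, so the old check `not isinstance(images, Iterable)`
# never wrapped a single URL, and later code iterated over its characters.
print(isinstance(url, Iterable))                              # True
print(not isinstance(url, Iterable))                          # False: old guard skipped
print(not isinstance(url, Iterable) or isinstance(url, str))  # True: new guard wraps it in [url]
```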

tests/pipelines/test_pipelines_image_text_to_text.py

Lines changed: 101 additions & 4 deletions
```diff
@@ -66,6 +66,78 @@ def run_pipeline_test(self, pipe, examples):
             ],
         )
 
+    @require_torch
+    def test_small_model_pt_token_text_only(self):
+        pipe = pipeline("image-text-to-text", model="llava-hf/llava-interleave-qwen-0.5b-hf")
+        text = "What is the capital of France? Assistant:"
+
+        outputs = pipe(text=text)
+        self.assertEqual(
+            outputs,
+            [
+                {
+                    "input_text": "What is the capital of France? Assistant:",
+                    "generated_text": "What is the capital of France? Assistant: The capital of France is Paris.",
+                }
+            ],
+        )
+
+        messages = [
+            [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": "Write a poem on Hugging Face, the company"},
+                    ],
+                },
+            ],
+            [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": "What is the capital of France?"},
+                    ],
+                },
+            ],
+        ]
+        outputs = pipe(text=messages)
+        self.assertEqual(
+            outputs,
+            [
+                [
+                    {
+                        "input_text": [
+                            {
+                                "role": "user",
+                                "content": [{"type": "text", "text": "Write a poem on Hugging Face, the company"}],
+                            }
+                        ],
+                        "generated_text": [
+                            {
+                                "role": "user",
+                                "content": [{"type": "text", "text": "Write a poem on Hugging Face, the company"}],
+                            },
+                            {
+                                "role": "assistant",
+                                "content": "Hugging Face, a company of minds\nWith tools and services that make our lives easier\nFrom",
+                            },
+                        ],
+                    }
+                ],
+                [
+                    {
+                        "input_text": [
+                            {"role": "user", "content": [{"type": "text", "text": "What is the capital of France?"}]}
+                        ],
+                        "generated_text": [
+                            {"role": "user", "content": [{"type": "text", "text": "What is the capital of France?"}]},
+                            {"role": "assistant", "content": "Paris"},
+                        ],
+                    }
+                ],
+            ],
+        )
+
     @require_torch
     def test_small_model_pt_token(self):
         pipe = pipeline("image-text-to-text", model="llava-hf/llava-interleave-qwen-0.5b-hf")
@@ -124,7 +196,7 @@ def test_model_pt_chat_template(self):
                 ],
             }
         ]
-        outputs = pipe([image_ny, image_chicago], text=messages, return_full_text=False, max_new_tokens=10)
+        outputs = pipe([image_ny, image_chicago], text=messages, return_full_text=True, max_new_tokens=10)
         self.assertEqual(
             outputs,
             [
@@ -134,12 +206,37 @@ def test_model_pt_chat_template(self):
                             "role": "user",
                             "content": [
                                 {"type": "text", "text": "What’s the difference between these two images?"},
-                                {"type": "image"},
-                                {"type": "image"},
+                                {
+                                    "type": "image",
+                                    "image": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
+                                },
+                                {
+                                    "type": "image",
+                                    "image": "https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg",
+                                },
                             ],
                         }
                     ],
-                    "generated_text": "The first image shows a statue of Liberty in the",
+                    "generated_text": [
+                        {
+                            "role": "user",
+                            "content": [
+                                {"type": "text", "text": "What’s the difference between these two images?"},
+                                {
+                                    "type": "image",
+                                    "image": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
+                                },
+                                {
+                                    "type": "image",
+                                    "image": "https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg",
+                                },
+                            ],
+                        },
+                        {
+                            "role": "assistant",
+                            "content": "The first image shows a statue of Liberty in the",
+                        },
+                    ],
                 }
             ],
         )
```
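
The last two hunks also show the effect of `return_full_text` on chat inputs: with `return_full_text=True` the `generated_text` field now contains the whole conversation (the user turn plus the appended assistant turn), while `return_full_text=False` returns only the newly generated text, as in the docs example. A short usage sketch, reusing the llava-hf/llava-interleave-qwen-0.5b-hf checkpoint from the tests; any chat-capable image-text-to-text checkpoint should behave the same.

```python
from transformers import pipeline

pipe = pipeline("image-text-to-text", model="llava-hf/llava-interleave-qwen-0.5b-hf")

messages = [
    {"role": "user", "content": [{"type": "text", "text": "What is the capital of France?"}]}
]

# Full chat back: a list with the user turn plus an appended assistant turn,
# matching the structure asserted in the updated test above.
full = pipe(text=messages, max_new_tokens=10, return_full_text=True)

# Only the newly generated assistant text, as in the docs example.
new_only = pipe(text=messages, max_new_tokens=10, return_full_text=False)

print(full[0]["generated_text"])
print(new_only[0]["generated_text"])
```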
