@@ -58,60 +58,56 @@ def __init__(self, messages: Dict, images: Union[str, List[str], "Image.Image",
58
58
for message in messages :
59
59
if not ("role" in message and "content" in message ):
60
60
raise ValueError ("When passing chat dicts as input, each dict must have a 'role' and 'content' key." )
61
- images = retrieve_images_in_messages (messages , images )
61
+ messages = add_images_to_messages (messages , images )
62
62
63
63
self .messages = messages
64
- self .images = images
65
64
66
65
67
def add_images_to_messages(
    messages: List[dict], images: Optional[Union[str, List[str], "Image.Image", List["Image.Image"]]]
) -> List[dict]:
    """
    Insert the images passed to the pipeline into the chat messages, in place.

    Each `{"type": "image"}` content entry that does not already carry an image
    (no "image"/"url"/"path"/"base64" key) is filled with the next image from
    `images`, in order. OpenAI/TGI-style `{"type": "image_url"}` entries are
    rewritten to the Transformers chat format (`{"type": "image", "image": url}`).

    Args:
        messages: List of chat message dicts, each with a "role" and a "content"
            list. Mutated in place.
        images: Image(s) passed to the pipeline. `None` means no extra images;
            a single string or image is treated as a one-element list.

    Returns:
        The (mutated) `messages` list.

    Raises:
        ValueError: If the number of image placeholders in the chat does not
            match the number of images passed, or if an "image_url" content
            entry is not a dict with a "url" key.
    """
    if images is None:
        images = []
    # A plain string is iterable but represents a single image, not a list of them.
    elif not isinstance(images, Iterable) or isinstance(images, str):
        images = [images]

    idx_images = 0
    for message in messages:
        for content in message["content"]:
            if not isinstance(content, dict):
                continue
            content_type = content.get("type")
            if content_type == "image":
                if not any(key in content for key in ["image", "url", "path", "base64"]):
                    if idx_images < len(images):
                        # Insert the image passed as argument in the chat message
                        content["image"] = images[idx_images]
                        idx_images += 1
                    else:
                        raise ValueError(
                            "The number of images in the chat messages should be the same as the number of images passed to the pipeline."
                        )
            # Add support for OpenAI/TGI chat format
            elif content_type == "image_url":
                if isinstance(content.get("image_url"), dict) and "url" in content["image_url"]:
                    # Rewrite content to be in the Transformers chat format
                    content["type"] = "image"
                    content["image"] = content["image_url"]["url"]
                    del content["image_url"]
                else:
                    raise ValueError(
                        "Wrong format for 'image_url' content type. The content should have an 'image_url' dict with a 'url' key."
                    )

    # The number of images passed should be consistent with the number of images in the chat without an image key
    if idx_images != len(images):
        raise ValueError(
            "The number of images in the chat messages should be the same as the number of images passed to the pipeline."
        )

    return messages
115
111
116
112
117
113
@add_end_docstrings (build_pipeline_init_args (has_processor = True ))
@@ -331,32 +327,30 @@ def __call__(
331
327
return super ().__call__ ({"images" : images , "text" : text }, ** kwargs )
332
328
333
329
def preprocess (self , inputs = None , timeout = None , continue_final_message = None , ** processing_kwargs ):
330
+ if isinstance (inputs , Chat ):
331
+ # If the user passes a chat that ends in an assistant message, we treat it as a prefill by default
332
+ # because very few models support multiple separate, consecutive assistant messages
333
+ if continue_final_message is None :
334
+ continue_final_message = inputs .messages [- 1 ]["role" ] == "assistant"
335
+ model_inputs = self .processor .apply_chat_template (
336
+ inputs .messages ,
337
+ add_generation_prompt = not continue_final_message ,
338
+ continue_final_message = continue_final_message ,
339
+ return_tensors = self .framework ,
340
+ tokenize = True ,
341
+ return_dict = True ,
342
+ )
343
+ model_inputs ["text" ] = inputs
344
+ return model_inputs
334
345
# In case we only have text inputs
335
346
if isinstance (inputs , (list , tuple , str )):
336
347
images = None
337
348
text = inputs
338
349
inputs_text = inputs
339
350
else :
340
- if isinstance (inputs , Chat ):
341
- # If the user passes a chat that ends in an assistant message, we treat it as a prefill by default
342
- # because very few models support multiple separate, consecutive assistant messages
343
- if continue_final_message is None :
344
- continue_final_message = inputs .messages [- 1 ]["role" ] == "assistant"
345
- text = self .processor .apply_chat_template (
346
- inputs .messages ,
347
- add_generation_prompt = not continue_final_message ,
348
- continue_final_message = continue_final_message ,
349
- return_tensors = self .framework ,
350
- ** processing_kwargs ,
351
- )
352
- inputs_text = inputs
353
- images = inputs .images
354
- else :
355
- text = inputs ["text" ]
356
- inputs_text = inputs ["text" ]
357
- images = inputs ["images" ]
358
-
359
- images = load_images (images , timeout = timeout )
351
+ images = load_images (inputs ["images" ], timeout = timeout )
352
+ text = inputs ["text" ]
353
+ inputs_text = inputs ["text" ]
360
354
361
355
# if batched text inputs, we set padding to True unless specified otherwise
362
356
if isinstance (text , (list , tuple )) and len (text ) > 1 :
0 commit comments