Skip to content

Commit 12df5b3

Browse files
authored
Merge pull request #1000 from andreadimaio/main
Enable watsonx.ai to process ImageContent in UserMessage
2 parents d21b911 + 56d73de commit 12df5b3

File tree

3 files changed

+95
-2
lines changed

3 files changed

+95
-2
lines changed

model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/AiChatServiceTest.java

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,11 @@
2020
import com.github.tomakehurst.wiremock.stubbing.Scenario;
2121

2222
import dev.langchain4j.agent.tool.Tool;
23+
import dev.langchain4j.data.image.Image;
2324
import dev.langchain4j.data.message.AiMessage;
25+
import dev.langchain4j.data.message.ImageContent;
26+
import dev.langchain4j.data.message.ImageContent.DetailLevel;
27+
import dev.langchain4j.data.message.TextContent;
2428
import dev.langchain4j.data.message.ToolExecutionResultMessage;
2529
import dev.langchain4j.service.MemoryId;
2630
import dev.langchain4j.service.SystemMessage;
@@ -72,6 +76,12 @@ interface AIService {
7276
Multi<String> streaming(String text);
7377
}
7478

79+
@Singleton
80+
@RegisterAiService(chatMemoryProviderSupplier = RegisterAiService.NoChatMemoryProviderSupplier.class)
81+
interface ImageDescriptor {
82+
String chat(Image image, @UserMessage String text);
83+
}
84+
7585
@Singleton
7686
@RegisterAiService(tools = Calculator.class)
7787
@SystemMessage("This is a systemMessage")
@@ -84,6 +94,9 @@ interface AIServiceWithTool {
8494
@Inject
8595
AIService aiService;
8696

97+
@Inject
98+
ImageDescriptor imageDescriptor;
99+
87100
@Inject
88101
AIServiceWithTool aiServiceWithTool;
89102

@@ -92,7 +105,6 @@ interface AIServiceWithTool {
92105

93106
@Singleton
94107
static class Calculator {
95-
96108
@Tool("Execute the sum of two numbers")
97109
public int sum(int first, int second) {
98110
return first + second;
@@ -125,6 +137,53 @@ void chat() throws Exception {
125137
assertEquals("AI Response", aiService.chat("Hello"));
126138
}
127139

140+
@Test
141+
void chat_with_image() throws Exception {
142+
143+
var messages = List.<TextChatMessage> of(
144+
TextChatMessageUser.of(
145+
dev.langchain4j.data.message.UserMessage.from(
146+
TextContent.from("Tell me more about this image"),
147+
ImageContent.from("test", "jpeg", DetailLevel.LOW))));
148+
149+
var RESPONSE = """
150+
{
151+
"id": "chat-0102753c2c33412fa639a6b0eb5401da",
152+
"model_id": "meta-llama/llama-3-2-90b-vision-instruct",
153+
"choices": [
154+
{
155+
"index": 0,
156+
"message": {
157+
"role": "assistant",
158+
"content": "The image depicts a white cat with yellow eyes."
159+
},
160+
"finish_reason": "stop"
161+
}
162+
],
163+
"created": 1729517211,
164+
"model_version": "3.2.0",
165+
"created_at": "2024-10-21T13:26:57.471Z",
166+
"usage": {
167+
"completion_tokens": 123,
168+
"prompt_tokens": 6422,
169+
"total_tokens": 6545
170+
}
171+
}""";
172+
173+
mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_API, 200)
174+
.body(mapper.writeValueAsString(generateChatRequest(messages, null)))
175+
.response(RESPONSE)
176+
.build();
177+
178+
Image image = Image.builder()
179+
.base64Data("test")
180+
.mimeType("jpeg")
181+
.build();
182+
183+
assertEquals("The image depicts a white cat with yellow eyes.",
184+
imageDescriptor.chat(image, "Tell me more about this image"));
185+
}
186+
128187
@Test
129188
void chat_with_tool() throws Exception {
130189

model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxUtils.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
package io.quarkiverse.langchain4j.watsonx;
22

3+
import java.util.Base64;
4+
import java.util.Objects;
35
import java.util.Optional;
46
import java.util.concurrent.Callable;
57

68
import jakarta.ws.rs.WebApplicationException;
79

10+
import dev.langchain4j.data.image.Image;
11+
import dev.langchain4j.internal.Utils;
812
import io.quarkiverse.langchain4j.watsonx.bean.WatsonxError;
913
import io.quarkiverse.langchain4j.watsonx.exception.WatsonxException;
1014

@@ -44,4 +48,20 @@ public static <T> T retryOn(Callable<T> action) {
4448
}
4549
throw new RuntimeException("Failed after " + maxAttempts + " attempts");
4650
}
51+
52+
public static String base64Image(Image image) {
53+
54+
if (Objects.nonNull(image.base64Data()))
55+
return image.base64Data();
56+
57+
try {
58+
byte[] bytes = switch (image.url().getScheme()) {
59+
case "http", "https", "file" -> Utils.readBytes(image.url().toString());
60+
default -> throw new RuntimeException("The only supported image schemes are: [http, https, file]");
61+
};
62+
return Base64.getEncoder().encodeToString(bytes);
63+
} catch (Exception e) {
64+
throw new RuntimeException("Error converting the image to base64, see the log for more details", e);
65+
}
66+
}
4767
}

model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/TextChatMessage.java

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package io.quarkiverse.langchain4j.watsonx.bean;
22

3+
import static io.quarkiverse.langchain4j.watsonx.WatsonxUtils.base64Image;
4+
35
import java.util.ArrayList;
46
import java.util.List;
57
import java.util.Map;
@@ -9,6 +11,7 @@
911
import dev.langchain4j.data.message.AiMessage;
1012
import dev.langchain4j.data.message.ChatMessage;
1113
import dev.langchain4j.data.message.Content;
14+
import dev.langchain4j.data.message.ImageContent;
1215
import dev.langchain4j.data.message.SystemMessage;
1316
import dev.langchain4j.data.message.TextContent;
1417
import dev.langchain4j.data.message.ToolExecutionResultMessage;
@@ -125,7 +128,18 @@ public static TextChatMessageUser of(UserMessage userMessage) {
125128
"type", "text",
126129
"text", textContent.text()));
127130
}
128-
case AUDIO, IMAGE, PDF, TEXT_FILE, VIDEO ->
131+
case IMAGE -> {
132+
var imageContent = ImageContent.class.cast(content);
133+
var base64 = "data:image/%s;base64,%s".formatted(
134+
imageContent.image().mimeType(),
135+
base64Image(imageContent.image()));
136+
values.add(Map.of(
137+
"type", "image_url",
138+
"image_url", Map.of(
139+
"url", base64,
140+
"detail", imageContent.detailLevel().name().toLowerCase())));
141+
}
142+
case AUDIO, PDF, TEXT_FILE, VIDEO ->
129143
throw new UnsupportedOperationException("Unimplemented case: " + content.type());
130144
}
131145
}

0 commit comments

Comments
 (0)