Skip to content

Commit a263342

Browse files
authored
Merge pull request #1288 from andreadimaio/main
Update watsonx.ai module to Langchain4j 1.0.0-beta1
2 parents cbeee42 + 101c14b commit a263342

20 files changed

+1504
-245
lines changed

model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/AiChatServiceTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,6 @@ private TextChatRequest generateChatRequest(List<TextChatMessage> messages, List
452452
.timeLimit(WireMockUtil.DEFAULT_TIME_LIMIT)
453453
.build();
454454

455-
return new TextChatRequest(modelId, spaceId, projectId, messages, tools, null, parameters);
455+
return new TextChatRequest(modelId, spaceId, projectId, messages, tools, parameters);
456456
}
457457
}

model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/ChatAllPropertiesTest.java

Lines changed: 366 additions & 8 deletions
Large diffs are not rendered by default.

model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/ChatDefaultPropertiesTest.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ void check_config() throws Exception {
8787
assertEquals(Optional.empty(), runtimeConfig.chatModel().stop());
8888
assertEquals(1.0, runtimeConfig.chatModel().temperature());
8989
assertEquals(1.0, runtimeConfig.chatModel().topP());
90+
assertEquals(Optional.empty(), runtimeConfig.chatModel().toolChoice());
9091
assertEquals("urn:ibm:params:oauth:grant-type:apikey", runtimeConfig.iam().grantType());
9192
}
9293

@@ -101,7 +102,7 @@ void check_chat_model_config() throws Exception {
101102
TextChatMessageSystem.of("SystemMessage"),
102103
TextChatMessageUser.of("UserMessage"));
103104

104-
TextChatRequest body = new TextChatRequest(modelId, spaceId, projectId, messages, null, null, parameters);
105+
TextChatRequest body = new TextChatRequest(modelId, spaceId, projectId, messages, null, parameters);
105106

106107
mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_API, 200)
107108
.body(mapper.writeValueAsString(body))
@@ -140,7 +141,7 @@ void check_chat_streaming_model_config() throws Exception {
140141
TextChatMessageSystem.of("SystemMessage"),
141142
TextChatMessageUser.of("UserMessage"));
142143

143-
TextChatRequest body = new TextChatRequest(modelId, spaceId, projectId, messagesToSend, null, null, parameters);
144+
TextChatRequest body = new TextChatRequest(modelId, spaceId, projectId, messagesToSend, null, parameters);
144145

145146
mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_STREAMING_API, 200)
146147
.body(mapper.writeValueAsString(body))

model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/GenerationAllPropertiesTest.java

Lines changed: 262 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
package io.quarkiverse.langchain4j.watsonx.deployment;
22

3+
import static dev.langchain4j.model.chat.request.ToolChoice.AUTO;
34
import static org.assertj.core.api.Assertions.assertThat;
45
import static org.awaitility.Awaitility.await;
56
import static org.junit.jupiter.api.Assertions.assertEquals;
67
import static org.junit.jupiter.api.Assertions.assertNotNull;
8+
import static org.junit.jupiter.api.Assertions.assertThrows;
79
import static org.junit.jupiter.api.Assertions.fail;
810

911
import java.time.Duration;
@@ -21,15 +23,27 @@
2123

2224
import dev.langchain4j.data.embedding.Embedding;
2325
import dev.langchain4j.data.message.AiMessage;
26+
import dev.langchain4j.data.message.ChatMessage;
27+
import dev.langchain4j.data.message.SystemMessage;
28+
import dev.langchain4j.data.message.UserMessage;
2429
import dev.langchain4j.data.segment.TextSegment;
30+
import dev.langchain4j.exception.UnsupportedFeatureException;
2531
import dev.langchain4j.model.StreamingResponseHandler;
2632
import dev.langchain4j.model.chat.ChatLanguageModel;
2733
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
2834
import dev.langchain4j.model.chat.TokenCountEstimator;
35+
import dev.langchain4j.model.chat.request.ChatRequest;
36+
import dev.langchain4j.model.chat.request.ResponseFormat;
37+
import dev.langchain4j.model.chat.response.ChatResponse;
38+
import dev.langchain4j.model.chat.response.ChatResponseMetadata;
39+
import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;
2940
import dev.langchain4j.model.embedding.EmbeddingModel;
3041
import dev.langchain4j.model.output.FinishReason;
3142
import dev.langchain4j.model.output.Response;
43+
import dev.langchain4j.model.output.TokenUsage;
3244
import dev.langchain4j.model.scoring.ScoringModel;
45+
import io.quarkiverse.langchain4j.watsonx.WatsonxChatRequestParameters;
46+
import io.quarkiverse.langchain4j.watsonx.WatsonxGenerationRequestParameters;
3347
import io.quarkiverse.langchain4j.watsonx.bean.EmbeddingParameters;
3448
import io.quarkiverse.langchain4j.watsonx.bean.EmbeddingRequest;
3549
import io.quarkiverse.langchain4j.watsonx.bean.ScoringParameters;
@@ -155,6 +169,254 @@ void check_config() throws Exception {
155169
assertEquals(10, runtimeConfig.scoringModel().truncateInputTokens().orElse(null));
156170
}
157171

172+
@Test
173+
void chat_request_test() throws Exception {
174+
// Use the chat method without customization:
175+
var config = langchain4jWatsonConfig.defaultConfig();
176+
String modelId = config.generationModel().modelId();
177+
String spaceId = config.spaceId().orElse(null);
178+
String projectId = config.projectId().orElse(null);
179+
180+
TextGenerationRequest body = new TextGenerationRequest(modelId, spaceId, projectId,
181+
"You are an helpful assistant@Hello, how are you?",
182+
parameters);
183+
mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_API, 200, "aaaa-mm-dd")
184+
.body(mapper.writeValueAsString(body))
185+
.response(WireMockUtil.RESPONSE_WATSONX_GENERATION_API)
186+
.build();
187+
188+
List<ChatMessage> chatMessages = List.of(
189+
SystemMessage.from("You are an helpful assistant"),
190+
UserMessage.from("Hello, how are you?"));
191+
192+
var response = chatModel.chat(ChatRequest.builder().messages(chatMessages).build());
193+
194+
ChatResponse expected = new ChatResponse.Builder()
195+
.aiMessage(AiMessage.from("AI Response"))
196+
.metadata(ChatResponseMetadata.builder()
197+
.modelName("mistralai/mistral-large")
198+
.tokenUsage(new TokenUsage(50, 5))
199+
.finishReason(FinishReason.STOP)
200+
.build())
201+
.build();
202+
203+
assertEquals(expected, response);
204+
// ----------------------------------------------
205+
206+
// Use the chat method with customization:
207+
var request = ChatRequest.builder()
208+
.messages(chatMessages)
209+
.parameters(
210+
WatsonxGenerationRequestParameters.builder()
211+
.modelName("deepseek")
212+
.minNewTokens(1)
213+
.maxOutputTokens(2)
214+
.decodingMethod("nogreedy")
215+
.lengthPenalty(new LengthPenalty(0.0, 1))
216+
.randomSeed(1)
217+
.stopSequences(List.of("[]"))
218+
.temperature(1.0)
219+
.timeLimit(Duration.ofSeconds(1))
220+
.topK(10)
221+
.topP(1.5)
222+
.repetitionPenalty(1.0)
223+
.truncateInputTokens(2)
224+
.includeStopSequence(true)
225+
.build())
226+
.build();
227+
228+
body = new TextGenerationRequest("deepseek", spaceId, projectId,
229+
"You are an helpful assistant@Hello, how are you?",
230+
TextGenerationParameters.builder()
231+
.minNewTokens(1)
232+
.maxNewTokens(2)
233+
.decodingMethod("nogreedy")
234+
.lengthPenalty(new LengthPenalty(0.0, 1))
235+
.randomSeed(1)
236+
.stopSequences(List.of("[]"))
237+
.temperature(1.0)
238+
.timeLimit(1000L)
239+
.topK(10)
240+
.topP(1.5)
241+
.repetitionPenalty(1.0)
242+
.truncateInputTokens(2)
243+
.includeStopSequence(true)
244+
.build());
245+
246+
mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_API, 200, "aaaa-mm-dd")
247+
.body(mapper.writeValueAsString(body))
248+
.response(WireMockUtil.RESPONSE_WATSONX_GENERATION_API)
249+
.build();
250+
251+
response = chatModel.chat(request);
252+
assertEquals(expected, response);
253+
// ----------------------------------------
254+
255+
// Use the chat method with unsupported parameter:
256+
assertThrows(UnsupportedFeatureException.class, () -> chatModel.chat(ChatRequest.builder()
257+
.messages(chatMessages)
258+
.parameters(WatsonxChatRequestParameters.builder().frequencyPenalty(1.0).build())
259+
.build()));
260+
261+
assertThrows(UnsupportedFeatureException.class, () -> chatModel.chat(ChatRequest.builder()
262+
.messages(chatMessages)
263+
.parameters(WatsonxChatRequestParameters.builder().presencePenalty(1.0).build())
264+
.build()));
265+
266+
assertThrows(UnsupportedFeatureException.class, () -> chatModel.chat(ChatRequest.builder()
267+
.messages(chatMessages)
268+
.parameters(WatsonxChatRequestParameters.builder().toolChoice(AUTO).build())
269+
.build()));
270+
271+
assertThrows(UnsupportedFeatureException.class, () -> chatModel.chat(ChatRequest.builder()
272+
.messages(chatMessages)
273+
.parameters(WatsonxChatRequestParameters.builder().responseFormat(ResponseFormat.JSON).build())
274+
.build()));
275+
// ----------------------------------------
276+
}
277+
278+
@Test
279+
void chat_request_streaming_test() throws Exception {
280+
// Use the chat method without customization:
281+
var config = langchain4jWatsonConfig.defaultConfig();
282+
String modelId = config.generationModel().modelId();
283+
String spaceId = config.spaceId().orElse(null);
284+
String projectId = config.projectId().orElse(null);
285+
286+
TextGenerationRequest body = new TextGenerationRequest(modelId, spaceId, projectId,
287+
"You are an helpful assistant@Hello, how are you?",
288+
parameters);
289+
mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_STREAMING_API, 200, "aaaa-mm-dd")
290+
.body(mapper.writeValueAsString(body))
291+
.responseMediaType(MediaType.SERVER_SENT_EVENTS)
292+
.response(WireMockUtil.RESPONSE_WATSONX_GENERATION_STREAMING_API)
293+
.build();
294+
295+
List<ChatMessage> chatMessages = List.of(
296+
SystemMessage.from("You are an helpful assistant"),
297+
UserMessage.from("Hello, how are you?"));
298+
299+
var streamingResponse = new AtomicReference<ChatResponse>();
300+
var streamingChatResponseHandler = new StreamingChatResponseHandler() {
301+
@Override
302+
public void onPartialResponse(String partialResponse) {
303+
}
304+
305+
@Override
306+
public void onCompleteResponse(ChatResponse completeResponse) {
307+
assertEquals(FinishReason.LENGTH, completeResponse.finishReason());
308+
assertEquals(2, completeResponse.tokenUsage().inputTokenCount());
309+
assertEquals(14, completeResponse.tokenUsage().outputTokenCount());
310+
assertEquals(16, completeResponse.tokenUsage().totalTokenCount());
311+
streamingResponse.set(completeResponse);
312+
}
313+
314+
@Override
315+
public void onError(Throwable error) {
316+
fail(error);
317+
}
318+
};
319+
320+
streamingChatModel.chat(ChatRequest.builder().messages(chatMessages).build(), streamingChatResponseHandler);
321+
322+
ChatResponse expected = new ChatResponse.Builder()
323+
.aiMessage(AiMessage.from(". I'm a beginner"))
324+
.metadata(ChatResponseMetadata.builder()
325+
.modelName("mistralai/mistral-large")
326+
.tokenUsage(new TokenUsage(2, 14))
327+
.finishReason(FinishReason.LENGTH)
328+
.build())
329+
.build();
330+
331+
await().atMost(Duration.ofMinutes(1))
332+
.pollInterval(Duration.ofSeconds(2))
333+
.until(() -> streamingResponse.get() != null);
334+
335+
assertEquals(expected, streamingResponse.get());
336+
// ----------------------------------------------
337+
338+
// Use the chat method with customization:
339+
var request = ChatRequest.builder()
340+
.messages(chatMessages)
341+
.parameters(
342+
WatsonxGenerationRequestParameters.builder()
343+
.modelName("deepseek")
344+
.minNewTokens(1)
345+
.maxOutputTokens(2)
346+
.decodingMethod("nogreedy")
347+
.lengthPenalty(new LengthPenalty(0.0, 1))
348+
.randomSeed(1)
349+
.stopSequences(List.of("[]"))
350+
.temperature(1.0)
351+
.timeLimit(Duration.ofSeconds(1))
352+
.topK(10)
353+
.topP(1.5)
354+
.repetitionPenalty(1.0)
355+
.truncateInputTokens(2)
356+
.includeStopSequence(true)
357+
.build())
358+
.build();
359+
360+
body = new TextGenerationRequest("deepseek", spaceId, projectId,
361+
"You are an helpful assistant@Hello, how are you?",
362+
TextGenerationParameters.builder()
363+
.minNewTokens(1)
364+
.maxNewTokens(2)
365+
.decodingMethod("nogreedy")
366+
.lengthPenalty(new LengthPenalty(0.0, 1))
367+
.randomSeed(1)
368+
.stopSequences(List.of("[]"))
369+
.temperature(1.0)
370+
.timeLimit(1000L)
371+
.topK(10)
372+
.topP(1.5)
373+
.repetitionPenalty(1.0)
374+
.truncateInputTokens(2)
375+
.includeStopSequence(true)
376+
.build());
377+
378+
mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_STREAMING_API, 200, "aaaa-mm-dd")
379+
.body(mapper.writeValueAsString(body))
380+
.responseMediaType(MediaType.SERVER_SENT_EVENTS)
381+
.response(WireMockUtil.RESPONSE_WATSONX_GENERATION_STREAMING_API)
382+
.build();
383+
384+
streamingChatModel.chat(request, streamingChatResponseHandler);
385+
386+
await().atMost(Duration.ofMinutes(1))
387+
.pollInterval(Duration.ofSeconds(2))
388+
.until(() -> streamingResponse.get() != null);
389+
390+
assertEquals(expected, streamingResponse.get());
391+
// ----------------------------------------
392+
393+
// Use the chat method with unsupported parameter:
394+
assertThrows(UnsupportedFeatureException.class, () -> streamingChatModel.chat(ChatRequest.builder()
395+
.messages(chatMessages)
396+
.parameters(WatsonxChatRequestParameters.builder().frequencyPenalty(1.0).build())
397+
.build(),
398+
streamingChatResponseHandler));
399+
400+
assertThrows(UnsupportedFeatureException.class, () -> streamingChatModel.chat(ChatRequest.builder()
401+
.messages(chatMessages)
402+
.parameters(WatsonxChatRequestParameters.builder().presencePenalty(1.0).build())
403+
.build(),
404+
streamingChatResponseHandler));
405+
406+
assertThrows(UnsupportedFeatureException.class, () -> streamingChatModel.chat(ChatRequest.builder()
407+
.messages(chatMessages)
408+
.parameters(WatsonxChatRequestParameters.builder().toolChoice(AUTO).build())
409+
.build(),
410+
streamingChatResponseHandler));
411+
412+
assertThrows(UnsupportedFeatureException.class, () -> streamingChatModel.chat(ChatRequest.builder()
413+
.messages(chatMessages)
414+
.parameters(WatsonxChatRequestParameters.builder().responseFormat(ResponseFormat.JSON).build())
415+
.build(),
416+
streamingChatResponseHandler));
417+
// ----------------------------------------
418+
}
419+
158420
@Test
159421
void check_chat_model_config() throws Exception {
160422
var config = langchain4jWatsonConfig.defaultConfig();
@@ -298,6 +560,5 @@ public void onComplete(Response<AiMessage> response) {
298560
assertThat(streamingResponse.get().text())
299561
.isNotNull()
300562
.isEqualTo(". I'm a beginner");
301-
302563
}
303564
}

0 commit comments

Comments
 (0)