Skip to content

Commit 6b3edf4

Browse files
committed
[watsonx.ai] Add truncate_input_tokens property
1 parent 7068836 commit 6b3edf4

File tree

8 files changed

+46
-10
lines changed

8 files changed

+46
-10
lines changed

model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/AiEmbeddingTest.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ void test_embed_list_of_three_textsegment() throws Exception {
8383

8484
var input = "Embedding THIS!";
8585
EmbeddingRequest request = new EmbeddingRequest(WireMockUtil.DEFAULT_EMBEDDING_MODEL, WireMockUtil.PROJECT_ID,
86-
List.of(input, input, input));
86+
List.of(input, input, input), null);
8787

8888
mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_EMBEDDING_API, 200)
8989
.body(mapper.writeValueAsString(request))
@@ -140,7 +140,7 @@ private List<Float> mockEmbeddingServer(String input) throws Exception {
140140
.build();
141141

142142
EmbeddingRequest request = new EmbeddingRequest(WireMockUtil.DEFAULT_EMBEDDING_MODEL, WireMockUtil.PROJECT_ID,
143-
List.of(input));
143+
List.of(input), null);
144144

145145
mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_EMBEDDING_API, 200)
146146
.body(mapper.writeValueAsString(request))

model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/GenerationAllPropertiesTest.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import dev.langchain4j.model.chat.TokenCountEstimator;
2626
import dev.langchain4j.model.embedding.EmbeddingModel;
2727
import dev.langchain4j.model.output.Response;
28+
import io.quarkiverse.langchain4j.watsonx.bean.EmbeddingParameters;
2829
import io.quarkiverse.langchain4j.watsonx.bean.EmbeddingRequest;
2930
import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationParameters;
3031
import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationParameters.LengthPenalty;
@@ -64,6 +65,7 @@ public class GenerationAllPropertiesTest extends WireMockAbstract {
6465
.overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.chat-model.truncate-input-tokens", "0")
6566
.overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.chat-model.include-stop-sequence", "false")
6667
.overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.embedding-model.model-id", "my_super_embedding_model")
68+
.overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.embedding-model.truncate-input-tokens", "10")
6769
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClass(WireMockUtil.class));
6870

6971
@Override
@@ -102,6 +104,8 @@ void handlerBeforeEach() {
102104
.includeStopSequence(false)
103105
.build();
104106

107+
static EmbeddingParameters embeddingParameters = new EmbeddingParameters(10);
108+
105109
@Test
106110
void check_config() throws Exception {
107111
var runtimeConfig = langchain4jWatsonConfig.defaultConfig();
@@ -133,6 +137,7 @@ void check_config() throws Exception {
133137
assertEquals("@", runtimeConfig.chatModel().promptJoiner());
134138
assertEquals(true, fixedRuntimeConfig.chatModel().promptFormatter());
135139
assertEquals("my_super_embedding_model", runtimeConfig.embeddingModel().modelId());
140+
assertEquals(10, runtimeConfig.embeddingModel().truncateInputTokens().orElse(null));
136141
}
137142

138143
@Test
@@ -158,7 +163,7 @@ void check_embedding_model() throws Exception {
158163
String projectId = config.projectId();
159164

160165
EmbeddingRequest request = new EmbeddingRequest(modelId, projectId,
161-
List.of("Embedding THIS!"));
166+
List.of("Embedding THIS!"), embeddingParameters);
162167

163168
mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_EMBEDDING_API, 200, "aaaa-mm-dd")
164169
.body(mapper.writeValueAsString(request))

model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/GenerationDefaultPropertiesTest.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ void check_config() throws Exception {
9898
assertTrue(runtimeConfig.chatModel().includeStopSequence().isEmpty());
9999
assertEquals("urn:ibm:params:oauth:grant-type:apikey", runtimeConfig.iam().grantType());
100100
assertEquals(WireMockUtil.DEFAULT_EMBEDDING_MODEL, runtimeConfig.embeddingModel().modelId());
101+
assertTrue(runtimeConfig.embeddingModel().truncateInputTokens().isEmpty());
101102
}
102103

103104
@Test
@@ -124,7 +125,7 @@ void check_embedding_model() throws Exception {
124125
String projectId = config.projectId();
125126

126127
EmbeddingRequest request = new EmbeddingRequest(modelId, projectId,
127-
List.of("Embedding THIS!"));
128+
List.of("Embedding THIS!"), null);
128129

129130
mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_EMBEDDING_API, 200)
130131
.body(mapper.writeValueAsString(request))

model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxEmbeddingModel.java

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import dev.langchain4j.model.embedding.EmbeddingModel;
1818
import dev.langchain4j.model.embedding.TokenCountEstimator;
1919
import dev.langchain4j.model.output.Response;
20+
import io.quarkiverse.langchain4j.watsonx.bean.EmbeddingParameters;
2021
import io.quarkiverse.langchain4j.watsonx.bean.EmbeddingRequest;
2122
import io.quarkiverse.langchain4j.watsonx.bean.EmbeddingResponse;
2223
import io.quarkiverse.langchain4j.watsonx.bean.EmbeddingResponse.Result;
@@ -28,6 +29,7 @@
2829
public class WatsonxEmbeddingModel implements EmbeddingModel, TokenCountEstimator {
2930

3031
private final String modelId, projectId, version;
32+
private final EmbeddingParameters parameters;
3133
private final WatsonxRestApi client;
3234

3335
public WatsonxEmbeddingModel(Builder builder) {
@@ -49,6 +51,11 @@ public WatsonxEmbeddingModel(Builder builder) {
4951
this.modelId = builder.modelId;
5052
this.projectId = builder.projectId;
5153
this.version = builder.version;
54+
55+
if (builder.truncateInputTokens != null)
56+
this.parameters = new EmbeddingParameters(builder.truncateInputTokens);
57+
else
58+
this.parameters = null;
5259
}
5360

5461
@Override
@@ -61,7 +68,7 @@ public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
6168
.map(TextSegment::text)
6269
.collect(Collectors.toList());
6370

64-
EmbeddingRequest request = new EmbeddingRequest(modelId, projectId, inputs);
71+
EmbeddingRequest request = new EmbeddingRequest(modelId, projectId, inputs, parameters);
6572
EmbeddingResponse result = retryOn(new Callable<EmbeddingResponse>() {
6673
@Override
6774
public EmbeddingResponse call() throws Exception {
@@ -102,6 +109,7 @@ public static final class Builder {
102109
private String version;
103110
private String projectId;
104111
private Duration timeout;
112+
private Integer truncateInputTokens;
105113
private boolean logResponses;
106114
private boolean logRequests;
107115
private URL url;
@@ -127,6 +135,11 @@ public Builder timeout(Duration timeout) {
127135
return this;
128136
}
129137

138+
public Builder truncateInputTokens(Integer truncateInputTokens) {
139+
this.truncateInputTokens = truncateInputTokens;
140+
return this;
141+
}
142+
130143
public Builder logRequests(boolean logRequests) {
131144
this.logRequests = logRequests;
132145
return this;
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
package io.quarkiverse.langchain4j.watsonx.bean;
2+
3+
public record EmbeddingParameters(Integer truncateInputTokens) {
4+
}

model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/EmbeddingRequest.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
import java.util.List;
44

5-
public record EmbeddingRequest(String modelId, String projectId, List<String> inputs) {
5+
public record EmbeddingRequest(String modelId, String projectId, List<String> inputs, EmbeddingParameters parameters) {
66

7-
public EmbeddingRequest of(String modelId, String projectId, String input) {
8-
return new EmbeddingRequest(modelId, projectId, List.of(input));
7+
public EmbeddingRequest of(String modelId, String projectId, String input, EmbeddingParameters parameters) {
8+
return new EmbeddingRequest(modelId, projectId, List.of(input), parameters);
99
}
1010
}

model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/WatsonxRecorder.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,8 @@ public Supplier<EmbeddingModel> embeddingModel(LangChain4jWatsonxConfig runtimeC
197197
.logResponses(firstOrDefault(false, embeddingModelConfig.logResponses(), watsonConfig.logResponses()))
198198
.version(watsonConfig.version())
199199
.projectId(watsonConfig.projectId())
200-
.modelId(embeddingModelConfig.modelId());
200+
.modelId(embeddingModelConfig.modelId())
201+
.truncateInputTokens(embeddingModelConfig.truncateInputTokens().orElse(null));
201202

202203
return new Supplier<>() {
203204
@Override

model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/config/EmbeddingModelConfig.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,25 @@ public interface EmbeddingModelConfig {
1212
/**
1313
* Model id to use.
1414
*
15-
* To view the complete model list, <a href=
15+
* To view the complete model list,
16+
* <a href=
1617
* "https://dataplatform.cloud.ibm.com/docs/content/wsj/analyze-data/fm-models-embed.html?context=wx&audience=wdp">click
1718
* here</a>.
1819
*/
1920
@WithDefault("ibm/slate-125m-english-rtrvr")
2021
String modelId();
2122

23+
/**
24+
* Represents the maximum number of input tokens accepted. This can be used to avoid requests failing due to input being
25+
* longer
26+
* than configured limits. If the text is truncated, then it truncates the end of the input (on the right), so the start of
27+
* the
28+
* input will remain the same. If this value exceeds the maximum sequence length (refer to the documentation to find this
29+
* value
30+
* for the model) then the call will fail if the total number of tokens exceeds the maximum sequence length.
31+
*/
32+
Optional<Integer> truncateInputTokens();
33+
2234
/**
2335
* Whether embedding model requests should be logged.
2436
*/

0 commit comments

Comments
 (0)