Skip to content

Commit a8c5241

Browse files
Fixing embedding dimensions issue and test field names
1 parent e7f6ac5 commit a8c5241

14 files changed

+103
-75
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@
118118
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioService;
119119
import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService;
120120
import org.elasticsearch.xpack.inference.services.cohere.CohereService;
121+
import org.elasticsearch.xpack.inference.services.custom.CustomService;
121122
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekService;
122123
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
123124
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
@@ -376,6 +377,7 @@ public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
376377
context -> new JinaAIService(httpFactory.get(), serviceComponents.get()),
377378
context -> new VoyageAIService(httpFactory.get(), serviceComponents.get()),
378379
context -> new DeepSeekService(httpFactory.get(), serviceComponents.get()),
380+
context -> new CustomService(httpFactory.get(), serviceComponents.get()),
379381
ElasticsearchInternalService::new
380382
);
381383
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java

Lines changed: 46 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,14 @@
77

88
package org.elasticsearch.xpack.inference.services.custom;
99

10-
import org.apache.logging.log4j.LogManager;
11-
import org.apache.logging.log4j.Logger;
1210
import org.elasticsearch.ElasticsearchStatusException;
1311
import org.elasticsearch.TransportVersion;
1412
import org.elasticsearch.TransportVersions;
1513
import org.elasticsearch.action.ActionListener;
1614
import org.elasticsearch.common.ValidationException;
17-
import org.elasticsearch.common.bytes.BytesReference;
1815
import org.elasticsearch.common.util.LazyInitializable;
1916
import org.elasticsearch.core.Nullable;
17+
import org.elasticsearch.core.Strings;
2018
import org.elasticsearch.core.TimeValue;
2119
import org.elasticsearch.inference.ChunkedInference;
2220
import org.elasticsearch.inference.InferenceServiceConfiguration;
@@ -26,10 +24,9 @@
2624
import org.elasticsearch.inference.ModelConfigurations;
2725
import org.elasticsearch.inference.ModelSecrets;
2826
import org.elasticsearch.inference.SettingsConfiguration;
27+
import org.elasticsearch.inference.SimilarityMeasure;
2928
import org.elasticsearch.inference.TaskType;
3029
import org.elasticsearch.rest.RestStatus;
31-
import org.elasticsearch.xcontent.XContentBuilder;
32-
import org.elasticsearch.xcontent.XContentFactory;
3330
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
3431
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
3532
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
@@ -38,15 +35,14 @@
3835
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
3936
import org.elasticsearch.xpack.inference.services.SenderService;
4037
import org.elasticsearch.xpack.inference.services.ServiceComponents;
38+
import org.elasticsearch.xpack.inference.services.ServiceUtils;
4139

42-
import java.io.IOException;
4340
import java.util.EnumSet;
4441
import java.util.HashMap;
4542
import java.util.List;
4643
import java.util.Map;
4744

4845
import static org.elasticsearch.inference.TaskType.unsupportedTaskTypeErrorMsg;
49-
import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS;
5046
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
5147
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
5248
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
@@ -55,8 +51,8 @@
5551
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation;
5652

5753
public class CustomService extends SenderService {
58-
private static final Logger logger = LogManager.getLogger(CustomService.class);
5954
public static final String NAME = "custom";
55+
private static final String SERVICE_NAME = "Custom";
6056

6157
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(
6258
TaskType.TEXT_EMBEDDING,
@@ -94,8 +90,6 @@ public void parseRequestConfig(
9490
ConfigurationParseContext.REQUEST
9591
);
9692

97-
logModelConfig(model.getConfigurations());
98-
9993
throwIfNotEmptyMap(config, NAME);
10094
throwIfNotEmptyMap(serviceSettingsMap, NAME);
10195
throwIfNotEmptyMap(taskSettingsMap, NAME);
@@ -168,27 +162,15 @@ public CustomModel parsePersistedConfigWithSecrets(
168162
Map<String, Object> taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
169163
Map<String, Object> secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS);
170164

171-
return createModelWithoutLoggingDeprecations(
172-
inferenceEntityId,
173-
taskType,
174-
serviceSettingsMap,
175-
taskSettingsMap,
176-
secretSettingsMap
177-
);
165+
return createModelWithoutLoggingDeprecations(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, secretSettingsMap);
178166
}
179167

180168
@Override
181169
public CustomModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map<String, Object> config) {
182170
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
183171
Map<String, Object> taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
184172

185-
return createModelWithoutLoggingDeprecations(
186-
inferenceEntityId,
187-
taskType,
188-
serviceSettingsMap,
189-
taskSettingsMap,
190-
null
191-
);
173+
return createModelWithoutLoggingDeprecations(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, null);
192174
}
193175

194176
@Override
@@ -208,7 +190,7 @@ public void doInfer(
208190

209191
var overriddenModel = CustomModel.of(customModel, taskSettings);
210192

211-
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("Custom");
193+
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(SERVICE_NAME);
212194
var manager = CustomRequestManager.of(overriddenModel, getServiceComponents().threadPool());
213195
var action = new SenderExecutableAction(getSender(), manager, failedToSendRequestErrorMessage);
214196

@@ -217,7 +199,7 @@ public void doInfer(
217199

218200
@Override
219201
protected void validateInputType(InputType inputType, Model model, ValidationException validationException) {
220-
// TODO
202+
ServiceUtils.validateInputTypeIsUnspecifiedOrInternal(inputType, validationException);
221203
}
222204

223205
@Override
@@ -232,6 +214,43 @@ protected void doChunkedInfer(
232214
listener.onFailure(new ElasticsearchStatusException("Chunking not supported by the {} service", RestStatus.BAD_REQUEST, NAME));
233215
}
234216

217+
@Override
218+
public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
219+
if (model instanceof CustomModel customModel && customModel.getTaskType() == TaskType.TEXT_EMBEDDING) {
220+
var newServiceSettings = getCustomServiceSettings(customModel, embeddingSize);
221+
222+
return new CustomModel(customModel, newServiceSettings);
223+
} else {
224+
throw new ElasticsearchStatusException(
225+
Strings.format(
226+
"Can't update embedding details for model of type: [%s], task type: [%s]",
227+
model.getClass().getSimpleName(),
228+
model.getTaskType()
229+
),
230+
RestStatus.BAD_REQUEST
231+
);
232+
}
233+
}
234+
235+
private static CustomServiceSettings getCustomServiceSettings(CustomModel customModel, int embeddingSize) {
236+
var serviceSettings = customModel.getServiceSettings();
237+
var similarityFromModel = serviceSettings.similarity();
238+
var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel;
239+
240+
return new CustomServiceSettings(
241+
similarityToUse,
242+
embeddingSize,
243+
serviceSettings.getMaxInputTokens(),
244+
serviceSettings.getUrl(),
245+
serviceSettings.getHeaders(),
246+
serviceSettings.getQueryParameters(),
247+
serviceSettings.getRequestContentString(),
248+
serviceSettings.getResponseJsonParser(),
249+
serviceSettings.rateLimitSettings(),
250+
serviceSettings.getErrorParser()
251+
);
252+
}
253+
235254
@Override
236255
public TransportVersion getMinimalSupportedVersion() {
237256
return TransportVersions.ADD_INFERENCE_CUSTOM_MODEL;
@@ -246,18 +265,11 @@ public static InferenceServiceConfiguration get() {
246265
() -> {
247266
var configurationMap = new HashMap<String, SettingsConfiguration>();
248267
return new InferenceServiceConfiguration.Builder().setService(NAME)
249-
.setName(NAME)
268+
.setName(SERVICE_NAME)
250269
.setTaskTypes(supportedTaskTypes)
251270
.setConfigurations(configurationMap)
252271
.build();
253272
}
254273
);
255274
}
256-
257-
private void logModelConfig(ModelConfigurations modelConfigurations) throws IOException {
258-
XContentBuilder builder = XContentFactory.jsonBuilder();
259-
XContentBuilder modelBuilder = modelConfigurations.toXContent(builder, EMPTY_PARAMS);
260-
String jsonString = BytesReference.bytes(modelBuilder).utf8ToString();
261-
logger.info("add custom model: " + jsonString);
262-
}
263275
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ public class CustomServiceSettings extends FilteredXContentObject implements Ser
6161
public static final String ERROR_PARSER = "error_parser";
6262

6363
private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(10_000);
64+
private static final String RESPONSE_SCOPE = String.join(".", ModelConfigurations.SERVICE_SETTINGS, RESPONSE);
6465

6566
public static CustomServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context, TaskType taskType) {
6667
ValidationException validationException = new ValidationException();
@@ -96,7 +97,7 @@ public static CustomServiceSettings fromMap(Map<String, Object> map, Configurati
9697
Map<String, Object> jsonParserMap = extractRequiredMap(
9798
Objects.requireNonNullElse(responseParserMap, new HashMap<>()),
9899
JSON_PARSER,
99-
ModelConfigurations.SERVICE_SETTINGS,
100+
RESPONSE_SCOPE,
100101
validationException
101102
);
102103

@@ -105,11 +106,11 @@ public static CustomServiceSettings fromMap(Map<String, Object> map, Configurati
105106
Map<String, Object> errorParserMap = extractRequiredMap(
106107
Objects.requireNonNullElse(responseParserMap, new HashMap<>()),
107108
ERROR_PARSER,
108-
ModelConfigurations.SERVICE_SETTINGS,
109+
RESPONSE_SCOPE,
109110
validationException
110111
);
111112

112-
var errorParser = ErrorResponseParser.fromMap(errorParserMap, validationException);
113+
var errorParser = ErrorResponseParser.fromMap(errorParserMap, RESPONSE_SCOPE, validationException);
113114

114115
RateLimitSettings rateLimitSettings = RateLimitSettings.of(
115116
map,
@@ -371,10 +372,10 @@ private static CustomResponseParser extractResponseParser(
371372
}
372373

373374
return switch (taskType) {
374-
case TEXT_EMBEDDING -> TextEmbeddingResponseParser.fromMap(responseParserMap, validationException);
375-
case SPARSE_EMBEDDING -> SparseEmbeddingResponseParser.fromMap(responseParserMap, validationException);
376-
case RERANK -> RerankResponseParser.fromMap(responseParserMap, validationException);
377-
case COMPLETION -> CompletionResponseParser.fromMap(responseParserMap, validationException);
375+
case TEXT_EMBEDDING -> TextEmbeddingResponseParser.fromMap(responseParserMap, RESPONSE_SCOPE, validationException);
376+
case SPARSE_EMBEDDING -> SparseEmbeddingResponseParser.fromMap(responseParserMap, RESPONSE_SCOPE, validationException);
377+
case RERANK -> RerankResponseParser.fromMap(responseParserMap, RESPONSE_SCOPE, validationException);
378+
case COMPLETION -> CompletionResponseParser.fromMap(responseParserMap, RESPONSE_SCOPE, validationException);
378379
default -> throw new IllegalArgumentException(
379380
Strings.format("Invalid task type received [%s] while constructing response parser", taskType)
380381
);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/CompletionResponseParser.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ public class CompletionResponseParser extends BaseCustomResponseParser<ChatCompl
3030

3131
private final String completionResultPath;
3232

33-
public static CompletionResponseParser fromMap(Map<String, Object> responseParserMap, ValidationException validationException) {
34-
var path = extractRequiredString(responseParserMap, COMPLETION_PARSER_RESULT, JSON_PARSER, validationException);
33+
public static CompletionResponseParser fromMap(Map<String, Object> responseParserMap, String scope, ValidationException validationException) {
34+
var path = extractRequiredString(responseParserMap, COMPLETION_PARSER_RESULT, String.join(".", scope, JSON_PARSER), validationException);
3535

3636
if (path == null) {
3737
throw validationException;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/ErrorResponseParser.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,12 @@ public class ErrorResponseParser implements ToXContentFragment, Function<HttpRes
3535

3636
private final String messagePath;
3737

38-
public static ErrorResponseParser fromMap(Map<String, Object> responseParserMap, ValidationException validationException) {
39-
var path = extractRequiredString(responseParserMap, MESSAGE_PATH, ERROR_PARSER, validationException);
38+
public static ErrorResponseParser fromMap(
39+
Map<String, Object> responseParserMap,
40+
String scope,
41+
ValidationException validationException
42+
) {
43+
var path = extractRequiredString(responseParserMap, MESSAGE_PATH, String.join(".", scope, ERROR_PARSER), validationException);
4044

4145
if (path == null) {
4246
throw validationException;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/RerankResponseParser.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,12 @@ public class RerankResponseParser extends BaseCustomResponseParser<RankedDocsRes
3737
private final String rerankIndexPath;
3838
private final String documentTextPath;
3939

40-
public static RerankResponseParser fromMap(Map<String, Object> responseParserMap, ValidationException validationException) {
40+
public static RerankResponseParser fromMap(Map<String, Object> responseParserMap, String scope, ValidationException validationException) {
41+
var fullScope = String.join(".", scope, JSON_PARSER);
4142

42-
var relevanceScore = extractRequiredString(responseParserMap, RERANK_PARSER_SCORE, JSON_PARSER, validationException);
43-
var rerankIndex = extractOptionalString(responseParserMap, RERANK_PARSER_INDEX, JSON_PARSER, validationException);
44-
var documentText = extractOptionalString(responseParserMap, RERANK_PARSER_DOCUMENT_TEXT, JSON_PARSER, validationException);
43+
var relevanceScore = extractRequiredString(responseParserMap, RERANK_PARSER_SCORE, fullScope, validationException);
44+
var rerankIndex = extractOptionalString(responseParserMap, RERANK_PARSER_INDEX, fullScope, validationException);
45+
var documentText = extractOptionalString(responseParserMap, RERANK_PARSER_DOCUMENT_TEXT, fullScope, validationException);
4546

4647
if (relevanceScore == null) {
4748
throw validationException;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/SparseEmbeddingResponseParser.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,11 @@ public class SparseEmbeddingResponseParser extends BaseCustomResponseParser<Spar
3535
private final String tokenPath;
3636
private final String weightPath;
3737

38-
public static SparseEmbeddingResponseParser fromMap(Map<String, Object> responseParserMap, ValidationException validationException) {
39-
var tokenPath = extractRequiredString(responseParserMap, SPARSE_EMBEDDING_TOKEN_PATH, JSON_PARSER, validationException);
38+
public static SparseEmbeddingResponseParser fromMap(Map<String, Object> responseParserMap, String scope, ValidationException validationException) {
39+
var fullScope = String.join(".", scope, JSON_PARSER);
40+
var tokenPath = extractRequiredString(responseParserMap, SPARSE_EMBEDDING_TOKEN_PATH, fullScope, validationException);
4041

41-
var weightPath = extractRequiredString(responseParserMap, SPARSE_EMBEDDING_WEIGHT_PATH, JSON_PARSER, validationException);
42+
var weightPath = extractRequiredString(responseParserMap, SPARSE_EMBEDDING_WEIGHT_PATH, fullScope, validationException);
4243

4344
if (tokenPath == null || weightPath == null) {
4445
throw validationException;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParser.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ public class TextEmbeddingResponseParser extends BaseCustomResponseParser<TextEm
3030

3131
private final String textEmbeddingsPath;
3232

33-
public static TextEmbeddingResponseParser fromMap(Map<String, Object> responseParserMap, ValidationException validationException) {
34-
var path = extractRequiredString(responseParserMap, TEXT_EMBEDDING_PARSER_EMBEDDINGS, JSON_PARSER, validationException);
33+
public static TextEmbeddingResponseParser fromMap(Map<String, Object> responseParserMap, String scope, ValidationException validationException) {
34+
var path = extractRequiredString(responseParserMap, TEXT_EMBEDDING_PARSER_EMBEDDINGS, String.join(".", scope, JSON_PARSER), validationException);
3535

3636
if (path == null) {
3737
throw validationException;

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -446,8 +446,8 @@ public void testFromMap_ReturnsError_IfResponseMapIsMissing() {
446446
exception.getMessage(),
447447
is(
448448
"Validation Failed: 1: [service_settings] does not contain the required setting [response];"
449-
+ "2: [service_settings] does not contain the required setting [json_parser];"
450-
+ "3: [service_settings] does not contain the required setting [error_parser];"
449+
+ "2: [service_settings.response] does not contain the required setting [json_parser];"
450+
+ "3: [service_settings.response] does not contain the required setting [error_parser];"
451451
+ "4: Encountered a null input map while parsing field [path];"
452452
)
453453
);

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CompletionResponseParserTests.java

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,11 @@ public static CompletionResponseParser createRandom() {
3838

3939
public void testFromMap() {
4040
var validation = new ValidationException();
41-
var parser = CompletionResponseParser.fromMap(new HashMap<>(Map.of(COMPLETION_PARSER_RESULT, "$.result[*].text")), validation);
41+
var parser = CompletionResponseParser.fromMap(
42+
new HashMap<>(Map.of(COMPLETION_PARSER_RESULT, "$.result[*].text")),
43+
"scope",
44+
validation
45+
);
4246

4347
assertThat(parser, is(new CompletionResponseParser("$.result[*].text")));
4448
}
@@ -47,12 +51,12 @@ public void testFromMap_ThrowsException_WhenRequiredFieldIsNotPresent() {
4751
var validation = new ValidationException();
4852
var exception = expectThrows(
4953
ValidationException.class,
50-
() -> CompletionResponseParser.fromMap(new HashMap<>(Map.of("some_field", "$.result[*].text")), validation)
54+
() -> CompletionResponseParser.fromMap(new HashMap<>(Map.of("some_field", "$.result[*].text")), "scope", validation)
5155
);
5256

5357
assertThat(
5458
exception.getMessage(),
55-
is("Validation Failed: 1: [json_parser] does not contain the required setting [completion_result];")
59+
is("Validation Failed: 1: [scope.json_parser] does not contain the required setting [completion_result];")
5660
);
5761
}
5862

0 commit comments

Comments
 (0)