From 4d9ec59a9d2171ea769a98cde6676a431bac7b61 Mon Sep 17 00:00:00 2001 From: Evgenii_Kazannik Date: Mon, 28 Apr 2025 18:35:36 +0200 Subject: [PATCH 01/12] Add Hugging Face Rerank support --- .../SettingsConfigurationTestUtils.java | 5 +- .../inference/InferenceBaseRestTest.java | 20 +++ .../inference/InferenceGetServicesIT.java | 5 +- .../InferenceNamedWriteablesProvider.java | 12 ++ ...ankFeaturePhaseRankCoordinatorContext.java | 6 +- .../huggingface/HuggingFaceBaseService.java | 32 ++-- .../huggingface/HuggingFaceModel.java | 23 ++- .../huggingface/HuggingFaceModelInput.java | 113 +++++++++++++ .../huggingface/HuggingFaceService.java | 65 ++++---- .../action/HuggingFaceActionCreator.java | 33 ++++ .../action/HuggingFaceActionVisitor.java | 3 + .../elser/HuggingFaceElserService.java | 27 ++- .../rerank/HuggingFaceRerankRequest.java | 111 +++++++++++++ .../HuggingFaceRerankRequestEntity.java | 71 ++++++++ .../rerank/HuggingFaceRerankModel.java | 91 ++++++++++ .../HuggingFaceRerankServiceSettings.java | 154 +++++++++++++++++ .../rerank/HuggingFaceRerankTaskSettings.java | 156 ++++++++++++++++++ .../HuggingFaceRerankResponseEntity.java | 123 ++++++++++++++ .../BaseTransportInferenceActionTestCase.java | 2 +- .../huggingface/HuggingFaceServiceTests.java | 18 +- .../action/HuggingFaceActionCreatorTests.java | 59 +++++++ .../HuggingFaceRerankRequestEntityTests.java | 68 ++++++++ .../rerank/HuggingFaceRerankRequestTests.java | 99 +++++++++++ .../rerank/HuggingFaceRerankModelTests.java | 41 +++++ .../HuggingFaceRerankTaskSettingsTests.java | 118 +++++++++++++ 25 files changed, 1384 insertions(+), 71 deletions(-) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModelInput.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequest.java create mode 100644 
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequestEntity.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankModel.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankServiceSettings.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankTaskSettings.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceRerankResponseEntity.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequestEntityTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequestTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankModelTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankTaskSettingsTests.java diff --git a/server/src/test/java/org/elasticsearch/inference/SettingsConfigurationTestUtils.java b/server/src/test/java/org/elasticsearch/inference/SettingsConfigurationTestUtils.java index 7e68f41de1b2e..e99b2178c50d4 100644 --- a/server/src/test/java/org/elasticsearch/inference/SettingsConfigurationTestUtils.java +++ b/server/src/test/java/org/elasticsearch/inference/SettingsConfigurationTestUtils.java @@ -20,9 +20,8 @@ public class SettingsConfigurationTestUtils { public static SettingsConfiguration getRandomSettingsConfigurationField() { - return new 
SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING)).setDefaultValue( - randomAlphaOfLength(10) - ) + return new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING, TaskType.RERANK)) + .setDefaultValue(randomAlphaOfLength(10)) .setDescription(randomAlphaOfLength(10)) .setLabel(randomAlphaOfLength(10)) .setRequired(randomBoolean()) diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java index 54d0aca772061..a99b5fbd34e3a 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java @@ -171,6 +171,22 @@ static String mockDenseServiceModelConfig() { """; } + static String mockRerankServiceModelConfig() { + return """ + { + "task_type": "rerank", + "service": "rerank_test_service", + "service_settings": { + "model": "rerank_model", + "api_key": "abc64" + }, + "task_settings": { + "return_documents": true + } + } + """; + } + static void deleteModel(String modelId) throws IOException { var request = new Request("DELETE", "_inference/" + modelId); var response = client().performRequest(request); @@ -484,6 +500,10 @@ private String jsonBody(List input, @Nullable String query) { @SuppressWarnings("unchecked") protected void assertNonEmptyInferenceResults(Map resultMap, int expectedNumberOfResults, TaskType taskType) { switch (taskType) { + case RERANK -> { + var results = (List>) resultMap.get(TaskType.RERANK.toString()); + assertThat(results, hasSize(expectedNumberOfResults)); + } case SPARSE_EMBEDDING -> { var results = (List>) 
resultMap.get(TaskType.SPARSE_EMBEDDING.toString()); assertThat(results, hasSize(expectedNumberOfResults)); diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java index 682eebd0fa69b..0aeb0d1d4a1cf 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java @@ -101,7 +101,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException { public void testGetServicesWithRerankTaskType() throws IOException { List services = getServices(TaskType.RERANK); - assertThat(services.size(), equalTo(7)); + assertThat(services.size(), equalTo(8)); var providers = providers(services); @@ -115,7 +115,8 @@ public void testGetServicesWithRerankTaskType() throws IOException { "googlevertexai", "jinaai", "test_reranking_service", - "voyageai" + "voyageai", + "hugging_face" ).toArray() ) ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index 5c719c08142a2..a7bf5511d9833 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -79,6 +79,8 @@ import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankTaskSettings; import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings; import 
org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings; +import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankServiceSettings; +import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings; import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankServiceSettings; import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankTaskSettings; @@ -357,6 +359,16 @@ private static void addHuggingFaceNamedWriteables(List namedWriteables) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java index 7f245ae854eac..c0f7aae38bf76 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java @@ -19,6 +19,7 @@ import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings; import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankTaskSettings; +import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings; import java.util.ArrayList; import java.util.Arrays; @@ -91,7 +92,10 @@ protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); 
+ Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); ChunkingSettings chunkingSettings = null; if (TaskType.TEXT_EMBEDDING.equals(taskType)) { @@ -65,7 +66,7 @@ public void parseRequestConfig( ); } - var model = createModel( + var modelBuilder = new HuggingFaceModelInput.Builder( inferenceEntityId, taskType, serviceSettingsMap, @@ -75,8 +76,13 @@ public void parseRequestConfig( ConfigurationParseContext.REQUEST ); + var model = createModel( + TaskType.RERANK.equals(taskType) ? modelBuilder.withTaskSettings(taskSettingsMap).build() : modelBuilder.build() + ); + throwIfNotEmptyMap(config, name()); throwIfNotEmptyMap(serviceSettingsMap, name()); + throwIfNotEmptyMap(taskSettingsMap, name()); parsedModelListener.onResponse(model); } catch (Exception e) { @@ -92,6 +98,7 @@ public HuggingFaceModel parsePersistedConfigWithSecrets( Map secrets ) { Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); Map secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS); ChunkingSettings chunkingSettings = null; @@ -99,7 +106,7 @@ public HuggingFaceModel parsePersistedConfigWithSecrets( chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); } - return createModel( + var modelBuilder = new HuggingFaceModelInput.Builder( inferenceEntityId, taskType, serviceSettingsMap, @@ -108,18 +115,23 @@ public HuggingFaceModel parsePersistedConfigWithSecrets( parsePersistedConfigErrorMsg(inferenceEntityId, name()), ConfigurationParseContext.PERSISTENT ); + + return createModel( + TaskType.RERANK.equals(taskType) ? 
modelBuilder.withTaskSettings(taskSettingsMap).build() : modelBuilder.build() + ); } @Override public HuggingFaceModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) { Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); ChunkingSettings chunkingSettings = null; if (TaskType.TEXT_EMBEDDING.equals(taskType)) { chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); } - return createModel( + var modelBuilder = new HuggingFaceModelInput.Builder( inferenceEntityId, taskType, serviceSettingsMap, @@ -128,17 +140,13 @@ public HuggingFaceModel parsePersistedConfig(String inferenceEntityId, TaskType parsePersistedConfigErrorMsg(inferenceEntityId, name()), ConfigurationParseContext.PERSISTENT ); + + return createModel( + TaskType.RERANK.equals(taskType) ? modelBuilder.withTaskSettings(taskSettingsMap).build() : modelBuilder.build() + ); } - protected abstract HuggingFaceModel createModel( - String inferenceEntityId, - TaskType taskType, - Map serviceSettings, - ChunkingSettings chunkingSettings, - Map secretSettings, - String failureMessage, - ConfigurationParseContext context - ); + protected abstract HuggingFaceModel createModel(HuggingFaceModelInput input); @Override public void doInfer( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModel.java index 6a750a9ded9b3..41cc7fd283322 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModel.java @@ -9,17 +9,19 @@ import 
org.elasticsearch.common.settings.SecureString; import org.elasticsearch.core.Nullable; -import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel; import org.elasticsearch.xpack.inference.services.ServiceUtils; import org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionVisitor; import org.elasticsearch.xpack.inference.services.settings.ApiKeySecrets; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import java.util.Objects; -public abstract class HuggingFaceModel extends Model { +public abstract class HuggingFaceModel extends RateLimitGroupingModel { private final HuggingFaceRateLimitServiceSettings rateLimitServiceSettings; private final SecureString apiKey; @@ -34,10 +36,27 @@ public HuggingFaceModel( apiKey = ServiceUtils.apiKey(apiKeySecrets); } + protected HuggingFaceModel(HuggingFaceModel model, TaskSettings taskSettings) { + super(model, taskSettings); + + rateLimitServiceSettings = model.rateLimitServiceSettings(); + apiKey = model.apiKey(); + } + public HuggingFaceRateLimitServiceSettings rateLimitServiceSettings() { return rateLimitServiceSettings; } + @Override + public int rateLimitGroupingHash() { + return Objects.hash(rateLimitServiceSettings.uri(), apiKey); + } + + @Override + public RateLimitSettings rateLimitSettings() { + return rateLimitServiceSettings.rateLimitSettings(); + } + public SecureString apiKey() { return apiKey; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModelInput.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModelInput.java new file mode 100644 index 
0000000000000..afd2737435e0d --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModelInput.java @@ -0,0 +1,113 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.huggingface; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; + +import java.util.Map; + +public class HuggingFaceModelInput { + private final String inferenceEntityId; + private final TaskType taskType; + private final Map serviceSettings; + @Nullable + private final Map taskSettings; + private final ChunkingSettings chunkingSettings; + @Nullable + private final Map secretSettings; + private final String failureMessage; + private final ConfigurationParseContext context; + + public HuggingFaceModelInput(Builder builder) { + this.inferenceEntityId = builder.inferenceEntityId; + this.taskType = builder.taskType; + this.serviceSettings = builder.serviceSettings; + this.taskSettings = builder.taskSettings; + this.chunkingSettings = builder.chunkingSettings; + this.secretSettings = builder.secretSettings; + this.failureMessage = builder.failureMessage; + this.context = builder.context; + } + + public String getInferenceEntityId() { + return inferenceEntityId; + } + + public TaskType getTaskType() { + return taskType; + } + + public Map getServiceSettings() { + return serviceSettings; + } + + @Nullable + public Map getTaskSettings() { + return taskSettings; + } + + public ChunkingSettings getChunkingSettings() { + return chunkingSettings; + } + + @Nullable + public Map getSecretSettings() { + return 
secretSettings; + } + + public String getFailureMessage() { + return failureMessage; + } + + public ConfigurationParseContext getContext() { + return context; + } + + public static class Builder { + private String inferenceEntityId; + private TaskType taskType; + private Map serviceSettings; + @Nullable + private Map taskSettings; + private ChunkingSettings chunkingSettings; + @Nullable + Map secretSettings; + private String failureMessage; + private ConfigurationParseContext context; + + public Builder( + String inferenceEntityId, + TaskType taskType, + Map serviceSettings, + ChunkingSettings chunkingSettings, + @Nullable Map secretSettings, + String failureMessage, + ConfigurationParseContext context + ) { + this.inferenceEntityId = inferenceEntityId; + this.taskType = taskType; + this.serviceSettings = serviceSettings; + this.chunkingSettings = chunkingSettings; + this.secretSettings = secretSettings; + this.failureMessage = failureMessage; + this.context = context; + } + + public Builder withTaskSettings(Map taskSettings) { + this.taskSettings = taskSettings; + return this; + } + + public HuggingFaceModelInput build() { + return new HuggingFaceModelInput(this); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java index f2a53520e18e6..e9158d2657106 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java @@ -12,10 +12,8 @@ import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.util.LazyInitializable; -import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import 
org.elasticsearch.inference.ChunkedInference; -import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; @@ -29,12 +27,12 @@ import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; -import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.ServiceUtils; import org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionCreator; import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel; import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankModel; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -51,34 +49,46 @@ public class HuggingFaceService extends HuggingFaceBaseService { public static final String NAME = "hugging_face"; private static final String SERVICE_NAME = "Hugging Face"; - private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING); + private static final EnumSet SUPPORTED_TASK_TYPES = EnumSet.of( + TaskType.RERANK, + TaskType.TEXT_EMBEDDING, + TaskType.SPARSE_EMBEDDING + ); public HuggingFaceService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { super(factory, serviceComponents); } @Override - protected HuggingFaceModel createModel( - String inferenceEntityId, - TaskType taskType, - Map serviceSettings, 
- ChunkingSettings chunkingSettings, - @Nullable Map secretSettings, - String failureMessage, - ConfigurationParseContext context - ) { - return switch (taskType) { + protected HuggingFaceModel createModel(HuggingFaceModelInput input) { + return switch (input.getTaskType()) { + case RERANK -> new HuggingFaceRerankModel( + input.getInferenceEntityId(), + input.getTaskType(), + NAME, + input.getServiceSettings(), + input.getTaskSettings(), + input.getSecretSettings(), + input.getContext() + ); case TEXT_EMBEDDING -> new HuggingFaceEmbeddingsModel( - inferenceEntityId, - taskType, + input.getInferenceEntityId(), + input.getTaskType(), + NAME, + input.getServiceSettings(), + input.getChunkingSettings(), + input.getSecretSettings(), + input.getContext() + ); + case SPARSE_EMBEDDING -> new HuggingFaceElserModel( + input.getInferenceEntityId(), + input.getTaskType(), NAME, - serviceSettings, - chunkingSettings, - secretSettings, - context + input.getServiceSettings(), + input.getSecretSettings(), + input.getContext() ); - case SPARSE_EMBEDDING -> new HuggingFaceElserModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings, context); - default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + default -> throw new ElasticsearchStatusException(input.getFailureMessage(), RestStatus.BAD_REQUEST); }; } @@ -149,7 +159,7 @@ public InferenceServiceConfiguration getConfiguration() { @Override public EnumSet supportedTaskTypes() { - return supportedTaskTypes; + return SUPPORTED_TASK_TYPES; } @Override @@ -173,8 +183,7 @@ public static InferenceServiceConfiguration get() { configurationMap.put( URL, - new SettingsConfiguration.Builder(supportedTaskTypes).setDefaultValue("https://api.openai.com/v1/embeddings") - .setDescription("The URL endpoint to use for the requests.") + new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES).setDescription("The URL endpoint to use for the requests.") .setLabel("URL") .setRequired(true) 
.setSensitive(false) @@ -183,12 +192,12 @@ public static InferenceServiceConfiguration get() { .build() ); - configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration(supportedTaskTypes)); - configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(supportedTaskTypes)); + configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration(SUPPORTED_TASK_TYPES)); + configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(SUPPORTED_TASK_TYPES)); return new InferenceServiceConfiguration.Builder().setService(NAME) .setName(SERVICE_NAME) - .setTaskTypes(supportedTaskTypes) + .setTaskTypes(SUPPORTED_TASK_TYPES) .setConfigurations(configurationMap) .build(); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java index d0578edaeaec6..5022db6b35167 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java @@ -9,14 +9,20 @@ import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceRequestManager; import 
org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceResponseHandler; import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel; import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.huggingface.request.rerank.HuggingFaceRerankRequest; +import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankModel; import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceElserResponseEntity; import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceEmbeddingsResponseEntity; +import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceRerankResponseEntity; import java.util.Objects; @@ -29,11 +35,38 @@ public class HuggingFaceActionCreator implements HuggingFaceActionVisitor { private final Sender sender; private final ServiceComponents serviceComponents; + private static final String FAILED_TO_SEND_REQUEST_ERROR_MESSAGE = + "Failed to send Hugging Face %s request from inference entity id [%s]"; + static final ResponseHandler RERANK_HANDLER = new HuggingFaceResponseHandler( + "hugging face rerank", + (request, response) -> HuggingFaceRerankResponseEntity.fromResponse((HuggingFaceRerankRequest) request, response) + ); + public HuggingFaceActionCreator(Sender sender, ServiceComponents serviceComponents) { this.sender = Objects.requireNonNull(sender); this.serviceComponents = Objects.requireNonNull(serviceComponents); } + @Override + public ExecutableAction create(HuggingFaceRerankModel model) { + var overriddenModel = HuggingFaceRerankModel.of(model, model.getTaskSettings()); + var manager = new GenericRequestManager<>( + serviceComponents.threadPool(), + overriddenModel, + RERANK_HANDLER, + inputs -> new HuggingFaceRerankRequest( + inputs.getQuery(), + inputs.getChunks(), + inputs.getReturnDocuments(), + inputs.getTopN(), + model + ), + 
QueryAndDocsInputs.class + ); + var errorMessage = format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, "RERANK", model.getInferenceEntityId()); + return new SenderExecutableAction(sender, manager, errorMessage); + } + @Override public ExecutableAction create(HuggingFaceEmbeddingsModel model) { var responseHandler = new HuggingFaceResponseHandler( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionVisitor.java index 3fb7b538769e9..6aab4176ca69d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionVisitor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionVisitor.java @@ -10,8 +10,11 @@ import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel; import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankModel; public interface HuggingFaceActionVisitor { + ExecutableAction create(HuggingFaceRerankModel model); + ExecutableAction create(HuggingFaceEmbeddingsModel model); ExecutableAction create(HuggingFaceElserModel model); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java index 8116eaf86e74a..b3738b9e44542 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java @@ -13,11 +13,9 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.Strings; import org.elasticsearch.common.util.LazyInitializable; -import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; -import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; @@ -35,10 +33,10 @@ import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; -import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceBaseService; import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceModel; +import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceModelInput; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -69,18 +67,17 @@ public String name() { } @Override - protected HuggingFaceModel createModel( - String inferenceEntityId, - TaskType taskType, - Map serviceSettings, - ChunkingSettings chunkingSettings, - @Nullable Map secretSettings, - String failureMessage, - ConfigurationParseContext context - ) { - return switch (taskType) { - case SPARSE_EMBEDDING -> new HuggingFaceElserModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings, context); - 
default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + protected HuggingFaceModel createModel(HuggingFaceModelInput input) { + return switch (input.getTaskType()) { + case SPARSE_EMBEDDING -> new HuggingFaceElserModel( + input.getInferenceEntityId(), + input.getTaskType(), + NAME, + input.getServiceSettings(), + input.getSecretSettings(), + input.getContext() + ); + default -> throw new ElasticsearchStatusException(input.getFailureMessage(), RestStatus.BAD_REQUEST); }; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequest.java new file mode 100644 index 0000000000000..87f8cb6b1d81f --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequest.java @@ -0,0 +1,111 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.huggingface.request.rerank; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.entity.ByteArrayEntity; +import org.elasticsearch.common.Strings; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceAccount; +import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankModel; +import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings; + +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; + +public class HuggingFaceRerankRequest implements Request { + + private final HuggingFaceAccount account; + private final String query; + private final List input; + private final Boolean returnDocuments; + private final Integer topN; + private final TaskSettings taskSettings; + private final HuggingFaceRerankModel model; + private final String inferenceEntityId; + + public HuggingFaceRerankRequest( + String query, + List input, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, + HuggingFaceRerankModel model + ) { + Objects.requireNonNull(model); + + this.account = HuggingFaceAccount.of(model); + this.input = Objects.requireNonNull(input); + this.query = Objects.requireNonNull(query); + this.returnDocuments = returnDocuments; + this.topN = topN; + taskSettings = model.getTaskSettings(); + this.model = model; + inferenceEntityId = model.getInferenceEntityId(); + } + + @Override + public HttpRequest createHttpRequest() { + 
HttpPost httpPost = new HttpPost(account.uri()); + + ByteArrayEntity byteEntity = new ByteArrayEntity( + Strings.toString( + new HuggingFaceRerankRequestEntity( + query, + input, + returnDocuments, + topN != null ? topN : model.getTaskSettings().getTopNDocumentsOnly(), + (HuggingFaceRerankTaskSettings) taskSettings, + model.getServiceSettings().modelId() + ) + ).getBytes(StandardCharsets.UTF_8) + ); + httpPost.setEntity(byteEntity); + httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()); + + decorateWithAuth(httpPost); + + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + public void decorateWithAuth(HttpPost httpPost) { + httpPost.setHeader(createAuthBearerHeader(model.apiKey())); + } + + @Override + public String getInferenceEntityId() { + return inferenceEntityId; + } + + @Override + public URI getURI() { + return account.uri(); + } + + public Integer getTopN() { + return topN; + } + + @Override + public Request truncate() { + return this; + } + + @Override + public boolean[] getTruncationInfo() { + return null; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequestEntity.java new file mode 100644 index 0000000000000..c2ad0a622072d --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequestEntity.java @@ -0,0 +1,71 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.huggingface.request.rerank; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public record HuggingFaceRerankRequestEntity( + String model, + String query, + List documents, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, + HuggingFaceRerankTaskSettings taskSettings +) implements ToXContentObject { + + private static final String DOCUMENTS_FIELD = "texts"; + private static final String QUERY_FIELD = "query"; + + public HuggingFaceRerankRequestEntity { + Objects.requireNonNull(query); + Objects.requireNonNull(documents); + Objects.requireNonNull(taskSettings); + } + + public HuggingFaceRerankRequestEntity( + String query, + List input, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, + HuggingFaceRerankTaskSettings taskSettings, + String model + ) { + this(model, query, input, returnDocuments, topN, taskSettings); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + builder.field(DOCUMENTS_FIELD, documents); + builder.field(QUERY_FIELD, query); + + // prefer the root level return_documents over task settings + if (returnDocuments != null) { + builder.field(HuggingFaceRerankTaskSettings.RETURN_TEXT, returnDocuments); + } else if (taskSettings.getReturnDocuments() != null) { + builder.field(HuggingFaceRerankTaskSettings.RETURN_TEXT, taskSettings.getReturnDocuments()); + } + + if (topN != null) { + builder.field(HuggingFaceRerankTaskSettings.TOP_N_DOCS_ONLY, topN); + } else if (taskSettings.getTopNDocumentsOnly() != null) { + builder.field(HuggingFaceRerankTaskSettings.TOP_N_DOCS_ONLY, 
taskSettings.getTopNDocumentsOnly()); + } + + builder.endObject(); + return builder; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankModel.java new file mode 100644 index 0000000000000..3d0b7e5ba5b47 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankModel.java @@ -0,0 +1,91 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.huggingface.rerank; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceModel; +import org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionVisitor; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; + +import java.util.Map; + +public class HuggingFaceRerankModel extends HuggingFaceModel { + public static HuggingFaceRerankModel of(HuggingFaceRerankModel model, HuggingFaceRerankTaskSettings taskSettings) { + return new HuggingFaceRerankModel(model, HuggingFaceRerankTaskSettings.of(model.getTaskSettings(), taskSettings)); + } + + public HuggingFaceRerankModel( + String inferenceEntityId, + TaskType taskType, + String service, + Map serviceSettings, + Map 
taskSettings, + @Nullable Map secrets, + ConfigurationParseContext context + ) { + this( + inferenceEntityId, + taskType, + service, + HuggingFaceRerankServiceSettings.fromMap(serviceSettings, context), + HuggingFaceRerankTaskSettings.fromMap(taskSettings), + DefaultSecretSettings.fromMap(secrets) + ); + } + + // Should only be used directly for testing + public HuggingFaceRerankModel( + String inferenceEntityId, + TaskType taskType, + String service, + HuggingFaceRerankServiceSettings serviceSettings, + HuggingFaceRerankTaskSettings taskSettings, + @Nullable DefaultSecretSettings secrets + ) { + super( + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings), + new ModelSecrets(secrets), + serviceSettings, + secrets + ); + } + + private HuggingFaceRerankModel(HuggingFaceRerankModel model, HuggingFaceRerankTaskSettings taskSettings) { + super(model, taskSettings); + } + + @Override + public HuggingFaceRerankServiceSettings getServiceSettings() { + return (HuggingFaceRerankServiceSettings) super.getServiceSettings(); + } + + @Override + public HuggingFaceRerankTaskSettings getTaskSettings() { + return (HuggingFaceRerankTaskSettings) super.getTaskSettings(); + } + + @Override + public DefaultSecretSettings getSecretSettings() { + return (DefaultSecretSettings) super.getSecretSettings(); + } + + @Override + public Integer getTokenLimit() { + return getServiceSettings().maxInputTokens(); + } + + @Override + public ExecutableAction accept(HuggingFaceActionVisitor creator) { + return creator.create(this); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankServiceSettings.java new file mode 100644 index 0000000000000..0e925f19d12b4 --- /dev/null +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankServiceSettings.java @@ -0,0 +1,154 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.huggingface.rerank; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceRateLimitServiceSettings; +import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceService; +import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.net.URI; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri; +import static org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings.extractUri; + +public class HuggingFaceRerankServiceSettings extends FilteredXContentObject + implements + ServiceSettings, + HuggingFaceRateLimitServiceSettings { + + public static final String NAME = "hugging_face_rerank_service_settings"; + public static final String URL = "url"; + private static final 
int RERANK_TOKEN_LIMIT = 512; + + private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(3000); + + public static HuggingFaceRerankServiceSettings fromMap(Map map, ConfigurationParseContext context) { + ValidationException validationException = new ValidationException(); + var uri = extractUri(map, URL, validationException); + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + HuggingFaceService.NAME, + context + ); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + return new HuggingFaceRerankServiceSettings(uri, rateLimitSettings); + } + + private final URI uri; + private final RateLimitSettings rateLimitSettings; + + public HuggingFaceRerankServiceSettings(String url) { + uri = createUri(url); + rateLimitSettings = DEFAULT_RATE_LIMIT_SETTINGS; + } + + HuggingFaceRerankServiceSettings(URI uri, @Nullable RateLimitSettings rateLimitSettings) { + this.uri = Objects.requireNonNull(uri); + this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); + } + + public HuggingFaceRerankServiceSettings(StreamInput in) throws IOException { + uri = createUri(in.readString()); + + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_15_0)) { + rateLimitSettings = new RateLimitSettings(in); + } else { + rateLimitSettings = DEFAULT_RATE_LIMIT_SETTINGS; + } + } + + @Override + public RateLimitSettings rateLimitSettings() { + return rateLimitSettings; + } + + @Override + public URI uri() { + return uri; + } + + public int maxInputTokens() { + return RERANK_TOKEN_LIMIT; + } + + // model is not defined in the service settings. + // since hugging face requires that the model be chosen when initializing a deployment within their service. 
+ @Override + public String modelId() { + return null; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + toXContentFragmentOfExposedFields(builder, params); + builder.endObject(); + + return builder; + } + + @Override + protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { + builder.field(URL, uri.toString()); + builder.field(MAX_INPUT_TOKENS, RERANK_TOKEN_LIMIT); + rateLimitSettings.toXContent(builder, params); + + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.V_8_12_0; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(uri.toString()); + + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_15_0)) { + rateLimitSettings.writeTo(out); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + HuggingFaceRerankServiceSettings that = (HuggingFaceRerankServiceSettings) o; + return Objects.equals(uri, that.uri) && Objects.equals(rateLimitSettings, that.rateLimitSettings); + } + + @Override + public int hashCode() { + return Objects.hash(uri, rateLimitSettings); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankTaskSettings.java new file mode 100644 index 0000000000000..9e9e22196ea4e --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankTaskSettings.java @@ -0,0 +1,156 @@ +/* + * Copyright Elasticsearch B.V. 
and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.huggingface.rerank; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; + +public class HuggingFaceRerankTaskSettings implements TaskSettings { + + public static final String NAME = "hugging_face_rerank_task_settings"; + public static final String RETURN_TEXT = "return_text"; + public static final String TOP_N_DOCS_ONLY = "top_n"; + + static final HuggingFaceRerankTaskSettings EMPTY_SETTINGS = new HuggingFaceRerankTaskSettings(null, null); + + public static HuggingFaceRerankTaskSettings fromMap(Map map) { + ValidationException validationException = new ValidationException(); + + if (map == null || map.isEmpty()) { + return EMPTY_SETTINGS; + } + + Boolean returnDocuments = extractOptionalBoolean(map, RETURN_TEXT, validationException); + Integer topNDocumentsOnly = extractOptionalPositiveInteger( + map, + TOP_N_DOCS_ONLY, + ModelConfigurations.TASK_SETTINGS, + validationException + ); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return 
of(topNDocumentsOnly, returnDocuments); + } + + /** + * Creates a new {@link HuggingFaceRerankTaskSettings} + * by preferring non-null fields from the request settings over the original settings. + * + * @param originalSettings the settings stored as part of the inference entity configuration + * @param requestTaskSettings the settings passed in within the task_settings field of the request + * @return a constructed {@link HuggingFaceRerankTaskSettings} + */ + public static HuggingFaceRerankTaskSettings of( + HuggingFaceRerankTaskSettings originalSettings, + HuggingFaceRerankTaskSettings requestTaskSettings + ) { + return new HuggingFaceRerankTaskSettings( + requestTaskSettings.getTopNDocumentsOnly() != null + ? requestTaskSettings.getTopNDocumentsOnly() + : originalSettings.getTopNDocumentsOnly(), + requestTaskSettings.getReturnDocuments() != null + ? requestTaskSettings.getReturnDocuments() + : originalSettings.getReturnDocuments() + ); + } + + public static HuggingFaceRerankTaskSettings of(Integer topNDocumentsOnly, Boolean returnDocuments) { + return new HuggingFaceRerankTaskSettings(topNDocumentsOnly, returnDocuments); + } + + private final Integer topNDocumentsOnly; + private final Boolean returnDocuments; + + public HuggingFaceRerankTaskSettings(StreamInput in) throws IOException { + this(in.readOptionalInt(), in.readOptionalBoolean()); + } + + public HuggingFaceRerankTaskSettings(@Nullable Integer topNDocumentsOnly, @Nullable Boolean doReturnDocuments) { + this.topNDocumentsOnly = topNDocumentsOnly; + this.returnDocuments = doReturnDocuments; + } + + @Override + public boolean isEmpty() { + return topNDocumentsOnly == null && returnDocuments == null; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (topNDocumentsOnly != null) { + builder.field(TOP_N_DOCS_ONLY, topNDocumentsOnly); + } + if (returnDocuments != null) { + builder.field(RETURN_TEXT, returnDocuments); 
+ } + builder.endObject(); + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.V_8_14_0; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalInt(topNDocumentsOnly); + out.writeOptionalBoolean(returnDocuments); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + HuggingFaceRerankTaskSettings that = (HuggingFaceRerankTaskSettings) o; + return Objects.equals(returnDocuments, that.returnDocuments) && Objects.equals(topNDocumentsOnly, that.topNDocumentsOnly); + } + + @Override + public int hashCode() { + return Objects.hash(returnDocuments, topNDocumentsOnly); + } + + public Integer getTopNDocumentsOnly() { + return topNDocumentsOnly; + } + + public Boolean getReturnDocuments() { + return returnDocuments; + } + + @Override + public TaskSettings updatedTaskSettings(Map newSettings) { + HuggingFaceRerankTaskSettings updatedSettings = HuggingFaceRerankTaskSettings.fromMap(new HashMap<>(newSettings)); + return HuggingFaceRerankTaskSettings.of(this, updatedSettings); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceRerankResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceRerankResponseEntity.java new file mode 100644 index 0000000000000..b7f7c41fe6de9 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceRerankResponseEntity.java @@ -0,0 +1,123 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.huggingface.response; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; +import org.elasticsearch.xpack.inference.services.huggingface.request.rerank.HuggingFaceRerankRequest; + +import java.io.IOException; +import java.util.Comparator; + +import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList; +import static org.elasticsearch.common.xcontent.XContentParserUtils.throwUnknownField; +import static org.elasticsearch.common.xcontent.XContentParserUtils.throwUnknownToken; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; + +public class HuggingFaceRerankResponseEntity extends ErrorResponse { + private static final Logger logger = LogManager.getLogger(HuggingFaceRerankResponseEntity.class); + + public static InferenceServiceResults fromResponse(HuggingFaceRerankRequest request, HttpResult response) throws IOException { + var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); + + try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, 
response.body())) { + moveToFirstToken(jsonParser); + + XContentParser.Token token = jsonParser.currentToken(); + ensureExpectedToken(XContentParser.Token.START_ARRAY, token, jsonParser); + token = jsonParser.currentToken(); + if (token == XContentParser.Token.START_ARRAY) { + var rankedDocs = parseList(jsonParser, HuggingFaceRerankResponseEntity::parseRankedDocObject); + var rankedDocsByRelevanceStream = rankedDocs.stream() + .sorted(Comparator.comparingDouble(RankedDocsResults.RankedDoc::relevanceScore).reversed()); + var rankedDocStreamTopN = request.getTopN() == null + ? rankedDocsByRelevanceStream + : rankedDocsByRelevanceStream.limit(request.getTopN()); + return new RankedDocsResults(rankedDocStreamTopN.toList()); + } else { + throwUnknownToken(token, jsonParser); + } + + throw new IllegalStateException("Reached an invalid state while parsing the HuggingFace response"); + } + } + + private static RankedDocsResults.RankedDoc parseRankedDocObject(XContentParser parser) throws IOException { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + int index = -1; + float relevanceScore = -1; + String documentText = null; + parser.nextToken(); + while (parser.currentToken() != XContentParser.Token.END_OBJECT) { + if (parser.currentToken() == XContentParser.Token.FIELD_NAME) { + switch (parser.currentName()) { + case "index": + parser.nextToken(); // move to VALUE_NUMBER + index = parser.intValue(); + parser.nextToken(); // move to next FIELD_NAME or END_OBJECT + break; + case "score": + parser.nextToken(); // move to VALUE_NUMBER + relevanceScore = parser.floatValue(); + parser.nextToken(); // move to next FIELD_NAME or END_OBJECT + break; + case "relevance_score": + parser.nextToken(); // move to VALUE_NUMBER + relevanceScore = parser.floatValue(); + parser.nextToken(); // move to next FIELD_NAME or END_OBJECT + break; + case "text": + parser.nextToken(); // move to VALUE_STRING + documentText = parser.text(); +
parser.nextToken(); // move to next FIELD_NAME or END_OBJECT + break; + case "document": + parser.nextToken(); // move to START_OBJECT; document text is wrapped in an object + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + do { + if (parser.currentToken() == XContentParser.Token.FIELD_NAME && parser.currentName().equals("text")) { + parser.nextToken(); // move to VALUE_STRING + documentText = parser.text(); + } + } while (parser.nextToken() != XContentParser.Token.END_OBJECT); + parser.nextToken();// move past END_OBJECT + // parser should now be at the next FIELD_NAME or END_OBJECT + break; + default: + throwUnknownField(parser.currentName(), parser); + } + } else { + parser.nextToken(); + } + } + + if (index == -1) { + logger.warn("Failed to find required field [index] in HuggingFace rerank response"); + } + if (relevanceScore == -1) { + logger.warn("Failed to find required field [relevance_score] in HuggingFace rerank response"); + } + // documentText may or may not be present depending on the request parameter + + return new RankedDocsResults.RankedDoc(index, relevanceScore, documentText); + } + + private HuggingFaceRerankResponseEntity(String errorMessage) { + super(errorMessage); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java index 9f08b07ea218d..70499c7987965 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java @@ -425,7 +425,7 @@ protected void mockService( doAnswer(ans -> { listenerAction.accept(ans.getArgument(9)); return null; - }).when(service).infer(any(), any(), any(), any(), any(), 
anyBoolean(), any(), any(), any(), any()); + }).when(service).infer(any(), any(), anyBoolean(), any(), any(), anyBoolean(), any(), any(), any(), any()); doAnswer(ans -> { listenerAction.accept(ans.getArgument(3)); return null; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java index b50b821d66359..f1c2a458afdba 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java @@ -130,7 +130,7 @@ public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsP service.parseRequestConfig( "id", TaskType.TEXT_EMBEDDING, - getRequestConfigMap(getServiceSettingsMap("url"), createRandomChunkingSettingsMap(), getSecretSettingsMap("secret")), + getRequestConfigMap(getServiceSettingsMap("url"), getSecretSettingsMap("secret")), modelVerificationActionListener ); } @@ -821,7 +821,7 @@ public void testGetConfiguration() throws Exception { { "service": "hugging_face", "name": "Hugging Face", - "task_types": ["text_embedding", "sparse_embedding"], + "task_types": ["text_embedding", "sparse_embedding", "rerank"], "configurations": { "api_key": { "description": "API Key for the provider you're connecting to.", @@ -830,7 +830,7 @@ public void testGetConfiguration() throws Exception { "sensitive": true, "updatable": true, "type": "str", - "supported_task_types": ["text_embedding", "sparse_embedding"] + "supported_task_types": ["text_embedding", "sparse_embedding", "rerank"] }, "rate_limit.requests_per_minute": { "description": "Minimize the number of rate limit errors.", @@ -839,17 +839,16 @@ public void testGetConfiguration() throws Exception { "sensitive": false, 
"updatable": false, "type": "int", - "supported_task_types": ["text_embedding", "sparse_embedding"] + "supported_task_types": ["text_embedding", "sparse_embedding", "rerank"] }, "url": { - "default_value": "https://api.openai.com/v1/embeddings", "description": "The URL endpoint to use for the requests.", "label": "URL", "required": true, "sensitive": false, "updatable": false, "type": "str", - "supported_task_types": ["text_embedding", "sparse_embedding"] + "supported_task_types": ["text_embedding", "sparse_embedding", "rerank"] } } } @@ -886,10 +885,15 @@ private Map getRequestConfigMap( private Map getRequestConfigMap(Map serviceSettings, Map secretSettings) { var builtServiceSettings = new HashMap<>(); + var builtTaskSettings = new HashMap<>(); builtServiceSettings.putAll(serviceSettings); builtServiceSettings.putAll(secretSettings); - return new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, builtServiceSettings)); + Map map = new HashMap<>(); + map.put(ModelConfigurations.SERVICE_SETTINGS, builtServiceSettings); + map.put(ModelConfigurations.TASK_SETTINGS, builtTaskSettings); + + return map; } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java index 09619cd076ff5..db423a9738dd6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java @@ -26,10 +26,12 @@ import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; 
+import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModelTests; import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModelTests; +import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankModelTests; import org.junit.After; import org.junit.Before; @@ -307,6 +309,63 @@ public void testSend_FailsFromInvalidResponseFormat_ForEmbeddingsAction() throws } } + public void testSend_FailsFromInvalidResponseFormat_ForRerankAction() throws IOException { + var settings = buildSettingsWithRetryFields( + TimeValue.timeValueMillis(1), + TimeValue.timeValueMinutes(1), + TimeValue.timeValueSeconds(0) + ); + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); + + try (var sender = createSender(senderFactory)) { + sender.start(); + + String responseJson = """ + { + "rerank": [ + { + "index": 0, + "relevance_score": -0.07996031, + "text": "luke" + } + ] + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = HuggingFaceRerankModelTests.createModel(getUrl(webServer), "secret", "model", 8, true); + var actionCreator = new HuggingFaceActionCreator( + sender, + new ServiceComponents(threadPool, mockThrottlerManager(), settings, TruncatorTests.createTruncator()) + ); + var action = actionCreator.create(model); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new QueryAndDocsInputs("popular name", List.of("Luke")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + assertThat( + thrownException.getMessage(), + is("Failed to parse object: 
expecting token of type [START_ARRAY] but found [START_OBJECT]") + ); + } + assertRerankActionCreator(List.of("Luke"), "popular name", 8, true); + } + + private void assertRerankActionCreator(List documents, String query, int topN, boolean returnText) throws IOException { + assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().getFirst().getUri().getQuery()); + assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); + assertThat(requestMap.size(), is(4)); + assertThat(requestMap.get("texts"), is(documents)); + assertThat(requestMap.get("query"), is(query)); + assertThat(requestMap.get("top_n"), is(topN)); + assertThat(requestMap.get("return_text"), is(returnText)); + } + public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequestEntityTests.java new file mode 100644 index 0000000000000..0466ab15e9f0f --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequestEntityTests.java @@ -0,0 +1,68 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.huggingface.request.rerank; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhitespaceInJsonString; + +public class HuggingFaceRerankRequestEntityTests extends ESTestCase { + private static final String INPUT = "texts"; + private static final String QUERY = "query"; + private static final String INFERENCE_ID = "model"; + private static final Integer TOP_N = 8; + private static final Boolean RETURN_DOCUMENTS = false; + + public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException { + var entity = new HuggingFaceRerankRequestEntity( + QUERY, + List.of(INPUT), + Boolean.TRUE, + TOP_N, + new HuggingFaceRerankTaskSettings(TOP_N, RETURN_DOCUMENTS), + INFERENCE_ID + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + {"texts":["texts"], + "query":"query", + "return_text":true, + "top_n":8}""")); + } + + public void testXContent_WritesMinimalFields() throws IOException { + var entity = new HuggingFaceRerankRequestEntity( + QUERY, + List.of(INPUT), + null, + null, + new HuggingFaceRerankTaskSettings(null, null), + INFERENCE_ID + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + String xContentResult = Strings.toString(builder); + 
+ assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + {"texts":["texts"],"query":"query"}""")); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequestTests.java new file mode 100644 index 0000000000000..0cb2bbb3bbaee --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequestTests.java @@ -0,0 +1,99 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.huggingface.request.rerank; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankModel; +import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankModelTests; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; + +public class HuggingFaceRerankRequestTests extends ESTestCase { + private static final String INPUT = "texts"; + private static final String QUERY = "query"; + private static final String INFERENCE_ID = "model"; + private static final Integer TOP_N = 8; + private static final Boolean RETURN_TEXT = 
false; + + private static final String AUTH_HEADER_VALUE = "foo"; + + public void testCreateRequest_WithMinimalFieldsSet() throws IOException { + testCreateRequest(null, null); + } + + public void testCreateRequest_WithTopN() throws IOException { + testCreateRequest(TOP_N, null); + } + + public void testCreateRequest_WithReturnDocuments() throws IOException { + testCreateRequest(null, RETURN_TEXT); + } + + private void testCreateRequest(Integer topN, Boolean returnDocuments) throws IOException { + var request = createRequest(topN, returnDocuments); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(AUTH_HEADER_VALUE)); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + + assertThat(requestMap.get(INPUT), is(List.of(INPUT))); + assertThat(requestMap.get(QUERY), is(QUERY)); + // input and query must exist + int itemsCount = 2; + if (topN != null) { + assertThat(requestMap.get("top_n"), is(topN)); + itemsCount++; + } + if (returnDocuments != null) { + assertThat(requestMap.get("return_text"), is(returnDocuments)); + itemsCount++; + } + assertThat(requestMap, aMapWithSize(itemsCount)); + } + + private static HuggingFaceRerankRequest createRequest(@Nullable Integer topN, @Nullable Boolean returnDocuments) { + var rerankModel = HuggingFaceRerankModelTests.createModel(randomAlphaOfLength(10), "secret", INFERENCE_ID, topN, returnDocuments); + + return new HuggingFaceRerankWithoutAuthRequest(QUERY, List.of(INPUT), rerankModel, topN, returnDocuments); + } + + /** + * We use this class to fake the auth implementation to avoid static mocking of {@link HuggingFaceRerankRequest} + */ + private static class HuggingFaceRerankWithoutAuthRequest 
extends HuggingFaceRerankRequest { + HuggingFaceRerankWithoutAuthRequest( + String query, + List input, + HuggingFaceRerankModel model, + @Nullable Integer topN, + @Nullable Boolean returnDocuments + ) { + super(query, input, returnDocuments, topN, model); + } + + @Override + public void decorateWithAuth(HttpPost httpPost) { + httpPost.setHeader(HttpHeaders.AUTHORIZATION, AUTH_HEADER_VALUE); + } + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankModelTests.java new file mode 100644 index 0000000000000..2c8c28819d926 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankModelTests.java @@ -0,0 +1,41 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.huggingface.rerank; + +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; + +import static org.hamcrest.Matchers.containsString; + +public class HuggingFaceRerankModelTests extends ESTestCase { + + public void testThrowsURISyntaxException_ForInvalidUrl() { + var thrownException = expectThrows(IllegalArgumentException.class, () -> createModel("^^", "secret", "model", 8, false)); + assertThat(thrownException.getMessage(), containsString("unable to parse url [^^]")); + } + + public static HuggingFaceRerankModel createModel( + String url, + String apiKey, + String modelId, + @Nullable Integer topN, + @Nullable Boolean returnDocuments + ) { + return new HuggingFaceRerankModel( + modelId, + TaskType.RERANK, + "service", + new HuggingFaceRerankServiceSettings(url), + new HuggingFaceRerankTaskSettings(topN, returnDocuments), + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankTaskSettingsTests.java new file mode 100644 index 0000000000000..a45f2ac9766cc --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankTaskSettingsTests.java @@ -0,0 +1,118 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.huggingface.rerank; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.Matchers.containsString; + +public class HuggingFaceRerankTaskSettingsTests extends AbstractWireSerializingTestCase { + + public static HuggingFaceRerankTaskSettings createRandom() { + var returnDocuments = randomBoolean() ? randomBoolean() : null; + var topNDocsOnly = randomBoolean() ? randomIntBetween(1, 10) : null; + + return new HuggingFaceRerankTaskSettings(topNDocsOnly, returnDocuments); + } + + public void testFromMap_WithValidValues_ReturnsSettings() { + Map taskMap = Map.of( + HuggingFaceRerankTaskSettings.RETURN_TEXT, + true, + HuggingFaceRerankTaskSettings.TOP_N_DOCS_ONLY, + 5 + ); + var settings = HuggingFaceRerankTaskSettings.fromMap(new HashMap<>(taskMap)); + assertTrue(settings.getReturnDocuments()); + assertEquals(5, settings.getTopNDocumentsOnly().intValue()); + } + + public void testFromMap_WithNullValues_ReturnsSettingsWithNulls() { + var settings = HuggingFaceRerankTaskSettings.fromMap(Map.of()); + assertNull(settings.getReturnDocuments()); + assertNull(settings.getTopNDocumentsOnly()); + } + + public void testFromMap_WithInvalidReturnDocuments_ThrowsValidationException() { + Map taskMap = Map.of( + HuggingFaceRerankTaskSettings.RETURN_TEXT, + "invalid", + HuggingFaceRerankTaskSettings.TOP_N_DOCS_ONLY, + 5 + ); + var thrownException = expectThrows(ValidationException.class, () -> HuggingFaceRerankTaskSettings.fromMap(new HashMap<>(taskMap))); + assertThat(thrownException.getMessage(), containsString("field [return_text] is not of the expected type")); + } + + public void testFromMap_WithInvalidTopNDocsOnly_ThrowsValidationException() { + Map taskMap = Map.of( + 
HuggingFaceRerankTaskSettings.RETURN_TEXT, + true, + HuggingFaceRerankTaskSettings.TOP_N_DOCS_ONLY, + "invalid" + ); + var thrownException = expectThrows(ValidationException.class, () -> HuggingFaceRerankTaskSettings.fromMap(new HashMap<>(taskMap))); + assertThat(thrownException.getMessage(), containsString("field [top_n] is not of the expected type")); + } + + public void testUpdatedTaskSettings_WithEmptyMap_ReturnsSameSettings() { + var initialSettings = new HuggingFaceRerankTaskSettings(5, true); + HuggingFaceRerankTaskSettings updatedSettings = (HuggingFaceRerankTaskSettings) initialSettings.updatedTaskSettings(Map.of()); + assertEquals(initialSettings, updatedSettings); + } + + public void testUpdatedTaskSettings_WithNewReturnDocuments_ReturnsUpdatedSettings() { + var initialSettings = new HuggingFaceRerankTaskSettings(5, true); + Map newSettings = Map.of(HuggingFaceRerankTaskSettings.RETURN_TEXT, false); + HuggingFaceRerankTaskSettings updatedSettings = (HuggingFaceRerankTaskSettings) initialSettings.updatedTaskSettings(newSettings); + assertFalse(updatedSettings.getReturnDocuments()); + assertEquals(initialSettings.getTopNDocumentsOnly(), updatedSettings.getTopNDocumentsOnly()); + } + + public void testUpdatedTaskSettings_WithNewTopNDocsOnly_ReturnsUpdatedSettings() { + var initialSettings = new HuggingFaceRerankTaskSettings(5, true); + Map newSettings = Map.of(HuggingFaceRerankTaskSettings.TOP_N_DOCS_ONLY, 7); + HuggingFaceRerankTaskSettings updatedSettings = (HuggingFaceRerankTaskSettings) initialSettings.updatedTaskSettings(newSettings); + assertEquals(7, updatedSettings.getTopNDocumentsOnly().intValue()); + assertEquals(initialSettings.getReturnDocuments(), updatedSettings.getReturnDocuments()); + } + + public void testUpdatedTaskSettings_WithMultipleNewValues_ReturnsUpdatedSettings() { + var initialSettings = new HuggingFaceRerankTaskSettings(5, true); + Map newSettings = Map.of( + HuggingFaceRerankTaskSettings.RETURN_TEXT, + false, +
HuggingFaceRerankTaskSettings.TOP_N_DOCS_ONLY, + 7 + ); + HuggingFaceRerankTaskSettings updatedSettings = (HuggingFaceRerankTaskSettings) initialSettings.updatedTaskSettings(newSettings); + assertFalse(updatedSettings.getReturnDocuments()); + assertEquals(7, updatedSettings.getTopNDocumentsOnly().intValue()); + } + + @Override + protected Writeable.Reader instanceReader() { + return HuggingFaceRerankTaskSettings::new; + } + + @Override + protected HuggingFaceRerankTaskSettings createTestInstance() { + return createRandom(); + } + + @Override + protected HuggingFaceRerankTaskSettings mutateInstance(HuggingFaceRerankTaskSettings instance) throws IOException { + return randomValueOtherThan(instance, HuggingFaceRerankTaskSettingsTests::createRandom); + } +} From b58aab402b8f63937df63f76da2091ca938eb13d Mon Sep 17 00:00:00 2001 From: Evgenii_Kazannik Date: Sat, 10 May 2025 12:47:54 +0200 Subject: [PATCH 02/12] Address comments --- .../org/elasticsearch/TransportVersions.java | 2 + .../inference/InferenceBaseRestTest.java | 16 -- .../huggingface/HuggingFaceBaseService.java | 84 +++++---- .../huggingface/HuggingFaceModelInput.java | 113 ------------ .../HuggingFaceModelParameters.java | 25 +++ .../huggingface/HuggingFaceService.java | 40 ++--- .../action/HuggingFaceActionCreator.java | 19 +- .../elser/HuggingFaceElserService.java | 18 +- .../rerank/HuggingFaceRerankRequest.java | 12 +- .../rerank/HuggingFaceRerankModel.java | 2 +- .../HuggingFaceRerankServiceSettings.java | 21 +-- .../rerank/HuggingFaceRerankTaskSettings.java | 2 +- .../HuggingFaceRerankResponseEntity.java | 165 +++++++++--------- 13 files changed, 208 insertions(+), 311 deletions(-) delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModelInput.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModelParameters.java diff --git 
a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 680770a41c44d..5757fa589e82f 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -174,6 +174,7 @@ static TransportVersion def(int id) { public static final TransportVersion RESCORE_VECTOR_ALLOW_ZERO_BACKPORT_8_19 = def(8_841_0_27); public static final TransportVersion INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19 = def(8_841_0_28); public static final TransportVersion ESQL_REPORT_SHARD_PARTITIONING_8_19 = def(8_841_0_29); + public static final TransportVersion ML_INFERENCE_HUGGING_FACE_RERANK_ADDED_8_19 = def(8_843_0_29); public static final TransportVersion V_9_0_0 = def(9_000_0_09); public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10); public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11); @@ -253,6 +254,7 @@ static TransportVersion def(int id) { public static final TransportVersion INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT = def(9_074_0_00); public static final TransportVersion ESQL_FIELD_ATTRIBUTE_DROP_TYPE = def(9_075_0_00); public static final TransportVersion ESQL_TIME_SERIES_SOURCE_STATUS = def(9_076_0_00); + public static final TransportVersion ML_INFERENCE_HUGGING_FACE_RERANK_ADDED = def(9_078_0_00); /* * STOP! READ THIS FIRST! 
No, really, diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java index a99b5fbd34e3a..b910985031e7c 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java @@ -171,22 +171,6 @@ static String mockDenseServiceModelConfig() { """; } - static String mockRerankServiceModelConfig() { - return """ - { - "task_type": "rerank", - "service": "rerank_test_service", - "service_settings": { - "model": "rerank_model", - "api_key": "abc64" - }, - "task_settings": { - "return_documents": true - } - } - """; - } - static void deleteModel(String modelId) throws IOException { var request = new Request("DELETE", "_inference/" + modelId); var response = client().performRequest(request); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java index 1114a139330d4..63f43adcb2613 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java @@ -26,6 +26,7 @@ import org.elasticsearch.xpack.inference.services.ServiceUtils; import org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionCreator; +import java.util.Collections; import java.util.Map; import static 
org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; @@ -57,7 +58,11 @@ public void parseRequestConfig( ) { try { Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); + Map taskSettingsMap = Collections.emptyMap(); + + if (TaskType.RERANK.equals(taskType)) { + taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); + } ChunkingSettings chunkingSettings = null; if (TaskType.TEXT_EMBEDDING.equals(taskType)) { @@ -66,18 +71,17 @@ public void parseRequestConfig( ); } - var modelBuilder = new HuggingFaceModelInput.Builder( - inferenceEntityId, - taskType, - serviceSettingsMap, - chunkingSettings, - serviceSettingsMap, - TaskType.unsupportedTaskTypeErrorMsg(taskType, name()), - ConfigurationParseContext.REQUEST - ); - var model = createModel( - TaskType.RERANK.equals(taskType) ? modelBuilder.withTaskSettings(taskSettingsMap).build() : modelBuilder.build() + new HuggingFaceModelParameters( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + chunkingSettings, + serviceSettingsMap, + TaskType.unsupportedTaskTypeErrorMsg(taskType, name()), + ConfigurationParseContext.REQUEST + ) ); throwIfNotEmptyMap(config, name()); @@ -98,55 +102,61 @@ public HuggingFaceModel parsePersistedConfigWithSecrets( Map secrets ) { Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); Map secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS); + Map taskSettingsMap = Collections.emptyMap(); + + if (TaskType.RERANK.equals(taskType)) { + taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); + } ChunkingSettings chunkingSettings = null; if 
(TaskType.TEXT_EMBEDDING.equals(taskType)) { chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); } - var modelBuilder = new HuggingFaceModelInput.Builder( - inferenceEntityId, - taskType, - serviceSettingsMap, - chunkingSettings, - secretSettingsMap, - parsePersistedConfigErrorMsg(inferenceEntityId, name()), - ConfigurationParseContext.PERSISTENT - ); - return createModel( - TaskType.RERANK.equals(taskType) ? modelBuilder.withTaskSettings(taskSettingsMap).build() : modelBuilder.build() + new HuggingFaceModelParameters( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + chunkingSettings, + secretSettingsMap, + parsePersistedConfigErrorMsg(inferenceEntityId, name()), + ConfigurationParseContext.PERSISTENT + ) ); } @Override public HuggingFaceModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) { Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); + Map taskSettingsMap = Collections.emptyMap(); + + if (TaskType.RERANK.equals(taskType)) { + taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); + } ChunkingSettings chunkingSettings = null; if (TaskType.TEXT_EMBEDDING.equals(taskType)) { chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); } - var modelBuilder = new HuggingFaceModelInput.Builder( - inferenceEntityId, - taskType, - serviceSettingsMap, - chunkingSettings, - null, - parsePersistedConfigErrorMsg(inferenceEntityId, name()), - ConfigurationParseContext.PERSISTENT - ); - return createModel( - TaskType.RERANK.equals(taskType) ? 
modelBuilder.withTaskSettings(taskSettingsMap).build() : modelBuilder.build() + new HuggingFaceModelParameters( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + chunkingSettings, + null, + parsePersistedConfigErrorMsg(inferenceEntityId, name()), + ConfigurationParseContext.PERSISTENT + ) ); } - protected abstract HuggingFaceModel createModel(HuggingFaceModelInput input); + protected abstract HuggingFaceModel createModel(HuggingFaceModelParameters input); @Override public void doInfer( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModelInput.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModelInput.java deleted file mode 100644 index afd2737435e0d..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModelInput.java +++ /dev/null @@ -1,113 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.inference.services.huggingface; - -import org.elasticsearch.core.Nullable; -import org.elasticsearch.inference.ChunkingSettings; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; - -import java.util.Map; - -public class HuggingFaceModelInput { - private final String inferenceEntityId; - private final TaskType taskType; - private final Map serviceSettings; - @Nullable - private final Map taskSettings; - private final ChunkingSettings chunkingSettings; - @Nullable - private final Map secretSettings; - private final String failureMessage; - private final ConfigurationParseContext context; - - public HuggingFaceModelInput(Builder builder) { - this.inferenceEntityId = builder.inferenceEntityId; - this.taskType = builder.taskType; - this.serviceSettings = builder.serviceSettings; - this.taskSettings = builder.taskSettings; - this.chunkingSettings = builder.chunkingSettings; - this.secretSettings = builder.secretSettings; - this.failureMessage = builder.failureMessage; - this.context = builder.context; - } - - public String getInferenceEntityId() { - return inferenceEntityId; - } - - public TaskType getTaskType() { - return taskType; - } - - public Map getServiceSettings() { - return serviceSettings; - } - - @Nullable - public Map getTaskSettings() { - return taskSettings; - } - - public ChunkingSettings getChunkingSettings() { - return chunkingSettings; - } - - @Nullable - public Map getSecretSettings() { - return secretSettings; - } - - public String getFailureMessage() { - return failureMessage; - } - - public ConfigurationParseContext getContext() { - return context; - } - - public static class Builder { - private String inferenceEntityId; - private TaskType taskType; - private Map serviceSettings; - @Nullable - private Map taskSettings; - private ChunkingSettings chunkingSettings; - @Nullable - Map secretSettings; - private String failureMessage; - 
private ConfigurationParseContext context; - - public Builder( - String inferenceEntityId, - TaskType taskType, - Map serviceSettings, - ChunkingSettings chunkingSettings, - @Nullable Map secretSettings, - String failureMessage, - ConfigurationParseContext context - ) { - this.inferenceEntityId = inferenceEntityId; - this.taskType = taskType; - this.serviceSettings = serviceSettings; - this.chunkingSettings = chunkingSettings; - this.secretSettings = secretSettings; - this.failureMessage = failureMessage; - this.context = context; - } - - public Builder withTaskSettings(Map taskSettings) { - this.taskSettings = taskSettings; - return this; - } - - public HuggingFaceModelInput build() { - return new HuggingFaceModelInput(this); - } - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModelParameters.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModelParameters.java new file mode 100644 index 0000000000000..6dabaa66ffb2b --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModelParameters.java @@ -0,0 +1,25 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.huggingface; + +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; + +import java.util.Map; + +public record HuggingFaceModelParameters( + String inferenceEntityId, + TaskType taskType, + Map serviceSettings, + Map taskSettings, + ChunkingSettings chunkingSettings, + Map secretSettings, + String failureMessage, + ConfigurationParseContext context +) {} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java index e9158d2657106..b487c04102836 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java @@ -60,35 +60,35 @@ public HuggingFaceService(HttpRequestSender.Factory factory, ServiceComponents s } @Override - protected HuggingFaceModel createModel(HuggingFaceModelInput input) { - return switch (input.getTaskType()) { + protected HuggingFaceModel createModel(HuggingFaceModelParameters input) { + return switch (input.taskType()) { case RERANK -> new HuggingFaceRerankModel( - input.getInferenceEntityId(), - input.getTaskType(), + input.inferenceEntityId(), + input.taskType(), NAME, - input.getServiceSettings(), - input.getTaskSettings(), - input.getSecretSettings(), - input.getContext() + input.serviceSettings(), + input.taskSettings(), + input.secretSettings(), + input.context() ); case TEXT_EMBEDDING -> new HuggingFaceEmbeddingsModel( - input.getInferenceEntityId(), - input.getTaskType(), + input.inferenceEntityId(), + input.taskType(), NAME, - input.getServiceSettings(), - input.getChunkingSettings(), 
- input.getSecretSettings(), - input.getContext() + input.serviceSettings(), + input.chunkingSettings(), + input.secretSettings(), + input.context() ); case SPARSE_EMBEDDING -> new HuggingFaceElserModel( - input.getInferenceEntityId(), - input.getTaskType(), + input.inferenceEntityId(), + input.taskType(), NAME, - input.getServiceSettings(), - input.getSecretSettings(), - input.getContext() + input.serviceSettings(), + input.secretSettings(), + input.context() ); - default -> throw new ElasticsearchStatusException(input.getFailureMessage(), RestStatus.BAD_REQUEST); + default -> throw new ElasticsearchStatusException(input.failureMessage(), RestStatus.BAD_REQUEST); }; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java index 5022db6b35167..642558af06b11 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java @@ -37,10 +37,15 @@ public class HuggingFaceActionCreator implements HuggingFaceActionVisitor { private static final String FAILED_TO_SEND_REQUEST_ERROR_MESSAGE = "Failed to send Hugging Face %s request from inference entity id [%s]"; - static final ResponseHandler RERANK_HANDLER = new HuggingFaceResponseHandler( - "hugging face rerank", - (request, response) -> HuggingFaceRerankResponseEntity.fromResponse((HuggingFaceRerankRequest) request, response) - ); + private static final String INVALID_REQUEST_TYPE_MESSAGE = "Invalid request type: expected HuggingFace %s request but got %s"; + static final ResponseHandler RERANK_HANDLER = new HuggingFaceResponseHandler("hugging face rerank", (request, response) -> { + var 
errorMessage = format(INVALID_REQUEST_TYPE_MESSAGE, "RERANK", request != null ? request.getClass().getName() : "null"); + + if ((request instanceof HuggingFaceRerankRequest) == false) { + throw new IllegalArgumentException(errorMessage); + } + return HuggingFaceRerankResponseEntity.fromResponse((HuggingFaceRerankRequest) request, response); + }); public HuggingFaceActionCreator(Sender sender, ServiceComponents serviceComponents) { this.sender = Objects.requireNonNull(sender); @@ -96,11 +101,7 @@ public ExecutableAction create(HuggingFaceElserModel model) { serviceComponents.truncator(), serviceComponents.threadPool() ); - var errorMessage = format( - "Failed to send Hugging Face %s request from inference entity id [%s]", - "ELSER", - model.getInferenceEntityId() - ); + var errorMessage = format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, "ELSER", model.getInferenceEntityId()); return new SenderExecutableAction(sender, requestCreator, errorMessage); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java index b3738b9e44542..e61995aac91f3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java @@ -36,7 +36,7 @@ import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceBaseService; import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceModel; -import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceModelInput; +import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceModelParameters; import 
org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -67,17 +67,17 @@ public String name() { } @Override - protected HuggingFaceModel createModel(HuggingFaceModelInput input) { - return switch (input.getTaskType()) { + protected HuggingFaceModel createModel(HuggingFaceModelParameters input) { + return switch (input.taskType()) { case SPARSE_EMBEDDING -> new HuggingFaceElserModel( - input.getInferenceEntityId(), - input.getTaskType(), + input.inferenceEntityId(), + input.taskType(), NAME, - input.getServiceSettings(), - input.getSecretSettings(), - input.getContext() + input.serviceSettings(), + input.secretSettings(), + input.context() ); - default -> throw new ElasticsearchStatusException(input.getFailureMessage(), RestStatus.BAD_REQUEST); + default -> throw new ElasticsearchStatusException(input.failureMessage(), RestStatus.BAD_REQUEST); }; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequest.java index 87f8cb6b1d81f..9759bc9ce7c69 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequest.java @@ -12,13 +12,11 @@ import org.apache.http.entity.ByteArrayEntity; import org.elasticsearch.common.Strings; import org.elasticsearch.core.Nullable; -import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.external.request.HttpRequest; import org.elasticsearch.xpack.inference.external.request.Request; import 
org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceAccount; import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankModel; -import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings; import java.net.URI; import java.nio.charset.StandardCharsets; @@ -34,9 +32,7 @@ public class HuggingFaceRerankRequest implements Request { private final List input; private final Boolean returnDocuments; private final Integer topN; - private final TaskSettings taskSettings; private final HuggingFaceRerankModel model; - private final String inferenceEntityId; public HuggingFaceRerankRequest( String query, @@ -52,9 +48,7 @@ public HuggingFaceRerankRequest( this.query = Objects.requireNonNull(query); this.returnDocuments = returnDocuments; this.topN = topN; - taskSettings = model.getTaskSettings(); this.model = model; - inferenceEntityId = model.getInferenceEntityId(); } @Override @@ -68,7 +62,7 @@ public HttpRequest createHttpRequest() { input, returnDocuments, topN != null ? 
topN : model.getTaskSettings().getTopNDocumentsOnly(), - (HuggingFaceRerankTaskSettings) taskSettings, + model.getTaskSettings(), model.getServiceSettings().modelId() ) ).getBytes(StandardCharsets.UTF_8) @@ -87,7 +81,7 @@ public void decorateWithAuth(HttpPost httpPost) { @Override public String getInferenceEntityId() { - return inferenceEntityId; + return model.getInferenceEntityId(); } @Override @@ -101,11 +95,13 @@ public Integer getTopN() { @Override public Request truncate() { + // Not applicable for rerank, only used in text embedding requests return this; } @Override public boolean[] getTruncationInfo() { + // Not applicable for rerank, only used in text embedding requests return null; } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankModel.java index 3d0b7e5ba5b47..59c2569ec3695 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankModel.java @@ -81,7 +81,7 @@ public DefaultSecretSettings getSecretSettings() { @Override public Integer getTokenLimit() { - return getServiceSettings().maxInputTokens(); + throw new UnsupportedOperationException("Token Limit for rerank is sent in request and not retrieved from the model"); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankServiceSettings.java index 0e925f19d12b4..3d4c6aef71e96 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankServiceSettings.java @@ -26,7 +26,6 @@ import java.util.Map; import java.util.Objects; -import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri; import static org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings.extractUri; @@ -37,7 +36,6 @@ public class HuggingFaceRerankServiceSettings extends FilteredXContentObject public static final String NAME = "hugging_face_rerank_service_settings"; public static final String URL = "url"; - private static final int RERANK_TOKEN_LIMIT = 512; private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(3000); @@ -73,12 +71,7 @@ public HuggingFaceRerankServiceSettings(String url) { public HuggingFaceRerankServiceSettings(StreamInput in) throws IOException { uri = createUri(in.readString()); - - if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_15_0)) { - rateLimitSettings = new RateLimitSettings(in); - } else { - rateLimitSettings = DEFAULT_RATE_LIMIT_SETTINGS; - } + rateLimitSettings = new RateLimitSettings(in); } @Override @@ -91,10 +84,6 @@ public URI uri() { return uri; } - public int maxInputTokens() { - return RERANK_TOKEN_LIMIT; - } - // model is not defined in the service settings. // since hugging face requires that the model be chosen when initializing a deployment within their service. 
@Override @@ -114,7 +103,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws @Override protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { builder.field(URL, uri.toString()); - builder.field(MAX_INPUT_TOKENS, RERANK_TOKEN_LIMIT); rateLimitSettings.toXContent(builder, params); return builder; @@ -127,16 +115,13 @@ public String getWriteableName() { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.V_8_12_0; + return TransportVersions.ML_INFERENCE_HUGGING_FACE_RERANK_ADDED; } @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(uri.toString()); - - if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_15_0)) { - rateLimitSettings.writeTo(out); - } + rateLimitSettings.writeTo(out); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankTaskSettings.java index 9e9e22196ea4e..cd23699e836a2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankTaskSettings.java @@ -118,7 +118,7 @@ public String getWriteableName() { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.V_8_14_0; + return TransportVersions.ML_INFERENCE_HUGGING_FACE_RERANK_ADDED; } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceRerankResponseEntity.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceRerankResponseEntity.java index b7f7c41fe6de9..566bd71438f7c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceRerankResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceRerankResponseEntity.java @@ -7,32 +7,66 @@ package org.elasticsearch.xpack.inference.services.huggingface.response; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; -import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; -import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; import org.elasticsearch.xpack.inference.services.huggingface.request.rerank.HuggingFaceRerankRequest; import java.io.IOException; import java.util.Comparator; +import java.util.List; import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList; -import static org.elasticsearch.common.xcontent.XContentParserUtils.throwUnknownField; -import static org.elasticsearch.common.xcontent.XContentParserUtils.throwUnknownToken; +import static org.elasticsearch.core.Strings.format; import static 
org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; -public class HuggingFaceRerankResponseEntity extends ErrorResponse { - private static final Logger logger = LogManager.getLogger(HuggingFaceRerankResponseEntity.class); +public class HuggingFaceRerankResponseEntity { - public static InferenceServiceResults fromResponse(HuggingFaceRerankRequest request, HttpResult response) throws IOException { + private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in Hugging Face rerank response"; + private static final String INVALID_ID_FIELD_FORMAT_TEMPLATE = "Expected numeric value for record ID field in Hugging Face rerank "; + + /** + * Parses the Hugging Face rerank response. + + * For a request like: + * + *
+     *     
+     *         {
+     *              "input": ["luke", "like", "leia", "chewy","r2d2", "star", "wars"],
+     *              "query": "star wars main character",
+     *              "return_documents": true,
+     *              "top_n": 1
+     *          }
+     *     
+     * 
+ + * The response would look like: + + *
+     *     
+     *         [
+     *              {
+     *                   "index": 5,
+     *                   "score": -0.06920313,
+     *                   "text": "star"
+     *              }
+     *         ]
+     *     
+     * 
+ */ + + public static RankedDocsResults fromResponse(HuggingFaceRerankRequest request, HttpResult response) throws IOException { var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { @@ -40,84 +74,57 @@ public static InferenceServiceResults fromResponse(HuggingFaceRerankRequest requ XContentParser.Token token = jsonParser.currentToken(); ensureExpectedToken(XContentParser.Token.START_ARRAY, token, jsonParser); - token = jsonParser.currentToken(); - if (token == XContentParser.Token.START_ARRAY) { - var rankedDocs = parseList(jsonParser, HuggingFaceRerankResponseEntity::parseRankedDocObject); - var rankedDocsByRelevanceStream = rankedDocs.stream() - .sorted(Comparator.comparingDouble(RankedDocsResults.RankedDoc::relevanceScore).reversed()); - var rankedDocStreamTopN = request.getTopN() == null - ? rankedDocsByRelevanceStream - : rankedDocsByRelevanceStream.limit(request.getTopN()); - return new RankedDocsResults(rankedDocStreamTopN.toList()); - } else { - throwUnknownToken(token, jsonParser); - } - throw new IllegalStateException("Reached an invalid state while parsing the HuggingFace response"); + var rankedDocs = doParse(jsonParser); + var rankedDocsByRelevanceStream = rankedDocs.stream() + .sorted(Comparator.comparingDouble(RankedDocsResults.RankedDoc::relevanceScore).reversed()); + var rankedDocStreamTopN = request.getTopN() == null + ? 
rankedDocsByRelevanceStream + : rankedDocsByRelevanceStream.limit(request.getTopN()); + return new RankedDocsResults(rankedDocStreamTopN.toList()); } } - private static RankedDocsResults.RankedDoc parseRankedDocObject(XContentParser parser) throws IOException { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); - int index = -1; - float relevanceScore = -1; - String documentText = null; - parser.nextToken(); - while (parser.currentToken() != XContentParser.Token.END_OBJECT) { - if (parser.currentToken() == XContentParser.Token.FIELD_NAME) { - switch (parser.currentName()) { - case "index": - parser.nextToken(); // move to VALUE_NUMBER - index = parser.intValue(); - parser.nextToken(); // move to next FIELD_NAME or END_OBJECT - break; - case "score": - parser.nextToken(); // move to VALUE_NUMBER - relevanceScore = parser.floatValue(); - parser.nextToken(); // move to next FIELD_NAME or END_OBJECT - break; - case "relevance_score": - parser.nextToken(); // move to VALUE_NUMBER - relevanceScore = parser.floatValue(); - parser.nextToken(); // move to next FIELD_NAME or END_OBJECT - break; - case "text": - parser.nextToken(); // move to VALUE_NUMBER - documentText = parser.text(); - parser.nextToken(); // move to next FIELD_NAME or END_OBJECT - break; - case "document": - parser.nextToken(); // move to START_OBJECT; document text is wrapped in an object - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); - do { - if (parser.currentToken() == XContentParser.Token.FIELD_NAME && parser.currentName().equals("text")) { - parser.nextToken(); // move to VALUE_STRING - documentText = parser.text(); - } - } while (parser.nextToken() != XContentParser.Token.END_OBJECT); - parser.nextToken();// move past END_OBJECT - // parser should now be at the next FIELD_NAME or END_OBJECT - break; - default: - throwUnknownField(parser.currentName(), parser); - } - } else { - parser.nextToken(); + private static List 
doParse(XContentParser parser) throws IOException { + return parseList(parser, (listParser, index) -> { + var parsedRankedDoc = HuggingFaceRerankResponseEntity.RankedDocEntry.parse(parser); + + if (parsedRankedDoc.id == null) { + throw new IllegalStateException(format(FAILED_TO_FIND_FIELD_TEMPLATE, RankedDocEntry.ID.getPreferredName())); } - } - if (index == -1) { - logger.warn("Failed to find required field [index] in HuggingFace rerank response"); - } - if (relevanceScore == -1) { - logger.warn("Failed to find required field [relevance_score] in HuggingFace rerank response"); - } - // documentText may or may not be present depending on the request parameter + if (parsedRankedDoc.score == null) { + throw new IllegalStateException(format(FAILED_TO_FIND_FIELD_TEMPLATE, RankedDocEntry.SCORE.getPreferredName())); + } - return new RankedDocsResults.RankedDoc(index, relevanceScore, documentText); + try { + return new RankedDocsResults.RankedDoc(parsedRankedDoc.id, parsedRankedDoc.score, parsedRankedDoc.text); + } catch (NumberFormatException e) { + throw new IllegalStateException(format(INVALID_ID_FIELD_FORMAT_TEMPLATE, parsedRankedDoc.id)); + } + }); } - private HuggingFaceRerankResponseEntity(String errorMessage) { - super(errorMessage); + private record RankedDocEntry(@Nullable Integer id, @Nullable Float score, @Nullable String text) { + + private static final ParseField TEXT = new ParseField("text"); + private static final ParseField SCORE = new ParseField("score"); + private static final ParseField ID = new ParseField("index"); + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>( + "hugging_face_rerank_response", + true, + args -> new HuggingFaceRerankResponseEntity.RankedDocEntry((int) args[0], (float) args[1], (String) args[2]) + ); + + static { + PARSER.declareInt(ConstructingObjectParser.constructorArg(), ID); + PARSER.declareFloat(ConstructingObjectParser.constructorArg(), SCORE); + 
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), TEXT); + } + + public static RankedDocEntry parse(XContentParser parser) { + return PARSER.apply(parser, null); + } } } From a4ebb87d7f2bfe2cc4fec3f7000e867b3e625598 Mon Sep 17 00:00:00 2001 From: Evgenii_Kazannik Date: Tue, 13 May 2025 14:04:56 +0200 Subject: [PATCH 03/12] Add transport version --- server/src/main/java/org/elasticsearch/TransportVersions.java | 2 +- .../inference/services/huggingface/HuggingFaceServiceTests.java | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 228d116d31266..0ecbcb964acba 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -254,7 +254,7 @@ static TransportVersion def(int id) { public static final TransportVersion INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT = def(9_074_0_00); public static final TransportVersion ESQL_FIELD_ATTRIBUTE_DROP_TYPE = def(9_075_0_00); public static final TransportVersion ESQL_TIME_SERIES_SOURCE_STATUS = def(9_076_0_00); - + public static final TransportVersion ML_INFERENCE_HUGGING_FACE_RERANK_ADDED = def(9_078_0_00); /* * STOP! READ THIS FIRST! 
No, really, * ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _ diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java index f1c2a458afdba..9f33580bd6d61 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java @@ -885,13 +885,11 @@ private Map getRequestConfigMap( private Map getRequestConfigMap(Map serviceSettings, Map secretSettings) { var builtServiceSettings = new HashMap<>(); - var builtTaskSettings = new HashMap<>(); builtServiceSettings.putAll(serviceSettings); builtServiceSettings.putAll(secretSettings); Map map = new HashMap<>(); map.put(ModelConfigurations.SERVICE_SETTINGS, builtServiceSettings); - map.put(ModelConfigurations.TASK_SETTINGS, builtTaskSettings); return map; } From f74e9e549b8811c88e99b550e94fbc2c7cf5fb2e Mon Sep 17 00:00:00 2001 From: Evgenii_Kazannik Date: Tue, 13 May 2025 15:04:29 +0200 Subject: [PATCH 04/12] Add transport version --- server/src/main/java/org/elasticsearch/TransportVersions.java | 1 + 1 file changed, 1 insertion(+) diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 0ecbcb964acba..171138272d427 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -254,6 +254,7 @@ static TransportVersion def(int id) { public static final TransportVersion INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT = def(9_074_0_00); public static final TransportVersion ESQL_FIELD_ATTRIBUTE_DROP_TYPE = def(9_075_0_00); public static final 
TransportVersion ESQL_TIME_SERIES_SOURCE_STATUS = def(9_076_0_00); + public static final TransportVersion ESQL_HASH_OPERATOR_STATUS_OUTPUT_TIME = def(9_077_0_00); public static final TransportVersion ML_INFERENCE_HUGGING_FACE_RERANK_ADDED = def(9_078_0_00); /* * STOP! READ THIS FIRST! No, really, From 00548915417c048fe9a0cf39238cb402e17eee56 Mon Sep 17 00:00:00 2001 From: Evgenii_Kazannik Date: Wed, 14 May 2025 11:34:27 +0200 Subject: [PATCH 05/12] Add to inference service and crud IT rerank tests --- .../inference/InferenceBaseRestTest.java | 14 ++++ .../xpack/inference/InferenceCrudIT.java | 15 ++++- .../MockRerankInferenceServiceIT.java | 67 +++++++++++++++++++ 3 files changed, 95 insertions(+), 1 deletion(-) create mode 100644 x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockRerankInferenceServiceIT.java diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java index b910985031e7c..d76d76a8d516c 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java @@ -171,6 +171,20 @@ static String mockDenseServiceModelConfig() { """; } + static String mockRerankServiceModelConfig() { + return """ + { + "service": "test_reranking_service", + "service_settings": { + "model_id": "my_model", + "api_key": "abc64" + }, + "task_settings": { + } + } + """; + } + static void deleteModel(String modelId) throws IOException { var request = new Request("DELETE", "_inference/" + modelId); var response = client().performRequest(request); diff --git 
a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java index 1b01f37955e5a..f85874d6e580c 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java @@ -53,9 +53,12 @@ public void testCRUD() throws IOException { for (int i = 0; i < 4; i++) { putModel("te_model_" + i, mockDenseServiceModelConfig(), TaskType.TEXT_EMBEDDING); } + for (int i = 0; i < 3; i++) { + putModel("re-model-" + i, mockRerankServiceModelConfig(), TaskType.RERANK); + } var getAllModels = getAllModels(); - int numModels = 12; + int numModels = 15; assertThat(getAllModels, hasSize(numModels)); var getSparseModels = getModels("_all", TaskType.SPARSE_EMBEDDING); @@ -71,6 +74,13 @@ public void testCRUD() throws IOException { for (var denseModel : getDenseModels) { assertEquals("text_embedding", denseModel.get("task_type")); } + + var getRerankModels = getModels("_all", TaskType.RERANK); + int numRerankModels = 4; + assertThat(getRerankModels, hasSize(numRerankModels)); + for (var denseModel : getRerankModels) { + assertEquals("rerank", denseModel.get("task_type")); + } String oldApiKey; { var singleModel = getModels("se_model_1", TaskType.SPARSE_EMBEDDING); @@ -100,6 +110,9 @@ public void testCRUD() throws IOException { for (int i = 0; i < 4; i++) { deleteModel("te_model_" + i, TaskType.TEXT_EMBEDDING); } + for (int i = 0; i < 3; i++) { + deleteModel("re-model-" + i, TaskType.RERANK); + } } public void testGetModelWithWrongTaskType() throws IOException { diff --git 
a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockRerankInferenceServiceIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockRerankInferenceServiceIT.java new file mode 100644 index 0000000000000..2524b8c8c9ae7 --- /dev/null +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockRerankInferenceServiceIT.java @@ -0,0 +1,67 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference; + +import org.elasticsearch.inference.TaskType; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +public class MockRerankInferenceServiceIT extends InferenceBaseRestTest { + + @SuppressWarnings("unchecked") + public void testMockService() throws IOException { + String inferenceEntityId = "test-mock"; + var putModel = putModel(inferenceEntityId, mockRerankServiceModelConfig(), TaskType.RERANK); + var model = getModels(inferenceEntityId, TaskType.RERANK).get(0); + + for (var modelMap : List.of(putModel, model)) { + assertEquals(inferenceEntityId, modelMap.get("inference_id")); + assertEquals(TaskType.RERANK, TaskType.fromString((String) modelMap.get("task_type"))); + assertEquals("test_reranking_service", modelMap.get("service")); + } + + List input = List.of(randomAlphaOfLength(10)); + var inference = infer(inferenceEntityId, input); + assertNonEmptyInferenceResults(inference, 1, TaskType.RERANK); + assertEquals(inference, infer(inferenceEntityId, input)); + assertNotEquals(inference, infer(inferenceEntityId, randomValueOtherThan(input, () -> List.of(randomAlphaOfLength(10))))); + } + + public void 
testMockServiceWithMultipleInputs() throws IOException { + String inferenceEntityId = "test-mock-with-multi-inputs"; + putModel(inferenceEntityId, mockRerankServiceModelConfig(), TaskType.RERANK); + var queryParams = Map.of("timeout", "120s"); + + var inference = infer( + inferenceEntityId, + TaskType.RERANK, + List.of(randomAlphaOfLength(5), randomAlphaOfLength(10)), + "What if?", + queryParams + ); + + assertNonEmptyInferenceResults(inference, 2, TaskType.RERANK); + } + + @SuppressWarnings("unchecked") + public void testMockService_DoesNotReturnSecretsInGetResponse() throws IOException { + String inferenceEntityId = "test-mock"; + var putModel = putModel(inferenceEntityId, mockRerankServiceModelConfig(), TaskType.RERANK); + var model = getModels(inferenceEntityId, TaskType.RERANK).get(0); + + var serviceSettings = (Map) model.get("service_settings"); + assertNull(serviceSettings.get("api_key")); + assertNotNull(serviceSettings.get("model_id")); + + var putServiceSettings = (Map) putModel.get("service_settings"); + assertNull(putServiceSettings.get("api_key")); + assertNotNull(putServiceSettings.get("model_id")); + } +} From 733818c0d1773ef34212a3d9411a247c03391085 Mon Sep 17 00:00:00 2001 From: Evgenii_Kazannik Date: Fri, 16 May 2025 10:10:10 +0200 Subject: [PATCH 06/12] Refactor slightly / error message --- .../action/HuggingFaceActionCreator.java | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java index 642558af06b11..103a11828826e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.services.huggingface.action; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; @@ -38,7 +39,7 @@ public class HuggingFaceActionCreator implements HuggingFaceActionVisitor { private static final String FAILED_TO_SEND_REQUEST_ERROR_MESSAGE = "Failed to send Hugging Face %s request from inference entity id [%s]"; private static final String INVALID_REQUEST_TYPE_MESSAGE = "Invalid request type: expected HuggingFace %s request but got %s"; - static final ResponseHandler RERANK_HANDLER = new HuggingFaceResponseHandler("hugging face rerank", (request, response) -> { + private static final ResponseHandler RERANK_HANDLER = new HuggingFaceResponseHandler("hugging face rerank", (request, response) -> { var errorMessage = format(INVALID_REQUEST_TYPE_MESSAGE, "RERANK", request != null ? 
request.getClass().getName() : "null"); if ((request instanceof HuggingFaceRerankRequest) == false) { @@ -68,7 +69,7 @@ public ExecutableAction create(HuggingFaceRerankModel model) { ), QueryAndDocsInputs.class ); - var errorMessage = format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, "RERANK", model.getInferenceEntityId()); + var errorMessage = buildErrorMessage(TaskType.RERANK, model.getInferenceEntityId()); return new SenderExecutableAction(sender, manager, errorMessage); } @@ -84,11 +85,7 @@ public ExecutableAction create(HuggingFaceEmbeddingsModel model) { serviceComponents.truncator(), serviceComponents.threadPool() ); - var errorMessage = format( - "Failed to send Hugging Face %s request from inference entity id [%s]", - "text embeddings", - model.getInferenceEntityId() - ); + var errorMessage = buildErrorMessage(TaskType.TEXT_EMBEDDING, model.getInferenceEntityId()); return new SenderExecutableAction(sender, requestCreator, errorMessage); } @@ -101,7 +98,11 @@ public ExecutableAction create(HuggingFaceElserModel model) { serviceComponents.truncator(), serviceComponents.threadPool() ); - var errorMessage = format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, "ELSER", model.getInferenceEntityId()); + var errorMessage = buildErrorMessage(TaskType.SPARSE_EMBEDDING, model.getInferenceEntityId()); return new SenderExecutableAction(sender, requestCreator, errorMessage); } + + public static String buildErrorMessage(TaskType requestType, String inferenceId) { + return format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, requestType.toString(), inferenceId); + } } From f97f81814cacf92885834f6a8f683380dae73143 Mon Sep 17 00:00:00 2001 From: Evgenii_Kazannik Date: Mon, 19 May 2025 20:49:27 +0200 Subject: [PATCH 07/12] correct 'testGetConfiguration' test case --- .../services/huggingface/HuggingFaceServiceTests.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java index 76149cef4586a..fcc3c5bfb98fb 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java @@ -1292,7 +1292,7 @@ public void testGetConfiguration() throws Exception { { "service": "hugging_face", "name": "Hugging Face", - "task_types": ["text_embedding", "sparse_embedding", "completion", "chat_completion", "rerank"], + "task_types": ["text_embedding", "sparse_embedding", "rerank", "completion", "chat_completion"], "configurations": { "api_key": { "description": "API Key for the provider you're connecting to.", @@ -1301,7 +1301,7 @@ public void testGetConfiguration() throws Exception { "sensitive": true, "updatable": true, "type": "str", - "supported_task_types": ["text_embedding", "sparse_embedding", "completion", "chat_completion", "rerank"] + "supported_task_types": ["text_embedding", "sparse_embedding", "rerank", "completion", "chat_completion"] }, "rate_limit.requests_per_minute": { "description": "Minimize the number of rate limit errors.", @@ -1310,7 +1310,7 @@ public void testGetConfiguration() throws Exception { "sensitive": false, "updatable": false, "type": "int", - "supported_task_types": ["text_embedding", "sparse_embedding", "completion", "chat_completion", "rerank"] + "supported_task_types": ["text_embedding", "sparse_embedding", "rerank", "completion", "chat_completion"] }, "url": { "description": "The URL endpoint to use for the requests.", @@ -1319,7 +1319,7 @@ public void testGetConfiguration() throws Exception { "sensitive": false, "updatable": false, "type": "str", - "supported_task_types": 
["text_embedding", "sparse_embedding", "completion", "chat_completion", "rerank"] + "supported_task_types": ["text_embedding", "sparse_embedding", "rerank", "completion", "chat_completion"] } } } From a52a1d862ff0f3e04e6f386665a5be1657ba05f6 Mon Sep 17 00:00:00 2001 From: Evgenii_Kazannik Date: Tue, 20 May 2025 23:50:18 +0200 Subject: [PATCH 08/12] apply suggestions --- .../results/RankedDocsResultsTests.java | 16 +++ .../huggingface/HuggingFaceBaseService.java | 19 +--- .../huggingface/HuggingFaceService.java | 50 ++++----- .../action/HuggingFaceActionCreator.java | 7 +- .../rerank/HuggingFaceRerankRequest.java | 7 +- .../HuggingFaceRerankRequestEntity.java | 12 --- .../rerank/HuggingFaceRerankModel.java | 2 +- .../HuggingFaceRerankResponseEntity.java | 52 +++------ .../HuggingFaceRerankRequestEntityTests.java | 27 ++--- .../HuggingFaceRerankResponseEntityTests.java | 100 ++++++++++++++++++ 10 files changed, 178 insertions(+), 114 deletions(-) create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceRerankResponseEntityTests.java diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/RankedDocsResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/RankedDocsResultsTests.java index ff6f6848f4b69..4971248e87431 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/RankedDocsResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/RankedDocsResultsTests.java @@ -82,4 +82,20 @@ private List rankedDocsNullStringToEmpty(List rankedDocFields) {} + + public static Map buildExpectationRerank(List rerank) { + return Map.of( + RankedDocsResults.RERANK, + rerank.stream() + .map( + rerankExpectation -> Map.of( + RankedDocsResults.RankedDoc.NAME, + rerankExpectation.rankedDocFields + ) + ) + .toList() + ); + } } diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java index 63f43adcb2613..b0d40b41914d5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java @@ -26,7 +26,6 @@ import org.elasticsearch.xpack.inference.services.ServiceUtils; import org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionCreator; -import java.util.Collections; import java.util.Map; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; @@ -58,11 +57,7 @@ public void parseRequestConfig( ) { try { Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - Map taskSettingsMap = Collections.emptyMap(); - - if (TaskType.RERANK.equals(taskType)) { - taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); - } + Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); ChunkingSettings chunkingSettings = null; if (TaskType.TEXT_EMBEDDING.equals(taskType)) { @@ -102,12 +97,8 @@ public HuggingFaceModel parsePersistedConfigWithSecrets( Map secrets ) { Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); Map secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS); - Map taskSettingsMap = Collections.emptyMap(); - - if (TaskType.RERANK.equals(taskType)) { - taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); - } ChunkingSettings 
chunkingSettings = null; if (TaskType.TEXT_EMBEDDING.equals(taskType)) { @@ -131,11 +122,7 @@ public HuggingFaceModel parsePersistedConfigWithSecrets( @Override public HuggingFaceModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) { Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - Map taskSettingsMap = Collections.emptyMap(); - - if (TaskType.RERANK.equals(taskType)) { - taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); - } + Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); ChunkingSettings chunkingSettings = null; if (TaskType.TEXT_EMBEDDING.equals(taskType)) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java index c5adc7cfcbfeb..d10fb77290c6b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java @@ -76,43 +76,43 @@ public HuggingFaceService(HttpRequestSender.Factory factory, ServiceComponents s } @Override - protected HuggingFaceModel createModel(HuggingFaceModelParameters input) { - return switch (input.taskType()) { + protected HuggingFaceModel createModel(HuggingFaceModelParameters params) { + return switch (params.taskType()) { case RERANK -> new HuggingFaceRerankModel( - input.inferenceEntityId(), - input.taskType(), + params.inferenceEntityId(), + params.taskType(), NAME, - input.serviceSettings(), - input.taskSettings(), - input.secretSettings(), - input.context() + params.serviceSettings(), + params.taskSettings(), + params.secretSettings(), + params.context() ); case TEXT_EMBEDDING -> new 
HuggingFaceEmbeddingsModel( - input.inferenceEntityId(), - input.taskType(), + params.inferenceEntityId(), + params.taskType(), NAME, - input.serviceSettings(), - input.chunkingSettings(), - input.secretSettings(), - input.context() + params.serviceSettings(), + params.chunkingSettings(), + params.secretSettings(), + params.context() ); case SPARSE_EMBEDDING -> new HuggingFaceElserModel( - input.inferenceEntityId(), - input.taskType(), + params.inferenceEntityId(), + params.taskType(), NAME, - input.serviceSettings(), - input.secretSettings(), - input.context() + params.serviceSettings(), + params.secretSettings(), + params.context() ); case CHAT_COMPLETION, COMPLETION -> new HuggingFaceChatCompletionModel( - input.inferenceEntityId(), - input.taskType(), + params.inferenceEntityId(), + params.taskType(), NAME, - input.serviceSettings(), - input.secretSettings(), - input.context() + params.serviceSettings(), + params.secretSettings(), + params.context() ); - default -> throw new ElasticsearchStatusException(input.failureMessage(), RestStatus.BAD_REQUEST); + default -> throw new ElasticsearchStatusException(params.failureMessage(), RestStatus.BAD_REQUEST); }; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java index af6c2eab11091..acd5a9e79277c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java @@ -51,9 +51,12 @@ public class HuggingFaceActionCreator implements HuggingFaceActionVisitor { OpenAiChatCompletionResponseEntity::fromResponse ); private static final ResponseHandler RERANK_HANDLER = new 
HuggingFaceResponseHandler("hugging face rerank", (request, response) -> { - var errorMessage = format(INVALID_REQUEST_TYPE_MESSAGE, "RERANK", request != null ? request.getClass().getName() : "null"); - if ((request instanceof HuggingFaceRerankRequest) == false) { + var errorMessage = format( + INVALID_REQUEST_TYPE_MESSAGE, + "RERANK", + request != null ? request.getClass().getSimpleName() : "null" + ); throw new IllegalArgumentException(errorMessage); } return HuggingFaceRerankResponseEntity.fromResponse((HuggingFaceRerankRequest) request, response); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequest.java index 9759bc9ce7c69..c8e833d78c3b2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequest.java @@ -62,20 +62,19 @@ public HttpRequest createHttpRequest() { input, returnDocuments, topN != null ? 
topN : model.getTaskSettings().getTopNDocumentsOnly(), - model.getTaskSettings(), - model.getServiceSettings().modelId() + model.getTaskSettings() ) ).getBytes(StandardCharsets.UTF_8) ); httpPost.setEntity(byteEntity); - httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()); + httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaTypeWithoutParameters()); decorateWithAuth(httpPost); return new HttpRequest(httpPost, getInferenceEntityId()); } - public void decorateWithAuth(HttpPost httpPost) { + void decorateWithAuth(HttpPost httpPost) { httpPost.setHeader(createAuthBearerHeader(model.apiKey())); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequestEntity.java index c2ad0a622072d..7ae161ee35657 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequestEntity.java @@ -17,7 +17,6 @@ import java.util.Objects; public record HuggingFaceRerankRequestEntity( - String model, String query, List documents, @Nullable Boolean returnDocuments, @@ -34,17 +33,6 @@ public record HuggingFaceRerankRequestEntity( Objects.requireNonNull(taskSettings); } - public HuggingFaceRerankRequestEntity( - String query, - List input, - @Nullable Boolean returnDocuments, - @Nullable Integer topN, - HuggingFaceRerankTaskSettings taskSettings, - String model - ) { - this(model, query, input, returnDocuments, topN, taskSettings); - } - @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankModel.java index 59c2569ec3695..4464e3aaf2dd1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankModel.java @@ -44,7 +44,7 @@ public HuggingFaceRerankModel( } // Should only be used directly for testing - public HuggingFaceRerankModel( + HuggingFaceRerankModel( String inferenceEntityId, TaskType taskType, String service, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceRerankResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceRerankResponseEntity.java index 566bd71438f7c..9382e19b78449 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceRerankResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceRerankResponseEntity.java @@ -23,16 +23,11 @@ import java.util.Comparator; import java.util.List; -import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList; -import static org.elasticsearch.core.Strings.format; import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; public class HuggingFaceRerankResponseEntity { - private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in Hugging Face rerank response"; - private static 
final String INVALID_ID_FIELD_FORMAT_TEMPLATE = "Expected numeric value for record ID field in Hugging Face rerank "; - /** * Parses the Hugging Face rerank response. @@ -41,10 +36,9 @@ public class HuggingFaceRerankResponseEntity { *
      *     
      *         {
-     *              "input": ["luke", "like", "leia", "chewy","r2d2", "star", "wars"],
+     *              "texts": ["luke", "leia"],
      *              "query": "star wars main character",
-     *              "return_documents": true,
-     *              "top_n": 1
+     *              "return_text": true
      *          }
      *     
      * 
@@ -53,15 +47,18 @@ public class HuggingFaceRerankResponseEntity { *
      *     
-     *         {
-     *              "rerank": [
-     *                  {
-     *                       "index": 5,
-     *                       "relevance_score": -0.06920313,
-     *                       "text": "star"
-     *                   }
-     *               ]
-     *          }
+     *         [
+     *              {
+     *                   "index": 0,
+     *                   "score": -0.07996220886707306,
+     *                   "text": "luke"
+     *               },
+     *               {
+     *                  "index": 1,
+     *                  "score": -0.08393221348524094,
+     *                  "text": "leia"
+     *              }
+     *         ]
      *     
      * 
*/ @@ -71,10 +68,6 @@ public static RankedDocsResults fromResponse(HuggingFaceRerankRequest request, H try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { moveToFirstToken(jsonParser); - - XContentParser.Token token = jsonParser.currentToken(); - ensureExpectedToken(XContentParser.Token.START_ARRAY, token, jsonParser); - var rankedDocs = doParse(jsonParser); var rankedDocsByRelevanceStream = rankedDocs.stream() .sorted(Comparator.comparingDouble(RankedDocsResults.RankedDoc::relevanceScore).reversed()); @@ -88,24 +81,11 @@ public static RankedDocsResults fromResponse(HuggingFaceRerankRequest request, H private static List doParse(XContentParser parser) throws IOException { return parseList(parser, (listParser, index) -> { var parsedRankedDoc = HuggingFaceRerankResponseEntity.RankedDocEntry.parse(parser); - - if (parsedRankedDoc.id == null) { - throw new IllegalStateException(format(FAILED_TO_FIND_FIELD_TEMPLATE, RankedDocEntry.ID.getPreferredName())); - } - - if (parsedRankedDoc.score == null) { - throw new IllegalStateException(format(FAILED_TO_FIND_FIELD_TEMPLATE, RankedDocEntry.SCORE.getPreferredName())); - } - - try { - return new RankedDocsResults.RankedDoc(parsedRankedDoc.id, parsedRankedDoc.score, parsedRankedDoc.text); - } catch (NumberFormatException e) { - throw new IllegalStateException(format(INVALID_ID_FIELD_FORMAT_TEMPLATE, parsedRankedDoc.id)); - } + return new RankedDocsResults.RankedDoc(parsedRankedDoc.id, parsedRankedDoc.score, parsedRankedDoc.text); }); } - private record RankedDocEntry(@Nullable Integer id, @Nullable Float score, @Nullable String text) { + private record RankedDocEntry(Integer id, Float score, @Nullable String text) { private static final ParseField TEXT = new ParseField("text"); private static final ParseField SCORE = new ParseField("score"); diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequestEntityTests.java index 0466ab15e9f0f..8e7502defeb8d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequestEntityTests.java @@ -18,12 +18,11 @@ import java.io.IOException; import java.util.List; -import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhitespaceInJsonString; +import static org.elasticsearch.common.xcontent.XContentHelper.stripWhitespace; public class HuggingFaceRerankRequestEntityTests extends ESTestCase { private static final String INPUT = "texts"; private static final String QUERY = "query"; - private static final String INFERENCE_ID = "model"; private static final Integer TOP_N = 8; private static final Boolean RETURN_DOCUMENTS = false; @@ -33,36 +32,28 @@ public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException List.of(INPUT), Boolean.TRUE, TOP_N, - new HuggingFaceRerankTaskSettings(TOP_N, RETURN_DOCUMENTS), - INFERENCE_ID + new HuggingFaceRerankTaskSettings(TOP_N, RETURN_DOCUMENTS) ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, ToXContent.EMPTY_PARAMS); String xContentResult = Strings.toString(builder); - - assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + String expected = """ {"texts":["texts"], "query":"query", "return_text":true, - "top_n":8}""")); + "top_n":8}"""; + assertEquals(stripWhitespace(expected), xContentResult); } public void testXContent_WritesMinimalFields() throws IOException { - 
var entity = new HuggingFaceRerankRequestEntity( - QUERY, - List.of(INPUT), - null, - null, - new HuggingFaceRerankTaskSettings(null, null), - INFERENCE_ID - ); + var entity = new HuggingFaceRerankRequestEntity(QUERY, List.of(INPUT), null, null, new HuggingFaceRerankTaskSettings(null, null)); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, ToXContent.EMPTY_PARAMS); String xContentResult = Strings.toString(builder); - - assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" - {"texts":["texts"],"query":"query"}""")); + String expected = """ + {"texts":["texts"],"query":"query"}"""; + assertEquals(stripWhitespace(expected), xContentResult); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceRerankResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceRerankResponseEntityTests.java new file mode 100644 index 0000000000000..f7009a5ba3545 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceRerankResponseEntityTests.java @@ -0,0 +1,100 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.huggingface.response; + +import org.apache.http.HttpResponse; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; +import org.elasticsearch.xpack.core.inference.results.RankedDocsResultsTests; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.services.huggingface.request.rerank.HuggingFaceRerankRequest; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.core.inference.results.RankedDocsResultsTests.buildExpectationRerank; +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class HuggingFaceRerankResponseEntityTests extends ESTestCase { + private static final String MISSED_FIELD_INDEX = "index"; + private static final String MISSED_FIELD_SCORE = "score"; + + public void testFromResponse_CreatesRankedDocsResults() throws IOException { + String responseJson = """ + [ + { + "index": 0, + "score": -0.07996220886707306, + "text": "luke" + } + ] + """; + HuggingFaceRerankRequest huggingFaceRerankRequestMock = mock(HuggingFaceRerankRequest.class); + when(huggingFaceRerankRequestMock.getTopN()).thenReturn(1); + + RankedDocsResults parsedResults = HuggingFaceRerankResponseEntity.fromResponse( + huggingFaceRerankRequestMock, + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat( + parsedResults.asMap(), + is( + buildExpectationRerank( + List.of( + new RankedDocsResultsTests.RerankExpectation( + Map.of("index", 0, "relevance_score", -0.07996220886707306F, "text", "luke") + ) + ) + ) + ) + ); + } + + public void testFails_CreateRankedDocsResults_IndexFieldNull() { + String responseJson = """ + [ + { + "score": -0.07996220886707306, + "text": "luke" + } + 
] + """; + assertMissingFieldThrowsIllegalArgumentException(responseJson, MISSED_FIELD_INDEX); + } + + public void testFails_CreateRankedDocsResults_ScoreFieldNull() { + String responseJson = """ + [ + { + "index": 0, + "text": "luke" + } + ] + """; + assertMissingFieldThrowsIllegalArgumentException(responseJson, MISSED_FIELD_SCORE); + } + + private void assertMissingFieldThrowsIllegalArgumentException(String responseJson, String missingField) { + HuggingFaceRerankRequest huggingFaceRerankRequestMock = mock(HuggingFaceRerankRequest.class); + when(huggingFaceRerankRequestMock.getTopN()).thenReturn(1); + + var thrownException = expectThrows( + IllegalArgumentException.class, + () -> HuggingFaceRerankResponseEntity.fromResponse( + huggingFaceRerankRequestMock, + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ) + ); + assertThat(thrownException.getMessage(), is("Required [%s]".formatted(missingField))); + } +} From c8c74d672c481872142df5e35714b2e9bc9a9507 Mon Sep 17 00:00:00 2001 From: Evgenii_Kazannik Date: Wed, 21 May 2025 15:17:10 +0200 Subject: [PATCH 09/12] fix tests --- .../action/HuggingFaceActionCreatorTests.java | 11 +++++++---- .../request/rerank/HuggingFaceRerankRequestTests.java | 2 +- .../HuggingFaceRerankResponseEntityTests.java | 2 +- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java index 694cdb3b4eaa9..f5d700016bf81 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java @@ -358,11 +358,14 @@ public 
void testSend_FailsFromInvalidResponseFormat_ForRerankAction() throws IOE private void assertRerankActionCreator(List documents, String query, int topN, boolean returnText) throws IOException { assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().getFirst().getUri().getQuery()); - assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); - assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + assertNull(webServer.requests().get(0).getUri().getQuery()); + assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaTypeWithoutParameters()) + ); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); - var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); assertThat(requestMap.size(), is(4)); assertThat(requestMap.get("texts"), is(documents)); assertThat(requestMap.get("query"), is(query)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequestTests.java index 0cb2bbb3bbaee..fc4365441f7d0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequestTests.java @@ -51,7 +51,7 @@ private void testCreateRequest(Integer topN, Boolean returnDocuments) throws IOE assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); var httpPost = (HttpPost) 
httpRequest.httpRequestBase(); - assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaTypeWithoutParameters())); assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(AUTH_HEADER_VALUE)); var requestMap = entityAsMap(httpPost.getEntity().getContent()); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceRerankResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceRerankResponseEntityTests.java index f7009a5ba3545..161ec7e87904f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceRerankResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceRerankResponseEntityTests.java @@ -95,6 +95,6 @@ private void assertMissingFieldThrowsIllegalArgumentException(String responseJso new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ) ); - assertThat(thrownException.getMessage(), is("Required [%s]".formatted(missingField))); + assertThat(thrownException.getMessage(), is("Required [" + missingField + "]")); } } From ae1a1d23a9a8190b4d5d2c9e71f3dcd1c5c313c2 Mon Sep 17 00:00:00 2001 From: Evgenii_Kazannik Date: Wed, 21 May 2025 17:34:45 +0200 Subject: [PATCH 10/12] apply suggestions --- .../rerank/HuggingFaceRerankRequest.java | 13 +- .../HuggingFaceRerankRequestEntity.java | 5 +- .../rerank/HuggingFaceRerankTaskSettings.java | 10 +- .../HuggingFaceRerankResponseEntity.java | 8 +- .../HuggingFaceRerankTaskSettingsTests.java | 22 ++-- .../HuggingFaceRerankResponseEntityTests.java | 117 ++++++++++++------ 6 files changed, 111 insertions(+), 64 deletions(-) diff 
--git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequest.java index c8e833d78c3b2..87fff4caa515d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequest.java @@ -56,15 +56,8 @@ public HttpRequest createHttpRequest() { HttpPost httpPost = new HttpPost(account.uri()); ByteArrayEntity byteEntity = new ByteArrayEntity( - Strings.toString( - new HuggingFaceRerankRequestEntity( - query, - input, - returnDocuments, - topN != null ? topN : model.getTaskSettings().getTopNDocumentsOnly(), - model.getTaskSettings() - ) - ).getBytes(StandardCharsets.UTF_8) + Strings.toString(new HuggingFaceRerankRequestEntity(query, input, returnDocuments, getTopN(), model.getTaskSettings())) + .getBytes(StandardCharsets.UTF_8) ); httpPost.setEntity(byteEntity); httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaTypeWithoutParameters()); @@ -89,7 +82,7 @@ public URI getURI() { } public Integer getTopN() { - return topN; + return topN != null ? 
topN : model.getTaskSettings().getTopNDocumentsOnly(); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequestEntity.java index 7ae161ee35657..463dd50fc14b7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequestEntity.java @@ -24,6 +24,7 @@ public record HuggingFaceRerankRequestEntity( HuggingFaceRerankTaskSettings taskSettings ) implements ToXContentObject { + private static final String RETURN_TEXT = "return_text"; private static final String DOCUMENTS_FIELD = "texts"; private static final String QUERY_FIELD = "query"; @@ -42,9 +43,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws // prefer the root level return_documents over task settings if (returnDocuments != null) { - builder.field(HuggingFaceRerankTaskSettings.RETURN_TEXT, returnDocuments); + builder.field(RETURN_TEXT, returnDocuments); } else if (taskSettings.getReturnDocuments() != null) { - builder.field(HuggingFaceRerankTaskSettings.RETURN_TEXT, taskSettings.getReturnDocuments()); + builder.field(RETURN_TEXT, taskSettings.getReturnDocuments()); } if (topN != null) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankTaskSettings.java index cd23699e836a2..9f90386edff90 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankTaskSettings.java @@ -28,7 +28,7 @@ public class HuggingFaceRerankTaskSettings implements TaskSettings { public static final String NAME = "hugging_face_rerank_task_settings"; - public static final String RETURN_TEXT = "return_text"; + public static final String RETURN_DOCUMENTS = "return_documents"; public static final String TOP_N_DOCS_ONLY = "top_n"; static final HuggingFaceRerankTaskSettings EMPTY_SETTINGS = new HuggingFaceRerankTaskSettings(null, null); @@ -40,7 +40,7 @@ public static HuggingFaceRerankTaskSettings fromMap(Map map) { return EMPTY_SETTINGS; } - Boolean returnDocuments = extractOptionalBoolean(map, RETURN_TEXT, validationException); + Boolean returnDocuments = extractOptionalBoolean(map, RETURN_DOCUMENTS, validationException); Integer topNDocumentsOnly = extractOptionalPositiveInteger( map, TOP_N_DOCS_ONLY, @@ -85,7 +85,7 @@ public static HuggingFaceRerankTaskSettings of(Integer topNDocumentsOnly, Boolea private final Boolean returnDocuments; public HuggingFaceRerankTaskSettings(StreamInput in) throws IOException { - this(in.readOptionalInt(), in.readOptionalBoolean()); + this(in.readOptionalVInt(), in.readOptionalBoolean()); } public HuggingFaceRerankTaskSettings(@Nullable Integer topNDocumentsOnly, @Nullable Boolean doReturnDocuments) { @@ -105,7 +105,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(TOP_N_DOCS_ONLY, topNDocumentsOnly); } if (returnDocuments != null) { - builder.field(RETURN_TEXT, returnDocuments); + builder.field(RETURN_DOCUMENTS, returnDocuments); } builder.endObject(); return builder; @@ -123,7 +123,7 @@ public TransportVersion getMinimalSupportedVersion() { @Override public void writeTo(StreamOutput out) throws 
IOException { - out.writeOptionalInt(topNDocumentsOnly); + out.writeOptionalVInt(topNDocumentsOnly); out.writeOptionalBoolean(returnDocuments); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceRerankResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceRerankResponseEntity.java index 9382e19b78449..64a7d4845236a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceRerankResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceRerankResponseEntity.java @@ -81,15 +81,15 @@ public static RankedDocsResults fromResponse(HuggingFaceRerankRequest request, H private static List doParse(XContentParser parser) throws IOException { return parseList(parser, (listParser, index) -> { var parsedRankedDoc = HuggingFaceRerankResponseEntity.RankedDocEntry.parse(parser); - return new RankedDocsResults.RankedDoc(parsedRankedDoc.id, parsedRankedDoc.score, parsedRankedDoc.text); + return new RankedDocsResults.RankedDoc(parsedRankedDoc.index, parsedRankedDoc.score, parsedRankedDoc.text); }); } - private record RankedDocEntry(Integer id, Float score, @Nullable String text) { + private record RankedDocEntry(Integer index, Float score, @Nullable String text) { private static final ParseField TEXT = new ParseField("text"); private static final ParseField SCORE = new ParseField("score"); - private static final ParseField ID = new ParseField("index"); + private static final ParseField INDEX = new ParseField("index"); private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( "hugging_face_rerank_response", @@ -98,7 +98,7 @@ private record RankedDocEntry(Integer id, Float score, @Nullable String text) { ); static { - 
PARSER.declareInt(ConstructingObjectParser.constructorArg(), ID); + PARSER.declareInt(ConstructingObjectParser.constructorArg(), INDEX); PARSER.declareFloat(ConstructingObjectParser.constructorArg(), SCORE); PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), TEXT); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankTaskSettingsTests.java index a45f2ac9766cc..301a797cd990f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankTaskSettingsTests.java @@ -7,9 +7,10 @@ package org.elasticsearch.xpack.inference.services.huggingface.rerank; +import org.elasticsearch.TransportVersion; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; import java.io.IOException; import java.util.HashMap; @@ -17,7 +18,7 @@ import static org.hamcrest.Matchers.containsString; -public class HuggingFaceRerankTaskSettingsTests extends AbstractWireSerializingTestCase { +public class HuggingFaceRerankTaskSettingsTests extends AbstractBWCWireSerializationTestCase { public static HuggingFaceRerankTaskSettings createRandom() { var returnDocuments = randomBoolean() ? 
randomBoolean() : null; @@ -28,7 +29,7 @@ public static HuggingFaceRerankTaskSettings createRandom() { public void testFromMap_WithValidValues_ReturnsSettings() { Map taskMap = Map.of( - HuggingFaceRerankTaskSettings.RETURN_TEXT, + HuggingFaceRerankTaskSettings.RETURN_DOCUMENTS, true, HuggingFaceRerankTaskSettings.TOP_N_DOCS_ONLY, 5 @@ -46,18 +47,18 @@ public void testFromMap_WithNullValues_ReturnsSettingsWithNulls() { public void testFromMap_WithInvalidReturnDocuments_ThrowsValidationException() { Map taskMap = Map.of( - HuggingFaceRerankTaskSettings.RETURN_TEXT, + HuggingFaceRerankTaskSettings.RETURN_DOCUMENTS, "invalid", HuggingFaceRerankTaskSettings.TOP_N_DOCS_ONLY, 5 ); var thrownException = expectThrows(ValidationException.class, () -> HuggingFaceRerankTaskSettings.fromMap(new HashMap<>(taskMap))); - assertThat(thrownException.getMessage(), containsString("field [return_text] is not of the expected type")); + assertThat(thrownException.getMessage(), containsString("field [return_documents] is not of the expected type")); } public void testFromMap_WithInvalidTopNDocsOnly_ThrowsValidationException() { Map taskMap = Map.of( - HuggingFaceRerankTaskSettings.RETURN_TEXT, + HuggingFaceRerankTaskSettings.RETURN_DOCUMENTS, true, HuggingFaceRerankTaskSettings.TOP_N_DOCS_ONLY, "invalid" @@ -74,7 +75,7 @@ public void UpdatedTaskSettings_WithEmptyMap_ReturnsSameSettings() { public void testUpdatedTaskSettings_WithNewReturnDocuments_ReturnsUpdatedSettings() { var initialSettings = new HuggingFaceRerankTaskSettings(5, true); - Map newSettings = Map.of(HuggingFaceRerankTaskSettings.RETURN_TEXT, false); + Map newSettings = Map.of(HuggingFaceRerankTaskSettings.RETURN_DOCUMENTS, false); HuggingFaceRerankTaskSettings updatedSettings = (HuggingFaceRerankTaskSettings) initialSettings.updatedTaskSettings(newSettings); assertFalse(updatedSettings.getReturnDocuments()); assertEquals(initialSettings.getTopNDocumentsOnly(), updatedSettings.getTopNDocumentsOnly()); @@ -91,7 +92,7 @@ 
public void testUpdatedTaskSettings_WithNewTopNDocsOnly_ReturnsUpdatedSettings() public void testUpdatedTaskSettings_WithMultipleNewValues_ReturnsUpdatedSettings() { var initialSettings = new HuggingFaceRerankTaskSettings(5, true); Map newSettings = Map.of( - HuggingFaceRerankTaskSettings.RETURN_TEXT, + HuggingFaceRerankTaskSettings.RETURN_DOCUMENTS, false, HuggingFaceRerankTaskSettings.TOP_N_DOCS_ONLY, 7 @@ -115,4 +116,9 @@ protected HuggingFaceRerankTaskSettings createTestInstance() { protected HuggingFaceRerankTaskSettings mutateInstance(HuggingFaceRerankTaskSettings instance) throws IOException { return randomValueOtherThan(instance, HuggingFaceRerankTaskSettingsTests::createRandom); } + + @Override + protected HuggingFaceRerankTaskSettings mutateInstanceForVersion(HuggingFaceRerankTaskSettings instance, TransportVersion version) { + return instance; + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceRerankResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceRerankResponseEntityTests.java index 161ec7e87904f..5d1b14c46f099 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceRerankResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceRerankResponseEntityTests.java @@ -27,45 +27,82 @@ public class HuggingFaceRerankResponseEntityTests extends ESTestCase { private static final String MISSED_FIELD_INDEX = "index"; private static final String MISSED_FIELD_SCORE = "score"; + private static final String RESPONSE_JSON_TWO_DOCS = """ + [ + { + "index": 4, + "score": -0.22222222222222222, + "text": "ranked second" + }, + { + "index": 1, + "score": 1.11111111111111111, + "text": "ranked first" + } + ] + """; + private static final List 
EXPECTED_TWO_DOCS = List.of( + new RankedDocsResultsTests.RerankExpectation(Map.of("index", 1, "relevance_score", 1.11111111111111111F, "text", "ranked first")), + new RankedDocsResultsTests.RerankExpectation(Map.of("index", 4, "relevance_score", -0.22222222222222222F, "text", "ranked second")) + ); - public void testFromResponse_CreatesRankedDocsResults() throws IOException { - String responseJson = """ - [ - { - "index": 0, - "score": -0.07996220886707306, - "text": "luke" - } - ] - """; - HuggingFaceRerankRequest huggingFaceRerankRequestMock = mock(HuggingFaceRerankRequest.class); - when(huggingFaceRerankRequestMock.getTopN()).thenReturn(1); + private static final String RESPONSE_JSON_FIVE_DOCS = """ + [ + { + "index": 1, + "score": 1.11111111111111111, + "text": "ranked first" + }, + { + "index": 3, + "score": -0.33333333333333333, + "text": "ranked third" + }, + { + "index": 0, + "score": -0.55555555555555555, + "text": "ranked fifth" + }, + { + "index": 2, + "score": -0.44444444444444444, + "text": "ranked fourth" + }, + { + "index": 4, + "score": -0.22222222222222222, + "text": "ranked second" + } + ] + """; + private static final List EXPECTED_FIVE_DOCS = List.of( + new RankedDocsResultsTests.RerankExpectation(Map.of("index", 1, "relevance_score", 1.11111111111111111F, "text", "ranked first")), + new RankedDocsResultsTests.RerankExpectation(Map.of("index", 4, "relevance_score", -0.22222222222222222F, "text", "ranked second")), + new RankedDocsResultsTests.RerankExpectation(Map.of("index", 3, "relevance_score", -0.33333333333333333F, "text", "ranked third")), + new RankedDocsResultsTests.RerankExpectation(Map.of("index", 2, "relevance_score", -0.44444444444444444F, "text", "ranked fourth")), + new RankedDocsResultsTests.RerankExpectation(Map.of("index", 0, "relevance_score", -0.55555555555555555F, "text", "ranked fifth")) + ); - RankedDocsResults parsedResults = HuggingFaceRerankResponseEntity.fromResponse( - huggingFaceRerankRequestMock, - new 
HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) - ); + private static final HuggingFaceRerankRequest REQUEST_MOCK = mock(HuggingFaceRerankRequest.class); - assertThat( - parsedResults.asMap(), - is( - buildExpectationRerank( - List.of( - new RankedDocsResultsTests.RerankExpectation( - Map.of("index", 0, "relevance_score", -0.07996220886707306F, "text", "luke") - ) - ) - ) - ) - ); + public void testFromResponse_CreatesRankedDocsResults_TopNNull_FiveDocs_NoLimit() throws IOException { + assertTopNLimit(null, RESPONSE_JSON_FIVE_DOCS, EXPECTED_FIVE_DOCS); + } + + public void testFromResponse_CreatesRankedDocsResults_TopN5_TwoDocs_NoLimit() throws IOException { + assertTopNLimit(5, RESPONSE_JSON_TWO_DOCS, EXPECTED_TWO_DOCS); + } + + public void testFromResponse_CreatesRankedDocsResults_TopN2_FiveDocs_Limits() throws IOException { + assertTopNLimit(2, RESPONSE_JSON_FIVE_DOCS, EXPECTED_TWO_DOCS); } public void testFails_CreateRankedDocsResults_IndexFieldNull() { String responseJson = """ [ { - "score": -0.07996220886707306, - "text": "luke" + "score": 1.11111111111111111, + "text": "ranked first" } ] """; @@ -76,8 +113,8 @@ public void testFails_CreateRankedDocsResults_ScoreFieldNull() { String responseJson = """ [ { - "index": 0, - "text": "luke" + "index": 1, + "text": "ranked first" } ] """; @@ -85,16 +122,26 @@ public void testFails_CreateRankedDocsResults_ScoreFieldNull() { } private void assertMissingFieldThrowsIllegalArgumentException(String responseJson, String missingField) { - HuggingFaceRerankRequest huggingFaceRerankRequestMock = mock(HuggingFaceRerankRequest.class); - when(huggingFaceRerankRequestMock.getTopN()).thenReturn(1); + when(REQUEST_MOCK.getTopN()).thenReturn(1); var thrownException = expectThrows( IllegalArgumentException.class, () -> HuggingFaceRerankResponseEntity.fromResponse( - huggingFaceRerankRequestMock, + REQUEST_MOCK, new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) 
) ); assertThat(thrownException.getMessage(), is("Required [" + missingField + "]")); } + + private void assertTopNLimit( + Integer topN, String responseJson, List expectation) throws IOException { + when(REQUEST_MOCK.getTopN()).thenReturn(topN); + + RankedDocsResults parsedResults = HuggingFaceRerankResponseEntity.fromResponse( + REQUEST_MOCK, + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + assertThat(parsedResults.asMap(), is(buildExpectationRerank(expectation))); + } } From 7f30c6a38f9233a195a5ffe033499fd05bc322f3 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Thu, 22 May 2025 02:35:29 +0000 Subject: [PATCH 11/12] [CI] Auto commit changes from spotless --- .../core/inference/results/RankedDocsResultsTests.java | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/RankedDocsResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/RankedDocsResultsTests.java index 4971248e87431..6873d0e7642d5 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/RankedDocsResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/RankedDocsResultsTests.java @@ -88,14 +88,7 @@ public record RerankExpectation(Map rankedDocFields) {} public static Map buildExpectationRerank(List rerank) { return Map.of( RankedDocsResults.RERANK, - rerank.stream() - .map( - rerankExpectation -> Map.of( - RankedDocsResults.RankedDoc.NAME, - rerankExpectation.rankedDocFields - ) - ) - .toList() + rerank.stream().map(rerankExpectation -> Map.of(RankedDocsResults.RankedDoc.NAME, rerankExpectation.rankedDocFields)).toList() ); } } From 4ee7f1fcf4c05fff3be38efba751ab894de4de27 Mon Sep 17 00:00:00 2001 From: Evgenii_Kazannik Date: Thu, 22 May 2025 18:16:06 +0200 Subject: [PATCH 12/12] add changelog information --- 
docs/changelog/127966.yaml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 docs/changelog/127966.yaml diff --git a/docs/changelog/127966.yaml b/docs/changelog/127966.yaml new file mode 100644 index 0000000000000..0c896715149bf --- /dev/null +++ b/docs/changelog/127966.yaml @@ -0,0 +1,5 @@ +pr: 127966 +summary: "[ML] Add Hugging Face Rerank support to the Inference Plugin" +area: Machine Learning +type: enhancement +issues: []