diff --git a/docs/changelog/127966.yaml b/docs/changelog/127966.yaml new file mode 100644 index 0000000000000..0c896715149bf --- /dev/null +++ b/docs/changelog/127966.yaml @@ -0,0 +1,5 @@ +pr: 127966 +summary: "[ML] Add Rerank support to the Inference Plugin" +area: Machine Learning +type: enhancement +issues: [] diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 2c119df3f8f37..8aa3f894bae64 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -179,6 +179,7 @@ static TransportVersion def(int id) { public static final TransportVersion V_8_19_FIELD_CAPS_ADD_CLUSTER_ALIAS = def(8_841_0_32); public static final TransportVersion ESQL_HASH_OPERATOR_STATUS_OUTPUT_TIME_8_19 = def(8_841_0_34); public static final TransportVersion RERANKER_FAILURES_ALLOWED_8_19 = def(8_841_0_35); + public static final TransportVersion ML_INFERENCE_HUGGING_FACE_RERANK_ADDED_8_19 = def(8_841_0_36); public static final TransportVersion V_9_0_0 = def(9_000_0_09); public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10); public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11); @@ -261,7 +262,7 @@ static TransportVersion def(int id) { public static final TransportVersion ESQL_HASH_OPERATOR_STATUS_OUTPUT_TIME = def(9_077_0_00); public static final TransportVersion ML_INFERENCE_HUGGING_FACE_CHAT_COMPLETION_ADDED = def(9_078_0_00); public static final TransportVersion NODES_STATS_SUPPORTS_MULTI_PROJECT = def(9_079_0_00); - + public static final TransportVersion ML_INFERENCE_HUGGING_FACE_RERANK_ADDED = def(9_080_0_00); /* * STOP! READ THIS FIRST! No, really, * ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _ diff --git a/server/src/test/java/org/elasticsearch/inference/SettingsConfigurationTestUtils.java b/server/src/test/java/org/elasticsearch/inference/SettingsConfigurationTestUtils.java index 7e68f41de1b2e..e99b2178c50d4 100644 --- a/server/src/test/java/org/elasticsearch/inference/SettingsConfigurationTestUtils.java +++ b/server/src/test/java/org/elasticsearch/inference/SettingsConfigurationTestUtils.java @@ -20,9 +20,8 @@ public class SettingsConfigurationTestUtils { public static SettingsConfiguration getRandomSettingsConfigurationField() { - return new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING)).setDefaultValue( - randomAlphaOfLength(10) - ) + return new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING, TaskType.RERANK)) + .setDefaultValue(randomAlphaOfLength(10)) .setDescription(randomAlphaOfLength(10)) .setLabel(randomAlphaOfLength(10)) .setRequired(randomBoolean()) diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/RankedDocsResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/RankedDocsResultsTests.java index ff6f6848f4b69..6873d0e7642d5 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/RankedDocsResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/RankedDocsResultsTests.java @@ -82,4 +82,13 @@ private List rankedDocsNullStringToEmpty(List rankedDocFields) {} + + public static Map buildExpectationRerank(List rerank) { + return Map.of( + RankedDocsResults.RERANK, + rerank.stream().map(rerankExpectation -> Map.of(RankedDocsResults.RankedDoc.NAME, rerankExpectation.rankedDocFields)).toList() + ); + } } diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java index 54d0aca772061..d76d76a8d516c 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java @@ -171,6 +171,20 @@ static String mockDenseServiceModelConfig() { """; } + static String mockRerankServiceModelConfig() { + return """ + { + "service": "test_reranking_service", + "service_settings": { + "model_id": "my_model", + "api_key": "abc64" + }, + "task_settings": { + } + } + """; + } + static void deleteModel(String modelId) throws IOException { var request = new Request("DELETE", "_inference/" + modelId); var response = client().performRequest(request); @@ -484,6 +498,10 @@ private String jsonBody(List input, @Nullable String query) { @SuppressWarnings("unchecked") protected void assertNonEmptyInferenceResults(Map resultMap, int expectedNumberOfResults, TaskType taskType) { switch (taskType) { + case RERANK -> { + var results = (List>) resultMap.get(TaskType.RERANK.toString()); + assertThat(results, hasSize(expectedNumberOfResults)); + } case SPARSE_EMBEDDING -> { var results = (List>) resultMap.get(TaskType.SPARSE_EMBEDDING.toString()); assertThat(results, hasSize(expectedNumberOfResults)); diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java index 1b01f37955e5a..f85874d6e580c 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java @@ -53,9 +53,12 @@ public void testCRUD() throws IOException { for (int i = 0; i < 4; i++) { putModel("te_model_" + i, mockDenseServiceModelConfig(), TaskType.TEXT_EMBEDDING); } + for (int i = 0; i < 3; i++) { + putModel("re-model-" + i, mockRerankServiceModelConfig(), TaskType.RERANK); + } var getAllModels = getAllModels(); - int numModels = 12; + int numModels = 15; assertThat(getAllModels, hasSize(numModels)); var getSparseModels = getModels("_all", TaskType.SPARSE_EMBEDDING); @@ -71,6 +74,13 @@ public void testCRUD() throws IOException { for (var denseModel : getDenseModels) { assertEquals("text_embedding", denseModel.get("task_type")); } + + var getRerankModels = getModels("_all", TaskType.RERANK); + int numRerankModels = 4; + assertThat(getRerankModels, hasSize(numRerankModels)); + for (var denseModel : getRerankModels) { + assertEquals("rerank", denseModel.get("task_type")); + } String oldApiKey; { var singleModel = getModels("se_model_1", TaskType.SPARSE_EMBEDDING); @@ -100,6 +110,9 @@ public void testCRUD() throws IOException { for (int i = 0; i < 4; i++) { deleteModel("te_model_" + i, TaskType.TEXT_EMBEDDING); } + for (int i = 0; i < 3; i++) { + deleteModel("re-model-" + i, TaskType.RERANK); + } } public void testGetModelWithWrongTaskType() throws IOException { diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java index 1aa302fcb2b32..52bd95e9d2619 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java @@ -101,7 +101,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException { public void testGetServicesWithRerankTaskType() throws IOException { List services = getServices(TaskType.RERANK); - assertThat(services.size(), equalTo(7)); + assertThat(services.size(), equalTo(8)); var providers = providers(services); @@ -115,7 +115,8 @@ public void testGetServicesWithRerankTaskType() throws IOException { "googlevertexai", "jinaai", "test_reranking_service", - "voyageai" + "voyageai", + "hugging_face" ).toArray() ) ); diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockRerankInferenceServiceIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockRerankInferenceServiceIT.java new file mode 100644 index 0000000000000..2524b8c8c9ae7 --- /dev/null +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockRerankInferenceServiceIT.java @@ -0,0 +1,67 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference; + +import org.elasticsearch.inference.TaskType; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +public class MockRerankInferenceServiceIT extends InferenceBaseRestTest { + + @SuppressWarnings("unchecked") + public void testMockService() throws IOException { + String inferenceEntityId = "test-mock"; + var putModel = putModel(inferenceEntityId, mockRerankServiceModelConfig(), TaskType.RERANK); + var model = getModels(inferenceEntityId, TaskType.RERANK).get(0); + + for (var modelMap : List.of(putModel, model)) { + assertEquals(inferenceEntityId, modelMap.get("inference_id")); + assertEquals(TaskType.RERANK, TaskType.fromString((String) modelMap.get("task_type"))); + assertEquals("test_reranking_service", modelMap.get("service")); + } + + List input = List.of(randomAlphaOfLength(10)); + var inference = infer(inferenceEntityId, input); + assertNonEmptyInferenceResults(inference, 1, TaskType.RERANK); + assertEquals(inference, infer(inferenceEntityId, input)); + assertNotEquals(inference, infer(inferenceEntityId, randomValueOtherThan(input, () -> List.of(randomAlphaOfLength(10))))); + } + + public void testMockServiceWithMultipleInputs() throws IOException { + String inferenceEntityId = "test-mock-with-multi-inputs"; + putModel(inferenceEntityId, mockRerankServiceModelConfig(), TaskType.RERANK); + var queryParams = Map.of("timeout", "120s"); + + var inference = infer( + inferenceEntityId, + TaskType.RERANK, + List.of(randomAlphaOfLength(5), randomAlphaOfLength(10)), + "What if?", + queryParams + ); + + assertNonEmptyInferenceResults(inference, 2, TaskType.RERANK); + } + + @SuppressWarnings("unchecked") + public void testMockService_DoesNotReturnSecretsInGetResponse() throws IOException { + String inferenceEntityId = "test-mock"; + var putModel = putModel(inferenceEntityId, mockRerankServiceModelConfig(), TaskType.RERANK); + var model = getModels(inferenceEntityId, TaskType.RERANK).get(0); + + var serviceSettings = (Map) model.get("service_settings"); + assertNull(serviceSettings.get("api_key")); + assertNotNull(serviceSettings.get("model_id")); + + var putServiceSettings = (Map) putModel.get("service_settings"); + assertNull(putServiceSettings.get("api_key")); + assertNotNull(putServiceSettings.get("model_id")); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index dc6c96d0d860d..fb7140ab8dfdd 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -80,6 +80,8 @@ import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings; import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings; +import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankServiceSettings; +import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings; import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankServiceSettings; import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankTaskSettings; @@ -365,6 +367,16 @@ private static void addHuggingFaceNamedWriteables(List namedWriteables) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java index 7f245ae854eac..c0f7aae38bf76 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java @@ -19,6 +19,7 @@ import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings; import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankTaskSettings; +import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings; import java.util.ArrayList; import java.util.Arrays; @@ -91,7 +92,10 @@ protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); ChunkingSettings chunkingSettings = null; if (TaskType.TEXT_EMBEDDING.equals(taskType)) { @@ -66,17 +67,21 @@ public void parseRequestConfig( } var model = createModel( - inferenceEntityId, - taskType, - serviceSettingsMap, - chunkingSettings, - serviceSettingsMap, - TaskType.unsupportedTaskTypeErrorMsg(taskType, name()), - ConfigurationParseContext.REQUEST + new HuggingFaceModelParameters( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + chunkingSettings, + serviceSettingsMap, + TaskType.unsupportedTaskTypeErrorMsg(taskType, name()), + ConfigurationParseContext.REQUEST + ) ); throwIfNotEmptyMap(config, name()); throwIfNotEmptyMap(serviceSettingsMap, name()); + throwIfNotEmptyMap(taskSettingsMap, name()); parsedModelListener.onResponse(model); } catch (Exception e) { @@ -92,6 +97,7 @@ public HuggingFaceModel parsePersistedConfigWithSecrets( Map secrets ) { Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); Map secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS); ChunkingSettings chunkingSettings = null; @@ -100,19 +106,23 @@ public HuggingFaceModel parsePersistedConfigWithSecrets( } return createModel( - inferenceEntityId, - taskType, - serviceSettingsMap, - chunkingSettings, - secretSettingsMap, - parsePersistedConfigErrorMsg(inferenceEntityId, name()), - ConfigurationParseContext.PERSISTENT + new HuggingFaceModelParameters( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + chunkingSettings, + secretSettingsMap, + parsePersistedConfigErrorMsg(inferenceEntityId, name()), + ConfigurationParseContext.PERSISTENT + ) ); } @Override public HuggingFaceModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) { Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); ChunkingSettings chunkingSettings = null; if (TaskType.TEXT_EMBEDDING.equals(taskType)) { @@ -120,25 +130,20 @@ public HuggingFaceModel parsePersistedConfig(String inferenceEntityId, TaskType } return createModel( - inferenceEntityId, - taskType, - serviceSettingsMap, - chunkingSettings, - null, - parsePersistedConfigErrorMsg(inferenceEntityId, name()), - ConfigurationParseContext.PERSISTENT + new HuggingFaceModelParameters( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + chunkingSettings, + null, + parsePersistedConfigErrorMsg(inferenceEntityId, name()), + ConfigurationParseContext.PERSISTENT + ) ); } - protected abstract HuggingFaceModel createModel( - String inferenceEntityId, - TaskType taskType, - Map serviceSettings, - ChunkingSettings chunkingSettings, - Map secretSettings, - String failureMessage, - ConfigurationParseContext context - ); + protected abstract HuggingFaceModel createModel(HuggingFaceModelParameters input); @Override public void doInfer( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModel.java index 62133eff4b658..41cc7fd283322 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModel.java @@ -11,6 +11,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel; import org.elasticsearch.xpack.inference.services.ServiceUtils; @@ -35,6 +36,13 @@ public HuggingFaceModel( apiKey = ServiceUtils.apiKey(apiKeySecrets); } + protected HuggingFaceModel(HuggingFaceModel model, TaskSettings taskSettings) { + super(model, taskSettings); + + rateLimitServiceSettings = model.rateLimitServiceSettings(); + apiKey = model.apiKey(); + } + public HuggingFaceRateLimitServiceSettings rateLimitServiceSettings() { return rateLimitServiceSettings; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModelParameters.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModelParameters.java new file mode 100644 index 0000000000000..6dabaa66ffb2b --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModelParameters.java @@ -0,0 +1,25 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.huggingface; + +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; + +import java.util.Map; + +public record HuggingFaceModelParameters( + String inferenceEntityId, + TaskType taskType, + Map serviceSettings, + Map taskSettings, + ChunkingSettings chunkingSettings, + Map secretSettings, + String failureMessage, + ConfigurationParseContext context +) {} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java index 133f7b5be6b62..d10fb77290c6b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java @@ -12,10 +12,8 @@ import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.util.LazyInitializable; -import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInference; -import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; @@ -32,7 +30,6 @@ import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; -import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.ServiceUtils; import org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionCreator; @@ -40,6 +37,7 @@ import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel; import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel; import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequest; +import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankModel; import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -62,6 +60,7 @@ public class HuggingFaceService extends HuggingFaceBaseService { private static final String SERVICE_NAME = "Hugging Face"; private static final EnumSet SUPPORTED_TASK_TYPES = EnumSet.of( + TaskType.RERANK, TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING, TaskType.COMPLETION, @@ -77,35 +76,43 @@ public HuggingFaceService(HttpRequestSender.Factory factory, ServiceComponents s } @Override - protected HuggingFaceModel createModel( - String inferenceEntityId, - TaskType taskType, - Map serviceSettings, - ChunkingSettings chunkingSettings, - @Nullable Map secretSettings, - String failureMessage, - ConfigurationParseContext context - ) { - return switch (taskType) { + protected HuggingFaceModel createModel(HuggingFaceModelParameters params) { + return switch (params.taskType()) { + case RERANK -> new HuggingFaceRerankModel( + params.inferenceEntityId(), + params.taskType(), + NAME, + params.serviceSettings(), + params.taskSettings(), + params.secretSettings(), + params.context() + ); case TEXT_EMBEDDING -> new HuggingFaceEmbeddingsModel( - inferenceEntityId, - taskType, + params.inferenceEntityId(), + params.taskType(), + NAME, + params.serviceSettings(), + params.chunkingSettings(), + params.secretSettings(), + params.context() + ); + case SPARSE_EMBEDDING -> new HuggingFaceElserModel( + params.inferenceEntityId(), + params.taskType(), NAME, - serviceSettings, - chunkingSettings, - secretSettings, - context + params.serviceSettings(), + params.secretSettings(), + params.context() ); - case SPARSE_EMBEDDING -> new HuggingFaceElserModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings, context); case CHAT_COMPLETION, COMPLETION -> new HuggingFaceChatCompletionModel( - inferenceEntityId, - taskType, + params.inferenceEntityId(), + params.taskType(), NAME, - serviceSettings, - secretSettings, - context + params.serviceSettings(), + params.secretSettings(), + params.context() ); - default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + default -> throw new ElasticsearchStatusException(params.failureMessage(), RestStatus.BAD_REQUEST); }; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java index df1ddcb017970..acd5a9e79277c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java @@ -14,6 +14,7 @@ import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -23,8 +24,11 @@ import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel; import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel; import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequest; +import org.elasticsearch.xpack.inference.services.huggingface.request.rerank.HuggingFaceRerankRequest; +import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankModel; import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceElserResponseEntity; import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceEmbeddingsResponseEntity; +import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceRerankResponseEntity; import org.elasticsearch.xpack.inference.services.openai.OpenAiChatCompletionResponseHandler; import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity; @@ -37,12 +41,27 @@ */ public class HuggingFaceActionCreator implements HuggingFaceActionVisitor { + private static final String FAILED_TO_SEND_REQUEST_ERROR_MESSAGE = + "Failed to send Hugging Face %s request from inference entity id [%s]"; + private static final String INVALID_REQUEST_TYPE_MESSAGE = "Invalid request type: expected HuggingFace %s request but got %s"; public static final String COMPLETION_ERROR_PREFIX = "Hugging Face completions"; static final String USER_ROLE = "user"; static final ResponseHandler COMPLETION_HANDLER = new OpenAiChatCompletionResponseHandler( "hugging face completion", OpenAiChatCompletionResponseEntity::fromResponse ); + private static final ResponseHandler RERANK_HANDLER = new HuggingFaceResponseHandler("hugging face rerank", (request, response) -> { + if ((request instanceof HuggingFaceRerankRequest) == false) { + var errorMessage = format( + INVALID_REQUEST_TYPE_MESSAGE, + "RERANK", + request != null ? request.getClass().getSimpleName() : "null" + ); + throw new IllegalArgumentException(errorMessage); + } + return HuggingFaceRerankResponseEntity.fromResponse((HuggingFaceRerankRequest) request, response); + }); + private final Sender sender; private final ServiceComponents serviceComponents; @@ -51,6 +70,26 @@ public HuggingFaceActionCreator(Sender sender, ServiceComponents serviceComponen this.serviceComponents = Objects.requireNonNull(serviceComponents); } + @Override + public ExecutableAction create(HuggingFaceRerankModel model) { + var overriddenModel = HuggingFaceRerankModel.of(model, model.getTaskSettings()); + var manager = new GenericRequestManager<>( + serviceComponents.threadPool(), + overriddenModel, + RERANK_HANDLER, + inputs -> new HuggingFaceRerankRequest( + inputs.getQuery(), + inputs.getChunks(), + inputs.getReturnDocuments(), + inputs.getTopN(), + model + ), + QueryAndDocsInputs.class + ); + var errorMessage = buildErrorMessage(TaskType.RERANK, model.getInferenceEntityId()); + return new SenderExecutableAction(sender, manager, errorMessage); + } + @Override public ExecutableAction create(HuggingFaceEmbeddingsModel model) { var responseHandler = new HuggingFaceResponseHandler( @@ -95,6 +134,6 @@ public ExecutableAction create(HuggingFaceChatCompletionModel model) { } public static String buildErrorMessage(TaskType requestType, String inferenceId) { - return format("Failed to send Hugging Face %s request from inference entity id [%s]", requestType.toString(), inferenceId); + return format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, requestType.toString(), inferenceId); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionVisitor.java index ee308db774b1d..4c7744ddd0346 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionVisitor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionVisitor.java @@ -11,8 +11,11 @@ import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel; import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel; import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankModel; public interface HuggingFaceActionVisitor { + ExecutableAction create(HuggingFaceRerankModel model); + ExecutableAction create(HuggingFaceEmbeddingsModel model); ExecutableAction create(HuggingFaceElserModel model); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java index 8116eaf86e74a..e61995aac91f3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java @@ -13,11 +13,9 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.Strings; import org.elasticsearch.common.util.LazyInitializable; -import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; -import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; @@ -35,10 +33,10 @@ import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; -import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceBaseService; import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceModel; +import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceModelParameters; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -69,18 +67,17 @@ public String name() { } @Override - protected HuggingFaceModel createModel( - String inferenceEntityId, - TaskType taskType, - Map serviceSettings, - ChunkingSettings chunkingSettings, - @Nullable Map secretSettings, - String failureMessage, - ConfigurationParseContext context - ) { - return switch (taskType) { - case SPARSE_EMBEDDING -> new HuggingFaceElserModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings, context); - default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + protected HuggingFaceModel createModel(HuggingFaceModelParameters input) { + return switch (input.taskType()) { + case SPARSE_EMBEDDING -> new HuggingFaceElserModel( + input.inferenceEntityId(), + input.taskType(), + NAME, + input.serviceSettings(), + input.secretSettings(), + input.context() + ); + default -> throw new ElasticsearchStatusException(input.failureMessage(), RestStatus.BAD_REQUEST); }; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequest.java new file mode 100644 index 0000000000000..87fff4caa515d --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequest.java @@ -0,0 +1,99 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.huggingface.request.rerank; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.entity.ByteArrayEntity; +import org.elasticsearch.common.Strings; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceAccount; +import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankModel; + +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; + +public class HuggingFaceRerankRequest implements Request { + + private final HuggingFaceAccount account; + private final String query; + private final List input; + private final Boolean returnDocuments; + private final Integer topN; + private final HuggingFaceRerankModel model; + + public HuggingFaceRerankRequest( + String query, + List input, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, + HuggingFaceRerankModel model + ) { + Objects.requireNonNull(model); + + this.account = HuggingFaceAccount.of(model); + this.input = Objects.requireNonNull(input); + this.query = Objects.requireNonNull(query); + this.returnDocuments = returnDocuments; + this.topN = topN; + this.model = model; + } + + @Override + public HttpRequest createHttpRequest() { + HttpPost httpPost = new HttpPost(account.uri()); + + ByteArrayEntity byteEntity = new ByteArrayEntity( + Strings.toString(new HuggingFaceRerankRequestEntity(query, input, returnDocuments, getTopN(), model.getTaskSettings())) + .getBytes(StandardCharsets.UTF_8) + ); + httpPost.setEntity(byteEntity); + httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaTypeWithoutParameters()); + + decorateWithAuth(httpPost); + + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + void decorateWithAuth(HttpPost httpPost) { + httpPost.setHeader(createAuthBearerHeader(model.apiKey())); + } + + @Override + public String getInferenceEntityId() { + return model.getInferenceEntityId(); + } + + @Override + public URI getURI() { + return account.uri(); + } + + public Integer getTopN() { + return topN != null ? topN : model.getTaskSettings().getTopNDocumentsOnly(); + } + + @Override + public Request truncate() { + // Not applicable for rerank, only used in text embedding requests + return this; + } + + @Override + public boolean[] getTruncationInfo() { + // Not applicable for rerank, only used in text embedding requests + return null; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequestEntity.java new file mode 100644 index 0000000000000..463dd50fc14b7 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequestEntity.java @@ -0,0 +1,60 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.huggingface.request.rerank; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public record HuggingFaceRerankRequestEntity( + String query, + List documents, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, + HuggingFaceRerankTaskSettings taskSettings +) implements ToXContentObject { + + private static final String RETURN_TEXT = "return_text"; + private static final String DOCUMENTS_FIELD = "texts"; + private static final String QUERY_FIELD = "query"; + + public HuggingFaceRerankRequestEntity { + Objects.requireNonNull(query); + Objects.requireNonNull(documents); + Objects.requireNonNull(taskSettings); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + builder.field(DOCUMENTS_FIELD, documents); + builder.field(QUERY_FIELD, query); + + // prefer the root level return_documents over task settings + if (returnDocuments != null) { + builder.field(RETURN_TEXT, returnDocuments); + } else if (taskSettings.getReturnDocuments() != null) { + builder.field(RETURN_TEXT, taskSettings.getReturnDocuments()); + } + + if (topN != null) { + builder.field(HuggingFaceRerankTaskSettings.TOP_N_DOCS_ONLY, topN); + } else if (taskSettings.getTopNDocumentsOnly() != null) { + builder.field(HuggingFaceRerankTaskSettings.TOP_N_DOCS_ONLY, taskSettings.getTopNDocumentsOnly()); + } + + builder.endObject(); + return builder; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankModel.java new file mode 100644 index 0000000000000..4464e3aaf2dd1 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankModel.java @@ -0,0 +1,91 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.huggingface.rerank; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceModel; +import org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionVisitor; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; + +import java.util.Map; + +public class HuggingFaceRerankModel extends HuggingFaceModel { + public static HuggingFaceRerankModel of(HuggingFaceRerankModel model, HuggingFaceRerankTaskSettings taskSettings) { + return new HuggingFaceRerankModel(model, HuggingFaceRerankTaskSettings.of(model.getTaskSettings(), taskSettings)); + } + + public HuggingFaceRerankModel( + String inferenceEntityId, + TaskType taskType, + String service, + Map serviceSettings, + Map taskSettings, + @Nullable Map secrets, + ConfigurationParseContext context + ) { + this( + inferenceEntityId, + taskType, + service, + HuggingFaceRerankServiceSettings.fromMap(serviceSettings, context), + HuggingFaceRerankTaskSettings.fromMap(taskSettings), + DefaultSecretSettings.fromMap(secrets) + ); + } + + // Should only be used directly for testing + HuggingFaceRerankModel( + String inferenceEntityId, + TaskType taskType, + String service, + HuggingFaceRerankServiceSettings serviceSettings, + HuggingFaceRerankTaskSettings taskSettings, + @Nullable DefaultSecretSettings secrets + ) { + super( + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings), + new ModelSecrets(secrets), + serviceSettings, + secrets + ); + } + + private HuggingFaceRerankModel(HuggingFaceRerankModel model, HuggingFaceRerankTaskSettings taskSettings) { + super(model, taskSettings); + } + + @Override + public HuggingFaceRerankServiceSettings getServiceSettings() { + return (HuggingFaceRerankServiceSettings) super.getServiceSettings(); + } + + @Override + public HuggingFaceRerankTaskSettings getTaskSettings() { + return (HuggingFaceRerankTaskSettings) super.getTaskSettings(); + } + + @Override + public DefaultSecretSettings getSecretSettings() { + return (DefaultSecretSettings) super.getSecretSettings(); + } + + @Override + public Integer getTokenLimit() { + throw new UnsupportedOperationException("Token Limit for rerank is sent in request and not retrieved from the model"); + } + + @Override + public ExecutableAction accept(HuggingFaceActionVisitor creator) { + return creator.create(this); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankServiceSettings.java new file mode 100644 index 0000000000000..3d4c6aef71e96 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankServiceSettings.java @@ -0,0 +1,139 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.huggingface.rerank; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceRateLimitServiceSettings; +import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceService; +import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.net.URI; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri; +import static org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings.extractUri; + +public class HuggingFaceRerankServiceSettings extends FilteredXContentObject + implements + ServiceSettings, + HuggingFaceRateLimitServiceSettings { + + public static final String NAME = "hugging_face_rerank_service_settings"; + public static final String URL = "url"; + + private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(3000); + + public static HuggingFaceRerankServiceSettings fromMap(Map map, ConfigurationParseContext context) { + ValidationException validationException = new ValidationException(); + var uri = extractUri(map, URL, validationException); + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + HuggingFaceService.NAME, + context + ); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + return new HuggingFaceRerankServiceSettings(uri, rateLimitSettings); + } + + private final URI uri; + private final RateLimitSettings rateLimitSettings; + + public HuggingFaceRerankServiceSettings(String url) { + uri = createUri(url); + rateLimitSettings = DEFAULT_RATE_LIMIT_SETTINGS; + } + + HuggingFaceRerankServiceSettings(URI uri, @Nullable RateLimitSettings rateLimitSettings) { + this.uri = Objects.requireNonNull(uri); + this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); + } + + public HuggingFaceRerankServiceSettings(StreamInput in) throws IOException { + uri = createUri(in.readString()); + rateLimitSettings = new RateLimitSettings(in); + } + + @Override + public RateLimitSettings rateLimitSettings() { + return rateLimitSettings; + } + + @Override + public URI uri() { + return uri; + } + + // model is not defined in the service settings. + // since hugging face requires that the model be chosen when initializing a deployment within their service. + @Override + public String modelId() { + return null; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + toXContentFragmentOfExposedFields(builder, params); + builder.endObject(); + + return builder; + } + + @Override + protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { + builder.field(URL, uri.toString()); + rateLimitSettings.toXContent(builder, params); + + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_HUGGING_FACE_RERANK_ADDED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(uri.toString()); + rateLimitSettings.writeTo(out); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + HuggingFaceRerankServiceSettings that = (HuggingFaceRerankServiceSettings) o; + return Objects.equals(uri, that.uri) && Objects.equals(rateLimitSettings, that.rateLimitSettings); + } + + @Override + public int hashCode() { + return Objects.hash(uri, rateLimitSettings); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankTaskSettings.java new file mode 100644 index 0000000000000..9f90386edff90 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankTaskSettings.java @@ -0,0 +1,156 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.huggingface.rerank; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; + +public class HuggingFaceRerankTaskSettings implements TaskSettings { + + public static final String NAME = "hugging_face_rerank_task_settings"; + public static final String RETURN_DOCUMENTS = "return_documents"; + public static final String TOP_N_DOCS_ONLY = "top_n"; + + static final HuggingFaceRerankTaskSettings EMPTY_SETTINGS = new HuggingFaceRerankTaskSettings(null, null); + + public static HuggingFaceRerankTaskSettings fromMap(Map map) { + ValidationException validationException = new ValidationException(); + + if (map == null || map.isEmpty()) { + return EMPTY_SETTINGS; + } + + Boolean returnDocuments = extractOptionalBoolean(map, RETURN_DOCUMENTS, validationException); + Integer topNDocumentsOnly = extractOptionalPositiveInteger( + map, + TOP_N_DOCS_ONLY, + ModelConfigurations.TASK_SETTINGS, + validationException + ); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return of(topNDocumentsOnly, returnDocuments); + } + + /** + * Creates a new {@link HuggingFaceRerankTaskSettings} + * by preferring non-null fields from the request settings over the original settings. + * + * @param originalSettings the settings stored as part of the inference entity configuration + * @param requestTaskSettings the settings passed in within the task_settings field of the request + * @return a constructed {@link HuggingFaceRerankTaskSettings} + */ + public static HuggingFaceRerankTaskSettings of( + HuggingFaceRerankTaskSettings originalSettings, + HuggingFaceRerankTaskSettings requestTaskSettings + ) { + return new HuggingFaceRerankTaskSettings( + requestTaskSettings.getTopNDocumentsOnly() != null + ? requestTaskSettings.getTopNDocumentsOnly() + : originalSettings.getTopNDocumentsOnly(), + requestTaskSettings.getReturnDocuments() != null + ? requestTaskSettings.getReturnDocuments() + : originalSettings.getReturnDocuments() + ); + } + + public static HuggingFaceRerankTaskSettings of(Integer topNDocumentsOnly, Boolean returnDocuments) { + return new HuggingFaceRerankTaskSettings(topNDocumentsOnly, returnDocuments); + } + + private final Integer topNDocumentsOnly; + private final Boolean returnDocuments; + + public HuggingFaceRerankTaskSettings(StreamInput in) throws IOException { + this(in.readOptionalVInt(), in.readOptionalBoolean()); + } + + public HuggingFaceRerankTaskSettings(@Nullable Integer topNDocumentsOnly, @Nullable Boolean doReturnDocuments) { + this.topNDocumentsOnly = topNDocumentsOnly; + this.returnDocuments = doReturnDocuments; + } + + @Override + public boolean isEmpty() { + return topNDocumentsOnly == null && returnDocuments == null; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (topNDocumentsOnly != null) { + builder.field(TOP_N_DOCS_ONLY, topNDocumentsOnly); + } + if (returnDocuments != null) { + builder.field(RETURN_DOCUMENTS, returnDocuments); + } + builder.endObject(); + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_HUGGING_FACE_RERANK_ADDED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalVInt(topNDocumentsOnly); + out.writeOptionalBoolean(returnDocuments); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + HuggingFaceRerankTaskSettings that = (HuggingFaceRerankTaskSettings) o; + return Objects.equals(returnDocuments, that.returnDocuments) && Objects.equals(topNDocumentsOnly, that.topNDocumentsOnly); + } + + @Override + public int hashCode() { + return Objects.hash(returnDocuments, topNDocumentsOnly); + } + + public Integer getTopNDocumentsOnly() { + return topNDocumentsOnly; + } + + public Boolean getReturnDocuments() { + return returnDocuments; + } + + @Override + public TaskSettings updatedTaskSettings(Map newSettings) { + HuggingFaceRerankTaskSettings updatedSettings = HuggingFaceRerankTaskSettings.fromMap(new HashMap<>(newSettings)); + return HuggingFaceRerankTaskSettings.of(this, updatedSettings); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceRerankResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceRerankResponseEntity.java new file mode 100644 index 0000000000000..64a7d4845236a --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceRerankResponseEntity.java @@ -0,0 +1,110 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.huggingface.response; + +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.services.huggingface.request.rerank.HuggingFaceRerankRequest; + +import java.io.IOException; +import java.util.Comparator; +import java.util.List; + +import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; + +public class HuggingFaceRerankResponseEntity { + + /** + * Parses the Hugging Face rerank response. + + * For a request like: + * + *
+     *     
+     *         {
+     *              "texts": ["luke", "leia"],
+     *              "query": "star wars main character",
+     *              "return_text": true
+     *          }
+     *     
+     * 
+ + * The response would look like: + + *
+     *     
+     *         [
+     *              {
+     *                   "index": 0,
+     *                   "score": -0.07996220886707306,
+     *                   "text": "luke"
+     *               },
+     *               {
+     *                  "index": 1,
+     *                  "score": -0.08393221348524094,
+     *                  "text": "leia"
+     *              }
+     *         ]
+     *     
+     * 
+ */ + + public static RankedDocsResults fromResponse(HuggingFaceRerankRequest request, HttpResult response) throws IOException { + var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); + + try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { + moveToFirstToken(jsonParser); + var rankedDocs = doParse(jsonParser); + var rankedDocsByRelevanceStream = rankedDocs.stream() + .sorted(Comparator.comparingDouble(RankedDocsResults.RankedDoc::relevanceScore).reversed()); + var rankedDocStreamTopN = request.getTopN() == null + ? rankedDocsByRelevanceStream + : rankedDocsByRelevanceStream.limit(request.getTopN()); + return new RankedDocsResults(rankedDocStreamTopN.toList()); + } + } + + private static List doParse(XContentParser parser) throws IOException { + return parseList(parser, (listParser, index) -> { + var parsedRankedDoc = HuggingFaceRerankResponseEntity.RankedDocEntry.parse(parser); + return new RankedDocsResults.RankedDoc(parsedRankedDoc.index, parsedRankedDoc.score, parsedRankedDoc.text); + }); + } + + private record RankedDocEntry(Integer index, Float score, @Nullable String text) { + + private static final ParseField TEXT = new ParseField("text"); + private static final ParseField SCORE = new ParseField("score"); + private static final ParseField INDEX = new ParseField("index"); + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>( + "hugging_face_rerank_response", + true, + args -> new HuggingFaceRerankResponseEntity.RankedDocEntry((int) args[0], (float) args[1], (String) args[2]) + ); + + static { + PARSER.declareInt(ConstructingObjectParser.constructorArg(), INDEX); + PARSER.declareFloat(ConstructingObjectParser.constructorArg(), SCORE); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), TEXT); + } + + public static RankedDocEntry parse(XContentParser parser) { + return PARSER.apply(parser, null); + } + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java index 9f08b07ea218d..70499c7987965 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java @@ -425,7 +425,7 @@ protected void mockService( doAnswer(ans -> { listenerAction.accept(ans.getArgument(9)); return null; - }).when(service).infer(any(), any(), any(), any(), any(), anyBoolean(), any(), any(), any(), any()); + }).when(service).infer(any(), any(), anyBoolean(), any(), any(), anyBoolean(), any(), any(), any(), any()); doAnswer(ans -> { listenerAction.accept(ans.getArgument(3)); return null; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java index 12c0d85674a58..fcc3c5bfb98fb 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java @@ -1292,7 +1292,7 @@ public void testGetConfiguration() throws Exception { { "service": "hugging_face", "name": "Hugging Face", - "task_types": ["text_embedding", "sparse_embedding", "completion", "chat_completion"], + "task_types": ["text_embedding", "sparse_embedding", "rerank", "completion", "chat_completion"], "configurations": { "api_key": { "description": "API Key for the provider you're connecting to.", @@ -1301,7 +1301,7 @@ public void testGetConfiguration() throws Exception { "sensitive": true, "updatable": true, "type": "str", - "supported_task_types": ["text_embedding", "sparse_embedding", "completion", "chat_completion"] + "supported_task_types": ["text_embedding", "sparse_embedding", "rerank", "completion", "chat_completion"] }, "rate_limit.requests_per_minute": { "description": "Minimize the number of rate limit errors.", @@ -1310,7 +1310,7 @@ public void testGetConfiguration() throws Exception { "sensitive": false, "updatable": false, "type": "int", - "supported_task_types": ["text_embedding", "sparse_embedding", "completion", "chat_completion"] + "supported_task_types": ["text_embedding", "sparse_embedding", "rerank", "completion", "chat_completion"] }, "url": { "description": "The URL endpoint to use for the requests.", @@ -1319,7 +1319,7 @@ public void testGetConfiguration() throws Exception { "sensitive": false, "updatable": false, "type": "str", - "supported_task_types": ["text_embedding", "sparse_embedding", "completion", "chat_completion"] + "supported_task_types": ["text_embedding", "sparse_embedding", "rerank", "completion", "chat_completion"] } } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java index 895aea8067c46..f5d700016bf81 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java @@ -27,12 +27,14 @@ import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModelTests; import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModelTests; import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModelTests; +import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankModelTests; import org.junit.After; import org.junit.Before; @@ -311,6 +313,66 @@ public void testSend_FailsFromInvalidResponseFormat_ForEmbeddingsAction() throws } } + public void testSend_FailsFromInvalidResponseFormat_ForRerankAction() throws IOException { + var settings = buildSettingsWithRetryFields( + TimeValue.timeValueMillis(1), + TimeValue.timeValueMinutes(1), + TimeValue.timeValueSeconds(0) + ); + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); + + try (var sender = createSender(senderFactory)) { + sender.start(); + + String responseJson = """ + { + "rerank": [ + { + "index": 0, + "relevance_score": -0.07996031, + "text": "luke" + } + ] + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = HuggingFaceRerankModelTests.createModel(getUrl(webServer), "secret", "model", 8, true); + var actionCreator = new HuggingFaceActionCreator( + sender, + new ServiceComponents(threadPool, mockThrottlerManager(), settings, TruncatorTests.createTruncator()) + ); + var action = actionCreator.create(model); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new QueryAndDocsInputs("popular name", List.of("Luke")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + assertThat( + thrownException.getMessage(), + is("Failed to parse object: expecting token of type [START_ARRAY] but found [START_OBJECT]") + ); + } + assertRerankActionCreator(List.of("Luke"), "popular name", 8, true); + } + + private void assertRerankActionCreator(List documents, String query, int topN, boolean returnText) throws IOException { + assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaTypeWithoutParameters()) + ); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap.size(), is(4)); + assertThat(requestMap.get("texts"), is(documents)); + assertThat(requestMap.get("query"), is(query)); + assertThat(requestMap.get("top_n"), is(topN)); + assertThat(requestMap.get("return_text"), is(returnText)); + } + public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequestEntityTests.java new file mode 100644 index 0000000000000..8e7502defeb8d --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequestEntityTests.java @@ -0,0 +1,59 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.huggingface.request.rerank; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.common.xcontent.XContentHelper.stripWhitespace; + +public class HuggingFaceRerankRequestEntityTests extends ESTestCase { + private static final String INPUT = "texts"; + private static final String QUERY = "query"; + private static final Integer TOP_N = 8; + private static final Boolean RETURN_DOCUMENTS = false; + + public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException { + var entity = new HuggingFaceRerankRequestEntity( + QUERY, + List.of(INPUT), + Boolean.TRUE, + TOP_N, + new HuggingFaceRerankTaskSettings(TOP_N, RETURN_DOCUMENTS) + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + String xContentResult = Strings.toString(builder); + String expected = """ + {"texts":["texts"], + "query":"query", + "return_text":true, + "top_n":8}"""; + assertEquals(stripWhitespace(expected), xContentResult); + } + + public void testXContent_WritesMinimalFields() throws IOException { + var entity = new HuggingFaceRerankRequestEntity(QUERY, List.of(INPUT), null, null, new HuggingFaceRerankTaskSettings(null, null)); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + String xContentResult = Strings.toString(builder); + String expected = """ + {"texts":["texts"],"query":"query"}"""; + assertEquals(stripWhitespace(expected), xContentResult); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequestTests.java new file mode 100644 index 0000000000000..fc4365441f7d0 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequestTests.java @@ -0,0 +1,99 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.huggingface.request.rerank; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankModel; +import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankModelTests; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; + +public class HuggingFaceRerankRequestTests extends ESTestCase { + private static final String INPUT = "texts"; + private static final String QUERY = "query"; + private static final String INFERENCE_ID = "model"; + private static final Integer TOP_N = 8; + private static final Boolean RETURN_TEXT = false; + + private static final String AUTH_HEADER_VALUE = "foo"; + + public void testCreateRequest_WithMinimalFieldsSet() throws IOException { + testCreateRequest(null, null); + } + + public void testCreateRequest_WithTopN() throws IOException { + testCreateRequest(TOP_N, null); + } + + public void testCreateRequest_WithReturnDocuments() throws IOException { + testCreateRequest(null, RETURN_TEXT); + } + + private void testCreateRequest(Integer topN, Boolean returnDocuments) throws IOException { + var request = createRequest(topN, returnDocuments); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaTypeWithoutParameters())); + assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(AUTH_HEADER_VALUE)); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + + assertThat(requestMap.get(INPUT), is(List.of(INPUT))); + assertThat(requestMap.get(QUERY), is(QUERY)); + // input and query must exist + int itemsCount = 2; + if (topN != null) { + assertThat(requestMap.get("top_n"), is(topN)); + itemsCount++; + } + if (returnDocuments != null) { + assertThat(requestMap.get("return_text"), is(returnDocuments)); + itemsCount++; + } + assertThat(requestMap, aMapWithSize(itemsCount)); + } + + private static HuggingFaceRerankRequest createRequest(@Nullable Integer topN, @Nullable Boolean returnDocuments) { + var rerankModel = HuggingFaceRerankModelTests.createModel(randomAlphaOfLength(10), "secret", INFERENCE_ID, topN, returnDocuments); + + return new HuggingFaceRerankWithoutAuthRequest(QUERY, List.of(INPUT), rerankModel, topN, returnDocuments); + } + + /** + * We use this class to fake the auth implementation to avoid static mocking of {@link HuggingFaceRerankRequest} + */ + private static class HuggingFaceRerankWithoutAuthRequest extends HuggingFaceRerankRequest { + HuggingFaceRerankWithoutAuthRequest( + String query, + List input, + HuggingFaceRerankModel model, + @Nullable Integer topN, + @Nullable Boolean returnDocuments + ) { + super(query, input, returnDocuments, topN, model); + } + + @Override + public void decorateWithAuth(HttpPost httpPost) { + httpPost.setHeader(HttpHeaders.AUTHORIZATION, AUTH_HEADER_VALUE); + } + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankModelTests.java new file mode 100644 index 0000000000000..2c8c28819d926 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankModelTests.java @@ -0,0 +1,41 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.huggingface.rerank; + +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; + +import static org.hamcrest.Matchers.containsString; + +public class HuggingFaceRerankModelTests extends ESTestCase { + + public void testThrowsURISyntaxException_ForInvalidUrl() { + var thrownException = expectThrows(IllegalArgumentException.class, () -> createModel("^^", "secret", "model", 8, false)); + assertThat(thrownException.getMessage(), containsString("unable to parse url [^^]")); + } + + public static HuggingFaceRerankModel createModel( + String url, + String apiKey, + String modelId, + @Nullable Integer topN, + @Nullable Boolean returnDocuments + ) { + return new HuggingFaceRerankModel( + modelId, + TaskType.RERANK, + "service", + new HuggingFaceRerankServiceSettings(url), + new HuggingFaceRerankTaskSettings(topN, returnDocuments), + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankTaskSettingsTests.java new file mode 100644 index 0000000000000..301a797cd990f --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankTaskSettingsTests.java @@ -0,0 +1,124 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.huggingface.rerank; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.Matchers.containsString; + +public class HuggingFaceRerankTaskSettingsTests extends AbstractBWCWireSerializationTestCase { + + public static HuggingFaceRerankTaskSettings createRandom() { + var returnDocuments = randomBoolean() ? randomBoolean() : null; + var topNDocsOnly = randomBoolean() ? randomIntBetween(1, 10) : null; + + return new HuggingFaceRerankTaskSettings(topNDocsOnly, returnDocuments); + } + + public void testFromMap_WithValidValues_ReturnsSettings() { + Map taskMap = Map.of( + HuggingFaceRerankTaskSettings.RETURN_DOCUMENTS, + true, + HuggingFaceRerankTaskSettings.TOP_N_DOCS_ONLY, + 5 + ); + var settings = HuggingFaceRerankTaskSettings.fromMap(new HashMap<>(taskMap)); + assertTrue(settings.getReturnDocuments()); + assertEquals(5, settings.getTopNDocumentsOnly().intValue()); + } + + public void testFromMap_WithNullValues_ReturnsSettingsWithNulls() { + var settings = HuggingFaceRerankTaskSettings.fromMap(Map.of()); + assertNull(settings.getReturnDocuments()); + assertNull(settings.getTopNDocumentsOnly()); + } + + public void testFromMap_WithInvalidReturnDocuments_ThrowsValidationException() { + Map taskMap = Map.of( + HuggingFaceRerankTaskSettings.RETURN_DOCUMENTS, + "invalid", + HuggingFaceRerankTaskSettings.TOP_N_DOCS_ONLY, + 5 + ); + var thrownException = expectThrows(ValidationException.class, () -> HuggingFaceRerankTaskSettings.fromMap(new HashMap<>(taskMap))); + assertThat(thrownException.getMessage(), containsString("field [return_documents] is not of the expected type")); + } + + public void testFromMap_WithInvalidTopNDocsOnly_ThrowsValidationException() { + Map taskMap = Map.of( + HuggingFaceRerankTaskSettings.RETURN_DOCUMENTS, + true, + HuggingFaceRerankTaskSettings.TOP_N_DOCS_ONLY, + "invalid" + ); + var thrownException = expectThrows(ValidationException.class, () -> HuggingFaceRerankTaskSettings.fromMap(new HashMap<>(taskMap))); + assertThat(thrownException.getMessage(), containsString("field [top_n] is not of the expected type")); + } + + public void UpdatedTaskSettings_WithEmptyMap_ReturnsSameSettings() { + var initialSettings = new HuggingFaceRerankTaskSettings(5, true); + HuggingFaceRerankTaskSettings updatedSettings = (HuggingFaceRerankTaskSettings) initialSettings.updatedTaskSettings(Map.of()); + assertEquals(initialSettings, updatedSettings); + } + + public void testUpdatedTaskSettings_WithNewReturnDocuments_ReturnsUpdatedSettings() { + var initialSettings = new HuggingFaceRerankTaskSettings(5, true); + Map newSettings = Map.of(HuggingFaceRerankTaskSettings.RETURN_DOCUMENTS, false); + HuggingFaceRerankTaskSettings updatedSettings = (HuggingFaceRerankTaskSettings) initialSettings.updatedTaskSettings(newSettings); + assertFalse(updatedSettings.getReturnDocuments()); + assertEquals(initialSettings.getTopNDocumentsOnly(), updatedSettings.getTopNDocumentsOnly()); + } + + public void testUpdatedTaskSettings_WithNewTopNDocsOnly_ReturnsUpdatedSettings() { + var initialSettings = new HuggingFaceRerankTaskSettings(5, true); + Map newSettings = Map.of(HuggingFaceRerankTaskSettings.TOP_N_DOCS_ONLY, 7); + HuggingFaceRerankTaskSettings updatedSettings = (HuggingFaceRerankTaskSettings) initialSettings.updatedTaskSettings(newSettings); + assertEquals(7, updatedSettings.getTopNDocumentsOnly().intValue()); + assertEquals(initialSettings.getReturnDocuments(), updatedSettings.getReturnDocuments()); + } + + public void testUpdatedTaskSettings_WithMultipleNewValues_ReturnsUpdatedSettings() { + var initialSettings = new HuggingFaceRerankTaskSettings(5, true); + Map newSettings = Map.of( + HuggingFaceRerankTaskSettings.RETURN_DOCUMENTS, + false, + HuggingFaceRerankTaskSettings.TOP_N_DOCS_ONLY, + 7 + ); + HuggingFaceRerankTaskSettings updatedSettings = (HuggingFaceRerankTaskSettings) initialSettings.updatedTaskSettings(newSettings); + assertFalse(updatedSettings.getReturnDocuments()); + assertEquals(7, updatedSettings.getTopNDocumentsOnly().intValue()); + } + + @Override + protected Writeable.Reader instanceReader() { + return HuggingFaceRerankTaskSettings::new; + } + + @Override + protected HuggingFaceRerankTaskSettings createTestInstance() { + return createRandom(); + } + + @Override + protected HuggingFaceRerankTaskSettings mutateInstance(HuggingFaceRerankTaskSettings instance) throws IOException { + return randomValueOtherThan(instance, HuggingFaceRerankTaskSettingsTests::createRandom); + } + + @Override + protected HuggingFaceRerankTaskSettings mutateInstanceForVersion(HuggingFaceRerankTaskSettings instance, TransportVersion version) { + return instance; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceRerankResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceRerankResponseEntityTests.java new file mode 100644 index 0000000000000..5d1b14c46f099 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceRerankResponseEntityTests.java @@ -0,0 +1,147 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.huggingface.response; + +import org.apache.http.HttpResponse; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; +import org.elasticsearch.xpack.core.inference.results.RankedDocsResultsTests; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.services.huggingface.request.rerank.HuggingFaceRerankRequest; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.core.inference.results.RankedDocsResultsTests.buildExpectationRerank; +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class HuggingFaceRerankResponseEntityTests extends ESTestCase { + private static final String MISSED_FIELD_INDEX = "index"; + private static final String MISSED_FIELD_SCORE = "score"; + private static final String RESPONSE_JSON_TWO_DOCS = """ + [ + { + "index": 4, + "score": -0.22222222222222222, + "text": "ranked second" + }, + { + "index": 1, + "score": 1.11111111111111111, + "text": "ranked first" + } + ] + """; + private static final List EXPECTED_TWO_DOCS = List.of( + new RankedDocsResultsTests.RerankExpectation(Map.of("index", 1, "relevance_score", 1.11111111111111111F, "text", "ranked first")), + new RankedDocsResultsTests.RerankExpectation(Map.of("index", 4, "relevance_score", -0.22222222222222222F, "text", "ranked second")) + ); + + private static final String RESPONSE_JSON_FIVE_DOCS = """ + [ + { + "index": 1, + "score": 1.11111111111111111, + "text": "ranked first" + }, + { + "index": 3, + "score": -0.33333333333333333, + "text": "ranked third" + }, + { + "index": 0, + "score": -0.55555555555555555, + "text": "ranked fifth" + }, + { + "index": 2, + "score": -0.44444444444444444, + "text": "ranked fourth" + }, + { + "index": 4, + "score": -0.22222222222222222, + "text": "ranked second" + } + ] + """; + private static final List EXPECTED_FIVE_DOCS = List.of( + new RankedDocsResultsTests.RerankExpectation(Map.of("index", 1, "relevance_score", 1.11111111111111111F, "text", "ranked first")), + new RankedDocsResultsTests.RerankExpectation(Map.of("index", 4, "relevance_score", -0.22222222222222222F, "text", "ranked second")), + new RankedDocsResultsTests.RerankExpectation(Map.of("index", 3, "relevance_score", -0.33333333333333333F, "text", "ranked third")), + new RankedDocsResultsTests.RerankExpectation(Map.of("index", 2, "relevance_score", -0.44444444444444444F, "text", "ranked fourth")), + new RankedDocsResultsTests.RerankExpectation(Map.of("index", 0, "relevance_score", -0.55555555555555555F, "text", "ranked fifth")) + ); + + private static final HuggingFaceRerankRequest REQUEST_MOCK = mock(HuggingFaceRerankRequest.class); + + public void testFromResponse_CreatesRankedDocsResults_TopNNull_FiveDocs_NoLimit() throws IOException { + assertTopNLimit(null, RESPONSE_JSON_FIVE_DOCS, EXPECTED_FIVE_DOCS); + } + + public void testFromResponse_CreatesRankedDocsResults_TopN5_TwoDocs_NoLimit() throws IOException { + assertTopNLimit(5, RESPONSE_JSON_TWO_DOCS, EXPECTED_TWO_DOCS); + } + + public void testFromResponse_CreatesRankedDocsResults_TopN2_FiveDocs_Limits() throws IOException { + assertTopNLimit(2, RESPONSE_JSON_FIVE_DOCS, EXPECTED_TWO_DOCS); + } + + public void testFails_CreateRankedDocsResults_IndexFieldNull() { + String responseJson = """ + [ + { + "score": 1.11111111111111111, + "text": "ranked first" + } + ] + """; + assertMissingFieldThrowsIllegalArgumentException(responseJson, MISSED_FIELD_INDEX); + } + + public void testFails_CreateRankedDocsResults_ScoreFieldNull() { + String responseJson = """ + [ + { + "index": 1, + "text": "ranked first" + } + ] + """; + assertMissingFieldThrowsIllegalArgumentException(responseJson, MISSED_FIELD_SCORE); + } + + private void assertMissingFieldThrowsIllegalArgumentException(String responseJson, String missingField) { + when(REQUEST_MOCK.getTopN()).thenReturn(1); + + var thrownException = expectThrows( + IllegalArgumentException.class, + () -> HuggingFaceRerankResponseEntity.fromResponse( + REQUEST_MOCK, + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ) + ); + assertThat(thrownException.getMessage(), is("Required [" + missingField + "]")); + } + + private void assertTopNLimit( + Integer topN, String responseJson, List expectation) throws IOException { + when(REQUEST_MOCK.getTopN()).thenReturn(topN); + + RankedDocsResults parsedResults = HuggingFaceRerankResponseEntity.fromResponse( + REQUEST_MOCK, + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + assertThat(parsedResults.asMap(), is(buildExpectationRerank(expectation))); + } +}