Add Hugging Face Rerank support #127966


Merged
23 commits
4d9ec59
Add Hugging Face Rerank support
Evgenii-Kazannik Apr 28, 2025
b58aab4
Address comments
Evgenii-Kazannik May 10, 2025
c567137
Merge branch 'main' into Add-Hugging-Face-Rerank-support
Evgenii-Kazannik May 13, 2025
a4ebb87
Add transport version
Evgenii-Kazannik May 13, 2025
5d316c1
Merge branch 'main' into Add-Hugging-Face-Rerank-support
Evgenii-Kazannik May 13, 2025
f74e9e5
Add transport version
Evgenii-Kazannik May 13, 2025
2ea07f0
Merge branch 'main' into Add-Hugging-Face-Rerank-support
Evgenii-Kazannik May 14, 2025
0054891
Add to inference service and crud IT rerank tests
Evgenii-Kazannik May 14, 2025
82fd86d
Merge branch 'main' into Add-Hugging-Face-Rerank-support
Evgenii-Kazannik May 16, 2025
733818c
Refactor slightly / error message
Evgenii-Kazannik May 16, 2025
cb67ac0
Merge branch 'main' into Add-Hugging-Face-Rerank-support
Evgenii-Kazannik May 19, 2025
f97f818
correct 'testGetConfiguration' test case
Evgenii-Kazannik May 19, 2025
88d6929
Merge branch 'main' into Add-Hugging-Face-Rerank-support
Evgenii-Kazannik May 19, 2025
a52a1d8
apply suggestions
Evgenii-Kazannik May 20, 2025
2eae767
Merge branch 'main' into Add-Hugging-Face-Rerank-support
Evgenii-Kazannik May 21, 2025
c8c74d6
fix tests
Evgenii-Kazannik May 21, 2025
ae1a1d2
apply suggestions
Evgenii-Kazannik May 21, 2025
1764a4d
Merge branch 'main' into Add-Hugging-Face-Rerank-support
Evgenii-Kazannik May 22, 2025
7f30c6a
[CI] Auto commit changes from spotless
elasticsearchmachine May 22, 2025
4ee7f1f
add changelog information
Evgenii-Kazannik May 22, 2025
887389f
Merge branch 'main' into Add-Hugging-Face-Rerank-support
Evgenii-Kazannik May 22, 2025
68755a6
Merge remote-tracking branch 'origin/Add-Hugging-Face-Rerank-support'…
Evgenii-Kazannik May 22, 2025
43398ca
Merge branch 'main' into Add-Hugging-Face-Rerank-support
Evgenii-Kazannik May 22, 2025
5 changes: 5 additions & 0 deletions docs/changelog/127966.yaml
@@ -0,0 +1,5 @@
pr: 127966
summary: "[ML] Add Rerank support to the Inference Plugin"
area: Machine Learning
type: enhancement
issues: []
@@ -179,6 +179,7 @@ static TransportVersion def(int id) {
public static final TransportVersion V_8_19_FIELD_CAPS_ADD_CLUSTER_ALIAS = def(8_841_0_32);
public static final TransportVersion ESQL_HASH_OPERATOR_STATUS_OUTPUT_TIME_8_19 = def(8_841_0_34);
public static final TransportVersion RERANKER_FAILURES_ALLOWED_8_19 = def(8_841_0_35);
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_RERANK_ADDED_8_19 = def(8_841_0_36);
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
@@ -261,7 +262,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ESQL_HASH_OPERATOR_STATUS_OUTPUT_TIME = def(9_077_0_00);
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_CHAT_COMPLETION_ADDED = def(9_078_0_00);
public static final TransportVersion NODES_STATS_SUPPORTS_MULTI_PROJECT = def(9_079_0_00);

public static final TransportVersion ML_INFERENCE_HUGGING_FACE_RERANK_ADDED = def(9_080_0_00);
/*
* STOP! READ THIS FIRST! No, really,
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _
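For context on the hunks above: a new `TransportVersion` constant such as `ML_INFERENCE_HUGGING_FACE_RERANK_ADDED` typically gates wire serialization of the new rerank settings so mixed-version clusters stay compatible. The following is a minimal self-contained sketch of that pattern using stub types — it is not the actual Elasticsearch `TransportVersion` class, and the helper method is illustrative, not from this PR:

```java
// Stub mimicking TransportVersion ordering; not the real Elasticsearch class.
public class TransportVersionGateSketch {
    record TransportVersion(int id) {
        boolean onOrAfter(TransportVersion other) {
            return id >= other.id;
        }
    }

    // Mirrors the constant added on main in the diff above (id 9_080_0_00).
    static final TransportVersion ML_INFERENCE_HUGGING_FACE_RERANK_ADDED =
        new TransportVersion(9_080_0_00);

    // A writer would skip the new rerank settings for peers on older versions.
    static boolean shouldWriteRerankSettings(TransportVersion wireVersion) {
        return wireVersion.onOrAfter(ML_INFERENCE_HUGGING_FACE_RERANK_ADDED);
    }

    public static void main(String[] args) {
        // An older peer (e.g. 9_079_0_00, the previous constant above) skips it.
        System.out.println(shouldWriteRerankSettings(new TransportVersion(9_079_0_00))); // false
        // A peer at or past the new version reads and writes it.
        System.out.println(shouldWriteRerankSettings(new TransportVersion(9_080_0_00))); // true
    }
}
```

Note the PR also backports a separate `_8_19` constant (`8_841_0_36`) so 8.19 nodes can participate in the same gating.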
@@ -20,9 +20,8 @@
public class SettingsConfigurationTestUtils {

public static SettingsConfiguration getRandomSettingsConfigurationField() {
return new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING)).setDefaultValue(
randomAlphaOfLength(10)
)
return new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING, TaskType.RERANK))
.setDefaultValue(randomAlphaOfLength(10))
.setDescription(randomAlphaOfLength(10))
.setLabel(randomAlphaOfLength(10))
.setRequired(randomBoolean())
@@ -82,4 +82,13 @@ private List<RankedDocsResults.RankedDoc> rankedDocsNullStringToEmpty(List<Ranke
protected RankedDocsResults doParseInstance(XContentParser parser) throws IOException {
return RankedDocsResults.createParser(true).apply(parser, null);
}

public record RerankExpectation(Map<String, Object> rankedDocFields) {}

public static Map<String, Object> buildExpectationRerank(List<RerankExpectation> rerank) {
return Map.of(
RankedDocsResults.RERANK,
rerank.stream().map(rerankExpectation -> Map.of(RankedDocsResults.RankedDoc.NAME, rerankExpectation.rankedDocFields)).toList()
);
}
}
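A sketch of the map shape `buildExpectationRerank` produces. The string values `"rerank"` and `"ranked_doc"` are assumptions — the real constants `RankedDocsResults.RERANK` and `RankedDoc.NAME` are defined outside this diff:

```java
import java.util.List;
import java.util.Map;

public class RerankExpectationSketch {
    // Assumed values; the real constants live in RankedDocsResults.
    static final String RERANK = "rerank";
    static final String RANKED_DOC_NAME = "ranked_doc";

    record RerankExpectation(Map<String, Object> rankedDocFields) {}

    // Same shape as buildExpectationRerank in the diff above: a single
    // "rerank" key mapping to a list of one-entry ranked-doc maps.
    static Map<String, Object> buildExpectationRerank(List<RerankExpectation> rerank) {
        return Map.of(
            RERANK,
            rerank.stream().map(e -> Map.of(RANKED_DOC_NAME, e.rankedDocFields())).toList()
        );
    }

    public static void main(String[] args) {
        var expected = buildExpectationRerank(
            List.of(new RerankExpectation(Map.of("index", 0, "relevance_score", 0.97)))
        );
        // One entry under the "rerank" key, wrapping the ranked-doc fields.
        System.out.println(expected);
    }
}
```

Tests would then compare this expectation map against the XContent-parsed rerank response.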
@@ -171,6 +171,20 @@ static String mockDenseServiceModelConfig() {
""";
}

static String mockRerankServiceModelConfig() {
Contributor commented:
I'm wondering whether the methods you've added to this class are actually used anywhere. The methods you took as a reference are being called; the ones you've added are not.

Evgenii-Kazannik (Author) replied on May 13, 2025:
Thanks for noticing. It's used now.
return """
{
"service": "test_reranking_service",
"service_settings": {
"model_id": "my_model",
"api_key": "abc64"
},
"task_settings": {
}
}
""";
}

static void deleteModel(String modelId) throws IOException {
var request = new Request("DELETE", "_inference/" + modelId);
var response = client().performRequest(request);
@@ -484,6 +498,10 @@ private String jsonBody(List<String> input, @Nullable String query) {
@SuppressWarnings("unchecked")
protected void assertNonEmptyInferenceResults(Map<String, Object> resultMap, int expectedNumberOfResults, TaskType taskType) {
switch (taskType) {
case RERANK -> {
Contributor commented:
It looks like this method is not called with the TaskType.RERANK param anywhere, meaning the assertion isn't triggered.
var results = (List<Map<String, Object>>) resultMap.get(TaskType.RERANK.toString());
assertThat(results, hasSize(expectedNumberOfResults));
}
case SPARSE_EMBEDDING -> {
var results = (List<Map<String, Object>>) resultMap.get(TaskType.SPARSE_EMBEDDING.toString());
assertThat(results, hasSize(expectedNumberOfResults));
@@ -53,9 +53,12 @@ public void testCRUD() throws IOException {
for (int i = 0; i < 4; i++) {
putModel("te_model_" + i, mockDenseServiceModelConfig(), TaskType.TEXT_EMBEDDING);
}
for (int i = 0; i < 3; i++) {
putModel("re-model-" + i, mockRerankServiceModelConfig(), TaskType.RERANK);
}

var getAllModels = getAllModels();
int numModels = 12;
int numModels = 15;
assertThat(getAllModels, hasSize(numModels));

var getSparseModels = getModels("_all", TaskType.SPARSE_EMBEDDING);
@@ -71,6 +74,13 @@ public void testCRUD() throws IOException {
for (var denseModel : getDenseModels) {
assertEquals("text_embedding", denseModel.get("task_type"));
}

var getRerankModels = getModels("_all", TaskType.RERANK);
int numRerankModels = 4;
assertThat(getRerankModels, hasSize(numRerankModels));
for (var denseModel : getRerankModels) {
assertEquals("rerank", denseModel.get("task_type"));
}
String oldApiKey;
{
var singleModel = getModels("se_model_1", TaskType.SPARSE_EMBEDDING);
@@ -100,6 +110,9 @@ public void testCRUD() throws IOException {
for (int i = 0; i < 4; i++) {
deleteModel("te_model_" + i, TaskType.TEXT_EMBEDDING);
}
for (int i = 0; i < 3; i++) {
deleteModel("re-model-" + i, TaskType.RERANK);
}
}

public void testGetModelWithWrongTaskType() throws IOException {
@@ -101,7 +101,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException {

public void testGetServicesWithRerankTaskType() throws IOException {
List<Object> services = getServices(TaskType.RERANK);
assertThat(services.size(), equalTo(7));
assertThat(services.size(), equalTo(8));

var providers = providers(services);

@@ -115,7 +115,8 @@ public void testGetServicesWithRerankTaskType() throws IOException {
"googlevertexai",
"jinaai",
"test_reranking_service",
"voyageai"
"voyageai",
"hugging_face"
).toArray()
)
);
@@ -0,0 +1,67 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference;

import org.elasticsearch.inference.TaskType;

import java.io.IOException;
import java.util.List;
import java.util.Map;

public class MockRerankInferenceServiceIT extends InferenceBaseRestTest {

@SuppressWarnings("unchecked")
public void testMockService() throws IOException {
String inferenceEntityId = "test-mock";
var putModel = putModel(inferenceEntityId, mockRerankServiceModelConfig(), TaskType.RERANK);
var model = getModels(inferenceEntityId, TaskType.RERANK).get(0);

for (var modelMap : List.of(putModel, model)) {
assertEquals(inferenceEntityId, modelMap.get("inference_id"));
assertEquals(TaskType.RERANK, TaskType.fromString((String) modelMap.get("task_type")));
assertEquals("test_reranking_service", modelMap.get("service"));
}

List<String> input = List.of(randomAlphaOfLength(10));
var inference = infer(inferenceEntityId, input);
assertNonEmptyInferenceResults(inference, 1, TaskType.RERANK);
assertEquals(inference, infer(inferenceEntityId, input));
assertNotEquals(inference, infer(inferenceEntityId, randomValueOtherThan(input, () -> List.of(randomAlphaOfLength(10)))));
}

public void testMockServiceWithMultipleInputs() throws IOException {
String inferenceEntityId = "test-mock-with-multi-inputs";
putModel(inferenceEntityId, mockRerankServiceModelConfig(), TaskType.RERANK);
var queryParams = Map.of("timeout", "120s");

var inference = infer(
inferenceEntityId,
TaskType.RERANK,
List.of(randomAlphaOfLength(5), randomAlphaOfLength(10)),
"What if?",
queryParams
);

assertNonEmptyInferenceResults(inference, 2, TaskType.RERANK);
}

@SuppressWarnings("unchecked")
public void testMockService_DoesNotReturnSecretsInGetResponse() throws IOException {
String inferenceEntityId = "test-mock";
var putModel = putModel(inferenceEntityId, mockRerankServiceModelConfig(), TaskType.RERANK);
var model = getModels(inferenceEntityId, TaskType.RERANK).get(0);

var serviceSettings = (Map<String, Object>) model.get("service_settings");
assertNull(serviceSettings.get("api_key"));
assertNotNull(serviceSettings.get("model_id"));

var putServiceSettings = (Map<String, Object>) putModel.get("service_settings");
assertNull(putServiceSettings.get("api_key"));
assertNotNull(putServiceSettings.get("model_id"));
}
}
@@ -80,6 +80,8 @@
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings;
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings;
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankTaskSettings;
@@ -365,6 +367,16 @@ private static void addHuggingFaceNamedWriteables(List<NamedWriteableRegistry.En
HuggingFaceChatCompletionServiceSettings::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(TaskSettings.class, HuggingFaceRerankTaskSettings.NAME, HuggingFaceRerankTaskSettings::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
HuggingFaceRerankServiceSettings.NAME,
HuggingFaceRerankServiceSettings::new
)
);
}

private static void addGoogleAiStudioNamedWritables(List<NamedWriteableRegistry.Entry> namedWriteables) {
@@ -19,6 +19,7 @@
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings;

import java.util.ArrayList;
import java.util.Arrays;
@@ -91,7 +92,10 @@ protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[
} else if (r.getEndpoints().isEmpty() == false
&& r.getEndpoints().get(0).getTaskSettings() instanceof GoogleVertexAiRerankTaskSettings googleVertexAiTaskSettings) {
configuredTopN = googleVertexAiTaskSettings.topN();
}
} else if (r.getEndpoints().isEmpty() == false
&& r.getEndpoints().get(0).getTaskSettings() instanceof HuggingFaceRerankTaskSettings huggingFaceRerankTaskSettings) {
configuredTopN = huggingFaceRerankTaskSettings.getTopNDocumentsOnly();
}
if (configuredTopN != null && configuredTopN < rankWindowSize) {
l.onFailure(
new IllegalArgumentException(
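The branch added above extends an existing pattern: the reranker pulls a configured top_n out of whichever provider's task settings match (Cohere, Google Vertex AI, and now Hugging Face), then rejects requests whose rank_window_size exceeds it. A self-contained sketch of that final guard — the error message text here is illustrative, not the exact string from the source:

```java
public class TopNValidationSketch {
    // Mirrors the guard after the provider-specific branches above:
    // a null top_n means "nothing configured" and always passes.
    static void validateTopN(Integer configuredTopN, int rankWindowSize) {
        if (configuredTopN != null && configuredTopN < rankWindowSize) {
            throw new IllegalArgumentException(
                "configured top_n [" + configuredTopN
                    + "] is less than rank_window_size [" + rankWindowSize + "]"
            );
        }
    }

    public static void main(String[] args) {
        validateTopN(null, 10); // passes: no top_n configured on the endpoint
        validateTopN(10, 10);   // passes: the rank window is fully covered
        try {
            validateTopN(5, 10); // rejected: endpoint would silently truncate
        } catch (IllegalArgumentException e) {
            System.out.println("rejected: " + e.getMessage());
        }
    }
}
```

Failing fast here avoids a confusing result where the rerank window asks for more documents than the endpoint is configured to score.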
@@ -57,6 +57,7 @@ public void parseRequestConfig(
) {
try {
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);

ChunkingSettings chunkingSettings = null;
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
@@ -66,17 +67,21 @@
}

var model = createModel(
inferenceEntityId,
taskType,
serviceSettingsMap,
chunkingSettings,
serviceSettingsMap,
TaskType.unsupportedTaskTypeErrorMsg(taskType, name()),
ConfigurationParseContext.REQUEST
new HuggingFaceModelParameters(
inferenceEntityId,
taskType,
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
serviceSettingsMap,
TaskType.unsupportedTaskTypeErrorMsg(taskType, name()),
ConfigurationParseContext.REQUEST
)
);

throwIfNotEmptyMap(config, name());
throwIfNotEmptyMap(serviceSettingsMap, name());
throwIfNotEmptyMap(taskSettingsMap, name());

parsedModelListener.onResponse(model);
} catch (Exception e) {
Expand All @@ -92,6 +97,7 @@ public HuggingFaceModel parsePersistedConfigWithSecrets(
Map<String, Object> secrets
) {
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
Map<String, Object> secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS);

ChunkingSettings chunkingSettings = null;
@@ -100,45 +106,44 @@ }
}

return createModel(
inferenceEntityId,
taskType,
serviceSettingsMap,
chunkingSettings,
secretSettingsMap,
parsePersistedConfigErrorMsg(inferenceEntityId, name()),
ConfigurationParseContext.PERSISTENT
new HuggingFaceModelParameters(
inferenceEntityId,
taskType,
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
secretSettingsMap,
parsePersistedConfigErrorMsg(inferenceEntityId, name()),
ConfigurationParseContext.PERSISTENT
)
);
}

@Override
public HuggingFaceModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map<String, Object> config) {
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);

ChunkingSettings chunkingSettings = null;
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
}

return createModel(
inferenceEntityId,
taskType,
serviceSettingsMap,
chunkingSettings,
null,
parsePersistedConfigErrorMsg(inferenceEntityId, name()),
ConfigurationParseContext.PERSISTENT
new HuggingFaceModelParameters(
inferenceEntityId,
taskType,
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
null,
parsePersistedConfigErrorMsg(inferenceEntityId, name()),
ConfigurationParseContext.PERSISTENT
)
);
}

protected abstract HuggingFaceModel createModel(
String inferenceEntityId,
TaskType taskType,
Map<String, Object> serviceSettings,
ChunkingSettings chunkingSettings,
Map<String, Object> secretSettings,
String failureMessage,
ConfigurationParseContext context
);
protected abstract HuggingFaceModel createModel(HuggingFaceModelParameters input);

@Override
public void doInfer(
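The refactor above collapses seven `createModel` arguments into one parameter object. The class itself is not shown in this diff, so the following is a hypothetical sketch inferred from the call sites; the stub enums and interface stand in for the real `TaskType`, `ChunkingSettings`, and `ConfigurationParseContext`:

```java
import java.util.Map;

public class HuggingFaceModelParametersSketch {
    enum TaskType { TEXT_EMBEDDING, SPARSE_EMBEDDING, RERANK }   // stub
    enum ConfigurationParseContext { REQUEST, PERSISTENT }       // stub
    interface ChunkingSettings {}                                // stub

    // Field order mirrors the argument order at the call sites above.
    record HuggingFaceModelParameters(
        String inferenceEntityId,
        TaskType taskType,
        Map<String, Object> serviceSettings,
        Map<String, Object> taskSettings,
        ChunkingSettings chunkingSettings,   // null for non-embedding task types
        Map<String, Object> secretSettings,  // null in parsePersistedConfig
        String failureMessage,
        ConfigurationParseContext context
    ) {}

    public static void main(String[] args) {
        var params = new HuggingFaceModelParameters(
            "hf-rerank-endpoint",
            TaskType.RERANK,
            Map.of("model_id", "my_model"),
            Map.of(),
            null,
            null,
            "unsupported task type",
            ConfigurationParseContext.REQUEST
        );
        System.out.println(params.taskType()); // RERANK
    }
}
```

Bundling the arguments lets the abstract `createModel(HuggingFaceModelParameters input)` grow new fields (here, the rerank task settings map) without touching every subclass signature.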
@@ -11,6 +11,7 @@
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
@@ -35,6 +36,13 @@ public HuggingFaceModel(
apiKey = ServiceUtils.apiKey(apiKeySecrets);
}

protected HuggingFaceModel(HuggingFaceModel model, TaskSettings taskSettings) {
super(model, taskSettings);

rateLimitServiceSettings = model.rateLimitServiceSettings();
apiKey = model.apiKey();
}

public HuggingFaceRateLimitServiceSettings rateLimitServiceSettings() {
return rateLimitServiceSettings;
}