Skip to content

Commit bfe40fd

Browse files
Add Hugging Face Rerank support (elastic#127966)
* Add Hugging Face Rerank support * Address comments * Add transport version * Add transport version * Add to inference service and crud IT rerank tests * Refactor slightly / error message * correct 'testGetConfiguration' test case * apply suggestions * fix tests * apply suggestions * [CI] Auto commit changes from spotless * add changelog information --------- Co-authored-by: elasticsearchmachine <infra-root+elasticsearchmachine@elastic.co> (cherry picked from commit c7cf850) # Conflicts: # server/src/main/java/org/elasticsearch/TransportVersions.java
1 parent deb0e61 commit bfe40fd

File tree

31 files changed

+1484
-84
lines changed

31 files changed

+1484
-84
lines changed

docs/changelog/127966.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 127966
2+
summary: "[ML] Add Rerank support to the Inference Plugin"
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@ static TransportVersion def(int id) {
224224
public static final TransportVersion INCLUDE_INDEX_MODE_IN_GET_DATA_STREAM_BACKPORT_8_19 = def(8_841_0_33);
225225
public static final TransportVersion ESQL_HASH_OPERATOR_STATUS_OUTPUT_TIME_8_19 = def(8_841_0_34);
226226
public static final TransportVersion RERANKER_FAILURES_ALLOWED_8_19 = def(8_841_0_35);
227+
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_RERANK_ADDED_8_19 = def(8_841_0_36);
227228

228229
/*
229230
* STOP! READ THIS FIRST! No, really,

server/src/test/java/org/elasticsearch/inference/SettingsConfigurationTestUtils.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,8 @@
2020
public class SettingsConfigurationTestUtils {
2121

2222
public static SettingsConfiguration getRandomSettingsConfigurationField() {
23-
return new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING)).setDefaultValue(
24-
randomAlphaOfLength(10)
25-
)
23+
return new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING, TaskType.RERANK))
24+
.setDefaultValue(randomAlphaOfLength(10))
2625
.setDescription(randomAlphaOfLength(10))
2726
.setLabel(randomAlphaOfLength(10))
2827
.setRequired(randomBoolean())

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/RankedDocsResultsTests.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,4 +82,13 @@ private List<RankedDocsResults.RankedDoc> rankedDocsNullStringToEmpty(List<Ranke
8282
protected RankedDocsResults doParseInstance(XContentParser parser) throws IOException {
8383
return RankedDocsResults.createParser(true).apply(parser, null);
8484
}
85+
86+
public record RerankExpectation(Map<String, Object> rankedDocFields) {}
87+
88+
public static Map<String, Object> buildExpectationRerank(List<RerankExpectation> rerank) {
89+
return Map.of(
90+
RankedDocsResults.RERANK,
91+
rerank.stream().map(rerankExpectation -> Map.of(RankedDocsResults.RankedDoc.NAME, rerankExpectation.rankedDocFields)).toList()
92+
);
93+
}
8594
}

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,20 @@ static String mockDenseServiceModelConfig() {
171171
""";
172172
}
173173

174+
static String mockRerankServiceModelConfig() {
175+
return """
176+
{
177+
"service": "test_reranking_service",
178+
"service_settings": {
179+
"model_id": "my_model",
180+
"api_key": "abc64"
181+
},
182+
"task_settings": {
183+
}
184+
}
185+
""";
186+
}
187+
174188
static void deleteModel(String modelId) throws IOException {
175189
var request = new Request("DELETE", "_inference/" + modelId);
176190
var response = client().performRequest(request);
@@ -484,6 +498,10 @@ private String jsonBody(List<String> input, @Nullable String query) {
484498
@SuppressWarnings("unchecked")
485499
protected void assertNonEmptyInferenceResults(Map<String, Object> resultMap, int expectedNumberOfResults, TaskType taskType) {
486500
switch (taskType) {
501+
case RERANK -> {
502+
var results = (List<Map<String, Object>>) resultMap.get(TaskType.RERANK.toString());
503+
assertThat(results, hasSize(expectedNumberOfResults));
504+
}
487505
case SPARSE_EMBEDDING -> {
488506
var results = (List<Map<String, Object>>) resultMap.get(TaskType.SPARSE_EMBEDDING.toString());
489507
assertThat(results, hasSize(expectedNumberOfResults));

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,12 @@ public void testCRUD() throws IOException {
5353
for (int i = 0; i < 4; i++) {
5454
putModel("te_model_" + i, mockDenseServiceModelConfig(), TaskType.TEXT_EMBEDDING);
5555
}
56+
for (int i = 0; i < 3; i++) {
57+
putModel("re-model-" + i, mockRerankServiceModelConfig(), TaskType.RERANK);
58+
}
5659

5760
var getAllModels = getAllModels();
58-
int numModels = 12;
61+
int numModels = 15;
5962
assertThat(getAllModels, hasSize(numModels));
6063

6164
var getSparseModels = getModels("_all", TaskType.SPARSE_EMBEDDING);
@@ -71,6 +74,13 @@ public void testCRUD() throws IOException {
7174
for (var denseModel : getDenseModels) {
7275
assertEquals("text_embedding", denseModel.get("task_type"));
7376
}
77+
78+
var getRerankModels = getModels("_all", TaskType.RERANK);
79+
int numRerankModels = 4;
80+
assertThat(getRerankModels, hasSize(numRerankModels));
81+
for (var denseModel : getRerankModels) {
82+
assertEquals("rerank", denseModel.get("task_type"));
83+
}
7484
String oldApiKey;
7585
{
7686
var singleModel = getModels("se_model_1", TaskType.SPARSE_EMBEDDING);
@@ -100,6 +110,9 @@ public void testCRUD() throws IOException {
100110
for (int i = 0; i < 4; i++) {
101111
deleteModel("te_model_" + i, TaskType.TEXT_EMBEDDING);
102112
}
113+
for (int i = 0; i < 3; i++) {
114+
deleteModel("re-model-" + i, TaskType.RERANK);
115+
}
103116
}
104117

105118
public void testGetModelWithWrongTaskType() throws IOException {

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
101101

102102
public void testGetServicesWithRerankTaskType() throws IOException {
103103
List<Object> services = getServices(TaskType.RERANK);
104-
assertThat(services.size(), equalTo(7));
104+
assertThat(services.size(), equalTo(8));
105105

106106
var providers = providers(services);
107107

@@ -115,7 +115,8 @@ public void testGetServicesWithRerankTaskType() throws IOException {
115115
"googlevertexai",
116116
"jinaai",
117117
"test_reranking_service",
118-
"voyageai"
118+
"voyageai",
119+
"hugging_face"
119120
).toArray()
120121
)
121122
);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference;
9+
10+
import org.elasticsearch.inference.TaskType;
11+
12+
import java.io.IOException;
13+
import java.util.List;
14+
import java.util.Map;
15+
16+
public class MockRerankInferenceServiceIT extends InferenceBaseRestTest {
17+
18+
@SuppressWarnings("unchecked")
19+
public void testMockService() throws IOException {
20+
String inferenceEntityId = "test-mock";
21+
var putModel = putModel(inferenceEntityId, mockRerankServiceModelConfig(), TaskType.RERANK);
22+
var model = getModels(inferenceEntityId, TaskType.RERANK).get(0);
23+
24+
for (var modelMap : List.of(putModel, model)) {
25+
assertEquals(inferenceEntityId, modelMap.get("inference_id"));
26+
assertEquals(TaskType.RERANK, TaskType.fromString((String) modelMap.get("task_type")));
27+
assertEquals("test_reranking_service", modelMap.get("service"));
28+
}
29+
30+
List<String> input = List.of(randomAlphaOfLength(10));
31+
var inference = infer(inferenceEntityId, input);
32+
assertNonEmptyInferenceResults(inference, 1, TaskType.RERANK);
33+
assertEquals(inference, infer(inferenceEntityId, input));
34+
assertNotEquals(inference, infer(inferenceEntityId, randomValueOtherThan(input, () -> List.of(randomAlphaOfLength(10)))));
35+
}
36+
37+
public void testMockServiceWithMultipleInputs() throws IOException {
38+
String inferenceEntityId = "test-mock-with-multi-inputs";
39+
putModel(inferenceEntityId, mockRerankServiceModelConfig(), TaskType.RERANK);
40+
var queryParams = Map.of("timeout", "120s");
41+
42+
var inference = infer(
43+
inferenceEntityId,
44+
TaskType.RERANK,
45+
List.of(randomAlphaOfLength(5), randomAlphaOfLength(10)),
46+
"What if?",
47+
queryParams
48+
);
49+
50+
assertNonEmptyInferenceResults(inference, 2, TaskType.RERANK);
51+
}
52+
53+
@SuppressWarnings("unchecked")
54+
public void testMockService_DoesNotReturnSecretsInGetResponse() throws IOException {
55+
String inferenceEntityId = "test-mock";
56+
var putModel = putModel(inferenceEntityId, mockRerankServiceModelConfig(), TaskType.RERANK);
57+
var model = getModels(inferenceEntityId, TaskType.RERANK).get(0);
58+
59+
var serviceSettings = (Map<String, Object>) model.get("service_settings");
60+
assertNull(serviceSettings.get("api_key"));
61+
assertNotNull(serviceSettings.get("model_id"));
62+
63+
var putServiceSettings = (Map<String, Object>) putModel.get("service_settings");
64+
assertNull(putServiceSettings.get("api_key"));
65+
assertNotNull(putServiceSettings.get("model_id"));
66+
}
67+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@
8080
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings;
8181
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionServiceSettings;
8282
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings;
83+
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankServiceSettings;
84+
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings;
8385
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings;
8486
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankServiceSettings;
8587
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankTaskSettings;
@@ -365,6 +367,16 @@ private static void addHuggingFaceNamedWriteables(List<NamedWriteableRegistry.En
365367
HuggingFaceChatCompletionServiceSettings::new
366368
)
367369
);
370+
namedWriteables.add(
371+
new NamedWriteableRegistry.Entry(TaskSettings.class, HuggingFaceRerankTaskSettings.NAME, HuggingFaceRerankTaskSettings::new)
372+
);
373+
namedWriteables.add(
374+
new NamedWriteableRegistry.Entry(
375+
ServiceSettings.class,
376+
HuggingFaceRerankServiceSettings.NAME,
377+
HuggingFaceRerankServiceSettings::new
378+
)
379+
);
368380
}
369381

370382
private static void addGoogleAiStudioNamedWritables(List<NamedWriteableRegistry.Entry> namedWriteables) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
2020
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
2121
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankTaskSettings;
22+
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings;
2223

2324
import java.util.ArrayList;
2425
import java.util.Arrays;
@@ -94,7 +95,10 @@ protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[
9495
} else if (r.getEndpoints().isEmpty() == false
9596
&& r.getEndpoints().get(0).getTaskSettings() instanceof GoogleVertexAiRerankTaskSettings googleVertexAiTaskSettings) {
9697
configuredTopN = googleVertexAiTaskSettings.topN();
97-
}
98+
} else if (r.getEndpoints().isEmpty() == false
99+
&& r.getEndpoints().get(0).getTaskSettings() instanceof HuggingFaceRerankTaskSettings huggingFaceRerankTaskSettings) {
100+
configuredTopN = huggingFaceRerankTaskSettings.getTopNDocumentsOnly();
101+
}
98102
if (configuredTopN != null && configuredTopN < rankWindowSize) {
99103
l.onFailure(
100104
new IllegalArgumentException(

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java

Lines changed: 35 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ public void parseRequestConfig(
5757
) {
5858
try {
5959
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
60+
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
6061

6162
ChunkingSettings chunkingSettings = null;
6263
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
@@ -66,17 +67,21 @@ public void parseRequestConfig(
6667
}
6768

6869
var model = createModel(
69-
inferenceEntityId,
70-
taskType,
71-
serviceSettingsMap,
72-
chunkingSettings,
73-
serviceSettingsMap,
74-
TaskType.unsupportedTaskTypeErrorMsg(taskType, name()),
75-
ConfigurationParseContext.REQUEST
70+
new HuggingFaceModelParameters(
71+
inferenceEntityId,
72+
taskType,
73+
serviceSettingsMap,
74+
taskSettingsMap,
75+
chunkingSettings,
76+
serviceSettingsMap,
77+
TaskType.unsupportedTaskTypeErrorMsg(taskType, name()),
78+
ConfigurationParseContext.REQUEST
79+
)
7680
);
7781

7882
throwIfNotEmptyMap(config, name());
7983
throwIfNotEmptyMap(serviceSettingsMap, name());
84+
throwIfNotEmptyMap(taskSettingsMap, name());
8085

8186
parsedModelListener.onResponse(model);
8287
} catch (Exception e) {
@@ -92,6 +97,7 @@ public HuggingFaceModel parsePersistedConfigWithSecrets(
9297
Map<String, Object> secrets
9398
) {
9499
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
100+
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
95101
Map<String, Object> secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS);
96102

97103
ChunkingSettings chunkingSettings = null;
@@ -100,45 +106,44 @@ public HuggingFaceModel parsePersistedConfigWithSecrets(
100106
}
101107

102108
return createModel(
103-
inferenceEntityId,
104-
taskType,
105-
serviceSettingsMap,
106-
chunkingSettings,
107-
secretSettingsMap,
108-
parsePersistedConfigErrorMsg(inferenceEntityId, name()),
109-
ConfigurationParseContext.PERSISTENT
109+
new HuggingFaceModelParameters(
110+
inferenceEntityId,
111+
taskType,
112+
serviceSettingsMap,
113+
taskSettingsMap,
114+
chunkingSettings,
115+
secretSettingsMap,
116+
parsePersistedConfigErrorMsg(inferenceEntityId, name()),
117+
ConfigurationParseContext.PERSISTENT
118+
)
110119
);
111120
}
112121

113122
@Override
114123
public HuggingFaceModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map<String, Object> config) {
115124
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
125+
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
116126

117127
ChunkingSettings chunkingSettings = null;
118128
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
119129
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
120130
}
121131

122132
return createModel(
123-
inferenceEntityId,
124-
taskType,
125-
serviceSettingsMap,
126-
chunkingSettings,
127-
null,
128-
parsePersistedConfigErrorMsg(inferenceEntityId, name()),
129-
ConfigurationParseContext.PERSISTENT
133+
new HuggingFaceModelParameters(
134+
inferenceEntityId,
135+
taskType,
136+
serviceSettingsMap,
137+
taskSettingsMap,
138+
chunkingSettings,
139+
null,
140+
parsePersistedConfigErrorMsg(inferenceEntityId, name()),
141+
ConfigurationParseContext.PERSISTENT
142+
)
130143
);
131144
}
132145

133-
protected abstract HuggingFaceModel createModel(
134-
String inferenceEntityId,
135-
TaskType taskType,
136-
Map<String, Object> serviceSettings,
137-
ChunkingSettings chunkingSettings,
138-
Map<String, Object> secretSettings,
139-
String failureMessage,
140-
ConfigurationParseContext context
141-
);
146+
protected abstract HuggingFaceModel createModel(HuggingFaceModelParameters input);
142147

143148
@Override
144149
public void doInfer(

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModel.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.elasticsearch.core.Nullable;
1212
import org.elasticsearch.inference.ModelConfigurations;
1313
import org.elasticsearch.inference.ModelSecrets;
14+
import org.elasticsearch.inference.TaskSettings;
1415
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
1516
import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
1617
import org.elasticsearch.xpack.inference.services.ServiceUtils;
@@ -35,6 +36,13 @@ public HuggingFaceModel(
3536
apiKey = ServiceUtils.apiKey(apiKeySecrets);
3637
}
3738

39+
protected HuggingFaceModel(HuggingFaceModel model, TaskSettings taskSettings) {
40+
super(model, taskSettings);
41+
42+
rateLimitServiceSettings = model.rateLimitServiceSettings();
43+
apiKey = model.apiKey();
44+
}
45+
3846
public HuggingFaceRateLimitServiceSettings rateLimitServiceSettings() {
3947
return rateLimitServiceSettings;
4048
}

0 commit comments

Comments
 (0)