Skip to content

Add Hugging Face Rerank support #127966

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
4d9ec59
Add Hugging Face Rerank support
Evgenii-Kazannik Apr 28, 2025
b58aab4
Address comments
Evgenii-Kazannik May 10, 2025
c567137
Merge branch 'main' into Add-Hugging-Face-Rerank-support
Evgenii-Kazannik May 13, 2025
a4ebb87
Add transport version
Evgenii-Kazannik May 13, 2025
5d316c1
Merge branch 'main' into Add-Hugging-Face-Rerank-support
Evgenii-Kazannik May 13, 2025
f74e9e5
Add transport version
Evgenii-Kazannik May 13, 2025
2ea07f0
Merge branch 'main' into Add-Hugging-Face-Rerank-support
Evgenii-Kazannik May 14, 2025
0054891
Add to inference service and crud IT rerank tests
Evgenii-Kazannik May 14, 2025
82fd86d
Merge branch 'main' into Add-Hugging-Face-Rerank-support
Evgenii-Kazannik May 16, 2025
733818c
Refactor slightly / error message
Evgenii-Kazannik May 16, 2025
cb67ac0
Merge branch 'main' into Add-Hugging-Face-Rerank-support
Evgenii-Kazannik May 19, 2025
f97f818
correct 'testGetConfiguration' test case
Evgenii-Kazannik May 19, 2025
88d6929
Merge branch 'main' into Add-Hugging-Face-Rerank-support
Evgenii-Kazannik May 19, 2025
a52a1d8
apply suggestions
Evgenii-Kazannik May 20, 2025
2eae767
Merge branch 'main' into Add-Hugging-Face-Rerank-support
Evgenii-Kazannik May 21, 2025
c8c74d6
fix tests
Evgenii-Kazannik May 21, 2025
ae1a1d2
apply suggestions
Evgenii-Kazannik May 21, 2025
1764a4d
Merge branch 'main' into Add-Hugging-Face-Rerank-support
Evgenii-Kazannik May 22, 2025
7f30c6a
[CI] Auto commit changes from spotless
elasticsearchmachine May 22, 2025
4ee7f1f
add changelog information
Evgenii-Kazannik May 22, 2025
887389f
Merge branch 'main' into Add-Hugging-Face-Rerank-support
Evgenii-Kazannik May 22, 2025
68755a6
Merge remote-tracking branch 'origin/Add-Hugging-Face-Rerank-support'…
Evgenii-Kazannik May 22, 2025
43398ca
Merge branch 'main' into Add-Hugging-Face-Rerank-support
Evgenii-Kazannik May 22, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,9 @@ static TransportVersion def(int id) {
public static final TransportVersion ESQL_REPORT_SHARD_PARTITIONING_8_19 = def(8_841_0_29);
public static final TransportVersion ESQL_DRIVER_TASK_DESCRIPTION_8_19 = def(8_841_0_30);
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_CHAT_COMPLETION_ADDED_8_19 = def(8_841_0_31);
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_RERANK_ADDED_8_19 = def(8_841_0_32);
public static final TransportVersion V_8_19_FIELD_CAPS_ADD_CLUSTER_ALIAS = def(8_841_0_32);
public static final TransportVersion ESQL_HASH_OPERATOR_STATUS_OUTPUT_TIME_8_19 = def(8_841_0_34);
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_RERANK_ADDED_8_19 = def(8_841_0_35);
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,20 @@ private List<RankedDocsResults.RankedDoc> rankedDocsNullStringToEmpty(List<Ranke
// Parses a RankedDocsResults instance from the supplied XContentParser.
// Delegates to the parser returned by RankedDocsResults.createParser(true) and applies it
// with null as the second argument.
// NOTE(review): the meaning of the `true` flag is not visible from this hunk; confirm against
// RankedDocsResults.createParser before relying on it.
protected RankedDocsResults doParseInstance(XContentParser parser) throws IOException {
return RankedDocsResults.createParser(true).apply(parser, null);
}

/**
 * Test expectation for a single ranked document.
 *
 * @param rankedDocFields the field map of the expected ranked document; {@code buildExpectationRerank}
 *                        nests this map under the {@code RankedDocsResults.RankedDoc.NAME} key
 */
public record RerankExpectation(Map<String, Object> rankedDocFields) {}

/**
 * Builds the expected rerank result as a nested map.
 *
 * <p>Each expectation becomes a single-entry map keyed by {@code RankedDocsResults.RankedDoc.NAME}
 * whose value is the expectation's field map. The resulting list, in the same order as the input,
 * is stored under the {@code RankedDocsResults.RERANK} key.
 *
 * @param rerank the expected ranked documents, in the order they should appear
 * @return an immutable single-entry map from {@code RankedDocsResults.RERANK} to the list of wrapped documents
 */
public static Map<String, Object> buildExpectationRerank(List<RerankExpectation> rerank) {
    // Wrap every expected document under the ranked-doc key first, then attach the list to the root key.
    List<Map<String, Object>> wrappedDocs = rerank.stream()
        .map(expected -> Map.<String, Object>of(RankedDocsResults.RankedDoc.NAME, expected.rankedDocFields()))
        .toList();
    return Map.of(RankedDocsResults.RERANK, wrappedDocs);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionCreator;

import java.util.Collections;
import java.util.Map;

import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
Expand Down Expand Up @@ -58,11 +57,7 @@ public void parseRequestConfig(
) {
try {
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> taskSettingsMap = Collections.emptyMap();

if (TaskType.RERANK.equals(taskType)) {
taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
}
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);

ChunkingSettings chunkingSettings = null;
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
Expand Down Expand Up @@ -102,12 +97,8 @@ public HuggingFaceModel parsePersistedConfigWithSecrets(
Map<String, Object> secrets
) {
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
Map<String, Object> secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS);
Map<String, Object> taskSettingsMap = Collections.emptyMap();

if (TaskType.RERANK.equals(taskType)) {
taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
}

ChunkingSettings chunkingSettings = null;
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
Expand All @@ -131,11 +122,7 @@ public HuggingFaceModel parsePersistedConfigWithSecrets(
@Override
public HuggingFaceModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map<String, Object> config) {
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> taskSettingsMap = Collections.emptyMap();

if (TaskType.RERANK.equals(taskType)) {
taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
}
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);

ChunkingSettings chunkingSettings = null;
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,43 +76,43 @@ public HuggingFaceService(HttpRequestSender.Factory factory, ServiceComponents s
}

@Override
protected HuggingFaceModel createModel(HuggingFaceModelParameters input) {
return switch (input.taskType()) {
protected HuggingFaceModel createModel(HuggingFaceModelParameters params) {
return switch (params.taskType()) {
case RERANK -> new HuggingFaceRerankModel(
input.inferenceEntityId(),
input.taskType(),
params.inferenceEntityId(),
params.taskType(),
NAME,
input.serviceSettings(),
input.taskSettings(),
input.secretSettings(),
input.context()
params.serviceSettings(),
params.taskSettings(),
params.secretSettings(),
params.context()
);
case TEXT_EMBEDDING -> new HuggingFaceEmbeddingsModel(
input.inferenceEntityId(),
input.taskType(),
params.inferenceEntityId(),
params.taskType(),
NAME,
input.serviceSettings(),
input.chunkingSettings(),
input.secretSettings(),
input.context()
params.serviceSettings(),
params.chunkingSettings(),
params.secretSettings(),
params.context()
);
case SPARSE_EMBEDDING -> new HuggingFaceElserModel(
input.inferenceEntityId(),
input.taskType(),
params.inferenceEntityId(),
params.taskType(),
NAME,
input.serviceSettings(),
input.secretSettings(),
input.context()
params.serviceSettings(),
params.secretSettings(),
params.context()
);
case CHAT_COMPLETION, COMPLETION -> new HuggingFaceChatCompletionModel(
input.inferenceEntityId(),
input.taskType(),
params.inferenceEntityId(),
params.taskType(),
NAME,
input.serviceSettings(),
input.secretSettings(),
input.context()
params.serviceSettings(),
params.secretSettings(),
params.context()
);
default -> throw new ElasticsearchStatusException(input.failureMessage(), RestStatus.BAD_REQUEST);
default -> throw new ElasticsearchStatusException(params.failureMessage(), RestStatus.BAD_REQUEST);
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,12 @@ public class HuggingFaceActionCreator implements HuggingFaceActionVisitor {
OpenAiChatCompletionResponseEntity::fromResponse
);
private static final ResponseHandler RERANK_HANDLER = new HuggingFaceResponseHandler("hugging face rerank", (request, response) -> {
var errorMessage = format(INVALID_REQUEST_TYPE_MESSAGE, "RERANK", request != null ? request.getClass().getName() : "null");

if ((request instanceof HuggingFaceRerankRequest) == false) {
var errorMessage = format(
INVALID_REQUEST_TYPE_MESSAGE,
"RERANK",
request != null ? request.getClass().getSimpleName() : "null"
);
throw new IllegalArgumentException(errorMessage);
}
return HuggingFaceRerankResponseEntity.fromResponse((HuggingFaceRerankRequest) request, response);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,20 +62,19 @@ public HttpRequest createHttpRequest() {
input,
returnDocuments,
topN != null ? topN : model.getTaskSettings().getTopNDocumentsOnly(),
model.getTaskSettings(),
model.getServiceSettings().modelId()
model.getTaskSettings()
)
).getBytes(StandardCharsets.UTF_8)
);
httpPost.setEntity(byteEntity);
httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());
httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaTypeWithoutParameters());

decorateWithAuth(httpPost);

return new HttpRequest(httpPost, getInferenceEntityId());
}

public void decorateWithAuth(HttpPost httpPost) {
void decorateWithAuth(HttpPost httpPost) {
httpPost.setHeader(createAuthBearerHeader(model.apiKey()));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import java.util.Objects;

public record HuggingFaceRerankRequestEntity(
String model,
String query,
List<String> documents,
@Nullable Boolean returnDocuments,
Expand All @@ -34,17 +33,6 @@ public record HuggingFaceRerankRequestEntity(
Objects.requireNonNull(taskSettings);
}

public HuggingFaceRerankRequestEntity(
String query,
List<String> input,
@Nullable Boolean returnDocuments,
@Nullable Integer topN,
HuggingFaceRerankTaskSettings taskSettings,
String model
) {
this(model, query, input, returnDocuments, topN, taskSettings);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public HuggingFaceRerankModel(
}

// Should only be used directly for testing
public HuggingFaceRerankModel(
HuggingFaceRerankModel(
String inferenceEntityId,
TaskType taskType,
String service,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,11 @@
import java.util.Comparator;
import java.util.List;

import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList;
import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;

public class HuggingFaceRerankResponseEntity {

private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in Hugging Face rerank response";
private static final String INVALID_ID_FIELD_FORMAT_TEMPLATE = "Expected numeric value for record ID field in Hugging Face rerank ";

/**
* Parses the Hugging Face rerank response.

Expand All @@ -41,10 +36,9 @@ public class HuggingFaceRerankResponseEntity {
* <pre>
* <code>
* {
* "input": ["luke", "like", "leia", "chewy","r2d2", "star", "wars"],
* "texts": ["luke", "leia"],
* "query": "star wars main character",
* "return_documents": true,
* "top_n": 1
* "return_text": true
* }
* </code>
* </pre>
Expand All @@ -53,15 +47,18 @@ public class HuggingFaceRerankResponseEntity {

* <pre>
* <code>
* {
* "rerank": [
* {
* "index": 5,
* "relevance_score": -0.06920313,
* "text": "star"
* }
* ]
* }
* [
* {
* "index": 0,
* "score": -0.07996220886707306,
* "text": "luke"
* },
* {
* "index": 1,
* "score": -0.08393221348524094,
* "text": "leia"
* }
* ]
* </code>
* </pre>
*/
Expand All @@ -71,10 +68,6 @@ public static RankedDocsResults fromResponse(HuggingFaceRerankRequest request, H

try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
moveToFirstToken(jsonParser);

XContentParser.Token token = jsonParser.currentToken();
ensureExpectedToken(XContentParser.Token.START_ARRAY, token, jsonParser);

var rankedDocs = doParse(jsonParser);
var rankedDocsByRelevanceStream = rankedDocs.stream()
.sorted(Comparator.comparingDouble(RankedDocsResults.RankedDoc::relevanceScore).reversed());
Expand All @@ -88,24 +81,11 @@ public static RankedDocsResults fromResponse(HuggingFaceRerankRequest request, H
private static List<RankedDocsResults.RankedDoc> doParse(XContentParser parser) throws IOException {
return parseList(parser, (listParser, index) -> {
var parsedRankedDoc = HuggingFaceRerankResponseEntity.RankedDocEntry.parse(parser);

if (parsedRankedDoc.id == null) {
throw new IllegalStateException(format(FAILED_TO_FIND_FIELD_TEMPLATE, RankedDocEntry.ID.getPreferredName()));
}

if (parsedRankedDoc.score == null) {
throw new IllegalStateException(format(FAILED_TO_FIND_FIELD_TEMPLATE, RankedDocEntry.SCORE.getPreferredName()));
}

try {
return new RankedDocsResults.RankedDoc(parsedRankedDoc.id, parsedRankedDoc.score, parsedRankedDoc.text);
} catch (NumberFormatException e) {
throw new IllegalStateException(format(INVALID_ID_FIELD_FORMAT_TEMPLATE, parsedRankedDoc.id));
}
return new RankedDocsResults.RankedDoc(parsedRankedDoc.id, parsedRankedDoc.score, parsedRankedDoc.text);
});
}

private record RankedDocEntry(@Nullable Integer id, @Nullable Float score, @Nullable String text) {
private record RankedDocEntry(Integer id, Float score, @Nullable String text) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: let's change id to index, I think that's clearer.


private static final ParseField TEXT = new ParseField("text");
private static final ParseField SCORE = new ParseField("score");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -358,11 +358,14 @@ public void testSend_FailsFromInvalidResponseFormat_ForRerankAction() throws IOE

private void assertRerankActionCreator(List<String> documents, String query, int topN, boolean returnText) throws IOException {
assertThat(webServer.requests(), hasSize(1));
assertNull(webServer.requests().getFirst().getUri().getQuery());
assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));
assertNull(webServer.requests().get(0).getUri().getQuery());
assertThat(
webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE),
equalTo(XContentType.JSON.mediaTypeWithoutParameters())
);
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));

var requestMap = entityAsMap(webServer.requests().getFirst().getBody());
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
assertThat(requestMap.size(), is(4));
assertThat(requestMap.get("texts"), is(documents));
assertThat(requestMap.get("query"), is(query));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,11 @@
import java.io.IOException;
import java.util.List;

import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhitespaceInJsonString;
import static org.elasticsearch.common.xcontent.XContentHelper.stripWhitespace;

public class HuggingFaceRerankRequestEntityTests extends ESTestCase {
private static final String INPUT = "texts";
private static final String QUERY = "query";
private static final String INFERENCE_ID = "model";
private static final Integer TOP_N = 8;
private static final Boolean RETURN_DOCUMENTS = false;

Expand All @@ -33,36 +32,28 @@ public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException
List.of(INPUT),
Boolean.TRUE,
TOP_N,
new HuggingFaceRerankTaskSettings(TOP_N, RETURN_DOCUMENTS),
INFERENCE_ID
new HuggingFaceRerankTaskSettings(TOP_N, RETURN_DOCUMENTS)
);

XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
entity.toXContent(builder, ToXContent.EMPTY_PARAMS);
String xContentResult = Strings.toString(builder);

assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
String expected = """
{"texts":["texts"],
"query":"query",
"return_text":true,
"top_n":8}"""));
"top_n":8}""";
assertEquals(stripWhitespace(expected), xContentResult);
}

public void testXContent_WritesMinimalFields() throws IOException {
var entity = new HuggingFaceRerankRequestEntity(
QUERY,
List.of(INPUT),
null,
null,
new HuggingFaceRerankTaskSettings(null, null),
INFERENCE_ID
);
var entity = new HuggingFaceRerankRequestEntity(QUERY, List.of(INPUT), null, null, new HuggingFaceRerankTaskSettings(null, null));

XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
entity.toXContent(builder, ToXContent.EMPTY_PARAMS);
String xContentResult = Strings.toString(builder);

assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
{"texts":["texts"],"query":"query"}"""));
String expected = """
{"texts":["texts"],"query":"query"}""";
assertEquals(stripWhitespace(expected), xContentResult);
}
}
Loading