Add Hugging Face Rerank support #127966

Merged
23 commits
4d9ec59
Add Hugging Face Rerank support
Evgenii-Kazannik Apr 28, 2025
b58aab4
Address comments
Evgenii-Kazannik May 10, 2025
c567137
Merge branch 'main' into Add-Hugging-Face-Rerank-support
Evgenii-Kazannik May 13, 2025
a4ebb87
Add transport version
Evgenii-Kazannik May 13, 2025
5d316c1
Merge branch 'main' into Add-Hugging-Face-Rerank-support
Evgenii-Kazannik May 13, 2025
f74e9e5
Add transport version
Evgenii-Kazannik May 13, 2025
2ea07f0
Merge branch 'main' into Add-Hugging-Face-Rerank-support
Evgenii-Kazannik May 14, 2025
0054891
Add to inference service and crud IT rerank tests
Evgenii-Kazannik May 14, 2025
82fd86d
Merge branch 'main' into Add-Hugging-Face-Rerank-support
Evgenii-Kazannik May 16, 2025
733818c
Refactor slightly / error message
Evgenii-Kazannik May 16, 2025
cb67ac0
Merge branch 'main' into Add-Hugging-Face-Rerank-support
Evgenii-Kazannik May 19, 2025
f97f818
correct 'testGetConfiguration' test case
Evgenii-Kazannik May 19, 2025
88d6929
Merge branch 'main' into Add-Hugging-Face-Rerank-support
Evgenii-Kazannik May 19, 2025
a52a1d8
apply suggestions
Evgenii-Kazannik May 20, 2025
2eae767
Merge branch 'main' into Add-Hugging-Face-Rerank-support
Evgenii-Kazannik May 21, 2025
c8c74d6
fix tests
Evgenii-Kazannik May 21, 2025
ae1a1d2
apply suggestions
Evgenii-Kazannik May 21, 2025
1764a4d
Merge branch 'main' into Add-Hugging-Face-Rerank-support
Evgenii-Kazannik May 22, 2025
7f30c6a
[CI] Auto commit changes from spotless
elasticsearchmachine May 22, 2025
4ee7f1f
add changelog information
Evgenii-Kazannik May 22, 2025
887389f
Merge branch 'main' into Add-Hugging-Face-Rerank-support
Evgenii-Kazannik May 22, 2025
68755a6
Merge remote-tracking branch 'origin/Add-Hugging-Face-Rerank-support'…
Evgenii-Kazannik May 22, 2025
43398ca
Merge branch 'main' into Add-Hugging-Face-Rerank-support
Evgenii-Kazannik May 22, 2025
Original file line number Diff line number Diff line change
@@ -20,9 +20,8 @@
public class SettingsConfigurationTestUtils {

public static SettingsConfiguration getRandomSettingsConfigurationField() {
- return new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING)).setDefaultValue(
- randomAlphaOfLength(10)
- )
+ return new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING, TaskType.RERANK))
+ .setDefaultValue(randomAlphaOfLength(10))
.setDescription(randomAlphaOfLength(10))
.setLabel(randomAlphaOfLength(10))
.setRequired(randomBoolean())
Original file line number Diff line number Diff line change
@@ -171,6 +171,22 @@ static String mockDenseServiceModelConfig() {
""";
}

static String mockRerankServiceModelConfig() {
Contributor:
I'm wondering whether the methods you've added to this class are actually used anywhere. The methods you used as a reference are called; the ones you've added are not.

Contributor Author (Evgenii-Kazannik, May 13, 2025):
Thanks for noticing. It's used now.

return """
{
"task_type": "rerank",
"service": "rerank_test_service",
"service_settings": {
"model": "rerank_model",
"api_key": "abc64"
},
"task_settings": {
"return_documents": true
}
}
""";
}

static void deleteModel(String modelId) throws IOException {
var request = new Request("DELETE", "_inference/" + modelId);
var response = client().performRequest(request);
@@ -484,6 +500,10 @@ private String jsonBody(List<String> input, @Nullable String query) {
@SuppressWarnings("unchecked")
protected void assertNonEmptyInferenceResults(Map<String, Object> resultMap, int expectedNumberOfResults, TaskType taskType) {
switch (taskType) {
case RERANK -> {
Contributor:
It looks like this method is never called with TaskType.RERANK, so this assertion is never triggered.

var results = (List<Map<String, Object>>) resultMap.get(TaskType.RERANK.toString());
assertThat(results, hasSize(expectedNumberOfResults));
}
case SPARSE_EMBEDDING -> {
var results = (List<Map<String, Object>>) resultMap.get(TaskType.SPARSE_EMBEDDING.toString());
assertThat(results, hasSize(expectedNumberOfResults));
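
The RERANK branch discussed above can be exercised in isolation. The following is a minimal sketch, not the PR's test code: it assumes the response map stores ranked documents under the key "rerank" (the assumed lowercase form of TaskType.RERANK.toString()), and the class name, method name, and document fields are hypothetical.

```java
import java.util.List;
import java.util.Map;

// Hypothetical stand-in for the RERANK branch of assertNonEmptyInferenceResults.
class RerankResultCheck {

    // Assumed response key; the PR uses TaskType.RERANK.toString() instead.
    static final String RERANK_KEY = "rerank";

    // Returns the number of ranked documents, or -1 if the key is missing.
    @SuppressWarnings("unchecked")
    static int countRankedDocs(Map<String, Object> resultMap) {
        Object value = resultMap.get(RERANK_KEY);
        if (value == null) {
            return -1;
        }
        return ((List<Map<String, Object>>) value).size();
    }

    public static void main(String[] args) {
        // Illustrative documents; the field names are assumptions.
        Map<String, Object> doc0 = Map.of("index", 0, "relevance_score", 0.9);
        Map<String, Object> doc1 = Map.of("index", 1, "relevance_score", 0.4);
        Map<String, Object> response = Map.of(RERANK_KEY, List.of(doc0, doc1));

        System.out.println(countRankedDocs(response));
    }
}
```

A test that calls the real helper with TaskType.RERANK would cover the same path, which is what the review comment asks for.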
Original file line number Diff line number Diff line change
@@ -101,7 +101,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException {

public void testGetServicesWithRerankTaskType() throws IOException {
List<Object> services = getServices(TaskType.RERANK);
- assertThat(services.size(), equalTo(7));
+ assertThat(services.size(), equalTo(8));

var providers = providers(services);

@@ -115,7 +115,8 @@ public void testGetServicesWithRerankTaskType() throws IOException {
"googlevertexai",
"jinaai",
"test_reranking_service",
- "voyageai"
+ "voyageai",
+ "hugging_face"
).toArray()
)
);
Original file line number Diff line number Diff line change
@@ -79,6 +79,8 @@
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings;
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings;
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankTaskSettings;
@@ -357,6 +359,16 @@ private static void addHuggingFaceNamedWriteables(List<NamedWriteableRegistry.En
namedWriteables.add(
new NamedWriteableRegistry.Entry(ServiceSettings.class, HuggingFaceServiceSettings.NAME, HuggingFaceServiceSettings::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(TaskSettings.class, HuggingFaceRerankTaskSettings.NAME, HuggingFaceRerankTaskSettings::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
HuggingFaceRerankServiceSettings.NAME,
HuggingFaceRerankServiceSettings::new
)
);
}

private static void addGoogleAiStudioNamedWritables(List<NamedWriteableRegistry.Entry> namedWriteables) {
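
The hunk above registers both the rerank task settings and the rerank service settings by name. A rough sketch of why each one needs its own entry follows. It assumes a name-keyed lookup at read time; the registry class, the String payload (standing in for a stream), and the entry name are all hypothetical and much simpler than NamedWriteableRegistry.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

// Hypothetical, simplified registry: maps (category, name) to a reader function.
class NamedReaderRegistrySketch {

    private final Map<String, Function<String, Object>> readers = new HashMap<>();

    private static String key(Class<?> category, String name) {
        return category.getName() + "#" + name;
    }

    // Registers a reader; a second registration under the same key is rejected.
    void register(Class<?> category, String name, Function<String, Object> reader) {
        if (readers.putIfAbsent(key(category, name), reader) != null) {
            throw new IllegalArgumentException("duplicate entry for [" + name + "]");
        }
    }

    // Looks up the reader by category and name, then lets it rebuild the object.
    Object read(Class<?> category, String name, String payload) {
        Function<String, Object> reader = readers.get(key(category, name));
        if (reader == null) {
            throw new IllegalArgumentException("unknown name [" + name + "]");
        }
        return reader.apply(payload);
    }

    public static void main(String[] args) {
        NamedReaderRegistrySketch registry = new NamedReaderRegistrySketch();
        // CharSequence stands in for the TaskSettings category; the name is made up.
        registry.register(CharSequence.class, "example_rerank_task_settings", payload -> "task settings from " + payload);
        System.out.println(registry.read(CharSequence.class, "example_rerank_task_settings", "stream"));
    }
}
```

In this simplified model, reading a settings object whose name was never registered fails, which is the situation the two new entries avoid.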
Original file line number Diff line number Diff line change
@@ -19,6 +19,7 @@
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings;

import java.util.ArrayList;
import java.util.Arrays;
@@ -91,7 +92,10 @@ protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[
} else if (r.getEndpoints().isEmpty() == false
&& r.getEndpoints().get(0).getTaskSettings() instanceof GoogleVertexAiRerankTaskSettings googleVertexAiTaskSettings) {
configuredTopN = googleVertexAiTaskSettings.topN();
- }
+ } else if (r.getEndpoints().isEmpty() == false
+ && r.getEndpoints().get(0).getTaskSettings() instanceof HuggingFaceRerankTaskSettings huggingFaceRerankTaskSettings) {
+ configuredTopN = huggingFaceRerankTaskSettings.getTopNDocumentsOnly();
+ }
if (configuredTopN != null && configuredTopN < rankWindowSize) {
l.onFailure(
new IllegalArgumentException(
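
The guard in this hunk rejects an endpoint whose configured top-N is smaller than the rank window, presumably because some documents in the window would then come back without a score. A minimal sketch of that condition, with a hypothetical class name and an error message that is illustrative rather than the one in the codebase:

```java
// Hypothetical stand-in for the top-N guard in computeScores.
class TopNGuard {

    // Returns an error message when the endpoint would return fewer documents
    // than the rank window needs, or null when the configuration is acceptable.
    static String validate(Integer configuredTopN, int rankWindowSize) {
        if (configuredTopN != null && configuredTopN < rankWindowSize) {
            return "configured top N [" + configuredTopN + "] is smaller than rank window size [" + rankWindowSize + "]";
        }
        return null;
    }

    public static void main(String[] args) {
        System.out.println(validate(5, 10));    // rejected: only 5 of 10 documents would be scored
        System.out.println(validate(null, 10)); // accepted: no top-N limit configured
        System.out.println(validate(10, 10));   // accepted: the limit covers the whole window
    }
}
```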
Original file line number Diff line number Diff line change
@@ -57,6 +57,7 @@ public void parseRequestConfig(
) {
try {
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);

ChunkingSettings chunkingSettings = null;
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
@@ -65,7 +66,7 @@ public void parseRequestConfig(
);
}

- var model = createModel(
+ var modelBuilder = new HuggingFaceModelInput.Builder(
inferenceEntityId,
taskType,
serviceSettingsMap,
@@ -75,8 +76,13 @@ public void parseRequestConfig(
ConfigurationParseContext.REQUEST
);

var model = createModel(
TaskType.RERANK.equals(taskType) ? modelBuilder.withTaskSettings(taskSettingsMap).build() : modelBuilder.build()
);

throwIfNotEmptyMap(config, name());
throwIfNotEmptyMap(serviceSettingsMap, name());
throwIfNotEmptyMap(taskSettingsMap, name());

parsedModelListener.onResponse(model);
} catch (Exception e) {
@@ -92,14 +98,15 @@ public HuggingFaceModel parsePersistedConfigWithSecrets(
Map<String, Object> secrets
) {
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
Contributor:
Correct me if I'm wrong, but won't this throw an exception if the config has no task settings? If so, doesn't that affect other integrations that don't require TASK_SETTINGS to be present?

Contributor Author:
I added a RERANK task type check to ensure the method isn't used for other tasks.

Map<String, Object> secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS);

ChunkingSettings chunkingSettings = null;
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
}

- return createModel(
+ var modelBuilder = new HuggingFaceModelInput.Builder(
inferenceEntityId,
taskType,
serviceSettingsMap,
@@ -108,18 +115,23 @@ public HuggingFaceModel parsePersistedConfigWithSecrets(
parsePersistedConfigErrorMsg(inferenceEntityId, name()),
ConfigurationParseContext.PERSISTENT
);

return createModel(
TaskType.RERANK.equals(taskType) ? modelBuilder.withTaskSettings(taskSettingsMap).build() : modelBuilder.build()
);
}
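
The reviewer's concern about the strict removal helper can be illustrated with two stand-in functions. This is only a sketch under stated assumptions: the helper names match the ones in the diff, but the bodies and the exception type are hypothetical and written just for this example.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical stand-ins for the two removal helpers discussed in the review.
class MapRemoval {

    // Lenient: returns null when the key is absent.
    @SuppressWarnings("unchecked")
    static Map<String, Object> removeFromMap(Map<String, Object> source, String key) {
        return (Map<String, Object>) source.remove(key);
    }

    // Strict: fails when the key is absent.
    static Map<String, Object> removeFromMapOrThrowIfNull(Map<String, Object> source, String key) {
        Map<String, Object> value = removeFromMap(source, key);
        if (value == null) {
            throw new IllegalArgumentException("Missing required setting [" + key + "]");
        }
        return value;
    }

    public static void main(String[] args) {
        // A persisted text_embedding config without a task_settings entry.
        Map<String, Object> config = new HashMap<>();
        config.put("service_settings", new HashMap<String, Object>());

        System.out.println(removeFromMap(config, "task_settings")); // lenient lookup yields null
        try {
            removeFromMapOrThrowIfNull(config, "task_settings");
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage()); // strict lookup rejects the same config
        }
    }
}
```

Under these assumptions, calling the strict variant unconditionally would reject configs for task types that never stored task settings, which is why the lookup needs either a task type check or the lenient variant.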

@Override
public HuggingFaceModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map<String, Object> config) {
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
Contributor:
Same question as above.

Contributor Author:
Added a type check before using the method. Thanks.

ChunkingSettings chunkingSettings = null;
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
}

- return createModel(
+ var modelBuilder = new HuggingFaceModelInput.Builder(
inferenceEntityId,
taskType,
serviceSettingsMap,
@@ -128,17 +140,13 @@ public HuggingFaceModel parsePersistedConfig(String inferenceEntityId, TaskType
parsePersistedConfigErrorMsg(inferenceEntityId, name()),
ConfigurationParseContext.PERSISTENT
);

return createModel(
TaskType.RERANK.equals(taskType) ? modelBuilder.withTaskSettings(taskSettingsMap).build() : modelBuilder.build()
Contributor:
It looks like the builder accepts null task settings, so how about we just pass in the task settings map regardless of whether it is null? That way we don't need to check for rerank here.

Contributor Author (Evgenii-Kazannik, May 13, 2025):
The models now accept the task settings map as is. Thanks.

);
}
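
The suggestion to pass the task settings map through unconditionally relies on the receiving side tolerating null. A minimal sketch of such a receiver follows; the class is hypothetical, "return_documents" is taken from the mock config in this PR, and "top_n" is an assumed key name.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical task settings holder that tolerates a missing (null) settings map.
class RerankTaskSettingsSketch {

    final Boolean returnDocuments; // null when not configured
    final Integer topN;            // null when not configured

    private RerankTaskSettingsSketch(Boolean returnDocuments, Integer topN) {
        this.returnDocuments = returnDocuments;
        this.topN = topN;
    }

    // A null map is treated like an empty one, so callers can pass the map
    // through for every task type without a rerank-specific branch.
    static RerankTaskSettingsSketch fromMap(Map<String, Object> map) {
        if (map == null) {
            return new RerankTaskSettingsSketch(null, null);
        }
        return new RerankTaskSettingsSketch((Boolean) map.get("return_documents"), (Integer) map.get("top_n"));
    }

    public static void main(String[] args) {
        Map<String, Object> settings = new HashMap<>();
        settings.put("return_documents", true);
        settings.put("top_n", 3); // "top_n" is an assumed key name

        System.out.println(fromMap(settings).topN);
        System.out.println(fromMap(null).topN);
    }
}
```

With a receiver like this, the ternary on TaskType.RERANK at the call site becomes unnecessary.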

- protected abstract HuggingFaceModel createModel(
- String inferenceEntityId,
- TaskType taskType,
- Map<String, Object> serviceSettings,
- ChunkingSettings chunkingSettings,
- Map<String, Object> secretSettings,
- String failureMessage,
- ConfigurationParseContext context
- );
+ protected abstract HuggingFaceModel createModel(HuggingFaceModelInput input);

@Override
public void doInfer(
Original file line number Diff line number Diff line change
@@ -9,17 +9,19 @@

import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.core.Nullable;
- import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionVisitor;
import org.elasticsearch.xpack.inference.services.settings.ApiKeySecrets;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

import java.util.Objects;

- public abstract class HuggingFaceModel extends Model {
+ public abstract class HuggingFaceModel extends RateLimitGroupingModel {
private final HuggingFaceRateLimitServiceSettings rateLimitServiceSettings;
private final SecureString apiKey;

@@ -34,10 +36,27 @@ public HuggingFaceModel(
apiKey = ServiceUtils.apiKey(apiKeySecrets);
}

protected HuggingFaceModel(HuggingFaceModel model, TaskSettings taskSettings) {
super(model, taskSettings);

rateLimitServiceSettings = model.rateLimitServiceSettings();
apiKey = model.apiKey();
}

public HuggingFaceRateLimitServiceSettings rateLimitServiceSettings() {
return rateLimitServiceSettings;
}

@Override
public int rateLimitGroupingHash() {
return Objects.hash(rateLimitServiceSettings.uri(), apiKey);
}

@Override
public RateLimitSettings rateLimitSettings() {
return rateLimitServiceSettings.rateLimitSettings();
}

public SecureString apiKey() {
return apiKey;
}
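
The rateLimitGroupingHash override above combines the endpoint URI and the API key. A small sketch of the effect, assuming the hash is used to decide which endpoints share a rate limit group; plain strings stand in for the URI and SecureString types, and the class name and URL are placeholders.

```java
import java.util.Objects;

// Hypothetical sketch of grouping endpoints for rate limiting.
class RateLimitGroupingSketch {

    // Mirrors the hash in the diff: endpoints with the same URI and API key fall into one group.
    static int groupingHash(String uri, String apiKey) {
        return Objects.hash(uri, apiKey);
    }

    public static void main(String[] args) {
        String uri = "https://example.invalid/rerank"; // placeholder URL
        int first = groupingHash(uri, "key-1");
        int second = groupingHash(uri, "key-1");
        int other = groupingHash(uri, "key-2");

        System.out.println(first == second); // same URI and key: same group
        System.out.println(first == other);  // different key: separate group for these inputs
    }
}
```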
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.huggingface;

import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;

import java.util.Map;

public class HuggingFaceModelInput {
Contributor:
How about we make this a record and maybe rename it to HuggingFaceModelParameters?

Contributor Author:
Yep, the record fits better. Thanks. Done.

private final String inferenceEntityId;
private final TaskType taskType;
private final Map<String, Object> serviceSettings;
@Nullable
private final Map<String, Object> taskSettings;
private final ChunkingSettings chunkingSettings;
@Nullable
private final Map<String, Object> secretSettings;
private final String failureMessage;
private final ConfigurationParseContext context;

public HuggingFaceModelInput(Builder builder) {
Contributor:
Should we make this private? We probably want instantiation to go through the builder.

Contributor Author:
The builder was replaced with the record as suggested, so this is no longer needed. Thank you for pointing that out, though.

this.inferenceEntityId = builder.inferenceEntityId;
this.taskType = builder.taskType;
this.serviceSettings = builder.serviceSettings;
this.taskSettings = builder.taskSettings;
this.chunkingSettings = builder.chunkingSettings;
this.secretSettings = builder.secretSettings;
this.failureMessage = builder.failureMessage;
this.context = builder.context;
}

public String getInferenceEntityId() {
return inferenceEntityId;
}

public TaskType getTaskType() {
return taskType;
}

public Map<String, Object> getServiceSettings() {
return serviceSettings;
}

@Nullable
public Map<String, Object> getTaskSettings() {
return taskSettings;
}

public ChunkingSettings getChunkingSettings() {
return chunkingSettings;
}

@Nullable
public Map<String, Object> getSecretSettings() {
return secretSettings;
}

public String getFailureMessage() {
return failureMessage;
}

public ConfigurationParseContext getContext() {
return context;
}

public static class Builder {
private String inferenceEntityId;
private TaskType taskType;
private Map<String, Object> serviceSettings;
@Nullable
private Map<String, Object> taskSettings;
private ChunkingSettings chunkingSettings;
@Nullable
Map<String, Object> secretSettings;
private String failureMessage;
private ConfigurationParseContext context;

public Builder(
String inferenceEntityId,
TaskType taskType,
Map<String, Object> serviceSettings,
ChunkingSettings chunkingSettings,
@Nullable Map<String, Object> secretSettings,
String failureMessage,
ConfigurationParseContext context
) {
this.inferenceEntityId = inferenceEntityId;
this.taskType = taskType;
this.serviceSettings = serviceSettings;
this.chunkingSettings = chunkingSettings;
this.secretSettings = secretSettings;
this.failureMessage = failureMessage;
this.context = context;
}

public Builder withTaskSettings(Map<String, Object> taskSettings) {
this.taskSettings = taskSettings;
return this;
}

public HuggingFaceModelInput build() {
return new HuggingFaceModelInput(this);
}
}
}
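
The review thread on this file ends with the builder being replaced by a record. The final shape is not shown in this commit, so the following is only a guess at what a record-based parameter object could look like. The record name carries a Sketch suffix to mark it as hypothetical, and TaskType, ChunkingSettings and ConfigurationParseContext are replaced with String placeholders so the example compiles on its own.

```java
import java.util.Map;

// Hypothetical record form of the parameter object, as suggested in the review.
record HuggingFaceModelParametersSketch(
    String inferenceEntityId,
    String taskType,                    // placeholder for TaskType
    Map<String, Object> serviceSettings,
    Map<String, Object> taskSettings,   // may be null
    String chunkingSettings,            // placeholder for ChunkingSettings, may be null
    Map<String, Object> secretSettings, // may be null
    String failureMessage,
    String context                      // placeholder for ConfigurationParseContext
) {}

class HuggingFaceModelParametersDemo {
    public static void main(String[] args) {
        var parameters = new HuggingFaceModelParametersSketch(
            "my-rerank-endpoint",
            "rerank",
            Map.of("url", "https://example.invalid"), // illustrative service settings
            Map.of("return_documents", true),
            null,
            null,
            "failed to parse",
            "REQUEST"
        );
        // Accessors are generated by the record, replacing the hand-written getters.
        System.out.println(parameters.taskType());
        System.out.println(parameters.taskSettings());
    }
}
```

A record removes the getters, the builder, and the question of constructor visibility in one step, at the cost of a longer positional constructor call.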