Add Hugging Face Rerank support #127966

@@ -171,6 +171,22 @@ static String mockDenseServiceModelConfig() {
""";
}

static String mockRerankServiceModelConfig() {

I'm wondering whether the methods you've added to this class are actually used anywhere. The methods you took as a reference are called; the ones you've added are not.
Thanks for noticing. It's used now.

@@ -484,6 +500,10 @@ private String jsonBody(List<String> input, @Nullable String query) {
@SuppressWarnings("unchecked")
protected void assertNonEmptyInferenceResults(Map<String, Object> resultMap, int expectedNumberOfResults, TaskType taskType) {
switch (taskType) {
case RERANK -> {

It looks like this method is never called with TaskType.RERANK, so this assertion is never triggered.

@@ -92,14 +98,15 @@ public HuggingFaceModel parsePersistedConfigWithSecrets(
Map<String, Object> secrets
) {
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);

Correct me if I'm wrong, but won't that throw an exception if there are no task settings in the config? If so, doesn't that affect other integrations that don't require TASK_SETTINGS to be present?
I added a Rerank type check to ensure the method isn't used for other tasks.

}

@Override
public HuggingFaceModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map<String, Object> config) {
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);

Same question as above.
Added a type check before using the method. Thanks.
Left a few comments.

@Override
public boolean[] getTruncationInfo() {
return null;

Can we have a comment here explaining why null is returned?
Yeah, truncation is only used in some services that support text embedding. Just say something like "Not applicable for rerank, only used in text embedding requests".
Added as suggested: "Not applicable for rerank, only used in text embedding requests".
Thanks all.

@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.V_8_12_0;

Please check the comments about TransportVersions that @jonathan-buttner left on this PR: #127254
They would apply here as well.
Done: I read them through and updated. I'm going to update the versions once more before the merge.

@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.V_8_14_0;

Same thing here regarding the TransportVersion comments.
Thanks, applied the change.
Looking good! I left a few suggestions.

this.returnDocuments = returnDocuments;
this.topN = topN;
taskSettings = model.getTaskSettings();
this.model = model;

Since we're saving a reference to the model, how about we remove the taskSettings and inferenceEntityId references and just use the model.
Thank you. Used them from the model.

import java.util.Map;

public class HuggingFaceModelInput {

How about we make this a record and maybe rename it to HuggingFaceModelParameters?
Yep. The record fits better. Thanks. Done
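
As a rough illustration of the record suggestion, the sketch below shows how a parameter carrier could look once the hand-written builder is dropped. Only the name HuggingFaceModelParameters and the failureMessage and context fields come from the thread; the other components, the ParseContext enum (a stand-in for ConfigurationParseContext), and the example values are assumptions made for this sketch, not the merged code.

```java
import java.util.Map;

final class ModelParametersSketch {
    // Hypothetical stand-in for the plugin's ConfigurationParseContext.
    enum ParseContext { REQUEST, PERSISTENT }

    // A record provides the canonical constructor, accessors, equals/hashCode
    // and toString, so no separate builder class is required.
    // The component list is assumed for illustration.
    record HuggingFaceModelParameters(
        String inferenceEntityId,
        String taskType,
        Map<String, Object> serviceSettings,
        Map<String, Object> taskSettings,
        Map<String, Object> secretSettings,
        String failureMessage,
        ParseContext context
    ) {}

    public static void main(String[] args) {
        var params = new HuggingFaceModelParameters(
            "test",
            "rerank",
            Map.of("url", "https://example.invalid"),
            Map.of("top_n", 1),
            null, // secrets may be absent when parsing a persisted config
            "Failed to parse stored model [test]",
            ParseContext.PERSISTENT
        );
        System.out.println(params.taskSettings());
    }
}
```

Because every component is final and set through the canonical constructor, the question of whether a builder-only constructor should be private no longer arises.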

private final String failureMessage;
private final ConfigurationParseContext context;

public HuggingFaceModelInput(Builder builder) {

Should we make this private? We probably want the instantiation done through the builder.
The builder was replaced with the record as suggested, so this is no longer needed.
Thank you for pointing that out, though.

@@ -128,17 +140,13 @@ public HuggingFaceModel parsePersistedConfig(String inferenceEntityId, TaskType
parsePersistedConfigErrorMsg(inferenceEntityId, name()),
ConfigurationParseContext.PERSISTENT
);

return createModel(
TaskType.RERANK.equals(taskType) ? modelBuilder.withTaskSettings(taskSettingsMap).build() : modelBuilder.build()

Looks like the builder accepts null task settings, so how about we just pass in the task settings map, regardless of whether it is null? That way we don't need to check for rerank here.
The models accept the task settings map as is now. Thanks.

return RERANK_TOKEN_LIMIT;
}

// model is not defined in the service settings.

Since we encountered situations where the model id was required for chat completion, have we done any testing to see if the serverless style endpoint requires the model id?
The thing is, HF currently does not provide serverless endpoints for Rerank models, so we cannot test it now.

@Override
protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
builder.field(URL, uri.toString());
builder.field(MAX_INPUT_TOKENS, RERANK_TOKEN_LIMIT);

Let's remove this, since we don't use it.
Removed. Thank you

@@ -0,0 +1,123 @@
/*

We're trying to move away from this style of parsing and instead use an ObjectParser or ConstructingObjectParser. How about we switch this implementation to use ConstructingObjectParser? Here's an example: https://github.com/elastic/elasticsearch/blob/main/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/response/OpenAiEmbeddingsResponseEntity.java

import static org.elasticsearch.common.xcontent.XContentParserUtils.throwUnknownToken;
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;

public class HuggingFaceRerankResponseEntity extends ErrorResponse {

Hmm, typically we separate the valid response from the error response. Does the HuggingFaceErrorResponseEntity suffice?

"Failed to send Hugging Face %s request from inference entity id [%s]";
static final ResponseHandler RERANK_HANDLER = new HuggingFaceResponseHandler(
"hugging face rerank",
(request, response) -> HuggingFaceRerankResponseEntity.fromResponse((HuggingFaceRerankRequest) request, response)

It'd be unlikely, but can we do an instanceof check for request being a HuggingFaceRerankRequest, and throw an IllegalArgumentException if it's invalid?
Good point. Thank you Jonathan. An explicit check was added.
# Conflicts: # server/src/main/java/org/elasticsearch/TransportVersions.java
Pinging @elastic/ml-core (Team:ML)

private static final String FAILED_TO_SEND_REQUEST_ERROR_MESSAGE =
"Failed to send Hugging Face %s request from inference entity id [%s]";
private static final String INVALID_REQUEST_TYPE_MESSAGE = "Invalid request type: expected HuggingFace %s request but got %s";
static final ResponseHandler RERANK_HANDLER = new HuggingFaceResponseHandler("hugging face rerank", (request, response) -> {

Why is it package private and not just private?

"ELSER",
model.getInferenceEntityId()
);
var errorMessage = format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, "ELSER", model.getInferenceEntityId());

I suggest adopting the approach implemented here, since it is likely to be merged before your changes:
https://github.com/elastic/elasticsearch/pull/127254/files#diff-e0d9eac4ad74ebb018731efce7d8418eb03989288ed59f752ba5dbe71eac7481R97-R99

}

@Override
public DefaultSecretSettings getSecretSettings() {

Do we need this method? It just calls super.method.
I'm not sure if we need it but it's a pattern we have throughout the plugin. I think it's fine to leave it.
Left a few comments
# Conflicts: # server/src/main/java/org/elasticsearch/TransportVersions.java # x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java # x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java # x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java # x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java # x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java
Looking good, I left some suggestions. Could we add some unit tests for the response parsing logic?

Map<String, Object> taskSettingsMap = Collections.emptyMap();

if (TaskType.RERANK.equals(taskType)) {
taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);

The task settings should be optional. I don't think we want to throw if the user does not specify any. In other services like cohere we default to an empty map like this:
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
Let's remove the if-block and use removeFromMapOrDefaultEmpty instead.
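
To make the difference between the two helpers concrete, here is a minimal sketch. The helper names come from the comments above, but the bodies are simplified stand-ins written for this example: the real plugin helpers may differ, for instance in the exception type they throw and in how they validate the value.

```java
import java.util.HashMap;
import java.util.Map;

final class OptionalTaskSettingsSketch {
    // Simplified stand-in: removes the entry and fails when it is missing.
    @SuppressWarnings("unchecked")
    static Map<String, Object> removeFromMapOrThrowIfNull(Map<String, Object> source, String field) {
        Object value = source.remove(field);
        if (value == null) {
            throw new IllegalArgumentException("Missing required setting [" + field + "]");
        }
        return (Map<String, Object>) value;
    }

    // Simplified stand-in: removes the entry and falls back to an empty map,
    // so callers need no task-type check before reading optional settings.
    @SuppressWarnings("unchecked")
    static Map<String, Object> removeFromMapOrDefaultEmpty(Map<String, Object> source, String field) {
        Object value = source.remove(field);
        return value == null ? new HashMap<String, Object>() : (Map<String, Object>) value;
    }

    public static void main(String[] args) {
        Map<String, Object> config = new HashMap<>();
        // No "task_settings" key: parsing still succeeds for every task type.
        System.out.println(removeFromMapOrDefaultEmpty(config, "task_settings").isEmpty());
    }
}
```

With the defaulting variant, a persisted config without task settings parses the same way for rerank and for every other task type, which is why the if-block becomes unnecessary.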

@@ -93,52 +103,60 @@ public HuggingFaceModel parsePersistedConfigWithSecrets(
) {
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS);
Map<String, Object> taskSettingsMap = Collections.emptyMap();

if (TaskType.RERANK.equals(taskType)) {

Same comment as above, let's use Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);

);
}

@Override
public HuggingFaceModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map<String, Object> config) {
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> taskSettingsMap = Collections.emptyMap();

if (TaskType.RERANK.equals(taskType)) {

Same comment as above, let's use:
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);

ConfigurationParseContext context
) {
return switch (taskType) {
protected HuggingFaceModel createModel(HuggingFaceModelParameters input) {

How about we rename input to parameters or params?

public static final String COMPLETION_ERROR_PREFIX = "Hugging Face completions";
static final String USER_ROLE = "user";
static final ResponseHandler COMPLETION_HANDLER = new OpenAiChatCompletionResponseHandler(
"hugging face completion",
OpenAiChatCompletionResponseEntity::fromResponse
);
private static final ResponseHandler RERANK_HANDLER = new HuggingFaceResponseHandler("hugging face rerank", (request, response) -> {
var errorMessage = format(INVALID_REQUEST_TYPE_MESSAGE, "RERANK", request != null ? request.getClass().getName() : "null");

How about we use .getSimpleName() here? That version tends to be a little more readable.
Let's move this block inside the if-block so we aren't calculating the error until it's needed.
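
The two suggestions above, together with the earlier request for an explicit instanceof check, could be combined roughly as follows. This is a sketch under assumptions: Request, RerankRequest and EmbeddingsRequest are hypothetical stand-ins for the plugin's request types, and handleRerank is an invented name. Only the message template is taken from the diff.

```java
final class RequestTypeCheckSketch {
    // Hypothetical stand-ins for the plugin's request types.
    interface Request {}
    record RerankRequest(String query) implements Request {}
    record EmbeddingsRequest(String text) implements Request {}

    static final String INVALID_REQUEST_TYPE_MESSAGE =
        "Invalid request type: expected HuggingFace %s request but got %s";

    static String handleRerank(Request request) {
        // Happy path first: no error message is formatted for valid requests.
        if (request instanceof RerankRequest rerankRequest) {
            return "parsed response for query [" + rerankRequest.query() + "]";
        }
        // Built only on the failure path; getSimpleName() omits the package prefix.
        var actualType = request != null ? request.getClass().getSimpleName() : "null";
        throw new IllegalArgumentException(String.format(INVALID_REQUEST_TYPE_MESSAGE, "RERANK", actualType));
    }

    public static void main(String[] args) {
        System.out.println(handleRerank(new RerankRequest("Main characters in Star Wars")));
    }
}
```

Passing an EmbeddingsRequest here yields a message ending in "but got EmbeddingsRequest" rather than a fully qualified class name.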

return parseList(parser, (listParser, index) -> {
var parsedRankedDoc = HuggingFaceRerankResponseEntity.RankedDocEntry.parse(parser);

if (parsedRankedDoc.id == null) {

I believe declare*() requires that the result be non-null, so I think we can remove these if-blocks that check for non-null. Can we create a test to ensure that null is not valid?

try {
return new RankedDocsResults.RankedDoc(parsedRankedDoc.id, parsedRankedDoc.score, parsedRankedDoc.text);
} catch (NumberFormatException e) {

I might be missing it but what logic could throw the NumberFormatException? Do we need the try/catch?

* <pre>
* <code>
* {
* "rerank": [

Does HF respond with the rerank field? Or is it just an array without the outer object?

try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
moveToFirstToken(jsonParser);

XContentParser.Token token = jsonParser.currentToken();

I think we can omit this line and the ensureExpectedToken because parseList will do the same check.

entity.toXContent(builder, ToXContent.EMPTY_PARAMS);
String xContentResult = Strings.toString(builder);

assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""

How about we use XContentHelper.stripWhitespace() instead?
# Conflicts: # server/src/main/java/org/elasticsearch/TransportVersions.java
Additionally replaced getFirst() with get(0) in HuggingFaceActionCreatorTests for the sake of backward compatibility.
Looking good, a few more test suggestions

});
}

private record RankedDocEntry(@Nullable Integer id, @Nullable Float score, @Nullable String text) {
private record RankedDocEntry(Integer id, Float score, @Nullable String text) {

nit: let's change id to index, I think that's clearer.

public class HuggingFaceRerankResponseEntityTests extends ESTestCase {
private static final String MISSED_FIELD_INDEX = "index";
private static final String MISSED_FIELD_SCORE = "score";

Let's add the following tests:
- responseJson with more than 1 item in the array, ensure that the score sorting works as expected
- topN of null does not do any limiting
- topN of 5 does not do anything for a result set of 2
- topN of 2 reduces the result set of 5 to 2
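
The behaviour these tests are meant to pin down can be sketched as below. This is an illustration under assumptions, not the response entity's actual code: it assumes results are ordered by score from highest to lowest and that topN is applied after sorting. The RankedDoc record and the rank method are invented for this example.

```java
import java.util.List;

final class TopNSketch {
    // Hypothetical result type; field names follow the JSON shown in this thread.
    record RankedDoc(int index, float relevanceScore, String text) {}

    // Sorts by relevance score, highest first, then keeps at most topN entries.
    // A null topN means "no limit".
    static List<RankedDoc> rank(List<RankedDoc> docs, Integer topN) {
        List<RankedDoc> sorted = docs.stream()
            .sorted((a, b) -> Float.compare(b.relevanceScore(), a.relevanceScore()))
            .toList();
        return topN == null ? sorted : sorted.subList(0, Math.min(topN, sorted.size()));
    }

    public static void main(String[] args) {
        var docs = List.of(
            new RankedDoc(0, 0.1f, "money"),
            new RankedDoc(1, 0.9f, "luke skywalker"),
            new RankedDoc(2, 0.5f, "darth vader")
        );
        System.out.println(rank(docs, 2));
    }
}
```

Clamping with Math.min covers the case where topN is larger than the result set, so a topN of 5 leaves a two-element list unchanged.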

public class HuggingFaceRerankTaskSettings implements TaskSettings {

public static final String NAME = "hugging_face_rerank_task_settings";
public static final String RETURN_TEXT = "return_text";

How about we use return_documents here instead? That aligns with what we allow in the root level of the request and cohere also uses that as the name of the field.
I've been doing some testing with bge-reranker-base-mkn. It's looking good. I did notice that the top_n from the task settings doesn't seem to be applying:
PUT _inference/rerank/test
{
  "service": "hugging_face",
  "service_settings": {
    "api_key": "<api key>",
    "url": "https://vvx7nmi2feeokepr....."
  },
  "task_settings": {
    "top_n": 1,
    "return_text": true
  }
}
POST _inference/rerank/test
{
  "query": "Main characters in Star Wars",
  "input": [
    "money",
    "luke skywalker",
    "yoga",
    "darth vader",
    "han solo",
    "fruit"
  ]
}
This produces:
{
  "rerank": [
    {
      "index": 3,
      "relevance_score": 0.7399865,
      "text": "darth vader"
    },
    {
      "index": 1,
      "relevance_score": 0.099996306,
      "text": "luke skywalker"
    },
    {
      "index": 4,
      "relevance_score": 0.00040448149,
      "text": "han solo"
    },
    {
      "index": 0,
      "relevance_score": 0.00004720623,
      "text": "money"
    },
    {
      "index": 5,
      "relevance_score": 0.00003734357,
      "text": "fruit"
    },
    {
      "index": 2,
      "relevance_score": 0.00003734357,
      "text": "yoga"
    }
  ]
}
I expected there to only be a single entry in the array. If I include top_n in the request it does limit it to 1 entry.
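
One way to express the expected behaviour is a small resolution step that falls back to the endpoint's task settings when the request omits top_n. This is a sketch under the assumption that a request-level value takes precedence over the stored one; the thread only confirms that both sources should be able to limit the results. The class and method names are invented for this example.

```java
final class EffectiveTopNSketch {
    // Request-level value wins (assumed precedence); otherwise fall back to
    // the value stored in the endpoint's task settings. Null means "no limit".
    static Integer resolveTopN(Integer requestTopN, Integer taskSettingsTopN) {
        return requestTopN != null ? requestTopN : taskSettingsTopN;
    }

    public static void main(String[] args) {
        // The scenario reported above: top_n of 1 in the task settings, none in the request.
        System.out.println(resolveTopN(null, 1));
    }
}
```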

import static org.hamcrest.Matchers.containsString;

public class HuggingFaceRerankTaskSettingsTests extends AbstractWireSerializingTestCase<HuggingFaceRerankTaskSettings> {

Let's make this extend AbstractBWCWireSerializationTestCase.

private final Boolean returnDocuments;

public HuggingFaceRerankTaskSettings(StreamInput in) throws IOException {
this(in.readOptionalInt(), in.readOptionalBoolean());

Instead of readOptionalInt let's use readOptionalVInt and writeOptionalVInt.

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalInt(topNDocumentsOnly);

Let's use writeOptionalVInt.
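
For context on why the variable-length variant is usually preferred for small values such as top_n, here is a self-contained sketch of the general technique: a presence byte followed by 7-bit groups with a continuation bit. It illustrates the encoding idea only and is not the Elasticsearch StreamOutput implementation; the class and method bodies are written for this example.

```java
import java.io.ByteArrayOutputStream;

final class OptionalVIntSketch {
    // Writes a presence byte, then the value in 7-bit groups, low bits first.
    // The high bit of each byte signals that more bytes follow.
    static byte[] writeOptionalVInt(Integer value) {
        var out = new ByteArrayOutputStream();
        if (value == null) {
            out.write(0); // absent
            return out.toByteArray();
        }
        out.write(1); // present
        int v = value;
        while ((v & ~0x7F) != 0) {
            out.write((v & 0x7F) | 0x80);
            v >>>= 7;
        }
        out.write(v);
        return out.toByteArray();
    }

    // Reverses writeOptionalVInt; returns null when the presence byte is 0.
    static Integer readOptionalVInt(byte[] bytes) {
        if (bytes[0] == 0) {
            return null;
        }
        int result = 0;
        int shift = 0;
        int pos = 1;
        byte b;
        do {
            b = bytes[pos++];
            result |= (b & 0x7F) << shift;
            shift += 7;
        } while ((b & 0x80) != 0);
        return result;
    }

    public static void main(String[] args) {
        // A small value needs one payload byte instead of a fixed four.
        System.out.println(writeOptionalVInt(1).length);
    }
}
```

The trade-off is that negative values need five payload bytes in this scheme, which is not a concern for a count such as top_n.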
# Conflicts: # server/src/main/java/org/elasticsearch/TransportVersions.java
… into Add-Hugging-Face-Rerank-support
Thanks for the changes!
💔 Backport failed
You can use sqren/backport to manually backport by running |
gradle check
?CA have been signed.
Used the following with success:
gradlew :x-pack:plugin:inference:check
gradlew.bat :x-pack:plugin:inference:spotlessApply
Tested via api:
The following HF task settings were also integrated:
raw_scores, truncate, truncation_direction
For now they are removed from the PR and saved in a separate branch,
in case we decide to make them part of the inference API.
@jonathan-buttner @Jan-Kazlouski-elastic
Apologies for the delay. I meant to create this much sooner.
Thanks for your patience
Tested on:
bge-reranker-base -> bge-reranker-base-mkn
jina-reranker-v1-turbo-en-GGUF -> jina-reranker-v1-turbo-en-gg-iuu
elasticsearch-specification PR