Skip to content

Add RerankRequestChunker #130485

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ protected void doInference(
InferenceService service,
ActionListener<InferenceServiceResults> listener
) {

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be removed.

service.infer(
model,
request.getQuery(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ public class ChunkingSettingsBuilder {
public static final SentenceBoundaryChunkingSettings DEFAULT_SETTINGS = new SentenceBoundaryChunkingSettings(250, 1);
// Old settings used for backward compatibility for endpoints created before 8.16 when default was changed
public static final WordBoundaryChunkingSettings OLD_DEFAULT_SETTINGS = new WordBoundaryChunkingSettings(250, 100);
public static final int ELASTIC_RERANKER_TOKEN_LIMIT = 512;
public static final int ELASTIC_RERANKER_EXTRA_TOKEN_COUNT = 3;
public static final float TOKENS_PER_WORD = 0.75f;

public static ChunkingSettings fromMap(Map<String, Object> settings) {
return fromMap(settings, true);
Expand Down Expand Up @@ -51,4 +54,17 @@ public static ChunkingSettings fromMap(Map<String, Object> settings, boolean ret
case RECURSIVE -> RecursiveChunkingSettings.fromMap(new HashMap<>(settings));
};
}

/**
 * Builds sentence-boundary chunking settings for the Elastic reranker so that, where possible,
 * a single chunk plus the query (plus special tokens) fits within the model's token limit.
 * For very long queries the chunk budget never drops below half the token limit, so chunks
 * stay usefully sized even though chunk + query may then exceed the limit.
 *
 * NOTE(review): token estimates assume TOKENS_PER_WORD (0.75) tokens per word, i.e. fewer
 * tokens than words; typical subword tokenizers produce more tokens than words — confirm.
 *
 * @param queryWordCount number of words in the rerank query
 * @return sentence-boundary settings with the computed max chunk size (in words) and an
 *         overlap of one sentence
 */
public static ChunkingSettings buildChunkingSettingsForElasticRerank(int queryWordCount) {
    // Estimated token footprint of the query (rounded up).
    var queryTokenCount = Math.ceil(queryWordCount * TOKENS_PER_WORD);
    // Tokens left over for a chunk once the query and extra special tokens are accounted for.
    var tokenBudgetWithFullQuery = (ELASTIC_RERANKER_TOKEN_LIMIT - ELASTIC_RERANKER_EXTRA_TOKEN_COUNT - queryTokenCount);

    // Use the leftover budget, but never shrink below half the token limit.
    var chunkTokenBudget = Math.max(Math.floor((float) ELASTIC_RERANKER_TOKEN_LIMIT / 2), tokenBudgetWithFullQuery);

    // Convert the token budget back into a word count for the chunker.
    var maxChunkSizeWordCount = (int) (chunkTokenBudget / TOKENS_PER_WORD);
    return new SentenceBoundaryChunkingSettings(maxChunkSizeWordCount, 1);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.chunking;

import com.ibm.icu.text.BreakIterator;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/**
 * Splits rerank inputs into chunks sized so that a single chunk plus the query fits within the
 * Elastic reranker token limit, and maps chunk-level rerank results back to document-level
 * results.
 */
public class RerankRequestChunker {
    // The original, un-chunked documents, indexed by document position.
    private final List<String> inputs;
    // All chunks across all documents; a chunk's position in this list is the index the
    // reranker sees, and each chunk records which document it was cut from.
    private final List<RerankChunks> rerankChunks;

    /**
     * @param query  the rerank query; its word count determines the per-chunk word budget
     * @param inputs the documents being reranked, in their original order
     */
    public RerankRequestChunker(String query, List<String> inputs) {
        this.inputs = inputs;
        this.rerankChunks = chunk(inputs, buildChunkingSettingsForElasticRerank(query));
    }

    // Chunks every input with the given settings, tagging each chunk with the index of the
    // document it was cut from.
    private List<RerankChunks> chunk(List<String> inputs, ChunkingSettings chunkingSettings) {
        var chunker = ChunkerBuilder.fromChunkingStrategy(chunkingSettings.getChunkingStrategy());
        var chunks = new ArrayList<RerankChunks>();
        for (int i = 0; i < inputs.size(); i++) {
            var input = inputs.get(i);
            for (var chunk : chunker.chunk(input, chunkingSettings)) {
                chunks.add(new RerankChunks(i, input.substring(chunk.start(), chunk.end())));
            }
        }
        return chunks;
    }

    /**
     * Returns the chunk strings in document order (chunks from the same document are adjacent),
     * suitable for sending to the reranker in place of the original inputs.
     */
    public List<String> getChunkedInputs() {
        List<String> chunkedInputs = new ArrayList<>();
        for (RerankChunks chunk : rerankChunks) {
            chunkedInputs.add(chunk.chunkString());
        }

        // TODO: Score the inputs here and only return the top N chunks for each document
        return chunkedInputs;
    }

    /**
     * Wraps the given listener so that chunk-level {@link RankedDocsResults} are collapsed back
     * into document-level results before being delivered. Any other result type fails the
     * listener with an {@link IllegalArgumentException}; upstream failures pass through.
     */
    public ActionListener<InferenceServiceResults> parseChunkedRerankResultsListener(ActionListener<InferenceServiceResults> listener) {
        return ActionListener.wrap(results -> {
            if (results instanceof RankedDocsResults rankedDocsResults) {
                listener.onResponse(parseRankedDocResultsForChunks(rankedDocsResults));
            } else {
                listener.onFailure(new IllegalArgumentException("Expected RankedDocsResults but got: " + results.getClass()));
            }

        }, listener::onFailure);
    }

    // TODO: Can we assume the rankeddocsresults are always sorted by relevance score?
    // TODO: Should we short circuit if no chunking was done?
    // Collapses chunk-level ranked docs into one ranked doc per original document. The first
    // chunk encountered for a document wins, so this presumes the results arrive sorted by
    // descending relevance score (see TODO above); the winning chunk's score is attached to
    // the full original input text.
    private RankedDocsResults parseRankedDocResultsForChunks(RankedDocsResults rankedDocsResults) {
        List<RankedDocsResults.RankedDoc> updatedRankedDocs = new ArrayList<>();
        Set<Integer> docIndicesSeen = new HashSet<>();
        for (RankedDocsResults.RankedDoc rankedDoc : rankedDocsResults.getRankedDocs()) {
            int chunkIndex = rankedDoc.index();
            int docIndex = rerankChunks.get(chunkIndex).docIndex();

            // Set.add returns false when the doc was already recorded, so a separate
            // contains() lookup is unnecessary.
            if (docIndicesSeen.add(docIndex)) {
                // Create a ranked doc with the full input string and the index for the document instead of the chunk
                RankedDocsResults.RankedDoc updatedRankedDoc = new RankedDocsResults.RankedDoc(
                    docIndex,
                    rankedDoc.relevanceScore(),
                    inputs.get(docIndex)
                );
                updatedRankedDocs.add(updatedRankedDoc);
            }
        }

        return new RankedDocsResults(updatedRankedDocs);
    }

    /** A chunk of text plus the index of the source document it was cut from. */
    public record RerankChunks(int docIndex, String chunkString) {}

    // Counts the words in the query with an ICU word BreakIterator and delegates to the shared
    // builder to size chunks for the Elastic reranker.
    private ChunkingSettings buildChunkingSettingsForElasticRerank(String query) {
        var wordIterator = BreakIterator.getWordInstance();
        wordIterator.setText(query);
        var queryWordCount = ChunkerUtils.countWords(0, query.length(), wordIterator);
        return ChunkingSettingsBuilder.buildChunkingSettingsForElasticRerank(queryWordCount);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ protected RankFeatureDoc[] preprocess(RankFeatureDoc[] originalDocs, boolean rer
}

protected InferenceAction.Request generateRequest(List<String> docFeatures) {
// TODO: Try running the RerankRequestChunker here.
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be removed as we're calling this in the service now.

return new InferenceAction.Request(
TaskType.RERANK,
inferenceId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextSimilarityConfigUpdate;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.chunking.RerankRequestChunker;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.ServiceUtils;

Expand Down Expand Up @@ -686,7 +687,15 @@ public void inferRerank(
Map<String, Object> requestTaskSettings,
ActionListener<InferenceServiceResults> listener
) {
var request = buildInferenceRequest(model.mlNodeDeploymentId(), new TextSimilarityConfigUpdate(query), inputs, inputType, timeout);
var rerankChunker = new RerankRequestChunker(query, inputs);
var chunkedInputs = rerankChunker.getChunkedInputs();
var request = buildInferenceRequest(
model.mlNodeDeploymentId(),
new TextSimilarityConfigUpdate(query),
chunkedInputs,
inputType,
timeout
);

var returnDocs = Boolean.TRUE;
if (returnDocuments != null) {
Expand All @@ -696,13 +705,14 @@ public void inferRerank(
returnDocs = RerankTaskSettings.of(modelSettings, requestSettings).returnDocuments();
}

Function<Integer, String> inputSupplier = returnDocs == Boolean.TRUE ? inputs::get : i -> null;
Function<Integer, String> inputSupplier = returnDocs == Boolean.TRUE ? chunkedInputs::get : i -> null;

ActionListener<InferModelAction.Response> mlResultsListener = listener.delegateFailureAndWrap(
(l, inferenceResult) -> l.onResponse(
textSimilarityResultsToRankedDocs(inferenceResult.getInferenceResults(), inputSupplier, topN)
)
);
ActionListener<InferModelAction.Response> mlResultsListener = rerankChunker.parseChunkedRerankResultsListener(listener)
.delegateFailureAndWrap(
(l, inferenceResult) -> l.onResponse(
textSimilarityResultsToRankedDocs(inferenceResult.getInferenceResults(), inputSupplier, topN)
)
);

var maybeDeployListener = mlResultsListener.delegateResponse(
(l, exception) -> maybeStartDeployment(model, exception, request, mlResultsListener)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
import java.util.HashMap;
import java.util.Map;

import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder.ELASTIC_RERANKER_EXTRA_TOKEN_COUNT;
import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder.ELASTIC_RERANKER_TOKEN_LIMIT;
import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder.TOKENS_PER_WORD;

public class ChunkingSettingsBuilderTests extends ESTestCase {

public static final SentenceBoundaryChunkingSettings DEFAULT_SETTINGS = new SentenceBoundaryChunkingSettings(250, 1);
Expand Down Expand Up @@ -47,6 +51,32 @@ public void testValidChunkingSettingsMap() {
});
}

public void testBuildChunkingSettingsForElasticReranker_QueryTokenCountLessThanHalfOfTokenLimit() {
    // Generate a word count for a non-empty query whose token footprint is guaranteed to leave
    // MORE than half the token limit for the chunk, so the remaining-budget branch is always
    // taken. Bounding at exactly half the limit is flaky: at the boundary the builder falls
    // back to the half-limit default and the expected chunk size below is off by one word.
    int maxQueryTokenCount = ELASTIC_RERANKER_TOKEN_LIMIT - ELASTIC_RERANKER_EXTRA_TOKEN_COUNT - (ELASTIC_RERANKER_TOKEN_LIMIT / 2)
        - 1;
    int queryWordCount = randomIntBetween(1, (int) (maxQueryTokenCount / TOKENS_PER_WORD));
    var queryTokenCount = Math.ceil(queryWordCount * TOKENS_PER_WORD);
    ChunkingSettings chunkingSettings = ChunkingSettingsBuilder.buildChunkingSettingsForElasticRerank(queryWordCount);
    assertTrue(chunkingSettings instanceof SentenceBoundaryChunkingSettings);
    SentenceBoundaryChunkingSettings sentenceBoundaryChunkingSettings = (SentenceBoundaryChunkingSettings) chunkingSettings;
    // Expected: the full remaining token budget after the query, converted back to words.
    int expectedMaxChunkSize = (int) ((ELASTIC_RERANKER_TOKEN_LIMIT - ELASTIC_RERANKER_EXTRA_TOKEN_COUNT - queryTokenCount)
        / TOKENS_PER_WORD);
    assertEquals(expectedMaxChunkSize, sentenceBoundaryChunkingSettings.maxChunkSize);
    assertEquals(1, sentenceBoundaryChunkingSettings.sentenceOverlap);
}

public void testBuildChunkingSettingsForElasticReranker_QueryTokenCountMoreThanHalfOfTokenLimit() {
    // Any query consuming at least half the usable token limit pushes the builder onto its
    // fallback: a chunk budget of half the token limit, regardless of how long the query is.
    int halfUsableTokenLimit = (ELASTIC_RERANKER_TOKEN_LIMIT - ELASTIC_RERANKER_EXTRA_TOKEN_COUNT) / 2;
    int queryWordCount = randomIntBetween((int) (halfUsableTokenLimit / TOKENS_PER_WORD), Integer.MAX_VALUE);

    ChunkingSettings chunkingSettings = ChunkingSettingsBuilder.buildChunkingSettingsForElasticRerank(queryWordCount);

    assertTrue(chunkingSettings instanceof SentenceBoundaryChunkingSettings);
    SentenceBoundaryChunkingSettings sentenceSettings = (SentenceBoundaryChunkingSettings) chunkingSettings;
    // Expected: half the token limit converted back into a word count.
    int expectedMaxChunkSize = (int) (Math.floor((float) ELASTIC_RERANKER_TOKEN_LIMIT / 2) / TOKENS_PER_WORD);
    assertEquals(expectedMaxChunkSize, sentenceSettings.maxChunkSize);
    assertEquals(1, sentenceSettings.sentenceOverlap);
}

private Map<Map<String, Object>, ChunkingSettings> chunkingSettingsMapToChunkingSettings() {
var maxChunkSizeWordBoundaryChunkingSettings = randomIntBetween(10, 300);
var overlap = randomIntBetween(1, maxChunkSizeWordBoundaryChunkingSettings / 2);
Expand Down