Skip to content

Commit 667b9e0

Browse files
updated rerank function arguments (#3988)
1 parent 29c84d7 commit 667b9e0

File tree

3 files changed

+49
-49
lines changed

3 files changed

+49
-49
lines changed

backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/nodes/rerank_documents.py

Lines changed: 21 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,11 @@
2121
from onyx.configs.agent_configs import AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS
2222
from onyx.configs.agent_configs import AGENT_RERANKING_STATS
2323
from onyx.context.search.models import InferenceSection
24-
from onyx.context.search.models import SearchRequest
25-
from onyx.context.search.pipeline import retrieval_preprocessing
24+
from onyx.context.search.models import RerankingDetails
2625
from onyx.context.search.postprocessing.postprocessing import rerank_sections
26+
from onyx.context.search.postprocessing.postprocessing import should_rerank
2727
from onyx.db.engine import get_session_context_manager
28+
from onyx.db.search_settings import get_current_search_settings
2829

2930

3031
def rerank_documents(
@@ -39,6 +40,8 @@ def rerank_documents(
3940

4041
# Rerank post retrieval and verification. First, create a search query
4142
# then create the list of reranked sections
43+
# If no question defined/question is None in the state, use the original
44+
# question from the search request as query
4245

4346
graph_config = cast(GraphConfig, config["metadata"]["config"])
4447
question = (
@@ -47,39 +50,28 @@ def rerank_documents(
4750
assert (
4851
graph_config.tooling.search_tool
4952
), "search_tool must be provided for agentic search"
50-
with get_session_context_manager() as db_session:
51-
# we ignore some of the user specified fields since this search is
52-
# internal to agentic search, but we still want to pass through
53-
# persona (for stuff like document sets) and rerank settings
54-
# (to not make an unnecessary db call).
55-
search_request = SearchRequest(
56-
query=question,
57-
persona=graph_config.inputs.search_request.persona,
58-
rerank_settings=graph_config.inputs.search_request.rerank_settings,
59-
)
60-
_search_query = retrieval_preprocessing(
61-
search_request=search_request,
62-
user=graph_config.tooling.search_tool.user, # bit of a hack
63-
llm=graph_config.tooling.fast_llm,
64-
db_session=db_session,
65-
)
6653

67-
# skip section filtering
54+
# Note that these are passed in values from the API and are overrides which are typically None
55+
rerank_settings = graph_config.inputs.search_request.rerank_settings
6856

69-
if (
70-
_search_query.rerank_settings
71-
and _search_query.rerank_settings.rerank_model_name
72-
and _search_query.rerank_settings.num_rerank > 0
73-
and len(verified_documents) > 0
74-
):
57+
if rerank_settings is None:
58+
with get_session_context_manager() as db_session:
59+
search_settings = get_current_search_settings(db_session)
60+
if not search_settings.disable_rerank_for_streaming:
61+
rerank_settings = RerankingDetails.from_db_model(search_settings)
62+
63+
if should_rerank(rerank_settings) and len(verified_documents) > 0:
7564
if len(verified_documents) > 1:
7665
reranked_documents = rerank_sections(
77-
_search_query,
78-
verified_documents,
66+
query_str=question,
67+
# if runnable, then rerank_settings is not None
68+
rerank_settings=cast(RerankingDetails, rerank_settings),
69+
sections_to_rerank=verified_documents,
7970
)
8071
else:
81-
num = "No" if len(verified_documents) == 0 else "One"
82-
logger.warning(f"{num} verified document(s) found, skipping reranking")
72+
logger.warning(
73+
f"{len(verified_documents)} verified document(s) found, skipping reranking"
74+
)
8375
reranked_documents = verified_documents
8476
else:
8577
logger.warning("No reranking settings found, using unranked documents")

backend/onyx/context/search/pipeline.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ def __init__(
6161
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
6262
prompt_config: PromptConfig | None = None,
6363
):
64+
# NOTE: The Search Request contains a lot of fields that are overrides, many of them can be None
65+
# and typically are None. The preprocessing will fetch default values to replace these empty overrides.
6466
self.search_request = search_request
6567
self.user = user
6668
self.llm = llm

backend/onyx/context/search/postprocessing/postprocessing.py

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from onyx.context.search.models import InferenceChunkUncleaned
1616
from onyx.context.search.models import InferenceSection
1717
from onyx.context.search.models import MAX_METRICS_CONTENT
18+
from onyx.context.search.models import RerankingDetails
1819
from onyx.context.search.models import RerankMetricsContainer
1920
from onyx.context.search.models import SearchQuery
2021
from onyx.document_index.document_index_utils import (
@@ -77,7 +78,8 @@ def _remove_metadata_suffix(chunk: InferenceChunkUncleaned) -> str:
7778

7879
@log_function_time(print_only=True)
7980
def semantic_reranking(
80-
query: SearchQuery,
81+
query_str: str,
82+
rerank_settings: RerankingDetails,
8183
chunks: list[InferenceChunk],
8284
model_min: int = CROSS_ENCODER_RANGE_MIN,
8385
model_max: int = CROSS_ENCODER_RANGE_MAX,
@@ -88,11 +90,9 @@ def semantic_reranking(
8890
8991
Note: this updates the chunks in place, it updates the chunk scores which came from retrieval
9092
"""
91-
rerank_settings = query.rerank_settings
92-
93-
if not rerank_settings or not rerank_settings.rerank_model_name:
94-
# Should never reach this part of the flow without reranking settings
95-
raise RuntimeError("Reranking flow should not be running")
93+
assert (
94+
rerank_settings.rerank_model_name
95+
), "Reranking flow cannot run without a specific model"
9696

9797
chunks_to_rerank = chunks[: rerank_settings.num_rerank]
9898

@@ -107,7 +107,7 @@ def semantic_reranking(
107107
f"{chunk.semantic_identifier or chunk.title or ''}\n{chunk.content}"
108108
for chunk in chunks_to_rerank
109109
]
110-
sim_scores_floats = cross_encoder.predict(query=query.query, passages=passages)
110+
sim_scores_floats = cross_encoder.predict(query=query_str, passages=passages)
111111

112112
# Old logic to handle multiple cross-encoders preserved but not used
113113
sim_scores = [numpy.array(sim_scores_floats)]
@@ -165,8 +165,20 @@ def semantic_reranking(
165165
return list(ranked_chunks), list(ranked_indices)
166166

167167

168+
def should_rerank(rerank_settings: RerankingDetails | None) -> bool:
169+
"""Based on the RerankingDetails model, only run rerank if the following conditions are met:
170+
- rerank_model_name is not None
171+
- num_rerank is greater than 0
172+
"""
173+
if not rerank_settings:
174+
return False
175+
176+
return bool(rerank_settings.rerank_model_name and rerank_settings.num_rerank > 0)
177+
178+
168179
def rerank_sections(
169-
query: SearchQuery,
180+
query_str: str,
181+
rerank_settings: RerankingDetails,
170182
sections_to_rerank: list[InferenceSection],
171183
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
172184
) -> list[InferenceSection]:
@@ -181,16 +193,13 @@ def rerank_sections(
181193
"""
182194
chunks_to_rerank = [section.center_chunk for section in sections_to_rerank]
183195

184-
if not query.rerank_settings:
185-
# Should never reach this part of the flow without reranking settings
186-
raise RuntimeError("Reranking settings not found")
187-
188196
ranked_chunks, _ = semantic_reranking(
189-
query=query,
197+
query_str=query_str,
198+
rerank_settings=rerank_settings,
190199
chunks=chunks_to_rerank,
191200
rerank_metrics_callback=rerank_metrics_callback,
192201
)
193-
lower_chunks = chunks_to_rerank[query.rerank_settings.num_rerank :]
202+
lower_chunks = chunks_to_rerank[rerank_settings.num_rerank :]
194203

195204
# Scores from rerank cannot be meaningfully combined with scores without rerank
196205
# However the ordering is still important
@@ -260,16 +269,13 @@ def search_postprocessing(
260269

261270
rerank_task_id = None
262271
sections_yielded = False
263-
if (
264-
search_query.rerank_settings
265-
and search_query.rerank_settings.rerank_model_name
266-
and search_query.rerank_settings.num_rerank > 0
267-
):
272+
if should_rerank(search_query.rerank_settings):
268273
post_processing_tasks.append(
269274
FunctionCall(
270275
rerank_sections,
271276
(
272-
search_query,
277+
search_query.query,
278+
search_query.rerank_settings, # Cannot be None here
273279
retrieved_sections,
274280
rerank_metrics_callback,
275281
),

0 commit comments

Comments
 (0)