Skip to content

Commit bcd37d0

Browse files
Committed with message:
feat: error handling & optimization
1 parent f3e2795 commit bcd37d0

File tree

3 files changed

+58
-33
lines changed

3 files changed

+58
-33
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
/web/test-results/
1515
backend/onyx/agent_search/main/test_data.json
1616
backend/tests/regression/answer_quality/test_data.json
17+
backend/tests/regression/search_quality/eval-*
18+
backend/tests/regression/search_quality/search_eval_config.yaml
19+
backend/tests/regression/search_quality/*.json
1720

1821
# secret files
1922
.env

backend/tests/regression/search_quality/generate_search_queries.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,19 @@
1919
from onyx.tools.tool_implementations.search.search_tool import SearchTool
2020
from onyx.tools.utils import explicit_tool_calling_supported
2121
from onyx.utils.logger import setup_logger
22+
from shared_configs.configs import MULTI_TENANT
2223

2324
logger = setup_logger()
2425

2526

2627
def _load_queries() -> list[str]:
2728
current_dir = Path(__file__).parent
28-
with open(current_dir / "search_queries.json", "r") as file:
29+
search_queries_path = current_dir / "search_queries.json"
30+
if not search_queries_path.exists():
31+
raise FileNotFoundError(
32+
f"Search queries file not found at {search_queries_path}"
33+
)
34+
with search_queries_path.open("r") as file:
2935
return json.load(file)
3036

3137

@@ -77,6 +83,9 @@ def __init__(self) -> None:
7783

7884

7985
def generate_search_queries() -> None:
86+
if MULTI_TENANT:
87+
raise ValueError("Multi-tenant is not supported currently")
88+
8089
SqlEngine.init_engine(
8190
pool_size=POSTGRES_API_SERVER_POOL_SIZE,
8291
max_overflow=POSTGRES_API_SERVER_POOL_OVERFLOW,

backend/tests/regression/search_quality/run_search_eval.py

Lines changed: 45 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import csv
22
import json
3-
import os
43
from bisect import bisect_left
54
from datetime import datetime
65
from pathlib import Path
@@ -32,6 +31,7 @@
3231
from onyx.document_index.factory import get_default_document_index
3332
from onyx.document_index.interfaces import DocumentIndex
3433
from onyx.utils.logger import setup_logger
34+
from shared_configs.configs import MULTI_TENANT
3535

3636
logger = setup_logger(__name__)
3737

@@ -51,9 +51,11 @@ class SearchEvalParameters(BaseModel):
5151

5252

5353
def _load_search_parameters() -> SearchEvalParameters:
54-
current_dir = os.path.dirname(os.path.abspath(__file__))
55-
config_path = os.path.join(current_dir, "search_eval_config.yaml")
56-
with open(config_path, "r") as file:
54+
current_dir = Path(__file__).parent
55+
config_path = current_dir / "search_eval_config.yaml"
56+
if not config_path.exists():
57+
raise FileNotFoundError(f"Search eval config file not found at {config_path}")
58+
with config_path.open("r") as file:
5759
config = yaml.safe_load(file)
5860

5961
export_folder = config.get("EXPORT_FOLDER", "eval-%Y-%m-%d-%H-%M-%S")
@@ -90,13 +92,29 @@ def _load_search_parameters() -> SearchEvalParameters:
9092

9193
def _load_query_pairs() -> list[tuple[str, str]]:
9294
current_dir = Path(__file__).parent
93-
94-
with open(current_dir / "search_queries.json", "r") as file:
95+
search_queries_path = current_dir / "search_queries.json"
96+
if not search_queries_path.exists():
97+
raise FileNotFoundError(
98+
f"Search queries file not found at {search_queries_path}"
99+
)
100+
with search_queries_path.open("r") as file:
95101
orig_queries = json.load(file)
96102

97-
with open(current_dir / "search_queries_modified.json", "r") as file:
103+
alt_queries_path = current_dir / "search_queries_modified.json"
104+
if not alt_queries_path.exists():
105+
raise FileNotFoundError(
106+
f"Modified search queries file not found at {alt_queries_path}. "
107+
"Try running generate_search_queries.py."
108+
)
109+
with alt_queries_path.open("r") as file:
98110
alt_queries = json.load(file)
99111

112+
if len(orig_queries) != len(alt_queries):
113+
raise ValueError(
114+
"Number of original and modified queries must be the same. "
115+
"Try running generate_search_queries.py again."
116+
)
117+
100118
return list(zip(orig_queries, alt_queries))
101119

102120

@@ -188,24 +206,24 @@ def _evaluate_one_query(
188206
# compute metrics
189207
search_ranks = {chunk.unique_id: rank for rank, chunk in enumerate(search_results)}
190208
return [
191-
_compute_jaccard_similarity(search_topk, rerank_topk),
209+
*_compute_jaccard_and_missing_chunks_ratio(search_topk, rerank_topk),
192210
_compute_average_rank_change(search_ranks, rerank_topk),
193-
_compute_average_missing_chunk_ratio(search_topk, rerank_topk),
194211
# score adjusted metrics
195-
_compute_jaccard_similarity(search_adj_topk, rerank_adj_topk),
212+
*_compute_jaccard_and_missing_chunks_ratio(search_adj_topk, rerank_adj_topk),
196213
_compute_average_rank_change(search_ranks, rerank_adj_topk),
197-
_compute_average_missing_chunk_ratio(search_adj_topk, rerank_adj_topk),
198214
]
199215

200216

201-
def _compute_jaccard_similarity(
217+
def _compute_jaccard_and_missing_chunks_ratio(
202218
search_topk: list[InferenceChunk], rerank_topk: list[InferenceChunk]
203-
) -> float:
219+
) -> tuple[float, float]:
204220
search_chunkids = {chunk.unique_id for chunk in search_topk}
205221
rerank_chunkids = {chunk.unique_id for chunk in rerank_topk}
206-
return len(search_chunkids.intersection(rerank_chunkids)) / len(
207-
search_chunkids.union(rerank_chunkids)
222+
jaccard_similarity = len(search_chunkids & rerank_chunkids) / len(
223+
search_chunkids | rerank_chunkids
208224
)
225+
missing_chunks_ratio = len(rerank_chunkids - search_chunkids) / len(rerank_chunkids)
226+
return jaccard_similarity, missing_chunks_ratio
209227

210228

211229
def _compute_average_rank_change(
@@ -218,22 +236,17 @@ def _compute_average_rank_change(
218236
return sum(rank_changes) / len(rank_changes)
219237

220238

221-
def _compute_average_missing_chunk_ratio(
222-
search_topk: list[InferenceChunk], rerank_topk: list[InferenceChunk]
223-
) -> float:
224-
search_chunkids = {chunk.unique_id for chunk in search_topk}
225-
rerank_chunkids = {chunk.unique_id for chunk in rerank_topk}
226-
return len(rerank_chunkids.difference(search_chunkids)) / len(rerank_chunkids)
227-
228-
229239
def run_search_eval() -> None:
240+
if MULTI_TENANT:
241+
raise ValueError("Multi-tenant is not supported currently")
242+
230243
SqlEngine.init_engine(
231244
pool_size=POSTGRES_API_SERVER_POOL_SIZE,
232245
max_overflow=POSTGRES_API_SERVER_POOL_OVERFLOW,
233246
)
234247

235-
search_parameters = _load_search_parameters()
236248
query_pairs = _load_query_pairs()
249+
search_parameters = _load_search_parameters()
237250

238251
with get_session_with_current_tenant() as db_session:
239252
multilingual_expansion = get_multilingual_expansion(db_session)
@@ -265,11 +278,11 @@ def run_search_eval() -> None:
265278
[
266279
"query",
267280
"jaccard_similarity",
268-
"average_rank_change",
269281
"missing_chunks_ratio",
282+
"average_rank_change",
270283
"jaccard_similarity_adj",
271-
"average_rank_change_adj",
272284
"missing_chunks_ratio_adj",
285+
"average_rank_change_adj",
273286
]
274287
)
275288

@@ -326,23 +339,23 @@ def run_search_eval() -> None:
326339
if not search_parameters.skip_rerank:
327340
average_metrics = [metric / len(query_pairs) for metric in sum_metrics]
328341
logger.info(f"Jaccard similarity: {average_metrics[0]}")
329-
logger.info(f"Average rank change: {average_metrics[1]}")
330-
logger.info(f"Average missing chunks ratio: {average_metrics[2]}")
342+
logger.info(f"Average missing chunks ratio: {average_metrics[1]}")
343+
logger.info(f"Average rank change: {average_metrics[2]}")
331344
logger.info(f"Jaccard similarity (adjusted): {average_metrics[3]}")
332-
logger.info(f"Average rank change (adjusted): {average_metrics[4]}")
333-
logger.info(f"Average missing chunks ratio (adjusted): {average_metrics[5]}")
345+
logger.info(f"Average missing chunks ratio (adjusted): {average_metrics[4]}")
346+
logger.info(f"Average rank change (adjusted): {average_metrics[5]}")
334347

335348
aggregate_file = export_path / "aggregate_results.csv"
336349
with aggregate_file.open("w") as file:
337350
aggregate_csv_writer = csv.writer(file)
338351
aggregate_csv_writer.writerow(
339352
[
340353
"jaccard_similarity",
341-
"average_rank_change",
342354
"missing_chunks_ratio",
355+
"average_rank_change",
343356
"jaccard_similarity_adj",
344-
"average_rank_change_adj",
345357
"missing_chunks_ratio_adj",
358+
"average_rank_change_adj",
346359
]
347360
)
348361
aggregate_csv_writer.writerow(average_metrics)

0 commit comments

Comments (0)