1
1
import csv
2
2
import json
3
- import os
4
3
from bisect import bisect_left
5
4
from datetime import datetime
6
5
from pathlib import Path
32
31
from onyx .document_index .factory import get_default_document_index
33
32
from onyx .document_index .interfaces import DocumentIndex
34
33
from onyx .utils .logger import setup_logger
34
+ from shared_configs .configs import MULTI_TENANT
35
35
36
36
logger = setup_logger (__name__ )
37
37
@@ -51,9 +51,11 @@ class SearchEvalParameters(BaseModel):
51
51
52
52
53
53
def _load_search_parameters () -> SearchEvalParameters :
54
- current_dir = os .path .dirname (os .path .abspath (__file__ ))
55
- config_path = os .path .join (current_dir , "search_eval_config.yaml" )
56
- with open (config_path , "r" ) as file :
54
+ current_dir = Path (__file__ ).parent
55
+ config_path = current_dir / "search_eval_config.yaml"
56
+ if not config_path .exists ():
57
+ raise FileNotFoundError (f"Search eval config file not found at { config_path } " )
58
+ with config_path .open ("r" ) as file :
57
59
config = yaml .safe_load (file )
58
60
59
61
export_folder = config .get ("EXPORT_FOLDER" , "eval-%Y-%m-%d-%H-%M-%S" )
@@ -90,13 +92,29 @@ def _load_search_parameters() -> SearchEvalParameters:
90
92
91
93
def _load_query_pairs () -> list [tuple [str , str ]]:
92
94
current_dir = Path (__file__ ).parent
93
-
94
- with open (current_dir / "search_queries.json" , "r" ) as file :
95
+ search_queries_path = current_dir / "search_queries.json"
96
+ if not search_queries_path .exists ():
97
+ raise FileNotFoundError (
98
+ f"Search queries file not found at { search_queries_path } "
99
+ )
100
+ with search_queries_path .open ("r" ) as file :
95
101
orig_queries = json .load (file )
96
102
97
- with open (current_dir / "search_queries_modified.json" , "r" ) as file :
103
+ alt_queries_path = current_dir / "search_queries_modified.json"
104
+ if not alt_queries_path .exists ():
105
+ raise FileNotFoundError (
106
+ f"Modified search queries file not found at { alt_queries_path } . "
107
+ "Try running generate_search_queries.py."
108
+ )
109
+ with alt_queries_path .open ("r" ) as file :
98
110
alt_queries = json .load (file )
99
111
112
+ if len (orig_queries ) != len (alt_queries ):
113
+ raise ValueError (
114
+ "Number of original and modified queries must be the same. "
115
+ "Try running generate_search_queries.py again."
116
+ )
117
+
100
118
return list (zip (orig_queries , alt_queries ))
101
119
102
120
@@ -188,24 +206,24 @@ def _evaluate_one_query(
188
206
# compute metrics
189
207
search_ranks = {chunk .unique_id : rank for rank , chunk in enumerate (search_results )}
190
208
return [
191
- _compute_jaccard_similarity (search_topk , rerank_topk ),
209
+ * _compute_jaccard_and_missing_chunks_ratio (search_topk , rerank_topk ),
192
210
_compute_average_rank_change (search_ranks , rerank_topk ),
193
- _compute_average_missing_chunk_ratio (search_topk , rerank_topk ),
194
211
# score adjusted metrics
195
- _compute_jaccard_similarity (search_adj_topk , rerank_adj_topk ),
212
+ * _compute_jaccard_and_missing_chunks_ratio (search_adj_topk , rerank_adj_topk ),
196
213
_compute_average_rank_change (search_ranks , rerank_adj_topk ),
197
- _compute_average_missing_chunk_ratio (search_adj_topk , rerank_adj_topk ),
198
214
]
199
215
200
216
201
- def _compute_jaccard_similarity (
217
+ def _compute_jaccard_and_missing_chunks_ratio (
202
218
search_topk : list [InferenceChunk ], rerank_topk : list [InferenceChunk ]
203
- ) -> float :
219
+ ) -> tuple [ float , float ] :
204
220
search_chunkids = {chunk .unique_id for chunk in search_topk }
205
221
rerank_chunkids = {chunk .unique_id for chunk in rerank_topk }
206
- return len (search_chunkids . intersection ( rerank_chunkids ) ) / len (
207
- search_chunkids . union ( rerank_chunkids )
222
+ jaccard_similarity = len (search_chunkids & rerank_chunkids ) / len (
223
+ search_chunkids | rerank_chunkids
208
224
)
225
+ missing_chunks_ratio = len (rerank_chunkids - search_chunkids ) / len (rerank_chunkids )
226
+ return jaccard_similarity , missing_chunks_ratio
209
227
210
228
211
229
def _compute_average_rank_change (
@@ -218,22 +236,17 @@ def _compute_average_rank_change(
218
236
return sum (rank_changes ) / len (rank_changes )
219
237
220
238
221
- def _compute_average_missing_chunk_ratio (
222
- search_topk : list [InferenceChunk ], rerank_topk : list [InferenceChunk ]
223
- ) -> float :
224
- search_chunkids = {chunk .unique_id for chunk in search_topk }
225
- rerank_chunkids = {chunk .unique_id for chunk in rerank_topk }
226
- return len (rerank_chunkids .difference (search_chunkids )) / len (rerank_chunkids )
227
-
228
-
229
239
def run_search_eval () -> None :
240
+ if MULTI_TENANT :
241
+ raise ValueError ("Multi-tenant is not supported currently" )
242
+
230
243
SqlEngine .init_engine (
231
244
pool_size = POSTGRES_API_SERVER_POOL_SIZE ,
232
245
max_overflow = POSTGRES_API_SERVER_POOL_OVERFLOW ,
233
246
)
234
247
235
- search_parameters = _load_search_parameters ()
236
248
query_pairs = _load_query_pairs ()
249
+ search_parameters = _load_search_parameters ()
237
250
238
251
with get_session_with_current_tenant () as db_session :
239
252
multilingual_expansion = get_multilingual_expansion (db_session )
@@ -265,11 +278,11 @@ def run_search_eval() -> None:
265
278
[
266
279
"query" ,
267
280
"jaccard_similarity" ,
268
- "average_rank_change" ,
269
281
"missing_chunks_ratio" ,
282
+ "average_rank_change" ,
270
283
"jaccard_similarity_adj" ,
271
- "average_rank_change_adj" ,
272
284
"missing_chunks_ratio_adj" ,
285
+ "average_rank_change_adj" ,
273
286
]
274
287
)
275
288
@@ -326,23 +339,23 @@ def run_search_eval() -> None:
326
339
if not search_parameters .skip_rerank :
327
340
average_metrics = [metric / len (query_pairs ) for metric in sum_metrics ]
328
341
logger .info (f"Jaccard similarity: { average_metrics [0 ]} " )
329
- logger .info (f"Average rank change : { average_metrics [1 ]} " )
330
- logger .info (f"Average missing chunks ratio : { average_metrics [2 ]} " )
342
+ logger .info (f"Average missing chunks ratio : { average_metrics [1 ]} " )
343
+ logger .info (f"Average rank change : { average_metrics [2 ]} " )
331
344
logger .info (f"Jaccard similarity (adjusted): { average_metrics [3 ]} " )
332
- logger .info (f"Average rank change (adjusted): { average_metrics [4 ]} " )
333
- logger .info (f"Average missing chunks ratio (adjusted): { average_metrics [5 ]} " )
345
+ logger .info (f"Average missing chunks ratio (adjusted): { average_metrics [4 ]} " )
346
+ logger .info (f"Average rank change (adjusted): { average_metrics [5 ]} " )
334
347
335
348
aggregate_file = export_path / "aggregate_results.csv"
336
349
with aggregate_file .open ("w" ) as file :
337
350
aggregate_csv_writer = csv .writer (file )
338
351
aggregate_csv_writer .writerow (
339
352
[
340
353
"jaccard_similarity" ,
341
- "average_rank_change" ,
342
354
"missing_chunks_ratio" ,
355
+ "average_rank_change" ,
343
356
"jaccard_similarity_adj" ,
344
- "average_rank_change_adj" ,
345
357
"missing_chunks_ratio_adj" ,
358
+ "average_rank_change_adj" ,
346
359
]
347
360
)
348
361
aggregate_csv_writer .writerow (average_metrics )
0 commit comments