Skip to content

Commit 8fa1acf

Browse files
committed
Add retries on batch eval (3 by default) and allow for custom evaluators
1 parent 732ac2f commit 8fa1acf

File tree

3 files changed

+21
-8
lines changed

3 files changed

+21
-8
lines changed

evaluators/ragas/langevals_ragas/lib/common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,9 @@ class _GenericEvaluatorEntry(EvaluatorEntry):
9090

9191

9292
class RagasEvaluator(BaseEvaluator[TEntry, TSettings, TResult]):
93-
def _evaluate_entry(self, entry):
93+
def _evaluate_entry(self, *args, **kwargs):
9494
disable_tqdm()
95-
return super()._evaluate_entry(entry)
95+
return super()._evaluate_entry(*args, **kwargs)
9696

9797
def evaluate_batch(self, *args, **kwargs):
9898
restore_tqdm()

langevals/utils.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import importlib
22
import importlib.metadata
33
import pkgutil
4+
import re
45
import textwrap
56
from typing import Optional, Type, get_args
67

@@ -64,8 +65,15 @@ def get_evaluator_definitions(evaluator_cls: BaseEvaluator):
6465
entry_type = get_args(fields["entry"].annotation)[0]
6566
result_type = get_args(fields["result"].annotation)[0]
6667

67-
module_name, evaluator_name = evaluator_cls.__module__.split(".", 1)
68-
module_name = module_name.split("langevals_")[1]
68+
namespaces = evaluator_cls.__module__.split(".", 1)
69+
if len(namespaces) == 2:
70+
module_name, evaluator_name = namespaces
71+
module_name = module_name.split("langevals_")[1]
72+
else:
73+
module_name = ""
74+
evaluator_name = evaluator_cls.__class__.__name__
75+
# CamelCase to snake_case
76+
evaluator_name = re.sub(r"(?<!^)(?=[A-Z])", "_", evaluator_name).lower()
6977

7078
if getattr(evaluator_cls, "name", None) is None:
7179
raise ValueError(f"Missing name attribute in {evaluator_cls}")

langevals_core/langevals_core/base_evaluator.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from pydantic import BaseModel, ConfigDict, Field
2121
import pandas as pd
22+
from tenacity import Retrying, retry, stop_after_attempt, wait_exponential
2223
from tqdm.auto import tqdm
2324
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, as_completed, wait
2425
from langevals_core.azure_patch import patch_litellm
@@ -267,9 +268,10 @@ def set_model_envs(self):
267268
def evaluate(self, entry: TEntry) -> SingleEvaluationResult:
268269
raise NotImplementedError("This method should be implemented by subclasses.")
269270

270-
def _evaluate_entry(self, entry):
271+
def _evaluate_entry(self, entry, retries=0):
271272
try:
272-
return self.evaluate(entry)
273+
retryer = Retrying(stop=stop_after_attempt(retries), reraise=True)
274+
return retryer(self.evaluate, entry)
273275
except Exception as exception:
274276
return EvaluationResultError(
275277
error_type=type(exception).__name__,
@@ -284,14 +286,15 @@ def evaluate_batch(
284286
data: List[TEntry],
285287
index=0,
286288
max_evaluations_in_parallel=50,
289+
retries=3,
287290
_executor_ref: Optional[Callable[[ThreadPoolExecutor], None]] = None,
288291
) -> BatchEvaluationResult:
289292
results: list[SingleEvaluationResult] = [
290293
EvaluationResultSkipped(details="not processed")
291294
] * len(data)
292295
with ThreadPoolExecutor(max_workers=max_evaluations_in_parallel) as executor:
293296
future_to_index = {
294-
executor.submit(self._evaluate_entry, entry): idx
297+
executor.submit(self._evaluate_entry, entry, retries): idx
295298
for idx, entry in enumerate(data)
296299
}
297300

@@ -306,7 +309,9 @@ def evaluate_batch(
306309
executor, "interrupted"
307310
) and executor.__getattribute__("interrupted"):
308311
raise KeyboardInterrupt()
309-
done, not_done = wait(not_done, timeout=0.1, return_when=FIRST_COMPLETED)
312+
done, not_done = wait(
313+
not_done, timeout=0.1, return_when=FIRST_COMPLETED
314+
)
310315
for future in done:
311316
idx = future_to_index[future]
312317
results[idx] = future.result()

0 commit comments

Comments
 (0)