Commit 47c60b5

feat: llm as prompt as optional (#2084)
- LLM-based metric

  ```py
  test_metric = DiscreteMetric(
      name="test_metric",
      prompt="Is the {response} a good response to the query {query}?",
      values=["pass", "fail"],
  )
  ```

- Writing custom metric logic

  ```py
  @numeric_metric(
      name="test_metric",
      range=(0, 1),
  )
  def test_metric(
      query: str,
      response: str,
  ) -> MetricResult:
      """
      Is the response a good response to the query?
      """
      result = 0
      return MetricResult(result=result, reason="")
  ```
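For reference, a minimal sketch of scoring a metric under the new API, where the LLM is supplied per call rather than bound to the metric at construction. Here `my_llm` stands for an already-configured `RagasLLM` instance (an assumption for illustration), and `test_metric` is the `DiscreteMetric` defined above:

```py
# The LLM is now a call-time argument; the metric itself only carries the prompt.
result = test_metric.score(
    llm=my_llm,  # assumed: a configured RagasLLM instance
    query="What is Ragas?",
    response="Ragas is an LLM evaluation toolkit.",
)
print(result.result)  # "pass" or "fail"
print(result.reason)  # model-generated reasoning

# A decorator-based metric whose function declares neither `llm` nor `prompt`
# (like the numeric example above) can be scored without passing an LLM at all.
```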
1 parent 10061a8 commit 47c60b5

File tree

6 files changed: +135 / -191 lines changed


experimental/ragas_experimental/metric/base.py

Lines changed: 29 additions & 45 deletions
```diff
@@ -6,7 +6,7 @@
 __all__ = ['Metric']
 
 # %% ../../nbs/api/metric/base.ipynb 2
-from abc import ABC, abstractmethod
+from abc import ABC
 import asyncio
 from dataclasses import dataclass, field
 from pydantic import BaseModel
@@ -31,24 +31,13 @@ class Metric(ABC):
     """Base class for all metrics in the LLM evaluation library."""
 
     name: str
-    prompt: str | Prompt
-    llm: RagasLLM
-    _response_models: t.Dict[bool, t.Type[BaseModel]] = field(
-        default_factory=dict, init=False, repr=False
-    )
+    prompt: t.Optional[t.Union[str, Prompt]] = None
+    _response_model: t.Type[BaseModel] = field(init=False)
 
     def __post_init__(self):
         if isinstance(self.prompt, str):
             self.prompt = Prompt(self.prompt)
 
-    @abstractmethod
-    def _get_response_model(self, with_reasoning: bool) -> t.Type[BaseModel]:
-        """Get the appropriate response model."""
-        pass
-
-    @abstractmethod
-    def _ensemble(self, results: t.List[MetricResult]) -> MetricResult:
-        pass
 
     def get_variables(self) -> t.List[str]:
         if isinstance(self.prompt, Prompt):
@@ -62,54 +51,49 @@ def get_variables(self) -> t.List[str]:
         ]
         return vars
 
-    def score(self, reasoning: bool = True, n: int = 1, **kwargs) -> t.Any:
-        responses = []
+    def score(self, llm: RagasLLM, **kwargs) -> MetricResult:
+
         traces = {}
         traces["input"] = kwargs
         prompt_input = self.prompt.format(**kwargs)
-        for _ in range(n):
-            response = self.llm.generate(
-                prompt_input, response_model=self._get_response_model(reasoning)
-            )
-            traces["output"] = response.model_dump()
-            response = MetricResult(**response.model_dump())
-            responses.append(response)
-        results = self._ensemble(responses)
-        results.traces = traces
-        return results
+        response = llm.generate(
+            prompt_input, response_model=self._response_model
+        )
+        traces["output"] = response.model_dump()
+        result = MetricResult(**response.model_dump())
+        result.traces = traces
+        return result
 
     async def ascore(
-        self, reasoning: bool = True, n: int = 1, **kwargs
+        self, llm: RagasLLM, **kwargs
     ) -> MetricResult:
-        responses = []  # Added missing initialization
+
         traces = {}
-        traces["input"] = kwargs
+
         prompt_input = self.prompt.format(**kwargs)
-        for _ in range(n):
-            response = await self.llm.agenerate(
-                prompt_input, response_model=self._get_response_model(reasoning)
-            )
-            traces["output"] = response.model_dump()
-            response = MetricResult(
-                **response.model_dump()
-            )  # Fixed missing parentheses
-            responses.append(response)
-        results = self._ensemble(responses)
-        results.traces = traces
-        return results
+        traces["input"] = prompt_input
+        response = await llm.agenerate(
+            prompt_input, response_model=self._response_model,
+        )
+        traces["output"] = response.model_dump()
+        result = MetricResult(
+            **response.model_dump()
+        )  # Fixed missing parentheses
+        result.traces = traces
+        return result
 
     def batch_score(
-        self, inputs: t.List[t.Dict[str, t.Any]], reasoning: bool = True, n: int = 1
+        self, llm: RagasLLM, inputs: t.List[t.Dict[str, t.Any]],
     ) -> t.List[t.Any]:
-        return [self.score(reasoning, n, **input_dict) for input_dict in inputs]
+        return [self.score(llm, **input_dict) for input_dict in inputs]
 
     async def abatch_score(
-        self, inputs: t.List[t.Dict[str, t.Any]], reasoning: bool = True, n: int = 1
+        self, llm: RagasLLM, inputs: t.List[t.Dict[str, t.Any]],
    ) -> t.List[MetricResult]:
         async_tasks = []
         for input_dict in inputs:
             # Add reasoning and n to the input parameters
-            async_tasks.append(self.ascore(reasoning=reasoning, n=n, **input_dict))
+            async_tasks.append(self.ascore(llm, **input_dict))
 
         # Run all tasks concurrently and return results
         return await asyncio.gather(*async_tasks)
```
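To make the reshaped surface concrete, a short sketch of single, batch, and async scoring against the new signatures, reusing the `test_metric` and `my_llm` assumptions from the commit message above:

```py
import asyncio

# score()/ascore() now take the LLM as their first argument instead of
# reading it from a field on the metric.
single = test_metric.score(my_llm, query="What is Ragas?", response="An eval library.")

# batch_score() forwards the same LLM to every input dict.
inputs = [
    {"query": "q1", "response": "r1"},
    {"query": "q2", "response": "r2"},
]
batch = test_metric.batch_score(my_llm, inputs)

# abatch_score() builds one ascore() task per input and awaits them all
# concurrently with asyncio.gather.
async_batch = asyncio.run(test_metric.abatch_score(my_llm, inputs))
```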

experimental/ragas_experimental/metric/decorator.py

Lines changed: 79 additions & 16 deletions
```diff
@@ -12,7 +12,6 @@
 from dataclasses import dataclass
 from . import MetricResult
 from ..llm import RagasLLM
-from ..prompt.base import Prompt
 
 
 def create_metric_decorator(metric_class):
@@ -27,8 +26,6 @@ def create_metric_decorator(metric_class):
     """
 
     def decorator_factory(
-        llm: RagasLLM,
-        prompt: t.Union[str, Prompt],
         name: t.Optional[str] = None,
         **metric_params,
     ):
@@ -50,24 +47,62 @@ def decorator(func):
             # Get metric name and check if function is async
             metric_name = name or func.__name__
             is_async = inspect.iscoroutinefunction(func)
+
+            # Check function signature to determine if it expects llm/prompt
+            sig = inspect.signature(func)
+            param_names = list(sig.parameters.keys())
+            expects_llm = 'llm' in param_names
+            expects_prompt = 'prompt' in param_names
 
             # TODO: Move to dataclass type implementation
             @dataclass
             class CustomMetric(metric_class):
+
+                def _validate_result_value(self, result_value):
+                    """Validate result value based on metric type constraints."""
+                    # Discrete metric validation
+                    if hasattr(self, 'values') and result_value not in self.values:
+                        return f"Metric {self.name} returned '{result_value}' but expected one of {self.values}"
+
+                    # Numeric metric validation
+                    if hasattr(self, 'range'):
+                        if not isinstance(result_value, (int, float)):
+                            return f"Metric {self.name} returned '{result_value}' but expected a numeric value"
+                        min_val, max_val = self.range
+                        if not (min_val <= result_value <= max_val):
+                            return f"Metric {self.name} returned {result_value} but expected value in range {self.range}"
+
+                    # Ranking metric validation
+                    if hasattr(self, 'num_ranks'):
+                        if not isinstance(result_value, list):
+                            return f"Metric {self.name} returned '{result_value}' but expected a list"
+                        if len(result_value) != self.num_ranks:
+                            return f"Metric {self.name} returned list of length {len(result_value)} but expected {self.num_ranks} items"
+
+                    return None  # No validation error
 
                 def _run_sync_in_async(self, func, *args, **kwargs):
                     """Run a synchronous function in an async context."""
                     # For sync functions, just run them normally
                     return func(*args, **kwargs)
 
-                def _execute_metric(self, is_async_execution, reasoning, **kwargs):
+                def _execute_metric(self, llm, is_async_execution, **kwargs):
                     """Execute the metric function with proper async handling."""
                     try:
+                        # Prepare function arguments based on what the function expects
+                        func_kwargs = kwargs.copy()
+                        func_args = []
+
+                        if expects_llm:
+                            func_args.append(llm)
+                        if expects_prompt:
+                            func_args.append(self.prompt)
+
                         if is_async:
                             # Async function implementation
                             if is_async_execution:
                                 # In async context, await the function directly
-                                result = func(self.llm, self.prompt, **kwargs)
+                                result = func(*func_args, **func_kwargs)
                             else:
                                 # In sync context, run the async function in an event loop
                                 try:
@@ -76,40 +111,68 @@ def _execute_metric(self, is_async_execution, reasoning, **kwargs):
                                     loop = asyncio.new_event_loop()
                                     asyncio.set_event_loop(loop)
                                     result = loop.run_until_complete(
-                                        func(self.llm, self.prompt, **kwargs)
+                                        func(*func_args, **func_kwargs)
                                     )
                         else:
                             # Sync function implementation
-                            result = func(self.llm, self.prompt, **kwargs)
-
+                            result = func(*func_args, **func_kwargs)
+
+                        # Ensure result is a MetricResult
+                        if not isinstance(result, MetricResult):
+                            raise ValueError(f"Custom metric function must return MetricResult, got {type(result)}")
+
+                        # Validate the result based on metric type
+                        validation_error = self._validate_result_value(result.result)
+                        if validation_error:
+                            return MetricResult(result=None, reason=validation_error)
+
                         return result
+
                     except Exception as e:
                         # Handle errors gracefully
                         error_msg = f"Error executing metric {self.name}: {str(e)}"
                         return MetricResult(result=None, reason=error_msg)
 
-                def score(self, reasoning: bool = True, n: int = 1, **kwargs):
+                def score(self, llm: t.Optional[RagasLLM] = None, **kwargs):
                     """Synchronous scoring method."""
                     return self._execute_metric(
-                        is_async_execution=False, reasoning=reasoning, **kwargs
+                        llm, is_async_execution=False, **kwargs
                    )
 
-                async def ascore(self, reasoning: bool = True, n: int = 1, **kwargs):
+                async def ascore(self, llm: t.Optional[RagasLLM] = None, **kwargs):
                     """Asynchronous scoring method."""
+                    # Prepare function arguments based on what the function expects
+                    func_kwargs = kwargs.copy()
+                    func_args = []
+
+                    if expects_llm:
+                        func_args.append(llm)
+                    if expects_prompt:
+                        func_args.append(self.prompt)
+
                    if is_async:
                         # For async functions, await the result
-                        result = await func(self.llm, self.prompt, **kwargs)
-                        return self._extract_result(result, reasoning)
+                        result = await func(*func_args, **func_kwargs)
                     else:
                         # For sync functions, run normally
                         result = self._run_sync_in_async(
-                            func, self.llm, self.prompt, **kwargs
+                            func, *func_args, **func_kwargs
                         )
-                    return result
+
+                    # Ensure result is a MetricResult
+                    if not isinstance(result, MetricResult):
+                        raise ValueError(f"Custom metric function must return MetricResult, got {type(result)}")
+
+                    # Validate the result based on metric type
+                    validation_error = self._validate_result_value(result.result)
+                    if validation_error:
+                        return MetricResult(result=None, reason=validation_error)
+
+                    return result
 
             # Create the metric instance with all parameters
             metric_instance = CustomMetric(
-                name=metric_name, prompt=prompt, llm=llm, **metric_params
+                name=metric_name,**metric_params
             )
 
             # Preserve metadata
```
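To illustrate the signature inspection added above: the decorator only injects `llm` (and `prompt`) into the wrapped function when the function declares parameters with those names; otherwise the metric runs as plain Python. A hedged sketch with illustrative metric names, import paths taken from this commit's file layout, and `my_llm` assumed to be a configured `RagasLLM`:

```py
from ragas_experimental.metric import MetricResult
from ragas_experimental.metric.discrete import discrete_metric

@discrete_metric(name="keyword_check", values=["pass", "fail"])
def keyword_check(query: str, response: str) -> MetricResult:
    # No `llm`/`prompt` parameters, so nothing is injected; pure Python logic.
    verdict = "pass" if "ragas" in response.lower() else "fail"
    return MetricResult(result=verdict, reason="keyword heuristic")

@discrete_metric(name="judge_check", values=["pass", "fail"])
def judge_check(llm, query: str, response: str) -> MetricResult:
    # Declaring an `llm` parameter opts in: the decorator passes the LLM that
    # was handed to score()/ascore(). A `prompt` parameter works the same way.
    verdict = "pass"  # placeholder: a real metric would call `llm` here
    return MetricResult(result=verdict, reason="judged with the injected llm")

keyword_check.score(query="what is ragas?", response="Ragas is a library")  # no LLM needed
judge_check.score(llm=my_llm, query="what is ragas?", response="...")       # my_llm assumed
```

Return values outside the metric's allowed set or range are caught by `_validate_result_value` and come back as `MetricResult(result=None, reason=...)` rather than raising.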

experimental/ragas_experimental/metric/discrete.py

Lines changed: 7 additions & 33 deletions
```diff
@@ -8,48 +8,22 @@
 # %% ../../nbs/api/metric/discrete.ipynb 2
 import typing as t
 from dataclasses import dataclass, field
-from pydantic import BaseModel, create_model
-from collections import Counter
-from . import Metric, MetricResult
+from pydantic import create_model
+from . import Metric
 from .decorator import create_metric_decorator
 
 
 @dataclass
 class DiscreteMetric(Metric):
     values: t.List[str] = field(default_factory=lambda: ["pass", "fail"])
 
-    def _get_response_model(self, with_reasoning: bool) -> t.Type[BaseModel]:
-        """Get or create a response model based on reasoning parameter."""
-
-        if with_reasoning in self._response_models:
-            return self._response_models[with_reasoning]
-
-        model_name = "response_model"
+    def __post_init__(self):
+        super().__post_init__()
         values = tuple(self.values)
-        fields = {"result": (t.Literal[values], ...)}
-
-        if with_reasoning:
-            fields["reason"] = (str, ...)  # type: ignore
-
-        model = create_model(model_name, **fields)  # type: ignore
-        self._response_models[with_reasoning] = model
-        return model
-
-    def _ensemble(self, results: t.List[MetricResult]) -> MetricResult:
-
-        if len(results) == 1:
-            return results[0]
-
-        candidates = [candidate.result for candidate in results]
-        counter = Counter(candidates)
-        max_count = max(counter.values())
-        for candidate in results:
-            if counter[candidate.result] == max_count:
-                result = candidate.result
-                reason = candidate.reason
-                return MetricResult(result=result, reason=reason)
+        self._response_model = create_model("response_model",
+                                            result=(t.Literal[values], ...),
+                                            reason=(str, ...))
 
-        return results[0]
 
 
 discrete_metric = create_metric_decorator(DiscreteMetric)
```
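For intuition, a standalone sketch of the response model that `DiscreteMetric.__post_init__` now builds eagerly, mirroring the `create_model` call in the diff (the values shown are the defaults):

```py
import typing as t
from pydantic import create_model

values = ("pass", "fail")
ResponseModel = create_model(
    "response_model",
    result=(t.Literal[values], ...),  # constrained to the metric's allowed values
    reason=(str, ...),                # reasoning is now always part of the schema
)

parsed = ResponseModel(result="pass", reason="The response answers the query.")
print(parsed.model_dump())
# {'result': 'pass', 'reason': 'The response answers the query.'}

# A value outside the Literal fails at parse time:
# ResponseModel(result="maybe", reason="...")  -> pydantic.ValidationError
```

Because the model is fixed at construction time, the old per-call `_get_response_model(with_reasoning)` cache and the majority-vote `_ensemble` are no longer needed.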

experimental/ragas_experimental/metric/numeric.py

Lines changed: 6 additions & 28 deletions
```diff
@@ -8,41 +8,19 @@
 # %% ../../nbs/api/metric/numeric.ipynb 2
 import typing as t
 from dataclasses import dataclass, field
-from pydantic import BaseModel, create_model
-from . import Metric, MetricResult
+from pydantic import create_model
+from . import Metric
 from .decorator import create_metric_decorator
 
 
 @dataclass
 class NumericMetric(Metric):
-    range: t.Tuple[float, float]
+    range: t.Tuple[float, float] = (0.0, 1.0)
 
-    def _get_response_model(self, with_reasoning: bool) -> t.Type[BaseModel]:
-        """Get or create a response model based on reasoning parameter."""
+    def __post_init__(self):
+        super().__post_init__()
+        self._response_model = create_model("response_model", result=(float, ...))
 
-        if with_reasoning in self._response_models:
-            return self._response_models[with_reasoning]
-
-        model_name = "response_model"
-        fields = {"result": (float, ...)}
-
-        if with_reasoning:
-            fields["reason"] = (str, ...)  # type: ignore
-
-        model = create_model(model_name, **fields)
-        self._response_models[with_reasoning] = model
-        return model
-
-    def _ensemble(self, results: t.List[MetricResult]) -> MetricResult:
-
-        if len(results) == 1:
-            return results[0]
-
-        candidates = [candidate.result for candidate in results]
-        result = sum(candidates) / len(candidates)
-        reason = results[0].reason
-
-        return MetricResult(result=result, reason=reason)
 
 
 numeric_metric = create_metric_decorator(NumericMetric)
```
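A short end-to-end sketch of the numeric path after this change: `range` now defaults to `(0.0, 1.0)`, and an out-of-range return from a decorated function is converted into an error result by the decorator's validation instead of raising (metric name and logic below are illustrative):

```py
from ragas_experimental.metric import MetricResult
from ragas_experimental.metric.numeric import numeric_metric

@numeric_metric(name="length_ratio", range=(0, 1))
def length_ratio(query: str, response: str) -> MetricResult:
    # Deliberately unclamped so long responses can exceed the declared range.
    ratio = len(response) / max(len(query), 1)
    return MetricResult(result=ratio, reason=f"len(response)/len(query) = {ratio:.2f}")

ok = length_ratio.score(query="0123456789", response="01234")
print(ok.result)    # 0.5 -> inside (0, 1), passes validation

bad = length_ratio.score(query="hi", response="a much longer response")
print(bad.result)   # None -> validation failed
print(bad.reason)   # "Metric length_ratio returned 11.0 but expected value in range (0, 1)"
```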
