from dataclasses import dataclass
from . import MetricResult
from ..llm import RagasLLM
-from ..prompt.base import Prompt


def create_metric_decorator(metric_class):
@@ -27,8 +26,6 @@ def create_metric_decorator(metric_class):
    """

    def decorator_factory(
-        llm: RagasLLM,
-        prompt: t.Union[str, Prompt],
        name: t.Optional[str] = None,
        **metric_params,
    ):
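With `llm` and `prompt` removed from `decorator_factory`, nothing is bound to the metric at decoration time; a metric function opts in by naming `llm` and/or `prompt` as parameters. A minimal sketch of the new call shape, assuming a hypothetical `discrete_metric` factory produced by `create_metric_decorator(DiscreteMetric)`:

# Sketch only: `discrete_metric`/`DiscreteMetric` are assumed names, not confirmed by this diff.
@discrete_metric(name="tone", values=["positive", "negative"])
def tone(llm, prompt, response: str) -> MetricResult:
    ...  # llm and prompt are injected at scoring time because they appear in the signature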
@@ -50,24 +47,62 @@ def decorator(func):
            # Get metric name and check if function is async
            metric_name = name or func.__name__
            is_async = inspect.iscoroutinefunction(func)
+
+            # Check function signature to determine if it expects llm/prompt
+            sig = inspect.signature(func)
+            param_names = list(sig.parameters.keys())
+            expects_llm = 'llm' in param_names
+            expects_prompt = 'prompt' in param_names
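The `expects_llm`/`expects_prompt` flags are derived purely from the decorated function's parameter names. A self-contained sketch of the same `inspect.signature` check:

import inspect

def sample_metric(llm, response: str):  # hypothetical metric function
    ...

param_names = list(inspect.signature(sample_metric).parameters.keys())
print(param_names)              # ['llm', 'response']
print('llm' in param_names)     # True  -> llm will be injected
print('prompt' in param_names)  # False -> prompt will not be passed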

            # TODO: Move to dataclass type implementation
            @dataclass
            class CustomMetric(metric_class):
+
+                def _validate_result_value(self, result_value):
+                    """Validate result value based on metric type constraints."""
+                    # Discrete metric validation
+                    if hasattr(self, 'values') and result_value not in self.values:
+                        return f"Metric {self.name} returned '{result_value}' but expected one of {self.values}"
+
+                    # Numeric metric validation
+                    if hasattr(self, 'range'):
+                        if not isinstance(result_value, (int, float)):
+                            return f"Metric {self.name} returned '{result_value}' but expected a numeric value"
+                        min_val, max_val = self.range
+                        if not (min_val <= result_value <= max_val):
+                            return f"Metric {self.name} returned {result_value} but expected value in range {self.range}"
+
+                    # Ranking metric validation
+                    if hasattr(self, 'num_ranks'):
+                        if not isinstance(result_value, list):
+                            return f"Metric {self.name} returned '{result_value}' but expected a list"
+                        if len(result_value) != self.num_ranks:
+                            return f"Metric {self.name} returned list of length {len(result_value)} but expected {self.num_ranks} items"
+
+                    return None  # No validation error
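`_validate_result_value` duck-types the metric family: discrete metrics expose `values`, numeric metrics a `(min, max)` `range`, and ranking metrics `num_ranks`. A standalone sketch of the discrete branch, using a hypothetical stand-in class:

from dataclasses import dataclass, field

@dataclass
class FakeDiscrete:  # stand-in; the real metric_class is assumed to expose `values`
    name: str = "tone"
    values: list = field(default_factory=lambda: ["positive", "negative"])

m = FakeDiscrete()
result_value = "neutral"
if hasattr(m, "values") and result_value not in m.values:
    print(f"Metric {m.name} returned '{result_value}' but expected one of {m.values}")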

                def _run_sync_in_async(self, func, *args, **kwargs):
                    """Run a synchronous function in an async context."""
                    # For sync functions, just run them normally
                    return func(*args, **kwargs)

-                def _execute_metric(self, is_async_execution, reasoning, **kwargs):
+                def _execute_metric(self, llm, is_async_execution, **kwargs):
                    """Execute the metric function with proper async handling."""
                    try:
+                        # Prepare function arguments based on what the function expects
+                        func_kwargs = kwargs.copy()
+                        func_args = []
+
+                        if expects_llm:
+                            func_args.append(llm)
+                        if expects_prompt:
+                            func_args.append(self.prompt)
+
                        if is_async:
                            # Async function implementation
                            if is_async_execution:
                                # In async context, await the function directly
-                                result = func(self.llm, self.prompt, **kwargs)
+                                result = func(*func_args, **func_kwargs)
                            else:
                                # In sync context, run the async function in an event loop
                                try:
@@ -76,40 +111,68 @@ def _execute_metric(self, is_async_execution, reasoning, **kwargs):
                                    loop = asyncio.new_event_loop()
                                    asyncio.set_event_loop(loop)
                                    result = loop.run_until_complete(
-                                        func(self.llm, self.prompt, **kwargs)
+                                        func(*func_args, **func_kwargs)
                                    )
                        else:
                            # Sync function implementation
-                            result = func(self.llm, self.prompt, **kwargs)
-
+                            result = func(*func_args, **func_kwargs)
+
+                        # Ensure result is a MetricResult
+                        if not isinstance(result, MetricResult):
+                            raise ValueError(f"Custom metric function must return MetricResult, got {type(result)}")
+
+                        # Validate the result based on metric type
+                        validation_error = self._validate_result_value(result.result)
+                        if validation_error:
+                            return MetricResult(result=None, reason=validation_error)
+
                        return result
+
                    except Exception as e:
                        # Handle errors gracefully
                        error_msg = f"Error executing metric {self.name}: {str(e)}"
                        return MetricResult(result=None, reason=error_msg)
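Two failure modes are handled distinctly: a non-`MetricResult` return raises a `ValueError` (caught by the handler above and surfaced as a reason), while a constraint violation short-circuits into an error-carrying `MetricResult`. Roughly, assuming `result` and `reason` are readable attributes of `MetricResult`:

# Sketch: `metric` is assumed to have been built with this decorator; `my_llm` a RagasLLM.
res = metric.score(llm=my_llm, response="some answer")
if res.result is None:
    print(res.reason)  # e.g. "Metric tone returned 'neutral' but expected one of [...]"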

-                def score(self, reasoning: bool = True, n: int = 1, **kwargs):
+                def score(self, llm: t.Optional[RagasLLM] = None, **kwargs):
                    """Synchronous scoring method."""
                    return self._execute_metric(
-                        is_async_execution=False, reasoning=reasoning, **kwargs
+                        llm, is_async_execution=False, **kwargs
                    )

-                async def ascore(self, reasoning: bool = True, n: int = 1, **kwargs):
+                async def ascore(self, llm: t.Optional[RagasLLM] = None, **kwargs):
                    """Asynchronous scoring method."""
+                    # Prepare function arguments based on what the function expects
+                    func_kwargs = kwargs.copy()
+                    func_args = []
+
+                    if expects_llm:
+                        func_args.append(llm)
+                    if expects_prompt:
+                        func_args.append(self.prompt)
+
                    if is_async:
                        # For async functions, await the result
-                        result = await func(self.llm, self.prompt, **kwargs)
-                        return self._extract_result(result, reasoning)
+                        result = await func(*func_args, **func_kwargs)
                    else:
                        # For sync functions, run normally
                        result = self._run_sync_in_async(
-                            func, self.llm, self.prompt, **kwargs
+                            func, *func_args, **func_kwargs
                        )
-                    return result
+
+                    # Ensure result is a MetricResult
+                    if not isinstance(result, MetricResult):
+                        raise ValueError(f"Custom metric function must return MetricResult, got {type(result)}")
+
+                    # Validate the result based on metric type
+                    validation_error = self._validate_result_value(result.result)
+                    if validation_error:
+                        return MetricResult(result=None, reason=validation_error)
+
+                    return result
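Because `llm` is now an argument to `score`/`ascore` rather than decorator state, the same metric can be evaluated against different models at call time. A hedged usage sketch:

import asyncio

# Assumed objects: `metric` built with this decorator; `fast_llm`/`strong_llm` are RagasLLM instances.
sync_result = metric.score(llm=fast_llm, response="The weather is great!")
async_result = asyncio.run(metric.ascore(llm=strong_llm, response="The weather is great!"))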

            # Create the metric instance with all parameters
            metric_instance = CustomMetric(
-                name=metric_name, prompt=prompt, llm=llm, **metric_params
+                name=metric_name, **metric_params
            )

            # Preserve metadata