Skip to content

Commit ab2617b

Browse files
fix(metrics): streamline async scoring methods and apply nest_asyncio
- Replaced event loop checks with a unified async wrapper for scoring. - Simplified error handling and callback management. - Ensured compatibility with Jupyter environments by applying nest_asyncio directly.
1 parent 5f042f1 commit ab2617b

File tree

1 file changed

+48
-60
lines changed

1 file changed

+48
-60
lines changed

ragas/src/ragas/metrics/base.py

Lines changed: 48 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from tqdm import tqdm
1313

1414
from ragas._analytics import EvaluationEvent, _analytics_batcher
15-
from ragas.async_utils import is_event_loop_running
15+
from ragas.async_utils import apply_nest_asyncio, run
1616
from ragas.callbacks import ChainType, new_group
1717
from ragas.dataset_schema import MetricAnnotation, MultiTurnSample, SingleTurnSample
1818
from ragas.losses import BinaryMetricLoss, MSELoss
@@ -143,26 +143,22 @@ def score(self, row: t.Dict, callbacks: Callbacks = None) -> float:
143143
callbacks=callbacks,
144144
metadata={"type": ChainType.METRIC},
145145
)
146-
try:
147-
if is_event_loop_running():
148-
try:
149-
import nest_asyncio
150146

151-
nest_asyncio.apply()
152-
except ImportError:
153-
raise ImportError(
154-
"It seems like your running this in a jupyter-like environment. Please install nest_asyncio with `pip install nest_asyncio` to make it work."
155-
)
156-
loop = asyncio.get_event_loop()
157-
score = loop.run_until_complete(self._ascore(row=row, callbacks=group_cm))
158-
except Exception as e:
159-
if not group_cm.ended:
160-
rm.on_chain_error(e)
161-
raise e
162-
else:
163-
if not group_cm.ended:
164-
rm.on_chain_end({"output": score})
165-
return score
147+
async def _async_wrapper():
148+
try:
149+
result = await self._ascore(row=row, callbacks=group_cm)
150+
except Exception as e:
151+
if not group_cm.ended:
152+
rm.on_chain_error(e)
153+
raise e
154+
else:
155+
if not group_cm.ended:
156+
rm.on_chain_end({"output": result})
157+
return result
158+
159+
# Apply nest_asyncio logic to ensure compatibility in notebook/Jupyter environments.
160+
apply_nest_asyncio()
161+
return run(_async_wrapper)
166162

167163
@deprecated("0.2", removal="0.3", alternative="single_turn_ascore")
168164
async def ascore(
@@ -477,27 +473,23 @@ def single_turn_score(
477473
callbacks=callbacks,
478474
metadata={"type": ChainType.METRIC},
479475
)
480-
try:
481-
if is_event_loop_running():
482-
try:
483-
import nest_asyncio
484476

485-
nest_asyncio.apply()
486-
except ImportError:
487-
raise ImportError(
488-
"It seems like your running this in a jupyter-like environment. Please install nest_asyncio with `pip install nest_asyncio` to make it work."
489-
)
490-
loop = asyncio.get_event_loop()
491-
score = loop.run_until_complete(
492-
self._single_turn_ascore(sample=sample, callbacks=group_cm)
493-
)
494-
except Exception as e:
495-
if not group_cm.ended:
496-
rm.on_chain_error(e)
497-
raise e
498-
else:
499-
if not group_cm.ended:
500-
rm.on_chain_end({"output": score})
477+
async def _async_wrapper():
478+
try:
479+
result = await self._single_turn_ascore(
480+
sample=sample, callbacks=group_cm
481+
)
482+
except Exception as e:
483+
if not group_cm.ended:
484+
rm.on_chain_error(e)
485+
raise e
486+
else:
487+
if not group_cm.ended:
488+
rm.on_chain_end({"output": result})
489+
return result
490+
491+
apply_nest_asyncio()
492+
score = run(_async_wrapper)
501493

502494
# track the evaluation event
503495
_analytics_batcher.add_evaluation(
@@ -605,27 +597,23 @@ def multi_turn_score(
605597
callbacks=callbacks,
606598
metadata={"type": ChainType.METRIC},
607599
)
608-
try:
609-
if is_event_loop_running():
610-
try:
611-
import nest_asyncio
612600

613-
nest_asyncio.apply()
614-
except ImportError:
615-
raise ImportError(
616-
"It seems like your running this in a jupyter-like environment. Please install nest_asyncio with `pip install nest_asyncio` to make it work."
617-
)
618-
loop = asyncio.get_event_loop()
619-
score = loop.run_until_complete(
620-
self._multi_turn_ascore(sample=sample, callbacks=group_cm)
621-
)
622-
except Exception as e:
623-
if not group_cm.ended:
624-
rm.on_chain_error(e)
625-
raise e
626-
else:
627-
if not group_cm.ended:
628-
rm.on_chain_end({"output": score})
601+
async def _async_wrapper():
602+
try:
603+
result = await self._multi_turn_ascore(
604+
sample=sample, callbacks=group_cm
605+
)
606+
except Exception as e:
607+
if not group_cm.ended:
608+
rm.on_chain_error(e)
609+
raise e
610+
else:
611+
if not group_cm.ended:
612+
rm.on_chain_end({"output": result})
613+
return result
614+
615+
apply_nest_asyncio()
616+
score = run(_async_wrapper)
629617

630618
# track the evaluation event
631619
_analytics_batcher.add_evaluation(

0 commit comments

Comments
 (0)