1- import torch
2-
1+ import asyncio
2+ import concurrent . futures
33from typing import List
44
5+ import torch
56from sentence_transformers import CrossEncoder
67
78from src .config .arg import RerankerConfig
1314
1415class SentenceTransformerRerankerEngine (BaseEngine , RerankerConfig ):
1516 model : CrossEncoder
17+ _executor : concurrent .futures .ThreadPoolExecutor = None
1618
1719 @classmethod
1820 def from_config (cls , config : RerankerConfig ) -> "SentenceTransformerRerankerEngine" :
@@ -21,6 +23,16 @@ def from_config(cls, config: RerankerConfig) -> "SentenceTransformerRerankerEngi
2123 logger .success (f"[RerankerEngine] load model from { config .path } " )
2224 return cls (model = model , ** config .model_dump ())
2325
@property
def executor(self) -> concurrent.futures.ThreadPoolExecutor:
    """Return this engine's thread pool, creating it on first access.

    The pool is cached on ``self._executor`` so later accesses reuse the
    same executor instance.
    """
    pool = self._executor
    if pool is not None:
        return pool

    # NOTE(review): the original comment said a single worker per model
    # avoids GPU memory conflicts, yet the pool is sized at 4 workers --
    # confirm which of the two is actually intended.
    pool = concurrent.futures.ThreadPoolExecutor(
        max_workers=4,
        thread_name_prefix=f"gpu_reranker_{self.alias}",
    )
    self._executor = pool
    return pool
35+
2436 def invoke (self , query : str , documents : List [str ]) -> List [RerankerResult ]:
2537 with torch .inference_mode (): # 使
2638 scores = self .model .predict (
@@ -36,5 +48,19 @@ def invoke(self, query: str, documents: List[str]) -> List[RerankerResult]:
3648
3749 return [RerankerResult (index = index , relevent_score = score ) for score , index in zip (scores , indexes )]
3850
async def async_invoke(self, query: str, documents: List[str]) -> List[RerankerResult]:
    """Run ``invoke`` on the engine's thread pool and await its result.

    The synchronous ``invoke`` call is submitted to ``self.executor`` so the
    event loop is not blocked while it runs.

    Args:
        query: Query text, forwarded unchanged to ``invoke``.
        documents: Candidate documents, forwarded unchanged to ``invoke``.

    Returns:
        Whatever ``invoke`` returns for the same arguments.

    Raises:
        Exception: Any exception raised by ``invoke`` is re-raised here when
            the executor future is awaited.
    """
    # A coroutine body always executes under a running loop, so ask for that
    # loop directly; get_event_loop() is the discouraged way to obtain it.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(self.executor, self.invoke, query, documents)
59+
def stream(self, query: str, documents: List[str]) -> List[RerankerResult]:
    """Streaming is not supported by this engine.

    Raises:
        NotImplementedError: Always; the message names the concrete class.
    """
    engine_name = self.__class__.__name__
    raise NotImplementedError(f"{engine_name} does not implement `stream` method")
62+
def __del__(self):
    """Shut down the thread pool, if one was created, without waiting for it."""
    pool = getattr(self, "_executor", None)
    if pool is None:
        # The pool was never created (or the attribute is absent): nothing to release.
        return
    pool.shutdown(wait=False)
0 commit comments