Skip to content

Commit 2e012f5

Browse files
committed
🎈 perf: improve reranker performance
1 parent cc5e8f4 commit 2e012f5

File tree

4 files changed

+39
-4
lines changed

4 files changed

+39
-4
lines changed

config/models.yaml.template

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,12 @@ embedding:
88
reranker:
99
bge-reranker-v2-m3:
1010
path: BAAI/bge-reranker-v2-m3
11-
batch_size: 4
11+
batch_size: 4 # 增加batch_size可以提高GPU利用率,但需要更多内存
1212
max_seq_len: 512
13+
# 性能优化建议:
14+
# 1. 根据GPU内存调整batch_size:RTX 4090建议4-8,A100建议8-16
15+
# 2. 使用GPU并发:多个GPU可以增加更多worker
16+
# 3. 启用缓存以避免重复计算(如适用)
1317

1418
llm:
1519
qwen2-7b-instruct-mlx:

src/api/router/v1/reranker.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ async def rerank(request: RerankRequest) -> RerankResponse:
4141
detail=f"Invalid model name: {request.model}",
4242
)
4343

44-
results = engine.invoke(request.query, request.documents)
44+
# 使用引擎的异步方法进行GPU推理
45+
results = await engine.async_invoke(request.query, request.documents)
4546
logger.debug(f"[Cohere Rerank] result: {results}")
4647

4748
sorted_results = sorted(enumerate(results), key=lambda x: x[1]["relevent_score"], reverse=True)

src/engine/base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ def invoke(self, *args, **kwargs) -> Any:
1616
@abstractmethod
1717
def stream(self, *args, **kwargs) -> Any:
1818
raise NotImplementedError(f"{self.__class__.__name__} does not implement `stream` method")
19+
20+
@abstractmethod
21+
def async_invoke(self, *args, **kwargs) -> Any:
22+
raise NotImplementedError(f"{self.__class__.__name__} does not implement `async_invoke` method")
1923

2024

2125
class RerankerResult(TypedDict):

src/engine/reranker/sentence_transformer.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
import torch
2-
1+
import asyncio
2+
import concurrent.futures
33
from typing import List
44

5+
import torch
56
from sentence_transformers import CrossEncoder
67

78
from src.config.arg import RerankerConfig
@@ -13,6 +14,7 @@
1314

1415
class SentenceTransformerRerankerEngine(BaseEngine, RerankerConfig):
1516
model: CrossEncoder
17+
_executor: concurrent.futures.ThreadPoolExecutor = None
1618

1719
@classmethod
1820
def from_config(cls, config: RerankerConfig) -> "SentenceTransformerRerankerEngine":
@@ -21,6 +23,16 @@ def from_config(cls, config: RerankerConfig) -> "SentenceTransformerRerankerEngi
2123
logger.success(f"[RerankerEngine] load model from {config.path}")
2224
return cls(model=model, **config.model_dump())
2325

26+
@property
27+
def executor(self) -> concurrent.futures.ThreadPoolExecutor:
28+
"""延迟初始化线程池执行器"""
29+
if self._executor is None:
30+
self._executor = concurrent.futures.ThreadPoolExecutor(
31+
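max_workers=4, # NOTE(review): 4 worker threads share this one model; the original comment said a single worker per model to avoid GPU memory conflicts — confirm the intended value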
32+
thread_name_prefix=f"gpu_reranker_{self.alias}"
33+
)
34+
return self._executor
35+
2436
def invoke(self, query: str, documents: List[str]) -> List[RerankerResult]:
2537
with torch.inference_mode(): # run prediction in inference mode (no autograd tracking)
2638
scores = self.model.predict(
@@ -36,5 +48,19 @@ def invoke(self, query: str, documents: List[str]) -> List[RerankerResult]:
3648

3749
return [RerankerResult(index=index, relevent_score=score) for score, index in zip(scores, indexes)]
3850

51+
async def async_invoke(self, query: str, documents: List[str]) -> List[RerankerResult]:
52+
loop = asyncio.get_event_loop()
53+
return await loop.run_in_executor(
54+
self.executor,
55+
self.invoke,
56+
query,
57+
documents
58+
)
59+
3960
def stream(self, query: str, documents: List[str]) -> List[RerankerResult]:
4061
raise NotImplementedError(f"{self.__class__.__name__} does not implement `stream` method")
62+
63+
def __del__(self):
64+
"""清理线程池资源"""
65+
if hasattr(self, '_executor') and self._executor is not None:
66+
self._executor.shutdown(wait=False)

0 commit comments

Comments
 (0)