1+ import torch
2+
13from typing import List
24
35from sentence_transformers import CrossEncoder
@@ -20,16 +22,17 @@ def from_config(cls, config: RerankerConfig) -> "SentenceTransformerRerankerEngi
2022 return cls (model = model , ** config .model_dump ())
2123
2224 def invoke (self , query : str , documents : List [str ]) -> List [RerankerResult ]:
23- scores = self .model .predict (
24- [(query , doc ) for doc in documents ],
25- batch_size = self .batch_size ,
26- show_progress_bar = True ,
27- activation_fct = None , # NOTE sentence_transformers CrossEncoder will use sigmoid to normalize the score
28- convert_to_tensor = True ,
29- convert_to_numpy = False ,
30- )
31- scores = scores .to ("cpu" ).numpy ()
32- scores , indexes = scores .tolist (), (- scores ).argsort ().argsort ().tolist ()
25+ with torch .inference_mode (): # 使
26+ scores = self .model .predict (
27+ [(query , doc ) for doc in documents ],
28+ batch_size = self .batch_size ,
29+ show_progress_bar = True ,
30+ activation_fct = None , # NOTE sentence_transformers CrossEncoder will use sigmoid to normalize the score
31+ convert_to_tensor = True ,
32+ convert_to_numpy = False ,
33+ )
34+ scores = scores .to ("cpu" ).numpy ()
35+ scores , indexes = scores .tolist (), (- scores ).argsort ().argsort ().tolist ()
3336
3437 return [RerankerResult (index = index , relevent_score = score ) for score , index in zip (scores , indexes )]
3538
0 commit comments