Skip to content

Commit 03893c4

Browse files
committed
🦄 refactor: use torch inference mode
1 parent 4fb1d4d commit 03893c4

File tree

1 file changed

+13
-10
lines changed

1 file changed

+13
-10
lines changed

src/engine/reranker/sentence_transformer.py

Lines changed: 13 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,5 @@
1+
import torch
2+
13
from typing import List
24

35
from sentence_transformers import CrossEncoder
@@ -20,16 +22,17 @@ def from_config(cls, config: RerankerConfig) -> "SentenceTransformerRerankerEngi
2022
return cls(model=model, **config.model_dump())
2123

2224
def invoke(self, query: str, documents: List[str]) -> List[RerankerResult]:
23-
scores = self.model.predict(
24-
[(query, doc) for doc in documents],
25-
batch_size=self.batch_size,
26-
show_progress_bar=True,
27-
activation_fct=None, # NOTE sentence_transformers CrossEncoder will use sigmoid to normalize the score
28-
convert_to_tensor=True,
29-
convert_to_numpy=False,
30-
)
31-
scores = scores.to("cpu").numpy()
32-
scores, indexes = scores.tolist(), (-scores).argsort().argsort().tolist()
25+
with torch.inference_mode(): # use inference mode so no autograd state is tracked during prediction
26+
scores = self.model.predict(
27+
[(query, doc) for doc in documents],
28+
batch_size=self.batch_size,
29+
show_progress_bar=True,
30+
activation_fct=None, # NOTE sentence_transformers CrossEncoder will use sigmoid to normalize the score
31+
convert_to_tensor=True,
32+
convert_to_numpy=False,
33+
)
34+
scores = scores.to("cpu").numpy()
35+
scores, indexes = scores.tolist(), (-scores).argsort().argsort().tolist()
3336

3437
return [RerankerResult(index=index, relevent_score=score) for score, index in zip(scores, indexes)]
3538

0 commit comments

Comments
 (0)