1+ import uuid
2+
13from fastapi import APIRouter , Depends , HTTPException , status
24
35from src .api .auth .bearer import auth_secret_key
4- from src .api .model .reranker import (ListRerankerResponse , RerankerRequest ,
5- RerankerResponse , RerankerResult ,
6- RerankModelCard )
6+ from src .api .model .reranker import (ApiVersion , BilledUnits ,
7+ ListRerankerResponse , RerankMeta ,
8+ RerankModelCard , RerankRequest ,
9+ RerankResponse , RerankResult )
710from src .config .gbl import RERANKER_ENGINE_MAPPING
811from src .logger import logger
912
@@ -23,11 +26,11 @@ async def list_rerankers() -> ListRerankerResponse:
2326 )
2427
2528
26- @reranker_router .post ("/rerank" , response_model = RerankerResponse , dependencies = [Depends (auth_secret_key )])
27- async def cohere_rerank (request : RerankerRequest ) -> RerankerResponse :
29+ @reranker_router .post ("/rerank" , response_model = RerankResponse , dependencies = [Depends (auth_secret_key )])
30+ async def cohere_rerank (request : RerankRequest ) -> RerankResponse :
2831 logger .info (f"[Cohere] use model: { request .model } " )
2932 if not request .query or not request .documents :
30- return RerankerResponse . create_response ([] )
33+ raise HTTPException ( status_code = status . HTTP_400_BAD_REQUEST , detail = "Invalid request, query and documents are required" )
3134
3235 try :
3336 engine = RERANKER_ENGINE_MAPPING [request .model ]
@@ -47,11 +50,18 @@ async def cohere_rerank(request: RerankerRequest) -> RerankerResponse:
4750 sorted_results = sorted_results [:request .top_n ]
4851
4952 cohere_results = [
50- RerankerResult (
53+ RerankResult (
5154 index = original_index ,
5255 relevance_score = round (result ["relevent_score" ], 6 )
5356 )
5457 for original_index , result in sorted_results
5558 ]
5659
57- return RerankerResponse .create_response (cohere_results )
60+ return RerankResponse (
61+ results = cohere_results ,
62+ id = uuid .uuid4 ().hex ,
63+ meta = RerankMeta (
64+ api_version = ApiVersion (version = "1" ),
65+ billed_units = BilledUnits (search_units = 1 ),
66+ ),
67+ )
0 commit comments