Skip to content

Commit 4fb1d4d

Browse files
committed
🐞 fix: fix reranker error
1 parent 642d842 commit 4fb1d4d

File tree

2 files changed

+24
-14
lines changed

2 files changed

+24
-14
lines changed

src/api/model/reranker.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,15 @@ class ListRerankerResponse(BaseModel):
1313

1414

1515
# Cohere协议的请求格式
16-
class RerankerRequest(BaseModel):
16+
class RerankRequest(BaseModel):
1717
model: str
1818
query: str
1919
documents: List[str]
2020
top_n: Optional[int] = None
2121

2222

2323
# Cohere协议的响应格式
24-
class RerankerResult(BaseModel):
24+
class RerankResult(BaseModel):
2525
index: int
2626
relevance_score: float
2727

@@ -34,12 +34,12 @@ class BilledUnits(BaseModel):
3434
search_units: int
3535

3636

37-
class Meta(BaseModel):
37+
class RerankMeta(BaseModel):
3838
api_version: ApiVersion
3939
billed_units: BilledUnits
4040

4141

42-
class RerankerResponse(BaseModel):
43-
results: List[RerankerResult]
42+
class RerankResponse(BaseModel):
43+
results: List[RerankResult]
4444
id: str
45-
meta: Meta
45+
meta: RerankMeta

src/api/router/v1/reranker.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1+
import uuid
2+
13
from fastapi import APIRouter, Depends, HTTPException, status
24

35
from src.api.auth.bearer import auth_secret_key
4-
from src.api.model.reranker import (ListRerankerResponse, RerankerRequest,
5-
RerankerResponse, RerankerResult,
6-
RerankModelCard)
6+
from src.api.model.reranker import (ApiVersion, BilledUnits,
7+
ListRerankerResponse, RerankMeta,
8+
RerankModelCard, RerankRequest,
9+
RerankResponse, RerankResult)
710
from src.config.gbl import RERANKER_ENGINE_MAPPING
811
from src.logger import logger
912

@@ -23,11 +26,11 @@ async def list_rerankers() -> ListRerankerResponse:
2326
)
2427

2528

26-
@reranker_router.post("/rerank", response_model=RerankerResponse, dependencies=[Depends(auth_secret_key)])
27-
async def cohere_rerank(request: RerankerRequest) -> RerankerResponse:
29+
@reranker_router.post("/rerank", response_model=RerankResponse, dependencies=[Depends(auth_secret_key)])
30+
async def cohere_rerank(request: RerankRequest) -> RerankResponse:
2831
logger.info(f"[Cohere] use model: {request.model}")
2932
if not request.query or not request.documents:
30-
return RerankerResponse.create_response([])
33+
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request, query and documents are required")
3134

3235
try:
3336
engine = RERANKER_ENGINE_MAPPING[request.model]
@@ -47,11 +50,18 @@ async def cohere_rerank(request: RerankerRequest) -> RerankerResponse:
4750
sorted_results = sorted_results[:request.top_n]
4851

4952
cohere_results = [
50-
RerankerResult(
53+
RerankResult(
5154
index=original_index,
5255
relevance_score=round(result["relevent_score"], 6)
5356
)
5457
for original_index, result in sorted_results
5558
]
5659

57-
return RerankerResponse.create_response(cohere_results)
60+
return RerankResponse(
61+
results=cohere_results,
62+
id=uuid.uuid4().hex,
63+
meta=RerankMeta(
64+
api_version=ApiVersion(version="1"),
65+
billed_units=BilledUnits(search_units=1),
66+
),
67+
)

0 commit comments

Comments
 (0)