6
6
from UnionChatBot .utils .EmbeddingAPI import MyEmbeddingFunction
7
7
from UnionChatBot .utils .RerankerAPI import BM25Reranker
8
8
9
+
9
10
class ChromaAdapter :
10
11
def __init__ (
11
- self ,
12
- host : str = "localhost" ,
13
- port : int = 32000 ,
14
- max_rag_documents : int = 20 ,
15
- topk_documents : int = 3 ,
16
- similarity_filter : float = 1.5 ,
17
- embedding_model : str = "all-MiniLM-L6-v2" ,
18
- reranker_type : str = "bm25" ,
19
- api_key : Optional [str ] = None ,
20
- folder_id : Optional [str ] = None ,
21
- text_type : str = "doc" ,
22
- api_url : Optional [str ] = None
12
+ self ,
13
+ host : str = "localhost" ,
14
+ port : int = 32000 ,
15
+ max_rag_documents : int = 20 ,
16
+ topk_documents : int = 3 ,
17
+ similarity_filter : float = 1.5 ,
18
+ embedding_model : str = "all-MiniLM-L6-v2" ,
19
+ reranker_type : str = "bm25" ,
20
+ api_key : Optional [str ] = None ,
21
+ folder_id : Optional [str ] = None ,
22
+ text_type : str = "doc" ,
23
+ api_url : Optional [str ] = None ,
23
24
):
24
25
self .reranker_type = reranker_type
25
26
if reranker_type == "bm25" :
@@ -48,29 +49,26 @@ def embedding_function(self):
48
49
api_url = self .api_url ,
49
50
folder_id = self .folder_id ,
50
51
iam_token = self .api_key ,
51
- text_type = self .text_type
52
+ text_type = self .text_type ,
52
53
)
53
54
return self ._embedding_function
54
55
55
56
def get_info_from_db(
    self, query: str, collection_name: str, n_results: int = 30, **kwargs
) -> Dict[str, Any]:
    """Query one collection for the entries nearest to ``query``.

    Args:
        query: Text sent as the single element of ``query_texts``.
        collection_name: Name of the collection fetched from ``self.client``.
        n_results: Number of results requested from the collection.
        **kwargs: Accepted by the signature but not forwarded to the query.

    Returns:
        The raw result of ``collection.query``, requested with the
        ``documents``, ``metadatas`` and ``distances`` fields included.
    """
    collection = self.client.get_collection(
        name=collection_name, embedding_function=self.embedding_function
    )
    # A single query text is sent, so each returned field is expected to
    # hold one inner list (callers index it with [0]).
    return collection.query(
        query_texts=[query],
        n_results=n_results,
        include=["documents", "metadatas", "distances"],
    )
71
67
72
68
def get_filtered_documents (self , data_raw : Dict [str , Any ]) -> dict :
73
- distances = data_raw ["distances" ][0 ] # Берем первый элемент, так как query_texts=[query]
69
+ distances = data_raw ["distances" ][
70
+ 0
71
+ ] # Берем первый элемент, так как query_texts=[query]
74
72
documents = data_raw ["documents" ][0 ]
75
73
metadatas = data_raw ["metadatas" ][0 ]
76
74
@@ -84,7 +82,8 @@ def get_filtered_documents(self, data_raw: Dict[str, Any]) -> dict:
84
82
metadatas [idx ]
85
83
for idx , dist in enumerate (distances )
86
84
if dist < self .similarity_filter
87
- ]}
85
+ ],
86
+ }
88
87
89
88
def get_pairs(self, query: str, documents: List[str]) -> List[List[str]]:
    """Pair the query with every document.

    Args:
        query: Text placed first in every pair.
        documents: Texts placed second, one pair per document.

    Returns:
        A list of two-element ``[query, document]`` lists in the same order
        as ``documents``; empty when ``documents`` is empty.
    """
    return [[query, doc] for doc in documents]
@@ -100,7 +99,7 @@ def get_info(self, query: str, collection_name: str) -> dict[str, list[Any] | st
100
99
data_raw = self .get_info_from_db (
101
100
query = query ,
102
101
collection_name = collection_name ,
103
- n_results = self .max_rag_documents
102
+ n_results = self .max_rag_documents ,
104
103
)
105
104
filtered_documents = self .get_filtered_documents (data_raw )
106
105
@@ -109,13 +108,19 @@ def get_info(self, query: str, collection_name: str) -> dict[str, list[Any] | st
109
108
"documents" : [],
110
109
"metadatas" : [],
111
110
"query" : query ,
112
- "collection_name" : collection_name
111
+ "collection_name" : collection_name ,
113
112
}
114
113
115
- idx_relevant_documents = self .apply_reranker (query = query , documents = filtered_documents ["documents" ])
114
+ idx_relevant_documents = self .apply_reranker (
115
+ query = query , documents = filtered_documents ["documents" ]
116
+ )
116
117
return {
117
- "documents" : [filtered_documents ["documents" ][idx ] for idx in idx_relevant_documents ],
118
- "metadatas" : [filtered_documents ["metadatas" ][idx ] for idx in idx_relevant_documents ],
118
+ "documents" : [
119
+ filtered_documents ["documents" ][idx ] for idx in idx_relevant_documents
120
+ ],
121
+ "metadatas" : [
122
+ filtered_documents ["metadatas" ][idx ] for idx in idx_relevant_documents
123
+ ],
119
124
"query" : query ,
120
- "collection_name" : collection_name
121
- }
125
+ "collection_name" : collection_name ,
126
+ }
0 commit comments