8
8
from pathlib import Path
9
9
from typing import Union , List , Optional , Dict , Generator
10
10
from tqdm .auto import tqdm
11
+ import warnings
11
12
12
13
try :
13
14
import faiss
@@ -37,7 +38,8 @@ class FAISSDocumentStore(SQLDocumentStore):
37
38
def __init__ (
38
39
self ,
39
40
sql_url : str = "sqlite:///faiss_document_store.db" ,
40
- vector_dim : int = 768 ,
41
+ vector_dim : int = None ,
42
+ embedding_dim : int = 768 ,
41
43
faiss_index_factory_str : str = "Flat" ,
42
44
faiss_index : Optional ["faiss.swigfaiss.Index" ] = None ,
43
45
return_embedding : bool = False ,
@@ -53,7 +55,8 @@ def __init__(
53
55
"""
54
56
:param sql_url: SQL connection URL for database. It defaults to local file based SQLite DB. For large scale
55
57
deployment, Postgres is recommended.
56
- :param vector_dim: the embedding vector size.
58
+ :param vector_dim: Deprecated. Use embedding_dim instead.
59
+ :param embedding_dim: The embedding vector size. Default: 768.
57
60
:param faiss_index_factory_str: Create a new FAISS index of the specified type.
58
61
The type is determined from the given string following the conventions
59
62
of the original FAISS index factory.
@@ -75,7 +78,7 @@ def __init__(
75
78
:param index: Name of index in document store to use.
76
79
:param similarity: The similarity function used to compare document vectors. 'dot_product' is the default since it is
77
80
more performant with DPR embeddings. 'cosine' is recommended if you are using a Sentence-Transformer model.
78
- In both cases, the returned values in Document.score are normalized to be in range [0,1]:
81
+ In both cases, the returned values in Document.score are normalized to be in range [0,1]:
79
82
For `dot_product`: expit(np.asarray(raw_score / 100))
80
83
FOr `cosine`: (raw_score + 1) / 2
81
84
:param embedding_field: Name of field containing an embedding vector.
@@ -89,7 +92,7 @@ def __init__(
89
92
exists.
90
93
:param faiss_index_path: Stored FAISS index file. Can be created via calling `save()`.
91
94
If specified no other params besides faiss_config_path must be specified.
92
- :param faiss_config_path: Stored FAISS initial configuration parameters.
95
+ :param faiss_config_path: Stored FAISS initial configuration parameters.
93
96
Can be created via calling `save()`
94
97
"""
95
98
# special case if we want to load an existing index from disk
@@ -103,14 +106,15 @@ def __init__(
103
106
104
107
# save init parameters to enable export of component config as YAML
105
108
self .set_config (
106
- sql_url = sql_url ,
107
- vector_dim = vector_dim ,
109
+ sql_url = sql_url ,
110
+ vector_dim = vector_dim ,
111
+ embedding_dim = embedding_dim ,
108
112
faiss_index_factory_str = faiss_index_factory_str ,
109
113
return_embedding = return_embedding ,
110
- duplicate_documents = duplicate_documents ,
111
- index = index ,
114
+ duplicate_documents = duplicate_documents ,
115
+ index = index ,
112
116
similarity = similarity ,
113
- embedding_field = embedding_field ,
117
+ embedding_field = embedding_field ,
114
118
progress_bar = progress_bar
115
119
)
116
120
@@ -124,14 +128,20 @@ def __init__(
124
128
raise ValueError ("The FAISS document store can currently only support dot_product, cosine and l2 similarity. "
125
129
"Please set similarity to one of the above." )
126
130
127
- self .vector_dim = vector_dim
131
+ if vector_dim is not None :
132
+ warnings .warn ("The 'vector_dim' parameter is deprecated, "
133
+ "use 'embedding_dim' instead." , DeprecationWarning , 2 )
134
+ self .embedding_dim = vector_dim
135
+ else :
136
+ self .embedding_dim = embedding_dim
137
+
128
138
self .faiss_index_factory_str = faiss_index_factory_str
129
139
self .faiss_indexes : Dict [str , faiss .swigfaiss .Index ] = {}
130
140
if faiss_index :
131
141
self .faiss_indexes [index ] = faiss_index
132
142
else :
133
143
self .faiss_indexes [index ] = self ._create_new_index (
134
- vector_dim = self .vector_dim ,
144
+ embedding_dim = self .embedding_dim ,
135
145
index_factory = faiss_index_factory_str ,
136
146
metric_type = self .metric_type ,
137
147
** kwargs
@@ -158,7 +168,7 @@ def _validate_params_load_from_disk(self, sig: Signature, locals: dict, kwargs:
158
168
if param .name not in allowed_params and param .default != locals [param .name ]:
159
169
invalid_param_set = True
160
170
break
161
-
171
+
162
172
if invalid_param_set or len (kwargs ) > 0 :
163
173
raise ValueError ("if faiss_index_path is passed no other params besides faiss_config_path are allowed." )
164
174
@@ -172,20 +182,20 @@ def _validate_index_sync(self):
172
182
"configuration file correctly points to the same database that "
173
183
"was used when creating the original index." )
174
184
175
- def _create_new_index (self , vector_dim : int , metric_type , index_factory : str = "Flat" , ** kwargs ):
185
+ def _create_new_index (self , embedding_dim : int , metric_type , index_factory : str = "Flat" , ** kwargs ):
176
186
if index_factory == "HNSW" :
177
187
# faiss index factory doesn't give the same results for HNSW IP, therefore direct init.
178
188
# defaults here are similar to DPR codebase (good accuracy, but very high RAM consumption)
179
189
n_links = kwargs .get ("n_links" , 64 )
180
- index = faiss .IndexHNSWFlat (vector_dim , n_links , metric_type )
190
+ index = faiss .IndexHNSWFlat (embedding_dim , n_links , metric_type )
181
191
index .hnsw .efSearch = kwargs .get ("efSearch" , 20 )#20
182
192
index .hnsw .efConstruction = kwargs .get ("efConstruction" , 80 )#80
183
193
if "ivf" in index_factory .lower (): # enable reconstruction of vectors for inverted index
184
194
self .faiss_indexes [index ].set_direct_map_type (faiss .DirectMap .Hashtable )
185
195
186
196
logger .info (f"HNSW params: n_links: { n_links } , efSearch: { index .hnsw .efSearch } , efConstruction: { index .hnsw .efConstruction } " )
187
197
else :
188
- index = faiss .index_factory (vector_dim , index_factory , metric_type )
198
+ index = faiss .index_factory (embedding_dim , index_factory , metric_type )
189
199
return index
190
200
191
201
def write_documents (self , documents : Union [List [dict ], List [Document ]], index : Optional [str ] = None ,
@@ -217,7 +227,7 @@ def write_documents(self, documents: Union[List[dict], List[Document]], index: O
217
227
218
228
if not self .faiss_indexes .get (index ):
219
229
self .faiss_indexes [index ] = self ._create_new_index (
220
- vector_dim = self .vector_dim ,
230
+ embedding_dim = self .embedding_dim ,
221
231
index_factory = self .faiss_index_factory_str ,
222
232
metric_type = faiss .METRIC_INNER_PRODUCT ,
223
233
)
@@ -544,7 +554,7 @@ def save(self, index_path: Union[str, Path], config_path: Optional[Union[str, Pa
544
554
:param config_path: Path to save the initial configuration parameters to.
545
555
Defaults to the same as the file path, save the extension (.json).
546
556
This file contains all the parameters passed to FAISSDocumentStore()
547
- at creation time (for example the SQL path, vector_dim , etc), and will be
557
+ at creation time (for example the SQL path, embedding_dim , etc), and will be
548
558
used by the `load` method to restore the index with the appropriate configuration.
549
559
:return: None
550
560
"""
@@ -574,7 +584,7 @@ def _load_init_params_from_config(self, index_path: Union[str, Path], config_pat
574
584
575
585
# Add other init params to override the ones defined in the init params file
576
586
init_params ["faiss_index" ] = faiss_index
577
- init_params ["vector_dim " ] = faiss_index .d
587
+ init_params ["embedding_dim " ] = faiss_index .d
578
588
579
589
return init_params
580
590
0 commit comments