Skip to content

Commit a44b6c1

Browse files
mathew55github-actions[bot]bogdankostic
authored
Unify vector_dim and embedding_dim parameter in Document Store (#1922)
* Refactored code to unify vector_dim and embedding_dim parameter in DocumentStores * Unit test cases updated to use `embedding_dim` instead of `vector_dim` * Unit test case update to use embedding_dim instead of vector_dim * Add latest docstring and tutorial changes * Put usage of `vector_dim` param in same if-block as corresponding warning Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: bogdankostic <bogdankostic@web.de>
1 parent 00dc30a commit a44b6c1

File tree

12 files changed

+86
-55
lines changed

12 files changed

+86
-55
lines changed

docs/_src/api/api/document_store.md

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1202,14 +1202,15 @@ the vector embeddings are indexed in a FAISS Index.
12021202
#### \_\_init\_\_
12031203

12041204
```python
1205-
| __init__(sql_url: str = "sqlite:///faiss_document_store.db", vector_dim: int = 768, faiss_index_factory_str: str = "Flat", faiss_index: Optional["faiss.swigfaiss.Index"] = None, return_embedding: bool = False, index: str = "document", similarity: str = "dot_product", embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = 'overwrite', faiss_index_path: Union[str, Path] = None, faiss_config_path: Union[str, Path] = None, **kwargs, ,)
1205+
| __init__(sql_url: str = "sqlite:///faiss_document_store.db", vector_dim: int = None, embedding_dim: int = 768, faiss_index_factory_str: str = "Flat", faiss_index: Optional["faiss.swigfaiss.Index"] = None, return_embedding: bool = False, index: str = "document", similarity: str = "dot_product", embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = 'overwrite', faiss_index_path: Union[str, Path] = None, faiss_config_path: Union[str, Path] = None, **kwargs, ,)
12061206
```
12071207

12081208
**Arguments**:
12091209

12101210
- `sql_url`: SQL connection URL for database. It defaults to local file based SQLite DB. For large scale
12111211
deployment, Postgres is recommended.
1212-
- `vector_dim`: the embedding vector size.
1212+
- `vector_dim`: Deprecated. Use embedding_dim instead.
1213+
- `embedding_dim`: The embedding vector size. Default: 768.
12131214
- `faiss_index_factory_str`: Create a new FAISS index of the specified type.
12141215
The type is determined from the given string following the conventions
12151216
of the original FAISS index factory.
@@ -1231,7 +1232,7 @@ the vector embeddings are indexed in a FAISS Index.
12311232
- `index`: Name of index in document store to use.
12321233
- `similarity`: The similarity function used to compare document vectors. 'dot_product' is the default since it is
12331234
more performant with DPR embeddings. 'cosine' is recommended if you are using a Sentence-Transformer model.
1234-
In both cases, the returned values in Document.score are normalized to be in range [0,1]:
1235+
In both cases, the returned values in Document.score are normalized to be in range [0,1]:
12351236
For `dot_product`: expit(np.asarray(raw_score / 100))
12361237
FOr `cosine`: (raw_score + 1) / 2
12371238
- `embedding_field`: Name of field containing an embedding vector.
@@ -1424,7 +1425,7 @@ Save FAISS Index to the specified file.
14241425
- `config_path`: Path to save the initial configuration parameters to.
14251426
Defaults to the same as the file path, save the extension (.json).
14261427
This file contains all the parameters passed to FAISSDocumentStore()
1427-
at creation time (for example the SQL path, vector_dim, etc), and will be
1428+
at creation time (for example the SQL path, embedding_dim, etc), and will be
14281429
used by the `load` method to restore the index with the appropriate configuration.
14291430

14301431
**Returns**:
@@ -1478,7 +1479,7 @@ Usage:
14781479
#### \_\_init\_\_
14791480

14801481
```python
1481-
| __init__(sql_url: str = "sqlite:///", milvus_url: str = "tcp://localhost:19530", connection_pool: str = "SingletonThread", index: str = "document", vector_dim: int = 768, index_file_size: int = 1024, similarity: str = "dot_product", index_type: IndexType = IndexType.FLAT, index_param: Optional[Dict[str, Any]] = None, search_param: Optional[Dict[str, Any]] = None, return_embedding: bool = False, embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = 'overwrite', **kwargs, ,)
1482+
| __init__(sql_url: str = "sqlite:///", milvus_url: str = "tcp://localhost:19530", connection_pool: str = "SingletonThread", index: str = "document", vector_dim: int = None, embedding_dim: int = 768, index_file_size: int = 1024, similarity: str = "dot_product", index_type: IndexType = IndexType.FLAT, index_param: Optional[Dict[str, Any]] = None, search_param: Optional[Dict[str, Any]] = None, return_embedding: bool = False, embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = 'overwrite', **kwargs, ,)
14821483
```
14831484

14841485
**Arguments**:
@@ -1491,7 +1492,8 @@ Usage:
14911492
See https://milvus.io/docs/v1.0.0/install_milvus.md for instructions to start a Milvus instance.
14921493
- `connection_pool`: Connection pool type to connect with Milvus server. Default: "SingletonThread".
14931494
- `index`: Index name for text, embedding and metadata (in Milvus terms, this is the "collection name").
1494-
- `vector_dim`: The embedding vector size. Default: 768.
1495+
- `vector_dim`: Deprecated. Use embedding_dim instead.
1496+
- `embedding_dim`: The embedding vector size. Default: 768.
14951497
- `index_file_size`: Specifies the size of each segment file that is stored by Milvus and its default value is 1024 MB.
14961498
When the size of newly inserted vectors reaches the specified volume, Milvus packs these vectors into a new segment.
14971499
Milvus creates one index file for each segment. When conducting a vector search, Milvus searches all index files one by one.

docs/_src/tutorials/tutorials/12.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ For more info on which suits your use case: https://github.com/facebookresearch/
5454
```python
5555
from haystack.document_stores import FAISSDocumentStore
5656

57-
document_store = FAISSDocumentStore(vector_dim=128, faiss_index_factory_str="Flat")
57+
document_store = FAISSDocumentStore(embedding_dim=128, faiss_index_factory_str="Flat")
5858
```
5959

6060
### Cleaning & indexing documents

haystack/document_stores/faiss.py

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from pathlib import Path
99
from typing import Union, List, Optional, Dict, Generator
1010
from tqdm.auto import tqdm
11+
import warnings
1112

1213
try:
1314
import faiss
@@ -37,7 +38,8 @@ class FAISSDocumentStore(SQLDocumentStore):
3738
def __init__(
3839
self,
3940
sql_url: str = "sqlite:///faiss_document_store.db",
40-
vector_dim: int = 768,
41+
vector_dim: int = None,
42+
embedding_dim: int = 768,
4143
faiss_index_factory_str: str = "Flat",
4244
faiss_index: Optional["faiss.swigfaiss.Index"] = None,
4345
return_embedding: bool = False,
@@ -53,7 +55,8 @@ def __init__(
5355
"""
5456
:param sql_url: SQL connection URL for database. It defaults to local file based SQLite DB. For large scale
5557
deployment, Postgres is recommended.
56-
:param vector_dim: the embedding vector size.
58+
:param vector_dim: Deprecated. Use embedding_dim instead.
59+
:param embedding_dim: The embedding vector size. Default: 768.
5760
:param faiss_index_factory_str: Create a new FAISS index of the specified type.
5861
The type is determined from the given string following the conventions
5962
of the original FAISS index factory.
@@ -75,7 +78,7 @@ def __init__(
7578
:param index: Name of index in document store to use.
7679
:param similarity: The similarity function used to compare document vectors. 'dot_product' is the default since it is
7780
more performant with DPR embeddings. 'cosine' is recommended if you are using a Sentence-Transformer model.
78-
In both cases, the returned values in Document.score are normalized to be in range [0,1]:
81+
In both cases, the returned values in Document.score are normalized to be in range [0,1]:
7982
For `dot_product`: expit(np.asarray(raw_score / 100))
8083
FOr `cosine`: (raw_score + 1) / 2
8184
:param embedding_field: Name of field containing an embedding vector.
@@ -89,7 +92,7 @@ def __init__(
8992
exists.
9093
:param faiss_index_path: Stored FAISS index file. Can be created via calling `save()`.
9194
If specified no other params besides faiss_config_path must be specified.
92-
:param faiss_config_path: Stored FAISS initial configuration parameters.
95+
:param faiss_config_path: Stored FAISS initial configuration parameters.
9396
Can be created via calling `save()`
9497
"""
9598
# special case if we want to load an existing index from disk
@@ -103,14 +106,15 @@ def __init__(
103106

104107
# save init parameters to enable export of component config as YAML
105108
self.set_config(
106-
sql_url=sql_url,
107-
vector_dim=vector_dim,
109+
sql_url=sql_url,
110+
vector_dim=vector_dim,
111+
embedding_dim=embedding_dim,
108112
faiss_index_factory_str=faiss_index_factory_str,
109113
return_embedding=return_embedding,
110-
duplicate_documents=duplicate_documents,
111-
index=index,
114+
duplicate_documents=duplicate_documents,
115+
index=index,
112116
similarity=similarity,
113-
embedding_field=embedding_field,
117+
embedding_field=embedding_field,
114118
progress_bar=progress_bar
115119
)
116120

@@ -124,14 +128,20 @@ def __init__(
124128
raise ValueError("The FAISS document store can currently only support dot_product, cosine and l2 similarity. "
125129
"Please set similarity to one of the above.")
126130

127-
self.vector_dim = vector_dim
131+
if vector_dim is not None:
132+
warnings.warn("The 'vector_dim' parameter is deprecated, "
133+
"use 'embedding_dim' instead.", DeprecationWarning, 2)
134+
self.embedding_dim = vector_dim
135+
else:
136+
self.embedding_dim = embedding_dim
137+
128138
self.faiss_index_factory_str = faiss_index_factory_str
129139
self.faiss_indexes: Dict[str, faiss.swigfaiss.Index] = {}
130140
if faiss_index:
131141
self.faiss_indexes[index] = faiss_index
132142
else:
133143
self.faiss_indexes[index] = self._create_new_index(
134-
vector_dim=self.vector_dim,
144+
embedding_dim=self.embedding_dim,
135145
index_factory=faiss_index_factory_str,
136146
metric_type=self.metric_type,
137147
**kwargs
@@ -158,7 +168,7 @@ def _validate_params_load_from_disk(self, sig: Signature, locals: dict, kwargs:
158168
if param.name not in allowed_params and param.default != locals[param.name]:
159169
invalid_param_set = True
160170
break
161-
171+
162172
if invalid_param_set or len(kwargs) > 0:
163173
raise ValueError("if faiss_index_path is passed no other params besides faiss_config_path are allowed.")
164174

@@ -172,20 +182,20 @@ def _validate_index_sync(self):
172182
"configuration file correctly points to the same database that "
173183
"was used when creating the original index.")
174184

175-
def _create_new_index(self, vector_dim: int, metric_type, index_factory: str = "Flat", **kwargs):
185+
def _create_new_index(self, embedding_dim: int, metric_type, index_factory: str = "Flat", **kwargs):
176186
if index_factory == "HNSW":
177187
# faiss index factory doesn't give the same results for HNSW IP, therefore direct init.
178188
# defaults here are similar to DPR codebase (good accuracy, but very high RAM consumption)
179189
n_links = kwargs.get("n_links", 64)
180-
index = faiss.IndexHNSWFlat(vector_dim, n_links, metric_type)
190+
index = faiss.IndexHNSWFlat(embedding_dim, n_links, metric_type)
181191
index.hnsw.efSearch = kwargs.get("efSearch", 20)#20
182192
index.hnsw.efConstruction = kwargs.get("efConstruction", 80)#80
183193
if "ivf" in index_factory.lower(): # enable reconstruction of vectors for inverted index
184194
self.faiss_indexes[index].set_direct_map_type(faiss.DirectMap.Hashtable)
185195

186196
logger.info(f"HNSW params: n_links: {n_links}, efSearch: {index.hnsw.efSearch}, efConstruction: {index.hnsw.efConstruction}")
187197
else:
188-
index = faiss.index_factory(vector_dim, index_factory, metric_type)
198+
index = faiss.index_factory(embedding_dim, index_factory, metric_type)
189199
return index
190200

191201
def write_documents(self, documents: Union[List[dict], List[Document]], index: Optional[str] = None,
@@ -217,7 +227,7 @@ def write_documents(self, documents: Union[List[dict], List[Document]], index: O
217227

218228
if not self.faiss_indexes.get(index):
219229
self.faiss_indexes[index] = self._create_new_index(
220-
vector_dim=self.vector_dim,
230+
embedding_dim=self.embedding_dim,
221231
index_factory=self.faiss_index_factory_str,
222232
metric_type=faiss.METRIC_INNER_PRODUCT,
223233
)
@@ -544,7 +554,7 @@ def save(self, index_path: Union[str, Path], config_path: Optional[Union[str, Pa
544554
:param config_path: Path to save the initial configuration parameters to.
545555
Defaults to the same as the file path, save the extension (.json).
546556
This file contains all the parameters passed to FAISSDocumentStore()
547-
at creation time (for example the SQL path, vector_dim, etc), and will be
557+
at creation time (for example the SQL path, embedding_dim, etc), and will be
548558
used by the `load` method to restore the index with the appropriate configuration.
549559
:return: None
550560
"""
@@ -574,7 +584,7 @@ def _load_init_params_from_config(self, index_path: Union[str, Path], config_pat
574584

575585
# Add other init params to override the ones defined in the init params file
576586
init_params["faiss_index"] = faiss_index
577-
init_params["vector_dim"] = faiss_index.d
587+
init_params["embedding_dim"] = faiss_index.d
578588

579589
return init_params
580590

haystack/document_stores/milvus.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from haystack.nodes.retriever import BaseRetriever
55

66
import logging
7+
import warnings
78
import numpy as np
89
from tqdm import tqdm
910
from scipy.special import expit
@@ -41,7 +42,8 @@ def __init__(
4142
milvus_url: str = "tcp://localhost:19530",
4243
connection_pool: str = "SingletonThread",
4344
index: str = "document",
44-
vector_dim: int = 768,
45+
vector_dim: int = None,
46+
embedding_dim: int = 768,
4547
index_file_size: int = 1024,
4648
similarity: str = "dot_product",
4749
index_type: IndexType = IndexType.FLAT,
@@ -62,7 +64,8 @@ def __init__(
6264
See https://milvus.io/docs/v1.0.0/install_milvus.md for instructions to start a Milvus instance.
6365
:param connection_pool: Connection pool type to connect with Milvus server. Default: "SingletonThread".
6466
:param index: Index name for text, embedding and metadata (in Milvus terms, this is the "collection name").
65-
:param vector_dim: The embedding vector size. Default: 768.
67+
:param vector_dim: Deprecated. Use embedding_dim instead.
68+
:param embedding_dim: The embedding vector size. Default: 768.
6669
:param index_file_size: Specifies the size of each segment file that is stored by Milvus and its default value is 1024 MB.
6770
When the size of newly inserted vectors reaches the specified volume, Milvus packs these vectors into a new segment.
6871
Milvus creates one index file for each segment. When conducting a vector search, Milvus searches all index files one by one.
@@ -98,13 +101,20 @@ def __init__(
98101
# save init parameters to enable export of component config as YAML
99102
self.set_config(
100103
sql_url=sql_url, milvus_url=milvus_url, connection_pool=connection_pool, index=index, vector_dim=vector_dim,
101-
index_file_size=index_file_size, similarity=similarity, index_type=index_type, index_param=index_param,
104+
embedding_dim=embedding_dim, index_file_size=index_file_size, similarity=similarity, index_type=index_type, index_param=index_param,
102105
search_param=search_param, duplicate_documents=duplicate_documents,
103106
return_embedding=return_embedding, embedding_field=embedding_field, progress_bar=progress_bar,
104107
)
105108

106109
self.milvus_server = Milvus(uri=milvus_url, pool=connection_pool)
107-
self.vector_dim = vector_dim
110+
111+
if vector_dim is not None:
112+
warnings.warn("The 'vector_dim' parameter is deprecated, "
113+
"use 'embedding_dim' instead.", DeprecationWarning, 2)
114+
self.embedding_dim = vector_dim
115+
else:
116+
self.embedding_dim = embedding_dim
117+
108118
self.index_file_size = index_file_size
109119

110120
if similarity in ("dot_product", "cosine"):
@@ -147,7 +157,7 @@ def _create_collection_and_index_if_not_exist(
147157
if not ok:
148158
collection_param = {
149159
'collection_name': index,
150-
'dimension': self.vector_dim,
160+
'dimension': self.embedding_dim,
151161
'index_file_size': self.index_file_size,
152162
'metric_type': self.metric_type
153163
}

haystack/document_stores/milvus2x.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
import warnings
23
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Union
34

45
if TYPE_CHECKING:
@@ -59,7 +60,8 @@ def __init__(
5960
port: str = "19530",
6061
connection_pool: str = "SingletonThread",
6162
index: str = "document",
62-
vector_dim: int = 768,
63+
vector_dim: int = None,
64+
embedding_dim: int = 768,
6365
index_file_size: int = 1024,
6466
similarity: str = "dot_product",
6567
index_type: str = "IVF_FLAT",
@@ -81,7 +83,8 @@ def __init__(
8183
See https://milvus.io/docs/v1.0.0/install_milvus.md for instructions to start a Milvus instance.
8284
:param connection_pool: Connection pool type to connect with Milvus server. Default: "SingletonThread".
8385
:param index: Index name for text, embedding and metadata (in Milvus terms, this is the "collection name").
84-
:param vector_dim: The embedding vector size. Default: 768.
86+
:param vector_dim: Deprecated. Use embedding_dim instead.
87+
:param embedding_dim: The embedding vector size. Default: 768.
8588
:param index_file_size: Specifies the size of each segment file that is stored by Milvus and its default value is 1024 MB.
8689
When the size of newly inserted vectors reaches the specified volume, Milvus packs these vectors into a new segment.
8790
Milvus creates one index file for each segment. When conducting a vector search, Milvus searches all index files one by one.
@@ -120,7 +123,7 @@ def __init__(
120123
# save init parameters to enable export of component config as YAML
121124
self.set_config(
122125
sql_url=sql_url, host=host, port=port, connection_pool=connection_pool, index=index, vector_dim=vector_dim,
123-
index_file_size=index_file_size, similarity=similarity, index_type=index_type, index_param=index_param,
126+
embedding_dim=embedding_dim, index_file_size=index_file_size, similarity=similarity, index_type=index_type, index_param=index_param,
124127
search_param=search_param, duplicate_documents=duplicate_documents, id_field=id_field,
125128
return_embedding=return_embedding, embedding_field=embedding_field, progress_bar=progress_bar,
126129
custom_fields=custom_fields,
@@ -135,7 +138,13 @@ def __init__(
135138
connections.add_connection(default={"host": host, "port": port})
136139
connections.connect()
137140

138-
self.vector_dim = vector_dim
141+
if vector_dim is not None:
142+
warnings.warn("The 'vector_dim' parameter is deprecated, "
143+
"use 'embedding_dim' instead.", DeprecationWarning, 2)
144+
self.embedding_dim = vector_dim
145+
else:
146+
self.embedding_dim = embedding_dim
147+
139148
self.index_file_size = index_file_size
140149

141150
if similarity == "dot_product":
@@ -187,7 +196,7 @@ def _create_collection_and_index_if_not_exist(
187196
if not has_collection:
188197
fields = [
189198
FieldSchema(name=self.id_field, dtype=DataType.INT64, is_primary=True, auto_id=True),
190-
FieldSchema(name=self.embedding_field, dtype=DataType.FLOAT_VECTOR, dim=self.vector_dim)
199+
FieldSchema(name=self.embedding_field, dtype=DataType.FLOAT_VECTOR, dim=self.embedding_dim)
191200
]
192201

193202
for field in custom_fields:

0 commit comments

Comments
 (0)