Skip to content

Commit 2ae8547

Browse files
committed
feat: replace milvus
1 parent 33daa7b commit 2ae8547

File tree

6 files changed

+143
-53
lines changed

6 files changed

+143
-53
lines changed

examples/using_milvus_as_vectorDB.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
import os
2+
import asyncio
3+
import numpy as np
4+
from nano_graphrag import GraphRAG, QueryParam
5+
from nano_graphrag._utils import logger
6+
from nano_graphrag.base import BaseVectorStorage
7+
from dataclasses import dataclass
8+
9+
10+
@dataclass
11+
class MilvusLiteStorge(BaseVectorStorage):
12+
13+
@staticmethod
14+
def create_collection_if_not_exist(client, collection_name: str, **kwargs):
15+
if client.has_collection(collection_name):
16+
return
17+
# TODO add constants for ID max length to 32
18+
client.create_collection(
19+
collection_name, max_length=32, id_type="string", **kwargs
20+
)
21+
22+
def __post_init__(self):
23+
from pymilvus import MilvusClient
24+
25+
self._client_file_name = os.path.join(
26+
self.global_config["working_dir"], "milvus_lite.db"
27+
)
28+
self._client = MilvusClient(self._client_file_name)
29+
self._max_batch_size = self.global_config["embedding_batch_num"]
30+
MilvusLiteStorge.create_collection_if_not_exist(
31+
self._client,
32+
self.namespace,
33+
dimension=self.embedding_func.embedding_dim,
34+
)
35+
36+
async def upsert(self, data: dict[str, dict]):
37+
logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
38+
list_data = [
39+
{
40+
"id": k,
41+
**{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
42+
}
43+
for k, v in data.items()
44+
]
45+
contents = [v["content"] for v in data.values()]
46+
batches = [
47+
contents[i : i + self._max_batch_size]
48+
for i in range(0, len(contents), self._max_batch_size)
49+
]
50+
embeddings_list = await asyncio.gather(
51+
*[self.embedding_func(batch) for batch in batches]
52+
)
53+
embeddings = np.concatenate(embeddings_list)
54+
for i, d in enumerate(list_data):
55+
d["vector"] = embeddings[i]
56+
results = self._client.upsert(collection_name=self.namespace, data=list_data)
57+
return results
58+
59+
async def query(self, query, top_k=5):
60+
embedding = await self.embedding_func([query])
61+
results = self._client.search(
62+
collection_name=self.namespace,
63+
data=embedding,
64+
limit=top_k,
65+
output_fields=list(self.meta_fields),
66+
search_params={"metric_type": "COSINE", "params": {"radius": 0.2}},
67+
)
68+
return [
69+
{**dp["entity"], "id": dp["id"], "distance": dp["distance"]}
70+
for dp in results[0]
71+
]
72+
73+
74+
def insert():
75+
data = ["YOUR TEXT DATA HERE", "YOUR TEXT DATA HERE"]
76+
rag = GraphRAG(
77+
working_dir="./nano_graphrag_cache_milvus_TEST",
78+
enable_llm_cache=True,
79+
vector_db_storage_cls=MilvusLiteStorge,
80+
)
81+
rag.insert(data)
82+
83+
84+
def query():
85+
rag = GraphRAG(
86+
working_dir="./nano_graphrag_cache_milvus_TEST",
87+
enable_llm_cache=True,
88+
vector_db_storage_cls=MilvusLiteStorge,
89+
)
90+
print(rag.query("YOUR QUERY HERE", param=QueryParam(mode="local")))
91+
92+
93+
insert()
94+
query()

nano_graphrag/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from .graphrag import GraphRAG, QueryParam
22

3-
__version__ = "0.0.3"
3+
__version__ = "0.0.4.dev"
44
__author__ = "Jianbai Ye"
55
__url__ = "https://github.com/gusye1234/nano-graphrag"
66

nano_graphrag/_storage.py

Lines changed: 18 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import networkx as nx
1010
import numpy as np
11-
from pymilvus import MilvusClient
11+
from nano_vectordb import NanoVectorDB
1212

1313
from ._utils import load_json, logger, write_json
1414
from .base import (
@@ -62,37 +62,23 @@ async def drop(self):
6262

6363

6464
@dataclass
65-
class MilvusLiteStorge(BaseVectorStorage):
66-
67-
@staticmethod
68-
def create_collection_if_not_exist(
69-
client: "MilvusClient", collection_name: str, **kwargs
70-
):
71-
if client.has_collection(collection_name):
72-
return
73-
# TODO add constants for ID max length to 32
74-
client.create_collection(
75-
collection_name, max_length=32, id_type="string", **kwargs
76-
)
65+
class NanoVectorDBStorage(BaseVectorStorage):
7766

7867
def __post_init__(self):
7968

8069
self._client_file_name = os.path.join(
81-
self.global_config["working_dir"], "milvus_lite.db"
70+
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
8271
)
83-
self._client = MilvusClient(self._client_file_name)
8472
self._max_batch_size = self.global_config["embedding_batch_num"]
85-
MilvusLiteStorge.create_collection_if_not_exist(
86-
self._client,
87-
self.namespace,
88-
dimension=self.embedding_func.embedding_dim,
73+
self._client = NanoVectorDB(
74+
self.embedding_func.embedding_dim, storage_file=self._client_file_name
8975
)
9076

9177
async def upsert(self, data: dict[str, dict]):
9278
logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
9379
list_data = [
9480
{
95-
"id": k,
81+
"__id__": k,
9682
**{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
9783
}
9884
for k, v in data.items()
@@ -107,23 +93,23 @@ async def upsert(self, data: dict[str, dict]):
10793
)
10894
embeddings = np.concatenate(embeddings_list)
10995
for i, d in enumerate(list_data):
110-
d["vector"] = embeddings[i]
111-
results = self._client.upsert(collection_name=self.namespace, data=list_data)
96+
d["__vector__"] = embeddings[i]
97+
results = self._client.upsert(datas=list_data)
11298
return results
11399

114-
async def query(self, query, top_k=5):
100+
async def query(self, query: str, top_k=5):
115101
embedding = await self.embedding_func([query])
116-
results = self._client.search(
117-
collection_name=self.namespace,
118-
data=embedding,
119-
limit=top_k,
120-
output_fields=list(self.meta_fields),
121-
search_params={"metric_type": "COSINE", "params": {"radius": 0.2}},
102+
embedding = embedding[0]
103+
results = self._client.query(
104+
query=embedding, top_k=top_k, better_than_threshold=0.2
122105
)
123-
return [
124-
{**dp["entity"], "id": dp["id"], "distance": dp["distance"]}
125-
for dp in results[0]
106+
results = [
107+
{**dp, "id": dp["__id__"], "distance": dp["__metrics__"]} for dp in results
126108
]
109+
return results
110+
111+
async def index_done_callback(self):
112+
self._client.save()
127113

128114

129115
@dataclass

nano_graphrag/graphrag.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,11 @@
1414
local_query,
1515
global_query,
1616
)
17-
from ._storage import JsonKVStorage, MilvusLiteStorge, NetworkXStorage
17+
from ._storage import (
18+
JsonKVStorage,
19+
NanoVectorDBStorage,
20+
NetworkXStorage,
21+
)
1822
from ._utils import EmbeddingFunc, compute_mdhash_id, limit_async_func_call, logger
1923
from .base import (
2024
BaseGraphStorage,
@@ -81,7 +85,7 @@ class GraphRAG:
8185

8286
# storage
8387
key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage
84-
vector_db_storage_cls: Type[BaseVectorStorage] = MilvusLiteStorge
88+
vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage
8589
graph_storage_cls: Type[BaseGraphStorage] = NetworkXStorage
8690
enable_llm_cache: bool = True
8791

readme.md

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -122,22 +122,6 @@ Some important prompts:
122122
- `PROMPTS["global_reduce_rag_response"]` is the system prompt template of the global search generation.
123123
- `PROMPTS["fail_response"]` is the fallback response when nothing is related to the user query.
124124

125-
### Storage
126-
127-
You can replace all storage-related components to your own implementation, `nano-graphrag` mainly uses three kinds of storage:
128-
129-
- `base.BaseKVStorage` for storing key-json pairs of data.
130-
- By default we use disk file storage as the backend.
131-
- `GraphRAG(.., key_string_value_json_storage_cls=YOURS,...)`
132-
- `base.BaseVectorStorage` for indexing embeddings.
133-
- By default we use [`milvus-lite`](https://github.com/milvus-io/milvus-lite) as the backend.
134-
- `GraphRAG(.., vector_db_storage_cls=YOURS,...)`
135-
- `base.BaseGraphStorage` for storing knowledge graph.
136-
- By default we use [`networkx`](https://github.com/networkx/networkx) as the backend.
137-
- `GraphRAG(.., graph_storage_cls=YOURS,...)`
138-
139-
You can refer to `nano_graphrag.base` to see detailed interfaces for each components.
140-
141125
### LLM
142126

143127
In `nano-graphrag`, we requires two types of LLM, a great one and a cheap one. The former is used to plan and respond, the latter is used to summary. By default, the great one is `gpt-4o` and the cheap one is `gpt-4o-mini`
@@ -191,6 +175,28 @@ GraphRAG(embedding_func=your_embed_func, embedding_batch_num=..., embedding_func
191175

192176
You can refer to an [example](./examples/using_local_embedding_model.py) that use `sentence-transformer` to locally compute embeddings.
193177

178+
### Storage
179+
180+
You can replace all storage-related components to your own implementation, `nano-graphrag` mainly uses three kinds of storage:
181+
182+
**`base.BaseKVStorage` for storing key-json pairs of data**
183+
184+
- By default we use disk file storage as the backend.
185+
- `GraphRAG(.., key_string_value_json_storage_cls=YOURS,...)`
186+
187+
**`base.BaseVectorStorage` for indexing embeddings**
188+
189+
- By default we use [`nano-vectordb`](https://github.com/gusye1234/nano-vectordb) as the backend.
190+
- Check out this [example](./examples/using_milvus_as_vectorDB.py) that use [`milvus-lite`](https://github.com/milvus-io/milvus-lite) as the backend (not available in Windows).
191+
- `GraphRAG(.., vector_db_storage_cls=YOURS,...)`
192+
193+
**`base.BaseGraphStorage` for storing knowledge graph**
194+
195+
- By default we use [`networkx`](https://github.com/networkx/networkx) as the backend.
196+
- `GraphRAG(.., graph_storage_cls=YOURS,...)`
197+
198+
You can refer to `nano_graphrag.base` to see detailed interfaces for each components.
199+
194200

195201

196202
## Benchmark

requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
openai
22
tiktoken
3-
pymilvus
43
networkx
5-
graspologic
4+
graspologic
5+
nano-vectordb

0 commit comments

Comments
 (0)