Skip to content

Commit 6d0b9f8

Browse files
author
Erez Sharim
committed
feat: mongodb implementation
adds a mongodb database implementation
1 parent cdea503 commit 6d0b9f8

File tree

8 files changed

+419
-17
lines changed

8 files changed

+419
-17
lines changed

.github/workflows/test.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,5 +26,10 @@ jobs:
2626
- name: Install the project
2727
run: uv sync --all-packages --all-extras
2828

29+
- name: Start MongoDB
30+
uses: supercharge/mongodb-github-action@1.12.0
31+
with:
32+
mongodb-version: '8.0'
33+
2934
- name: Run tests
3035
run: uv run ./scripts/test.py

packages/core/src/flux0_core/storage/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@
44
class StorageType(Enum):
55
NANODB_MEMORY = "nanodb_memory"
66
NANODB_JSON = "nanodb_json"
7+
MONGODB = "mongodb"
Lines changed: 315 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,315 @@
1+
from __future__ import annotations
2+
3+
import uuid
4+
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Sequence, Tuple, Type, cast
5+
6+
import jsonpatch
7+
8+
if TYPE_CHECKING:
9+
from pymongo import ASCENDING, DESCENDING, AsyncMongoClient
10+
from pymongo.asynchronous.collection import AsyncCollection as Collection
11+
from pymongo.errors import DuplicateKeyError
12+
13+
try:
14+
from pymongo import ASCENDING, DESCENDING, AsyncMongoClient
15+
from pymongo.errors import DuplicateKeyError
16+
except ImportError:
17+
raise ImportError(
18+
"MongoDB dependencies are not installed. Install them with: pip install pymongo jsonpatch"
19+
)
20+
21+
from flux0_nanodb.api import DocumentCollection, DocumentDatabase
22+
from flux0_nanodb.projection import Projection, apply_projection
23+
from flux0_nanodb.query import And, Comparison, Or, QueryFilter
24+
from flux0_nanodb.types import (
25+
DeleteResult,
26+
DocumentID,
27+
DocumentVersion,
28+
InsertOneResult,
29+
JSONPatchOperation,
30+
SortingOrder,
31+
TDocument,
32+
UpdateOneResult,
33+
)
34+
35+
36+
class MongoDocumentDatabase(DocumentDatabase):
37+
"""MongoDB implementation of DocumentDatabase using PyMongo async API."""
38+
39+
def __init__(self, client: AsyncMongoClient[Any], database_name: str):
40+
self.client = client
41+
self.database = client[database_name]
42+
43+
async def create_collection(
44+
self, name: str, schema: Type[TDocument]
45+
) -> DocumentCollection[TDocument]:
46+
"""Create a new collection with the given name and document schema."""
47+
# MongoDB creates collections automatically on first write
48+
# We'll just return a collection instance
49+
collection = self.database[name]
50+
return MongoDocumentCollection(collection, schema)
51+
52+
async def get_collection(
53+
self, name: str, schema: Type[TDocument]
54+
) -> DocumentCollection[TDocument]:
55+
"""Retrieve an existing collection by its name and document schema."""
56+
# Check if collection exists
57+
collection_names = await self.database.list_collection_names()
58+
if name not in collection_names:
59+
raise ValueError(f"Collection '{name}' does not exist")
60+
61+
collection = self.database[name]
62+
return MongoDocumentCollection(collection, schema)
63+
64+
async def delete_collection(self, name: str) -> None:
65+
"""Delete a collection by its name."""
66+
await self.database.drop_collection(name)
67+
68+
69+
class MongoDocumentCollection(DocumentCollection[TDocument]):
70+
"""MongoDB implementation of DocumentCollection using PyMongo async API."""
71+
72+
def __init__(self, collection: Collection[Any], schema: Type[TDocument]):
73+
self.collection = collection
74+
self.schema = schema
75+
76+
def _convert_query_filter_to_mongo(self, query_filter: QueryFilter) -> Dict[str, Any]:
77+
"""Convert our QueryFilter to MongoDB query format."""
78+
if isinstance(query_filter, Comparison):
79+
field = query_filter.path
80+
# Handle MongoDB's _id field mapping
81+
if field == "id":
82+
field = "_id"
83+
84+
if query_filter.op == "$eq":
85+
return {field: query_filter.value}
86+
elif query_filter.op == "$ne":
87+
return {field: {"$ne": query_filter.value}}
88+
elif query_filter.op == "$gt":
89+
return {field: {"$gt": query_filter.value}}
90+
elif query_filter.op == "$gte":
91+
return {field: {"$gte": query_filter.value}}
92+
elif query_filter.op == "$lt":
93+
return {field: {"$lt": query_filter.value}}
94+
elif query_filter.op == "$lte":
95+
return {field: {"$lte": query_filter.value}}
96+
elif query_filter.op == "$in":
97+
return {field: {"$in": query_filter.value}}
98+
else:
99+
raise ValueError(f"Unsupported operator: {query_filter.op}")
100+
101+
elif isinstance(query_filter, And):
102+
return {
103+
"$and": [
104+
self._convert_query_filter_to_mongo(expr) for expr in query_filter.expressions
105+
]
106+
}
107+
108+
elif isinstance(query_filter, Or):
109+
return {
110+
"$or": [
111+
self._convert_query_filter_to_mongo(expr) for expr in query_filter.expressions
112+
]
113+
}
114+
115+
else:
116+
raise ValueError(f"Unsupported query filter type: {type(query_filter)}")
117+
118+
def _convert_projection_to_mongo(self, projection: Mapping[str, Projection]) -> Dict[str, int]:
119+
"""Convert our Projection to MongoDB projection format."""
120+
mongo_projection = {}
121+
for field, proj_type in projection.items():
122+
# Handle MongoDB's _id field mapping
123+
mongo_field = "_id" if field == "id" else field
124+
mongo_projection[mongo_field] = 1 if proj_type == Projection.INCLUDE else 0
125+
return mongo_projection
126+
127+
def _convert_sort_to_mongo(
128+
self, sort: Sequence[Tuple[str, SortingOrder]]
129+
) -> List[Tuple[str, int]]:
130+
"""Convert our sort specification to MongoDB sort format."""
131+
mongo_sort = []
132+
for field, order in sort:
133+
# Handle MongoDB's _id field mapping
134+
mongo_field = "_id" if field == "id" else field
135+
mongo_order = ASCENDING if order == SortingOrder.ASC else DESCENDING
136+
mongo_sort.append((mongo_field, mongo_order))
137+
return mongo_sort
138+
139+
def _convert_from_mongo_doc(self, mongo_doc: Dict[str, Any]) -> TDocument:
140+
"""Convert MongoDB document to our document format."""
141+
if mongo_doc is None:
142+
return None
143+
144+
# Convert MongoDB's _id to our id field
145+
if "_id" in mongo_doc:
146+
mongo_doc["id"] = DocumentID(str(mongo_doc["_id"]))
147+
del mongo_doc["_id"]
148+
149+
return cast(TDocument, mongo_doc)
150+
151+
def _convert_to_mongo_doc(self, document: TDocument) -> Dict[str, Any]:
152+
"""Convert our document format to MongoDB document."""
153+
mongo_doc = dict(document)
154+
155+
# Convert our id field to MongoDB's _id
156+
if "id" in mongo_doc:
157+
mongo_doc["_id"] = mongo_doc["id"]
158+
del mongo_doc["id"]
159+
160+
return mongo_doc
161+
162+
async def find(
163+
self,
164+
filters: Optional[QueryFilter],
165+
projection: Optional[Mapping[str, Projection]] = None,
166+
limit: Optional[int] = None,
167+
offset: Optional[int] = None,
168+
sort: Optional[Sequence[Tuple[str, SortingOrder]]] = None,
169+
) -> Sequence[TDocument]:
170+
"""Find all documents that match the optional filters."""
171+
# Build MongoDB query
172+
mongo_query = {}
173+
if filters:
174+
mongo_query = self._convert_query_filter_to_mongo(filters)
175+
176+
# Build MongoDB projection
177+
mongo_projection = None
178+
if projection:
179+
mongo_projection = self._convert_projection_to_mongo(projection)
180+
181+
# Start with the base query
182+
cursor = self.collection.find(mongo_query, mongo_projection)
183+
184+
# Apply sorting
185+
if sort:
186+
mongo_sort = self._convert_sort_to_mongo(sort)
187+
cursor = cursor.sort(mongo_sort)
188+
189+
# Apply pagination
190+
if offset:
191+
cursor = cursor.skip(offset)
192+
if limit:
193+
cursor = cursor.limit(limit)
194+
195+
# Execute query and convert results
196+
results = await cursor.to_list(length=None)
197+
documents = [self._convert_from_mongo_doc(doc) for doc in results]
198+
199+
# Apply projection if it was specified (MongoDB projection might not handle deep paths)
200+
if projection:
201+
projected_docs = []
202+
for doc in documents:
203+
projected_doc = apply_projection(doc, projection)
204+
projected_docs.append(cast(TDocument, projected_doc))
205+
return projected_docs
206+
207+
return documents
208+
209+
async def insert_one(self, document: TDocument) -> InsertOneResult:
210+
"""Insert a single document into the collection."""
211+
mongo_doc = self._convert_to_mongo_doc(document)
212+
213+
# Generate ID and version if not present
214+
if "_id" not in mongo_doc:
215+
mongo_doc["_id"] = str(uuid.uuid4())
216+
if "version" not in mongo_doc:
217+
mongo_doc["version"] = DocumentVersion(str(uuid.uuid4()))
218+
219+
try:
220+
result = await self.collection.insert_one(mongo_doc)
221+
return InsertOneResult(
222+
acknowledged=result.acknowledged, inserted_id=DocumentID(str(result.inserted_id))
223+
)
224+
except DuplicateKeyError:
225+
# Handle duplicate key error
226+
return InsertOneResult(
227+
acknowledged=False, inserted_id=DocumentID(str(mongo_doc["_id"]))
228+
)
229+
230+
async def update_one(
231+
self, filters: QueryFilter, patch: List[JSONPatchOperation], upsert: bool = False
232+
) -> UpdateOneResult:
233+
"""Apply a JSON Patch to a single document that matches the provided filters."""
234+
mongo_query = self._convert_query_filter_to_mongo(filters)
235+
236+
# Find the document to patch
237+
existing_doc = await self.collection.find_one(mongo_query)
238+
if existing_doc is None and not upsert:
239+
return UpdateOneResult(
240+
acknowledged=True, matched_count=0, modified_count=0, upserted_id=None
241+
)
242+
243+
if existing_doc is None and upsert:
244+
# Create a new document for upsert
245+
new_doc: Dict[str, Any] = {}
246+
# Generate ID and version
247+
new_doc["_id"] = str(uuid.uuid4())
248+
new_doc["version"] = DocumentVersion(str(uuid.uuid4()))
249+
else:
250+
# Convert from MongoDB format for patching
251+
# At this point, existing_doc cannot be None, so we add an assertion for the type checker
252+
assert existing_doc is not None
253+
new_doc = dict(existing_doc)
254+
if "_id" in new_doc:
255+
new_doc["id"] = str(new_doc["_id"])
256+
del new_doc["_id"]
257+
258+
# Apply JSON patch
259+
patch_obj = jsonpatch.JsonPatch([dict(op) for op in patch])
260+
try:
261+
patched_doc = patch_obj.apply(new_doc)
262+
except jsonpatch.JsonPatchException as e:
263+
raise ValueError(f"Invalid JSON patch: {e}")
264+
265+
# Convert back to MongoDB format
266+
mongo_doc = dict(patched_doc)
267+
if "id" in mongo_doc:
268+
mongo_doc["_id"] = mongo_doc["id"]
269+
del mongo_doc["id"]
270+
271+
# Update version
272+
mongo_doc["version"] = DocumentVersion(str(uuid.uuid4()))
273+
274+
if existing_doc is None:
275+
# Insert new document (upsert)
276+
iresult = await self.collection.insert_one(mongo_doc)
277+
return UpdateOneResult(
278+
acknowledged=iresult.acknowledged,
279+
matched_count=0,
280+
modified_count=0,
281+
upserted_id=DocumentID(str(iresult.inserted_id)),
282+
)
283+
else:
284+
# Update existing document
285+
uresult = await self.collection.replace_one(mongo_query, mongo_doc)
286+
return UpdateOneResult(
287+
acknowledged=uresult.acknowledged,
288+
matched_count=uresult.matched_count,
289+
modified_count=uresult.modified_count,
290+
upserted_id=None,
291+
)
292+
293+
async def delete_one(self, filters: QueryFilter) -> DeleteResult[TDocument]:
294+
"""Delete the first document that matches the provided filters."""
295+
mongo_query = self._convert_query_filter_to_mongo(filters)
296+
297+
# Find the document before deleting it
298+
existing_doc = await self.collection.find_one(mongo_query)
299+
deleted_document = None
300+
if existing_doc:
301+
deleted_document = self._convert_from_mongo_doc(existing_doc)
302+
303+
# Delete the document
304+
result = await self.collection.delete_one(mongo_query)
305+
306+
return DeleteResult(
307+
acknowledged=result.acknowledged,
308+
deleted_count=result.deleted_count,
309+
deleted_document=deleted_document,
310+
)
311+
312+
313+
def create_client(uri: str) -> AsyncMongoClient[Any]:
314+
"""Create a MongoDB client."""
315+
return AsyncMongoClient(uri)

packages/nanodb/tests/test_db.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
)
1212
from flux0_nanodb.json import JsonDocumentDatabase
1313
from flux0_nanodb.memory import MemoryDocumentDatabase
14+
from flux0_nanodb.mongodb import MongoDocumentDatabase
1415
from flux0_nanodb.projection import Projection
1516
from flux0_nanodb.query import Comparison, QueryFilter
1617
from flux0_nanodb.types import (
@@ -21,6 +22,7 @@
2122
JSONPatchOperation,
2223
SortingOrder,
2324
)
25+
from pymongo import AsyncMongoClient
2426

2527

2628
# A test document that extends our base Document.
@@ -34,7 +36,7 @@ class SimpleDocument(TypedDict, total=False):
3436

3537

3638
# Fixture to provide a DocumentDatabase instance.
37-
@pytest_asyncio.fixture(params=["memory", "json"])
39+
@pytest_asyncio.fixture(params=["memory", "json", "mongodb"])
3840
async def db(request):
3941
if request.param == "memory":
4042
yield MemoryDocumentDatabase()
@@ -53,6 +55,39 @@ async def db(request):
5355
# Cleanup after tests
5456
if data_dir.exists():
5557
shutil.rmtree(data_dir)
58+
elif request.param == "mongodb":
59+
from pymongo.errors import OperationFailure, ServerSelectionTimeoutError
60+
61+
client = AsyncMongoClient("mongodb://localhost:27017")
62+
db_instance = MongoDocumentDatabase(client, "test_nanodb")
63+
64+
# Clean up any existing test data
65+
try:
66+
await client.drop_database("test_nanodb")
67+
except OperationFailure as e:
68+
# Database might not exist yet, which is fine
69+
# But re-raise if it's a different operation failure
70+
if "not found" not in str(e).lower():
71+
raise
72+
except ServerSelectionTimeoutError:
73+
# Re-raise connection issues as they indicate real problems
74+
raise
75+
76+
yield db_instance
77+
78+
# Cleanup after tests
79+
try:
80+
await client.drop_database("test_nanodb")
81+
except OperationFailure as e:
82+
# Database might not exist, which is fine
83+
# But re-raise if it's a different operation failure
84+
if "not found" not in str(e).lower():
85+
raise
86+
except ServerSelectionTimeoutError:
87+
# Re-raise connection issues as they indicate real problems
88+
raise
89+
finally:
90+
await client.close()
5691

5792

5893
# Fixture to provide a collection of TestDocument.

0 commit comments

Comments
 (0)