diff --git a/AUTHORS b/AUTHORS index 40508b532..ad897f0e5 100644 --- a/AUTHORS +++ b/AUTHORS @@ -259,6 +259,7 @@ that much better: * Agustin Barto (https://github.com/abarto) * Stankiewicz Mateusz (https://github.com/mas15) * Felix Schultheiß (https://github.com/felix-smashdocs) + * Lake Chan (https://github.com/StoneMoe) * Jan Stein (https://github.com/janste63) * Timothé Perez (https://github.com/AchilleAsh) * oleksandr-l5 (https://github.com/oleksandr-l5) diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 3f0850327..aa90d8755 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -1778,6 +1778,9 @@ def __eq__(self, other): def __ne__(self, other): return not self == other + def __hash__(self): + return int("{}{}".format(self.collection_name, self.grid_id).encode().hex(), 16) + @property def fs(self): if not self._fs: diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index 7da9c2e63..45814e858 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -512,7 +512,20 @@ def delete(self, write_concern=None, _from_doc_delete=False, cascade_refs=None): ) with set_write_concern(queryset._collection, write_concern) as collection: + gridfs_refs = set() + for name, field in doc._fields.items(): + if field.__class__.__name__ == "FileField": + gridfs_refs.update(queryset.scalar(name)) + if ( + field.__class__.__name__ == "ListField" + and field.field.__class__.__name__ == "FileField" + ): + for ref_list in queryset.scalar(name): + gridfs_refs.update(ref_list) + result = collection.delete_many(queryset._query) + for ref in gridfs_refs: + ref.delete() # If we're using an unack'd write concern, we don't really know how # many items have been deleted at this point, hence we only return diff --git a/tests/fields/test_file_field.py b/tests/fields/test_file_field.py index 4f584b4c8..10aa4ee5d 100644 --- a/tests/fields/test_file_field.py +++ b/tests/fields/test_file_field.py @@ -9,6 +9,7 @@ from mongoengine import * from mongoengine.connection import get_db +from mongoengine.fields import FileField, ListField try: from PIL import Image # noqa: F401 @@ -575,6 +576,73 @@ class Animal(Document): assert marmot.photos[0].foo == "bar" assert marmot.photos[0].get().length == 8313 + def test_cascade_del_filefield(self): + """Ensure cascade deletion also remove file chunks""" + + class User(Document): + username = StringField() + + class Album(Document): + user = ReferenceField("User") + photo = FileField() + + User.register_delete_rule(Album, "user", CASCADE) + + User.drop_collection() + Album.drop_collection() + self.db["fs.files"].drop() + self.db["fs.chunks"].drop() + + user = User(username="bob").save() + assert User.objects.get() == user + + album = Album(user=user) + with open(TEST_IMAGE_PATH, "rb") as img: + album.photo.put(img) + album.save() + assert Album.objects.get().user == user + + user.delete() + assert User.objects.count() == 0 + assert Album.objects.count() == 0 + assert self.db["fs.files"].count() == 0 + assert self.db["fs.chunks"].count() == 0 + + def test_cascade_del_complex_field_filefield(self): + """Ensure cascade deletion also remove file chunks""" + + class User(Document): + username = StringField() + + class Album(Document): + user = ReferenceField("User") + photos = ListField(FileField()) + + User.register_delete_rule(Album, "user", CASCADE) + + User.drop_collection() + Album.drop_collection() + self.db["fs.files"].drop() + self.db["fs.chunks"].drop() + + user = User(username="bob").save() + assert User.objects.get() == user + + album = Album(user=user) + with open(TEST_IMAGE_PATH, "rb") as img: + photos_field = album._fields["photos"].field + new_proxy = photos_field.get_proxy_obj("photos", album) + new_proxy.put(img, content_type="image/jpeg", foo="bar") + album.photos.append(new_proxy) + album.save() + assert Album.objects.get().user == user + + user.delete() + assert User.objects.count() == 0 + assert Album.objects.count() == 0 + assert self.db["fs.files"].count() == 0 + assert self.db["fs.chunks"].count() == 0 + if __name__ == "__main__": unittest.main()