Skip to content

Commit

Permalink
Merge pull request mongomock#905 from mongomock/fix/gridfs-patching
Browse files Browse the repository at this point in the history
fix(gridfs): Adapt code to refactored Mongo driver GridFS implementation
  • Loading branch information
mdomke authored Nov 6, 2024
2 parents d821f32 + f5472ac commit b233622
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 55 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed
- Remove legacy syntax constructs using `pyupgrade --py38-plus`

### Fixed
- The Mongo Python driver did refactor the `gridfs` implementation, so that the patched code had to
be adapted.

## [4.2.0] - 2024-09-11
### Changed
- Switch to [hatch](https://hatch.pypa.io) as build system.
Expand All @@ -20,4 +24,5 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Remove support for deprecated Python versions (everything prior to 3.8)


[4.3.0]: https://github.com/mongomock/mongomock/compare/4.2.0...4.3.0
[4.2.0]: https://github.com/mongomock/mongomock/compare/4.1.3...4.2.0
30 changes: 19 additions & 11 deletions mongomock/gridfs.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from unittest import mock

from mongomock import Database as MongoMockDatabase, Collection as MongoMockCollection
from mongomock import Collection as MongoMockCollection
from mongomock.collection import Cursor as MongoMockCursor
from mongomock import Database as MongoMockDatabase

try:
from gridfs.grid_file import GridOut as PyMongoGridOut
from gridfs.grid_file import GridOutCursor as PyMongoGridOutCursor
from pymongo.collection import Collection as PyMongoCollection
from pymongo.database import Database as PyMongoDatabase
from gridfs.grid_file import GridOut as PyMongoGridOut, GridOutCursor as PyMongoGridOutCursor

_HAVE_PYMONGO = True
except ImportError:
_HAVE_PYMONGO = False
Expand All @@ -16,15 +19,13 @@
# need both classes as one might want to access both mongomock and real
# MongoDb.
class _MongoMockGridOutCursor(MongoMockCursor):

def __init__(self, collection, *args, **kwargs):
self.__root_collection = collection
super().__init__(collection.files, *args, **kwargs)

def next(self):
next_file = super().next()
return PyMongoGridOut(
self.__root_collection, file_document=next_file, session=self.session)
return PyMongoGridOut(self.__root_collection, file_document=next_file, session=self.session)

__next__ = next

Expand All @@ -45,7 +46,6 @@ def _create_grid_out_cursor(collection, *args, **kwargs):


def enable_gridfs_integration():

"""This function enables the use of mongomock Database's and Collection's inside gridfs
Gridfs library use `isinstance` to make sure the passed elements
Expand All @@ -54,8 +54,16 @@ def enable_gridfs_integration():
"""

if not _HAVE_PYMONGO:
raise NotImplementedError('gridfs mocking requires pymongo to work')

mock.patch('gridfs.Database', (PyMongoDatabase, MongoMockDatabase)).start()
mock.patch('gridfs.grid_file.Collection', (PyMongoCollection, MongoMockCollection)).start()
mock.patch('gridfs.GridOutCursor', _create_grid_out_cursor).start()
raise NotImplementedError("gridfs mocking requires pymongo to work")

Database = (PyMongoDatabase, MongoMockDatabase)
Collection = (PyMongoCollection, MongoMockCollection)

try:
mock.patch("gridfs.synchronous.grid_file.Database", Database).start()
mock.patch("gridfs.synchronous.grid_file.Collection", Collection).start()
mock.patch("gridfs.synchronous.grid_file.GridOutCursor", _create_grid_out_cursor).start()
except (AttributeError, ModuleNotFoundError):
mock.patch("gridfs.Database", Database).start()
mock.patch("gridfs.grid_file.Collection", Collection).start()
mock.patch("gridfs.GridOutCursor", _create_grid_out_cursor).start()
99 changes: 55 additions & 44 deletions tests/test__gridfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,51 +3,51 @@
import unittest
from unittest import TestCase, skipIf, skipUnless

from packaging import version

import mongomock
import mongomock.gridfs
from mongomock import helpers
from packaging import version

try:
import gridfs
from gridfs import errors

_HAVE_GRIDFS = True
except ImportError:
_HAVE_GRIDFS = False


try:
from bson.objectid import ObjectId

import pymongo
from pymongo import MongoClient as PymongoClient
except ImportError:
...


@skipUnless(helpers.HAVE_PYMONGO, 'pymongo not installed')
@skipUnless(_HAVE_GRIDFS and hasattr(gridfs.__builtins__, 'copy'), 'gridfs not installed')
@skipIf(os.getenv('NO_LOCAL_MONGO'), 'No local Mongo server running')
@skipUnless(helpers.HAVE_PYMONGO, "pymongo not installed")
@skipUnless(_HAVE_GRIDFS and hasattr(gridfs.__builtins__, "copy"), "gridfs not installed")
@skipIf(os.getenv("NO_LOCAL_MONGO"), "No local Mongo server running")
class GridFsTest(TestCase):

@classmethod
def setUpClass(cls):
mongomock.gridfs.enable_gridfs_integration()

def setUp(self):
super(GridFsTest, self).setUp()
super().setUp()
self.fake_conn = mongomock.MongoClient()
self.mongo_conn = self._connect_to_local_mongodb()
self.db_name = 'mongomock___testing_db'
self.db_name = "mongomock___testing_db"

self.mongo_conn[self.db_name]['fs']['files'].drop()
self.mongo_conn[self.db_name]['fs']['chunks'].drop()
self.mongo_conn[self.db_name]["fs"]["files"].drop()
self.mongo_conn[self.db_name]["fs"]["chunks"].drop()

self.real_gridfs = gridfs.GridFS(self.mongo_conn[self.db_name])
self.fake_gridfs = gridfs.GridFS(self.fake_conn[self.db_name])

def tearDown(self):
super(GridFsTest, self).setUp()
super().setUp()
self.mongo_conn.close()
self.fake_conn.close()

Expand Down Expand Up @@ -101,12 +101,14 @@ def test__delete_no_file(self):
self.fake_gridfs.delete(ObjectId())

def test__list_files(self):
fids = [self.fake_gridfs.put(GenFile(50, 9), filename='one'),
self.fake_gridfs.put(GenFile(62, 5), filename='two'),
self.fake_gridfs.put(GenFile(654, 1), filename='three'),
self.fake_gridfs.put(GenFile(5), filename='four')]
names = ['one', 'two', 'three', 'four']
names_no_two = [x for x in names if x != 'two']
fids = [
self.fake_gridfs.put(GenFile(50, 9), filename="one"),
self.fake_gridfs.put(GenFile(62, 5), filename="two"),
self.fake_gridfs.put(GenFile(654, 1), filename="three"),
self.fake_gridfs.put(GenFile(5), filename="four"),
]
names = ["one", "two", "three", "four"]
names_no_two = [x for x in names if x != "two"]
for x in self.fake_gridfs.list():
self.assertIn(x, names)

Expand All @@ -116,47 +118,56 @@ def test__list_files(self):
self.assertIn(x, names_no_two)

three_file = self.get_fake_file(fids[2])
self.assertEqual('three', three_file['filename'])
self.assertEqual(654, three_file['length'])
self.assertEqual("three", three_file["filename"])
self.assertEqual(654, three_file["length"])
self.fake_gridfs.delete(fids[0])
self.fake_gridfs.delete(fids[2])
self.fake_gridfs.delete(fids[3])
self.assertEqual(0, len(self.fake_gridfs.list()))

def test__find_files(self):
fids = [self.fake_gridfs.put(GenFile(50, 9), filename='a'),
self.fake_gridfs.put(GenFile(62, 5), filename='b'),
self.fake_gridfs.put(GenFile(654, 1), filename='b'),
self.fake_gridfs.put(GenFile(5), filename='a')]
c = self.fake_gridfs.find({'filename': 'a'}).sort('uploadDate', -1)
should_be_fid3 = c.next()
should_be_fid0 = c.next()
file_ids = []
for name, data in [
("a", GenFile(50, 9)),
("b", GenFile(62, 5)),
("b", GenFile(654, 1)),
("a", GenFile(5)),
]:
time.sleep(0.001)
file_ids.append(self.fake_gridfs.put(data, filename=name))

c = self.fake_gridfs.find({"filename": "a"}).sort("uploadDate", -1)
file3 = c.next()
file0 = c.next()
self.assertFalse(c.alive)
self.assertNotEqual(file3.uploadDate, file0.uploadDate)

self.assertEqual(fids[3], should_be_fid3._id)
self.assertEqual(fids[0], should_be_fid0._id)
self.assertEqual(file_ids[3], file3._id)
self.assertEqual(file_ids[0], file0._id)

def test__put_exists(self):
self.fake_gridfs.put(GenFile(1), _id='12345')
self.fake_gridfs.put(GenFile(1), _id="12345")
with self.assertRaises(errors.FileExists):
self.fake_gridfs.put(GenFile(2, 3), _id='12345')
self.fake_gridfs.put(GenFile(2, 3), _id="12345")

def assertSameFile(self, real, fake, max_delta_seconds=1):
# https://pymongo.readthedocs.io/en/stable/migrate-to-pymongo4.html#disable-md5-parameter-is-removed
if helpers.PYMONGO_VERSION < version.parse('4.0'):
self.assertEqual(real['md5'], fake['md5'])
if helpers.PYMONGO_VERSION < version.parse("4.0"):
self.assertEqual(real["md5"], fake["md5"])

self.assertEqual(real['length'], fake['length'])
self.assertEqual(real['chunkSize'], fake['chunkSize'])
self.assertEqual(real["length"], fake["length"])
self.assertEqual(real["chunkSize"], fake["chunkSize"])
self.assertLessEqual(
abs(real['uploadDate'] - fake['uploadDate']).seconds, max_delta_seconds,
msg='real: %s, fake: %s' % (real['uploadDate'], fake['uploadDate']))
abs(real["uploadDate"] - fake["uploadDate"]).seconds,
max_delta_seconds,
msg="real: {}, fake: {}".format(real["uploadDate"], fake["uploadDate"]),
)

def get_mongo_file(self, i):
return self.mongo_conn[self.db_name]['fs']['files'].find_one({'_id': i})
return self.mongo_conn[self.db_name]["fs"]["files"].find_one({"_id": i})

def get_fake_file(self, i):
return self.fake_conn[self.db_name]['fs']['files'].find_one({'_id': i})
return self.fake_conn[self.db_name]["fs"]["files"].find_one({"_id": i})

def _connect_to_local_mongodb(self, num_retries=60):
"""Performs retries on connection refused errors (for travis-ci builds)"""
Expand All @@ -165,16 +176,16 @@ def _connect_to_local_mongodb(self, num_retries=60):
time.sleep(0.5)
try:
return PymongoClient(
host=os.environ.get('TEST_MONGO_HOST', 'localhost'), maxPoolSize=1
host=os.environ.get("TEST_MONGO_HOST", "localhost"), maxPoolSize=1
)
except pymongo.errors.ConnectionFailure as e:
if retry == num_retries - 1:
raise
if 'connection refused' not in e.message.lower():
if "connection refused" not in e.message.lower():
raise


class GenFile(object):
class GenFile:
def __init__(self, length, value=0, do_encode=True):
self.gen = self._gen_data(length, value)
self.do_encode = do_encode
Expand All @@ -186,11 +197,11 @@ def _gen_data(self, length, value):

def _maybe_encode(self, s):
if self.do_encode and isinstance(s, str):
return s.encode('UTF-8')
return s.encode("UTF-8")
return s

def read(self, num_bytes=-1):
s = ''
s = ""
if num_bytes <= 0:
bytes_left = -1
else:
Expand All @@ -205,5 +216,5 @@ def read(self, num_bytes=-1):
return self._maybe_encode(s)


if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()

0 comments on commit b233622

Please sign in to comment.