diff --git a/asynctest/mock.py b/asynctest/mock.py index 157f423..0d9f63d 100644 --- a/asynctest/mock.py +++ b/asynctest/mock.py @@ -946,13 +946,20 @@ def _decorate_coroutine_callable(func, new_patching): patchings = [new_patching] def patched_factory(*args, **kwargs): + # Patches must be copied for each instance of the coroutine, to avoid + # concurrency issues. If we don't do that the first coroutine instance + # to terminate will deactivate the patch and break the other + # coroutines. + # If one wants to ensure the same patch mock is used for concurrent + # coroutines, it must set it explicitly. + local_patchings = [patching.copy() for patching in patchings] extra_args = [] patchers_to_exit = [] patch_dict_with_limited_scope = [] exc_info = tuple() try: - for patching in patchings: + for patching in local_patchings: arg = patching.__enter__() if patching.scope == LIMITED: patchers_to_exit.append(patching) @@ -974,7 +981,7 @@ def patched_factory(*args, **kwargs): args += tuple(extra_args) gen = func(*args, **kwargs) - return _PatchedGenerator(gen, patchings, + return _PatchedGenerator(gen, local_patchings, asyncio.iscoroutinefunction(func)) except BaseException: if patching not in patchers_to_exit and _is_started(patching): @@ -1090,6 +1097,7 @@ def copy(self): def __enter__(self): # When patching a coroutine, we reuse the same mock object + # for the whole instance of the coroutine if self.mock_to_reuse is not None: self.target = self.getter() self.temp_original, self.is_local = self.get_original() @@ -1147,7 +1155,8 @@ def patch(target, new=DEFAULT, spec=None, create=False, spec_set=None, ``new`` specifies which object will replace the ``target`` when the patch is applied. By default, the target will be patched with an instance of :class:`~asynctest.CoroutineMock` if it is a coroutine, or - a :class:`~asynctest.MagicMock` object. + a :class:`~asynctest.MagicMock` object. In this case, the mock is passed as + an extra positional argument. It is a replacement to :func:`unittest.mock.patch`, but using :mod:`asynctest.mock` objects. @@ -1163,7 +1172,15 @@ def patch(target, new=DEFAULT, spec=None, create=False, spec_set=None, yields a value and pauses its execution (with ``yield``, ``yield from`` or ``await``). - The behavior differs from :func:`unittest.mock.patch` for generators. + Since asynctest 0.13, each instance of the generator or coroutine will have + its own set of patches. When several instances of the same coroutine are + running concurrently, they will patch the same target. If ``new`` is left + to its default vaue, the patched target is subject to a race condition. At + some point, its value might not be the same as the one passed to the extra + argument. + + To avoid this problem, ``scope`` should be set to + :const:`~asynctest.LIMITED`, or ``new`` should be specified. When used as a context manager, the patch is still active even if the generator or coroutine is paused, which may affect concurrent tasks:: @@ -1202,6 +1219,9 @@ def test_coro(self, mock_function1, mock_function2): .. versionadded:: 0.6 patch into generators and coroutines with a decorator. + + .. versionadded:: 0.13 patchs are now associated with a generator or + coroutine instance instead of the function. """ getter, attribute = unittest.mock._get_target(target) patcher = _patch(getter, attribute, new, spec, create, spec_set, autospec, @@ -1263,6 +1283,12 @@ def __init__(self, in_dict, values=(), clear=False, scope=GLOBAL, self._is_started = False self._global_patchings = [] + def copy(self): + patcher = _patch_dict(self.in_dict, self.values, self.clear, + self.scope) + patcher._global_patchings = [p.copy() for p in self._global_patchings] + return patcher + def _keep_global_patch(self, other_patching): self._global_patchings.append(other_patching) diff --git a/test/test_mock.py b/test/test_mock.py index 6692aa8..8ca5842 100644 --- a/test/test_mock.py +++ b/test/test_mock.py @@ -1075,6 +1075,62 @@ async def a_native_coroutine(fut, mock): run_coroutine(tester(a_native_coroutine)) + def test_mock_on_patched_coroutine_is_a_new_mock_for_each_call(self): + # See bug #121 + issued_mocks = set() + + + with self.subTest("old style coroutine"): + @asynctest.mock.patch("test.test_mock.Test") + @asyncio.coroutine + def store_mock_from_patch(mock): + issued_mocks.add(mock) + + run_coroutine(store_mock_from_patch()) + run_coroutine(store_mock_from_patch()) + + self.assertEqual(2, len(issued_mocks)) + + issued_mocks.clear() + + with self.subTest("native coroutine"): + @asynctest.mock.patch("test.test_mock.Test") + async def store_mock_from_patch(mock): + issued_mocks.add(mock) + + run_coroutine(store_mock_from_patch()) + run_coroutine(store_mock_from_patch()) + + self.assertEqual(2, len(issued_mocks)) + + def test_concurrent_patches_dont_affect_each_other(self): + with self.subTest("old style coroutine"): + @asynctest.patch("test.test_mock.Test") + @asyncio.coroutine + def test_patch(recursive, *_): + before = Test + + if recursive: + yield from test_patch(False) + + # mock must still apply + self.assertEqual(before, Test) + + run_coroutine(test_patch(True)) + + with self.subTest("native coroutine"): + @asynctest.patch("test.test_mock.Test") + async def test_patch_native(recursive, *_): + before = Test + + if recursive: + await test_patch_native(False) + + # mock must still apply + self.assertEqual(before, Test) + + run_coroutine(test_patch(True)) + class Test_patch_object(unittest.TestCase): def test_patch_with_MagicMock(self): @@ -1199,6 +1255,31 @@ def test_a_function(self): self.assertTrue(run_coroutine(instance.test_a_coroutine())) self.assertFalse(test.test_mock.Test().a_dict['is_patched']) + def test_patch_decorates_concurrent_coroutines(self): + # See bug #121 + with self.subTest("old style coroutine"): + @patch_dict_is_patched() + @asyncio.coroutine + def a_coroutine(recursive): + import test.test_mock + self.assertTrue(test.test_mock.Test().a_dict['is_patched']) + if recursive: + yield from a_coroutine(False) + self.assertTrue(test.test_mock.Test().a_dict['is_patched']) + + run_coroutine(a_coroutine(True)) + + with self.subTest("native coroutine"): + @patch_dict_is_patched() + async def a_native_coroutine(recursive): + import test.test_mock + self.assertTrue(test.test_mock.Test().a_dict['is_patched']) + if recursive: + await a_native_coroutine(False) + self.assertTrue(test.test_mock.Test().a_dict['is_patched']) + + run_coroutine(a_native_coroutine(True)) + class Test_patch_autospec(unittest.TestCase): test_class_path = "{}.Test".format(__name__)