Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bind patches to coroutine instances to avoid concurrency issues #123

Merged
merged 3 commits into from
May 14, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 30 additions & 4 deletions asynctest/mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,13 +946,20 @@ def _decorate_coroutine_callable(func, new_patching):
patchings = [new_patching]

def patched_factory(*args, **kwargs):
# Patches must be copied for each instance of the coroutine, to avoid
# concurrency issues. If we don't do that the first coroutine instance
# to terminate will deactivate the patch and break the other
# coroutines.
# If one wants to ensure the same patch mock is used for concurrent
# coroutines, it must set it explicitly.
local_patchings = [patching.copy() for patching in patchings]
extra_args = []
patchers_to_exit = []
patch_dict_with_limited_scope = []

exc_info = tuple()
try:
for patching in patchings:
for patching in local_patchings:
arg = patching.__enter__()
if patching.scope == LIMITED:
patchers_to_exit.append(patching)
Expand All @@ -974,7 +981,7 @@ def patched_factory(*args, **kwargs):

args += tuple(extra_args)
gen = func(*args, **kwargs)
return _PatchedGenerator(gen, patchings,
return _PatchedGenerator(gen, local_patchings,
asyncio.iscoroutinefunction(func))
except BaseException:
if patching not in patchers_to_exit and _is_started(patching):
Expand Down Expand Up @@ -1090,6 +1097,7 @@ def copy(self):

def __enter__(self):
# When patching a coroutine, we reuse the same mock object
# for the whole instance of the coroutine
if self.mock_to_reuse is not None:
self.target = self.getter()
self.temp_original, self.is_local = self.get_original()
Expand Down Expand Up @@ -1147,7 +1155,8 @@ def patch(target, new=DEFAULT, spec=None, create=False, spec_set=None,
``new`` specifies which object will replace the ``target`` when the patch
is applied. By default, the target will be patched with an instance of
:class:`~asynctest.CoroutineMock` if it is a coroutine, or
a :class:`~asynctest.MagicMock` object.
a :class:`~asynctest.MagicMock` object. In this case, the mock is passed as
an extra positional argument.

It is a replacement to :func:`unittest.mock.patch`, but using
:mod:`asynctest.mock` objects.
Expand All @@ -1163,7 +1172,15 @@ def patch(target, new=DEFAULT, spec=None, create=False, spec_set=None,
yields a value and pauses its execution (with ``yield``, ``yield from``
or ``await``).

The behavior differs from :func:`unittest.mock.patch` for generators.
Since asynctest 0.13, each instance of the generator or coroutine will have
its own set of patches. When several instances of the same coroutine are
running concurrently, they will patch the same target. If ``new`` is left
to its default vaue, the patched target is subject to a race condition. At
some point, its value might not be the same as the one passed to the extra
argument.

To avoid this problem, ``scope`` should be set to
:const:`~asynctest.LIMITED`, or ``new`` should be specified.

When used as a context manager, the patch is still active even if the
generator or coroutine is paused, which may affect concurrent tasks::
Expand Down Expand Up @@ -1202,6 +1219,9 @@ def test_coro(self, mock_function1, mock_function2):

.. versionadded:: 0.6 patch into generators and coroutines with
a decorator.

.. versionadded:: 0.13 patchs are now associated with a generator or
coroutine instance instead of the function.
"""
getter, attribute = unittest.mock._get_target(target)
patcher = _patch(getter, attribute, new, spec, create, spec_set, autospec,
Expand Down Expand Up @@ -1263,6 +1283,12 @@ def __init__(self, in_dict, values=(), clear=False, scope=GLOBAL,
self._is_started = False
self._global_patchings = []

def copy(self):
patcher = _patch_dict(self.in_dict, self.values, self.clear,
self.scope)
patcher._global_patchings = [p.copy() for p in self._global_patchings]
return patcher

def _keep_global_patch(self, other_patching):
self._global_patchings.append(other_patching)

Expand Down
81 changes: 81 additions & 0 deletions test/test_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -1075,6 +1075,62 @@ async def a_native_coroutine(fut, mock):

run_coroutine(tester(a_native_coroutine))

def test_mock_on_patched_coroutine_is_a_new_mock_for_each_call(self):
# See bug #121
issued_mocks = set()


with self.subTest("old style coroutine"):
@asynctest.mock.patch("test.test_mock.Test")
@asyncio.coroutine
def store_mock_from_patch(mock):
issued_mocks.add(mock)

run_coroutine(store_mock_from_patch())
run_coroutine(store_mock_from_patch())

self.assertEqual(2, len(issued_mocks))

issued_mocks.clear()

with self.subTest("native coroutine"):
@asynctest.mock.patch("test.test_mock.Test")
async def store_mock_from_patch(mock):
issued_mocks.add(mock)

run_coroutine(store_mock_from_patch())
run_coroutine(store_mock_from_patch())

self.assertEqual(2, len(issued_mocks))

def test_concurrent_patches_dont_affect_each_other(self):
with self.subTest("old style coroutine"):
@asynctest.patch("test.test_mock.Test")
@asyncio.coroutine
def test_patch(recursive, *_):
before = Test

if recursive:
yield from test_patch(False)

# mock must still apply
self.assertEqual(before, Test)

run_coroutine(test_patch(True))

with self.subTest("native coroutine"):
@asynctest.patch("test.test_mock.Test")
async def test_patch_native(recursive, *_):
before = Test

if recursive:
await test_patch_native(False)

# mock must still apply
self.assertEqual(before, Test)

run_coroutine(test_patch(True))


class Test_patch_object(unittest.TestCase):
def test_patch_with_MagicMock(self):
Expand Down Expand Up @@ -1199,6 +1255,31 @@ def test_a_function(self):
self.assertTrue(run_coroutine(instance.test_a_coroutine()))
self.assertFalse(test.test_mock.Test().a_dict['is_patched'])

def test_patch_decorates_concurrent_coroutines(self):
# See bug #121
with self.subTest("old style coroutine"):
@patch_dict_is_patched()
@asyncio.coroutine
def a_coroutine(recursive):
import test.test_mock
self.assertTrue(test.test_mock.Test().a_dict['is_patched'])
if recursive:
yield from a_coroutine(False)
self.assertTrue(test.test_mock.Test().a_dict['is_patched'])

run_coroutine(a_coroutine(True))

with self.subTest("native coroutine"):
@patch_dict_is_patched()
async def a_native_coroutine(recursive):
import test.test_mock
self.assertTrue(test.test_mock.Test().a_dict['is_patched'])
if recursive:
await a_native_coroutine(False)
self.assertTrue(test.test_mock.Test().a_dict['is_patched'])

run_coroutine(a_native_coroutine(True))


class Test_patch_autospec(unittest.TestCase):
test_class_path = "{}.Test".format(__name__)
Expand Down