Skip to content

Commit

Permalink
Refactor MockCache to have a narrow interface
Browse files Browse the repository at this point in the history
It should also be responsible for stopping the patchers, instead of acting merely as storage.

Follow up the previous commit.
  • Loading branch information
nicoddemus committed Mar 21, 2024
1 parent 4faf92a commit 5257e3c
Showing 1 changed file with 15 additions and 17 deletions.
32 changes: 15 additions & 17 deletions src/pytest_mock/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,33 +49,37 @@ class MockCacheItem:

@dataclass
class MockCache:
"""
Cache MagicMock and Patcher instances so we can undo them later.
"""

cache: List[MockCacheItem] = field(default_factory=list)

def find(self, mock: MockType) -> MockCacheItem:
the_mock = next(
(mock_item for mock_item in self.cache if mock_item.mock == mock), None
)
if the_mock is None:
raise ValueError("This mock object is not registered")
return the_mock
def _find(self, mock: MockType) -> MockCacheItem:
for mock_item in self.cache:
if mock_item.mock is mock:
return mock_item
raise ValueError("This mock object is not registered")

def add(self, mock: MockType, **kwargs: Any) -> MockCacheItem:
self.cache.append(MockCacheItem(mock=mock, **kwargs))
return self.cache[-1]

def remove(self, mock: MockType) -> None:
mock_item = self.find(mock)
mock_item = self._find(mock)
if mock_item.patch:
mock_item.patch.stop()
self.cache.remove(mock_item)

def clear(self) -> None:
for mock_item in reversed(self.cache):
if mock_item.patch is not None:
mock_item.patch.stop()
self.cache.clear()

def __iter__(self) -> Iterator[MockCacheItem]:
return iter(self.cache)

def __reversed__(self) -> Iterator[MockCacheItem]:
return reversed(self.cache)


class MockerFixture:
"""
Expand Down Expand Up @@ -146,19 +150,13 @@ def stopall(self) -> None:
Stop all patchers started by this fixture. Can be safely called multiple
times.
"""
for mock_item in reversed(self._mock_cache):
if mock_item.patch is not None:
mock_item.patch.stop()
self._mock_cache.clear()

def stop(self, mock: unittest.mock.MagicMock) -> None:
"""
Stops a previous patch or spy call by passing the ``MagicMock`` object
returned by it.
"""
mock_item = self._mock_cache.find(mock)
if mock_item.patch:
mock_item.patch.stop()
self._mock_cache.remove(mock)

def spy(self, obj: object, name: str) -> MockType:
Expand Down

0 comments on commit 5257e3c

Please sign in to comment.