diff --git a/testfixtures/__init__.py b/testfixtures/__init__.py index 7ab126e..d2dd079 100644 --- a/testfixtures/__init__.py +++ b/testfixtures/__init__.py @@ -25,6 +25,7 @@ def __repr__(self): Replace, replace, replace_in_environ, + replace_on_class, ) from testfixtures.shouldraise import ShouldRaise, should_raise, ShouldAssert from testfixtures.shouldwarn import ShouldWarn, ShouldNotWarn diff --git a/testfixtures/replace.py b/testfixtures/replace.py index c41e822..a6404b6 100644 --- a/testfixtures/replace.py +++ b/testfixtures/replace.py @@ -1,10 +1,12 @@ import os from contextlib import contextmanager from functools import partial +from gc import get_referrers, get_referents from operator import setitem, getitem +from types import ModuleType from typing import Any, TypeVar, Callable, Dict, Tuple -from testfixtures.resolve import resolve, not_there, Resolved +from testfixtures.resolve import resolve, not_there, Resolved, classmethod_type, class_type from testfixtures.utils import wrap, extend_docstring import warnings @@ -125,6 +127,36 @@ def in_environ(self, name: str, replacement: Any) -> None: self(os.environ, name=name, accessor=getitem, strict=False, replacement=not_there if replacement is not_there else str(replacement)) + def _find_container(self, attribute, name: str, break_on_static: bool): + for referrer in get_referrers(attribute): + if break_on_static and isinstance(referrer, staticmethod): + return None, referrer + elif isinstance(referrer, dict) and '__dict__' in referrer: + if referrer.get(name) is attribute: + for container in get_referrers(referrer): + if isinstance(container, type): + return container, None + return None, None + + def on_class(self, attribute: Callable, replacement: Any, name: str = None) -> None: + if not callable(attribute): + raise TypeError('attribute must be callable') + name = name or getattr(attribute, '__name__', None) + container = None + if isinstance(attribute, classmethod_type): + for referred in get_referents(attribute): + if isinstance(referred, class_type): + container = referred + else: + container, staticmethod_ = self._find_container(attribute, name, break_on_static=True) + if staticmethod_ is not None: + container, _ = self._find_container(staticmethod_, name, break_on_static=False) + + if container is None: + raise AttributeError(f'could not find container of {attribute!r} using name {name!r}') + + self(container, name=name, accessor=getattr, replacement=replacement) + def restore(self) -> None: """ Restore all the original objects that have been replaced by @@ -170,6 +202,14 @@ def replace_in_environ(name: str, replacement: Any): with Replacer() as r: r.in_environ(name, replacement) yield + + +@contextmanager +def replace_on_class(attribute: Callable, replacement: Any, name: str = None): + with Replacer() as r: + r.on_class(attribute, replacement, name) + yield + class Replace(object): """ A context manager that uses a :class:`Replacer` to replace a single target. diff --git a/testfixtures/resolve.py b/testfixtures/resolve.py index 65cfcac..aa83f3c 100644 --- a/testfixtures/resolve.py +++ b/testfixtures/resolve.py @@ -63,3 +63,18 @@ def resolve(dotted_name: str, container: Optional[Any] = None) -> Resolved: found = getattr(found, name) setter = getattr return Resolved(container, setter, name, found) + + +class _Reference: + + @classmethod + def classmethod(cls): # pragma: no cover + pass + + @staticmethod + def staticmethod(cls): # pragma: no cover + pass + + +class_type = type(_Reference) +classmethod_type = type(_Reference.classmethod) diff --git a/testfixtures/tests/test_replace.py b/testfixtures/tests/test_replace.py index daa8e8c..67bc519 100644 --- a/testfixtures/tests/test_replace.py +++ b/testfixtures/tests/test_replace.py @@ -9,6 +9,7 @@ compare, not_there, replace_in_environ, + replace_on_class, ) from unittest import TestCase @@ -17,7 +18,7 @@ from testfixtures.mock import Mock from testfixtures.tests import sample1, sample3 from testfixtures.tests import sample2 -from .sample1 import z +from .sample1 import z, X from .sample3 import SOME_CONSTANT from ..compat import PY_310_PLUS @@ -713,6 +714,169 @@ def test_ensure_not_present(self): assert 'TESTFIXTURES_SAMPLE_KEY_PRESENT' not in os.environ +class TestOnClass: + + def test_method_on_class(self): + + class SampleClass: + + def method(self, x): + return x*2 + + sample = SampleClass() + + with Replacer() as replace: + replace.on_class(SampleClass.method, lambda self, x: x*3) + compare(sample.method(1), expected=3) + + compare(sample.method(1), expected=2) + + def test_method_on_instance(self): + + class SampleClass: + + def method(self, x): + return x*2 + + sample = SampleClass() + + with Replacer() as replace: + with ShouldRaise(AttributeError): + replace.on_class(sample.method, lambda self, x: x*3) + + # ...so use explicit and non-strict: + replace(sample.method, lambda x: x * 3, container=sample, strict=False) + + compare(sample.method(1), expected=3) + + compare(sample.method(1), expected=2) + + def test_badly_decorated_method(self): + + def bad(f): + def inner(self, x): + return f(self, x) + return inner + + class SampleClass: + + @bad + def method(self, x): + return x*2 + + sample = SampleClass() + + with Replacer() as replace: + + # without the name, we get a useful error: + with ShouldRaise(AttributeError( + f"could not find container of {SampleClass.method} using name 'inner'" + )): + replace.on_class(SampleClass.method, lambda self_, x: x*3) + + assert SampleClass.__dict__['method'] is original + replace.on_class(SampleClass.method, lambda self_, x: x*3, name='method') + compare(sample.method(1), expected=3) + + compare(sample.method(1), expected=2) + + def test_classmethod(self): + + class SampleClass: + + @classmethod + def method(cls, x): + return x*2 + + with Replacer() as replace: + replace.on_class(SampleClass.method, classmethod(lambda cls, x: x*3)) + compare(SampleClass.method(1), expected=3) + + compare(SampleClass.method(1), expected=2) + + def test_staticmethod(self): + + class SampleClass: + + @staticmethod + def method(x): + return x*2 + + with Replacer() as replace: + replace.on_class(SampleClass.method, lambda x: x*3) + compare(SampleClass.method(1), expected=3) + + compare(SampleClass.method(1), expected=2) + + def test_not_callable(self): + + class SampleClass: + + FOO = 1 + + sample = SampleClass() + + replace = Replacer() + with ShouldRaise(TypeError('attribute must be callable')): + replace.on_class(SampleClass.FOO, 2) + compare(sample.FOO, expected=1) + + def test_method_on_class_in_module(self): + sample = X() + + with Replacer() as replace: + replace.on_class(X.y, lambda self_: 'replacement y') + compare(sample.y(), expected='replacement y') + + compare(sample.y(), expected='original y') + + def test_method_on_instance_in_module(self): + + sample = X() + + with Replacer() as replace: + replace(sample.y, lambda: 'replacement y', container=sample, strict=False) + compare(sample.y(), expected='replacement y') + + compare(sample.y(), expected='original y') + + def test_classmethod_on_class_in_module(self): + + with Replacer() as replace: + replace.on_class(X.aMethod, classmethod(lambda cls: (cls, cls))) + compare(X.aMethod(), expected=(X, X)) + + compare(X.aMethod(), expected=X) + + def test_classmethod_on_instance_in_module(self): + + sample = X() + + with Replacer() as replace: + replace.on_class(sample.aMethod, classmethod(lambda cls: (cls, cls))) + compare(sample.aMethod(), expected=(X, X)) + + compare(sample.aMethod(), expected=X) + + def test_staticmethod_on_class_in_module(self): + + with Replacer() as replace: + replace.on_class(X.bMethod, lambda: 3) + compare(X.bMethod(), expected=3) + + compare(X.bMethod(), expected=2) + + def test_staticmethod_on_instance_in_module(self): + + sample = X() + + with Replacer() as replace: + replace(sample.bMethod, lambda: 3, container=sample, strict=False) + compare(sample.bMethod(), expected=3) + + compare(X.bMethod(), expected=2) + + class TestConvenience: def test_environ(self): @@ -720,3 +884,17 @@ def test_environ(self): with replace_in_environ('TESTFIXTURES_SAMPLE_KEY_PRESENT', 'NEW'): compare(os.environ['TESTFIXTURES_SAMPLE_KEY_PRESENT'], expected='NEW') compare(os.environ['TESTFIXTURES_SAMPLE_KEY_PRESENT'], expected='ORIGINAL') + + def test_on_class(self): + + class SampleClass: + + def method(self, x): + return x*2 + + sample = SampleClass() + + with replace_on_class(SampleClass.method, lambda self, x: x*3, name='method'): + compare(sample.method(1), expected=3) + + compare(sample.method(1), expected=2)