Skip to content

Commit

Permalink
Static-analysis friendly ways to replace methods on classes
Browse files Browse the repository at this point in the history
  • Loading branch information
cjw296 committed Feb 8, 2023
1 parent e964516 commit f49ecd0
Show file tree
Hide file tree
Showing 4 changed files with 236 additions and 2 deletions.
1 change: 1 addition & 0 deletions testfixtures/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def __repr__(self):
Replace,
replace,
replace_in_environ,
replace_on_class,
)
from testfixtures.shouldraise import ShouldRaise, should_raise, ShouldAssert
from testfixtures.shouldwarn import ShouldWarn, ShouldNotWarn
Expand Down
42 changes: 41 additions & 1 deletion testfixtures/replace.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import os
from contextlib import contextmanager
from functools import partial
from gc import get_referrers, get_referents
from operator import setitem, getitem
from types import ModuleType
from typing import Any, TypeVar, Callable, Dict, Tuple

from testfixtures.resolve import resolve, not_there, Resolved
from testfixtures.resolve import resolve, not_there, Resolved, classmethod_type, class_type
from testfixtures.utils import wrap, extend_docstring

import warnings
Expand Down Expand Up @@ -125,6 +127,36 @@ def in_environ(self, name: str, replacement: Any) -> None:
self(os.environ, name=name, accessor=getitem, strict=False,
replacement=not_there if replacement is not_there else str(replacement))

def _find_container(self, attribute, name: str, break_on_static: bool):
for referrer in get_referrers(attribute):
if break_on_static and isinstance(referrer, staticmethod):
return None, referrer
elif isinstance(referrer, dict) and '__dict__' in referrer:
if referrer.get(name) is attribute:
for container in get_referrers(referrer):
if isinstance(container, type):
return container, None
return None, None

def on_class(self, attribute: Callable, replacement: Any, name: str = None) -> None:
if not callable(attribute):
raise TypeError('attribute must be callable')
name = name or getattr(attribute, '__name__', None)
container = None
if isinstance(attribute, classmethod_type):
for referred in get_referents(attribute):
if isinstance(referred, class_type):
container = referred
else:
container, staticmethod_ = self._find_container(attribute, name, break_on_static=True)
if staticmethod_ is not None:
container, _ = self._find_container(staticmethod_, name, break_on_static=False)

if container is None:
raise AttributeError(f'could not find container of {attribute!r} using name {name!r}')

self(container, name=name, accessor=getattr, replacement=replacement)

def restore(self) -> None:
"""
Restore all the original objects that have been replaced by
Expand Down Expand Up @@ -170,6 +202,14 @@ def replace_in_environ(name: str, replacement: Any):
with Replacer() as r:
r.in_environ(name, replacement)
yield


@contextmanager
def replace_on_class(attribute: Callable, replacement: Any, name: str = None):
with Replacer() as r:
r.on_class(attribute, replacement, name)
yield

class Replace(object):
"""
A context manager that uses a :class:`Replacer` to replace a single target.
Expand Down
15 changes: 15 additions & 0 deletions testfixtures/resolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,18 @@ def resolve(dotted_name: str, container: Optional[Any] = None) -> Resolved:
found = getattr(found, name)
setter = getattr
return Resolved(container, setter, name, found)


class _Reference:

@classmethod
def classmethod(cls): # pragma: no cover
pass

@staticmethod
def staticmethod(cls): # pragma: no cover
pass


class_type = type(_Reference)
classmethod_type = type(_Reference.classmethod)
180 changes: 179 additions & 1 deletion testfixtures/tests/test_replace.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
compare,
not_there,
replace_in_environ,
replace_on_class,
)
from unittest import TestCase

Expand All @@ -17,7 +18,7 @@
from testfixtures.mock import Mock
from testfixtures.tests import sample1, sample3
from testfixtures.tests import sample2
from .sample1 import z
from .sample1 import z, X
from .sample3 import SOME_CONSTANT
from ..compat import PY_310_PLUS

Expand Down Expand Up @@ -713,10 +714,187 @@ def test_ensure_not_present(self):
assert 'TESTFIXTURES_SAMPLE_KEY_PRESENT' not in os.environ


class TestOnClass:

def test_method_on_class(self):

class SampleClass:

def method(self, x):
return x*2

sample = SampleClass()

with Replacer() as replace:
replace.on_class(SampleClass.method, lambda self, x: x*3)
compare(sample.method(1), expected=3)

compare(sample.method(1), expected=2)

def test_method_on_instance(self):

class SampleClass:

def method(self, x):
return x*2

sample = SampleClass()

with Replacer() as replace:
with ShouldRaise(AttributeError):
replace.on_class(sample.method, lambda self, x: x*3)

# ...so use explicit and non-strict:
replace(sample.method, lambda x: x * 3, container=sample, strict=False)

compare(sample.method(1), expected=3)

compare(sample.method(1), expected=2)

def test_badly_decorated_method(self):

def bad(f):
def inner(self, x):
return f(self, x)
return inner

class SampleClass:

@bad
def method(self, x):
return x*2

sample = SampleClass()

with Replacer() as replace:

# without the name, we get a useful error:
with ShouldRaise(AttributeError(
f"could not find container of {SampleClass.method} using name 'inner'"
)):
replace.on_class(SampleClass.method, lambda self_, x: x*3)

assert SampleClass.__dict__['method'] is original
replace.on_class(SampleClass.method, lambda self_, x: x*3, name='method')
compare(sample.method(1), expected=3)

compare(sample.method(1), expected=2)

def test_classmethod(self):

class SampleClass:

@classmethod
def method(cls, x):
return x*2

with Replacer() as replace:
replace.on_class(SampleClass.method, classmethod(lambda cls, x: x*3))
compare(SampleClass.method(1), expected=3)

compare(SampleClass.method(1), expected=2)

def test_staticmethod(self):

class SampleClass:

@staticmethod
def method(x):
return x*2

with Replacer() as replace:
replace.on_class(SampleClass.method, lambda x: x*3)
compare(SampleClass.method(1), expected=3)

compare(SampleClass.method(1), expected=2)

def test_not_callable(self):

class SampleClass:

FOO = 1

sample = SampleClass()

replace = Replacer()
with ShouldRaise(TypeError('attribute must be callable')):
replace.on_class(SampleClass.FOO, 2)
compare(sample.FOO, expected=1)

def test_method_on_class_in_module(self):
sample = X()

with Replacer() as replace:
replace.on_class(X.y, lambda self_: 'replacement y')
compare(sample.y(), expected='replacement y')

compare(sample.y(), expected='original y')

def test_method_on_instance_in_module(self):

sample = X()

with Replacer() as replace:
replace(sample.y, lambda: 'replacement y', container=sample, strict=False)
compare(sample.y(), expected='replacement y')

compare(sample.y(), expected='original y')

def test_classmethod_on_class_in_module(self):

with Replacer() as replace:
replace.on_class(X.aMethod, classmethod(lambda cls: (cls, cls)))
compare(X.aMethod(), expected=(X, X))

compare(X.aMethod(), expected=X)

def test_classmethod_on_instance_in_module(self):

sample = X()

with Replacer() as replace:
replace.on_class(sample.aMethod, classmethod(lambda cls: (cls, cls)))
compare(sample.aMethod(), expected=(X, X))

compare(sample.aMethod(), expected=X)

def test_staticmethod_on_class_in_module(self):

with Replacer() as replace:
replace.on_class(X.bMethod, lambda: 3)
compare(X.bMethod(), expected=3)

compare(X.bMethod(), expected=2)

def test_staticmethod_on_instance_in_module(self):

sample = X()

with Replacer() as replace:
replace(sample.bMethod, lambda: 3, container=sample, strict=False)
compare(sample.bMethod(), expected=3)

compare(X.bMethod(), expected=2)


class TestConvenience:

def test_environ(self):
os.environ['TESTFIXTURES_SAMPLE_KEY_PRESENT'] = 'ORIGINAL'
with replace_in_environ('TESTFIXTURES_SAMPLE_KEY_PRESENT', 'NEW'):
compare(os.environ['TESTFIXTURES_SAMPLE_KEY_PRESENT'], expected='NEW')
compare(os.environ['TESTFIXTURES_SAMPLE_KEY_PRESENT'], expected='ORIGINAL')

def test_on_class(self):

class SampleClass:

def method(self, x):
return x*2

sample = SampleClass()

with replace_on_class(SampleClass.method, lambda self, x: x*3, name='method'):
compare(sample.method(1), expected=3)

compare(sample.method(1), expected=2)

0 comments on commit f49ecd0

Please sign in to comment.