Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

late decoration #73

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 39 additions & 7 deletions enforce/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import functools
from collections import OrderedDict
from multiprocessing import RLock
from typing import ForwardRef

from wrapt import decorator, ObjectProxy

Expand Down Expand Up @@ -65,6 +66,38 @@ def runtime_validation(data=None, *, enabled=None, group=None):
return generate_decorated()


def _should_decorate_later(data):
for annotation_val in data.__annotations__.values():
if not isinstance(annotation_val, str):
continue
# test if we can reference the class
frame = inspect.stack()[5].frame
try:
typing._eval_type(_ForwardRef(annotation_val), frame.f_globals, frame.f_locals)
except NameError:
# this indicates that late binding is in order
return True
return False


def _decorate(data, configuration, obj_instance=None, parent_root=None, stack_depth=1) -> typing.Callable:
data = apply_enforcer(data, parent_root=parent_root, settings=configuration)
universal = get_universal_decorator(stack_depth=stack_depth)
return universal(data)


def _decorate_later(data, configuration, obj_instance=None, parent_root=None) -> typing.Callable:
enforced = None

def wrap(*args, **kwargs):
nonlocal enforced, data
if enforced is None:
enforced = _decorate(data, configuration, obj_instance, parent_root, stack_depth=2)
return enforced(*args, **kwargs)

return wrap


def decorate(data, configuration, obj_instance=None, parent_root=None) -> typing.Callable:
"""
Performs the function decoration with a type checking wrapper
Expand All @@ -74,14 +107,13 @@ def decorate(data, configuration, obj_instance=None, parent_root=None) -> typing
if not hasattr(data, '__annotations__'):
return data

data = apply_enforcer(data, parent_root=parent_root, settings=configuration)

universal = get_universal_decorator()

return universal(data)
if _should_decorate_later(data):
return _decorate_later(data, configuration, obj_instance, parent_root)
else:
return _decorate(data, configuration, obj_instance, parent_root)


def get_universal_decorator():
def get_universal_decorator(stack_depth=1):
def universal(wrapped, instance, args, kwargs):
"""
This function will be returned by the decorator. It adds type checking before triggering
Expand All @@ -107,7 +139,7 @@ def universal(wrapped, instance, args, kwargs):
else:
parameters = Parameters(args, kwargs, skip)

frame = inspect.stack()[2].frame
frame = inspect.stack()[stack_depth].frame
outer_locals = frame.f_locals
outer_globals = frame.f_globals

Expand Down
71 changes: 71 additions & 0 deletions tests/test_forward_my_cls_reference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import unittest

from enforce.exceptions import RuntimeTypeError

from enforce import runtime_validation


class TestForwardMyClsReference(unittest.TestCase):

def test_output_ok(self):
class A:

@runtime_validation
def clone(self) -> 'A':
return A()

val = A().clone()
self.assertIsInstance(val, A)

def test_output_fail(self):
try:
class A:

@runtime_validation
def clone(self) -> 'A':
return 'str'

A().clone()
except RuntimeTypeError:
pass
else:
raise Exception('A typerror should have been raised')

def test_input_ok(self):
class A:

@runtime_validation
def __eq__(self, other: 'A') -> bool:
return self is other

a = A()
assert a == a

def test_input_not_ok(self):
class A:

@runtime_validation
def __eq__(self, other: 'A') -> bool:
return self is other

a = A()
try:
a == 'a'
except RuntimeTypeError:
pass
else:
raise Exception('A typerror should have been raised')

def test_input_fwd_ref_other_type(self):
class B:
pass

class A:

@runtime_validation
def __eq__(self, other: 'B') -> bool:
return self is other

a = A()
b = B()
assert a != b