Skip to content

Commit

Permalink
Fix inheritance for polyforce
Browse files Browse the repository at this point in the history
  • Loading branch information
tarsil committed Oct 10, 2023
1 parent 8ed4b24 commit 2b64732
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 20 deletions.
36 changes: 24 additions & 12 deletions polyforce/_internal/_construction.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import inspect
from inspect import Parameter, Signature
from itertools import islice
from typing import TYPE_CHECKING, Any, Dict, List, Set, Tuple, Type, cast
from typing import TYPE_CHECKING, Any, Dict, List, Set, Type, cast

from polyforce.exceptions import MissingAnnotation, ReturnSignatureMissing

Expand All @@ -18,10 +18,13 @@ def complete_poly_class(cls: Type["PolyModel"], config: ConfigWrapper) -> bool:
"""
methods: List[str] = [
attr
for attr in dir(cls)
for attr in cls.__dict__.keys()
if not attr.startswith("__") and not attr.endswith("__") and callable(getattr(cls, attr))
]
methods.append("__init__")

if "__init__" in cls.__dict__:
methods.append("__init__")

signatures: Dict[str, Signature] = {}

for method in methods:
Expand All @@ -31,22 +34,32 @@ def complete_poly_class(cls: Type["PolyModel"], config: ConfigWrapper) -> bool:
return True


def ignore_signature(signature: Signature) -> Signature:
"""
Ignores the signature and assigns the Any to all the fields and return signature.
"""
merged_params: Dict[str, Parameter] = {}
for param in islice(signature.parameters.values(), 1, None):
param = param.replace(annotation=Any)
merged_params[param.name] = param
return Signature(parameters=list(merged_params.values()), return_annotation=Any)


def generate_model_signature(
cls: Type["PolyModel"], value: str, config: ConfigWrapper
) -> Signature:
"""
Generates a signature for each method of the given class.
"""
func = getattr(cls, value)
func_signature = Signature.from_callable(func)
signature = Signature.from_callable(func)

if config.ignore:
return func_signature
return ignore_signature(signature)

params = func_signature.parameters.values()
params = signature.parameters.values()
merged_params: Dict[str, Parameter] = {}

if func_signature.return_annotation == inspect.Signature.empty:
if signature.return_annotation == inspect.Signature.empty:
raise ReturnSignatureMissing(func=value)

for param in islice(params, 1, None): # skip self arg
Expand All @@ -63,7 +76,7 @@ def generate_model_signature(

# Generate the new signatures.
return Signature(
parameters=list(merged_params.values()), return_annotation=func_signature.return_annotation
parameters=list(merged_params.values()), return_annotation=signature.return_annotation
)


Expand Down Expand Up @@ -97,16 +110,15 @@ def __new__(cls, name: str, bases: Any, attrs: Any) -> Any:
return super().__new__(cls, name, bases, attrs)

@staticmethod
def _collect_data_from_bases(bases: Any) -> Tuple[Set[str], Set[str]]:
def _collect_data_from_bases(bases: Any) -> Set[str]:
"""
Collects all the data from the bases.
"""
from ..main import PolyModel

field_names: Set[str] = set()
class_vars: Set[str] = set()

for base in bases:
if issubclass(base, PolyModel) and base is not PolyModel:
class_vars.update(base.__class_vars__)
return field_names, class_vars
return class_vars
2 changes: 1 addition & 1 deletion polyforce/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __str__(self) -> str:

class ReturnSignatureMissing(PolyException):
detail: Union[str, None] = (
"Missing return: {func}. A return value of a function should be type annotated. "
"Missing return in '{func}'. A return value of a function should be type annotated. "
"If your function doesn't return a value or returns None, annotate it as returning 'NoReturn' or 'None' respectively."
)

Expand Down
31 changes: 24 additions & 7 deletions polyforce/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from inspect import Signature
from inspect import Parameter, Signature
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Set, _SpecialForm

from typing_extensions import get_args, get_origin
Expand Down Expand Up @@ -32,6 +32,15 @@ class MyObject(PolyModel):
config = Config()

def __getattribute__(self, __name: str) -> Any:
"""
Special action where it adds the static check validation
for the data being passed.
It checks if the values are properly checked and validated with
the right types.
The class version of the decorator `polyforce.decorator.polycheck`.
"""
try:
func = object.__getattribute__(self, __name)
signatures = object.__getattribute__(self, "__signature__")
Expand All @@ -41,11 +50,19 @@ def polycheck(*args: Any, **kwargs: Any) -> Any:
nonlocal signature
nonlocal func
params = dict(zip(signature.parameters.values(), args))
params_from_kwargs = {
signature.parameters.get(key, type(value)): value
for key, value in kwargs.items()
}
params.update(params_from_kwargs) # type: ignore
params_from_kwargs: Dict[Parameter, Any] = {}

for key, value in kwargs.items():
parameter = signature.parameters.get(key)
if parameter:
params_from_kwargs[parameter] = value
continue

params_from_kwargs[
Parameter(name=key, kind=Parameter.KEYWORD_ONLY, annotation=type(value))
] = value

params.update(params_from_kwargs)

for parameter, value in params.items():
type_hint = parameter.annotation
Expand Down Expand Up @@ -74,7 +91,7 @@ def polycheck(*args: Any, **kwargs: Any) -> Any:
f" but received type '{type(value)}' instead."
)

return func(*args, **kwargs)
return func(*args, **kwargs)

return polycheck

Expand Down
36 changes: 36 additions & 0 deletions tests/models/test_inheritance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from polyforce import PolyModel


class User(PolyModel):
...


class Profile(User):
def __init__(self) -> None:
super().__init__()

def get_name(self, name: str) -> str:
return name


def test_can_inherit():
profile = Profile()
name = profile.get_name("poly")

assert name == "poly"


def test_ignores_checks():
class NewUser(PolyModel):
config = {"ignore": True}

class NewProfile(NewUser):
def __init__(self):
super().__init__()

def get_name(self, name):
return name

profile = NewProfile()
name = profile.get_name(1)
assert name == 1

0 comments on commit 2b64732

Please sign in to comment.