Skip to content

Commit

Permalink
Add more tests to field and defaults
Browse files Browse the repository at this point in the history
  • Loading branch information
tarsil committed Oct 14, 2023
1 parent 68f2346 commit c3c9898
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 5 deletions.
52 changes: 50 additions & 2 deletions polyforce/fields.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import TYPE_CHECKING, Any, Callable, List, Tuple, Type, TypedDict, Union

from typing_extensions import Annotated, Self, Unpack, get_args
from typing_extensions import Annotated, Self, Unpack, _SpecialForm, get_args

from ._internal import _representation
from .core import _utils
Expand Down Expand Up @@ -87,11 +87,59 @@ def __init__(self, **kwargs: Unpack[_FieldInputs]) -> None:
if self.default is not PolyforceUndefined and self.default_factory is not None:
raise TypeError("cannot specify both default and default_factory")

self.title = kwargs.pop("title", None)
self.name = kwargs.pop("name", None)
self._validate_default_with_annotation()
self.title = kwargs.pop("title", None)
self.description = kwargs.pop("description", None)
self.metadata = metadata

def _extract_type_hint(self, type_hint: Union[Type, tuple]) -> Any:
"""
Extracts the base type from a type hint, considering typing extensions.
This function checks if the given type hint is a generic type hint and extracts
the base type. If not, it returns the original type hint.
Args:
type_hint (Union[Type, tuple]): The type hint to extract the base type from.
Returns:
Union[Type, tuple]: The base type of the type hint or the original type hint.
Example:
```
from typing import List, Union
# Extract the base type from a List hint
base_type = extract_type_hint(List[int]) # Returns int
# If the hint is not a generic type, it returns the original hint
original_hint = extract_type_hint(Union[int, str]) # Returns Union[int, str]
```
"""

origin = getattr(type_hint, "__origin__", type_hint)
if isinstance(origin, _SpecialForm):
origin = type_hint.__args__ # type: ignore
return origin

def _validate_default_with_annotation(self) -> None:
"""
Validates if the default is allowed for the type of annotation
generated by the field.
"""
if not self.default or self.default == PolyforceUndefined:
return None

default = self.default() if callable(self.default) else self.default

type_hint = self._extract_type_hint(self.annotation)
if not isinstance(default, type_hint):
raise TypeError(
f"default '{type(default).__name__}' for field '{self.name}' is not valid for the field type annotation, it must be type '{self.annotation.__name__}'"
)
self.default = default

@classmethod
def _extract_annotation(
cls, annotation: Union[Type[Any], None]
Expand Down
55 changes: 52 additions & 3 deletions tests/fields/test_fields.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,76 @@
from typing import Any, Dict, List, Mapping, Union

import pytest

from polyforce import PolyField, PolyModel
from polyforce.exceptions import ValidationError


class Model(PolyModel):
def __init__(self, name: str, age: Union[str, int]) -> None:
...

def create_model(self, names: List[str]) -> None:
...
return names

def get_model(self, models: Dict[str, Any]) -> Dict[str, Any]:
...
return models

def set_model(self, models: Mapping[str, PolyModel]) -> None:
...
return models


def test_can_create_polyfield():
field = PolyField(annotation=str, name="field")
assert field is not None
assert field.annotation == str
assert field.name == "field"
assert field.is_required() is True


def test_raise_type_error_on_default_field():
with pytest.raises(TypeError) as raised:
PolyField(annotation=str, default=2, name="name")

assert (
raised.value.args[0]
== "default 'int' for field 'name' is not valid for the field type annotation, it must be type 'str'"
)


def test_default_field():
default = "john"

def get_default():
nonlocal default
return default

field = PolyField(annotation=str, default=get_default, name="name")
assert field.default == default


def test_functions():
model = Model(name="PolyModel", age=1)

names = model.create_model(names=["poly"])
assert names == ["poly"]

models = model.get_model(models={"name": "poly"})
assert models == {"name": "poly"}

models = model.set_model(models={"name": "poly"})
assert models == {"name": "poly"}


@pytest.mark.parametrize("func", ["get_model", "set_model"])
def test_functions_raises_validation_error(func):
model = Model(name="PolyModel", age=1)

with pytest.raises(ValidationError):
model.create_model(names="a")

with pytest.raises(ValidationError):
getattr(model, func)(models="a")


def test_poly_fields():
Expand Down

0 comments on commit c3c9898

Please sign in to comment.