Skip to content

Commit

Permalink
Merge pull request #24 from Ricardaux/master
Browse files Browse the repository at this point in the history
Add support for Union type and improve error management
  • Loading branch information
JGiard authored Jul 19, 2024
2 parents 63c940b + d0baeac commit 1f95777
Show file tree
Hide file tree
Showing 13 changed files with 218 additions and 16 deletions.
1 change: 1 addition & 0 deletions src/pyckson/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from pyckson.json import *
from pyckson.parser import parse
from pyckson.parsers.base import Parser
from pyckson.parsers.base import ParserException
from pyckson.serializer import serialize
from pyckson.serializers.base import Serializer
from pyckson.dates.helpers import configure_date_formatter, configure_explicit_nulls
Expand Down
6 changes: 6 additions & 0 deletions src/pyckson/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,12 @@ def is_typing_dict_annotation(annotation):
return False


def is_union_annotation(annotation) -> bool:
if hasattr(annotation, '__name__'):
return annotation.__name__ == 'Union'
return False


def using(attr):
def class_decorator(cls):
set_cls_attr(cls, PYCKSON_RULE_ATTR, attr)
Expand Down
13 changes: 11 additions & 2 deletions src/pyckson/model/union.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,14 @@ def inspect_optional_typing(annotation) -> Tuple[bool, type]:
else:
union_params = annotation.__union_params__

is_optional = len(union_params) == 2 and isinstance(None, union_params[1])
return is_optional, union_params[0]
try:
is_optional = isinstance(None, union_params[-1])
except TypeError:
is_optional = False
if is_optional:
union_param = Union[union_params[:-1]]
elif len(union_params) > 1:
union_param = Union[union_params]
else:
union_param = union_params[0]
return is_optional, union_param
31 changes: 31 additions & 0 deletions src/pyckson/parsers/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from decimal import Decimal
from enum import Enum

class ParserException(Exception):
pass


class Parser:
Expand All @@ -16,22 +20,30 @@ def __init__(self, cls):
self.cls = cls

def parse(self, json_value):
if not isinstance(json_value, self.cls):
raise ParserException(f'"{json_value}" is supposed to be a {self.cls.__name__}.')
return self.cls(json_value)


class ListParser(Parser):
def __init__(self, sub_parser: Parser):
self.sub_parser = sub_parser
self.cls = list

def parse(self, json_value):
if not isinstance(json_value, list):
raise ParserException(f'"{json_value}" is supposed to be a list.')
return [self.sub_parser.parse(item) for item in json_value]


class SetParser(Parser):
def __init__(self, sub_parser: Parser):
self.sub_parser = sub_parser
self.cls = set

def parse(self, json_value):
if not isinstance(json_value, set) and not isinstance(json_value, list):
raise ParserException(f'"{json_value}" is supposed to be a set or a list.')
return {self.sub_parser.parse(item) for item in json_value}


Expand All @@ -40,14 +52,19 @@ def __init__(self, cls):
self.cls = cls

def parse(self, value):
if value not in self.cls.__members__:
raise ParserException(f'"{value}" is not a valid value for "{self.cls.__name__}" Enum.')
return self.cls[value]


class CaseInsensitiveEnumParser(Parser):
def __init__(self, cls):
self.values = {member.name.lower(): member for member in cls}
self.cls = Enum

def parse(self, value):
if value.lower() not in self.values:
raise ParserException(f'"{value}" is not a valid value for "{self.cls.__name__}" Enum.')
return self.values[value.lower()]


Expand Down Expand Up @@ -75,3 +92,17 @@ def parse(self, json_value):
class DecimalParser(Parser):
def parse(self, json_value):
return Decimal(json_value)


class UnionParser(Parser):
def __init__(self, value_parsers: list[Parser]):
self.value_parsers = value_parsers

def parse(self, json_value):
for parser in self.value_parsers:
if hasattr(parser, 'cls') and isinstance(json_value, parser.cls):
try:
return parser.parse(json_value)
except:
pass
raise TypeError(f'{json_value} is not compatible with Union type in Pyckson.')
7 changes: 5 additions & 2 deletions src/pyckson/parsers/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
from pyckson.const import BASIC_TYPES, PYCKSON_TYPEINFO, PYCKSON_ENUM_OPTIONS, ENUM_CASE_INSENSITIVE, PYCKSON_PARSER, \
DATE_TYPES, EXTRA_TYPES, get_cls_attr, has_cls_attr, ENUM_USE_VALUES
from pyckson.helpers import TypeProvider, is_list_annotation, is_set_annotation, is_enum_annotation, \
is_basic_dict_annotation, is_typing_dict_annotation
is_basic_dict_annotation, is_typing_dict_annotation, is_union_annotation
from pyckson.parsers.advanced import UnresolvedParser, ClassParser, CustomDeferredParser, DateParser
from pyckson.parsers.base import Parser, BasicParser, ListParser, CaseInsensitiveEnumParser, DefaultEnumParser, \
BasicDictParser, SetParser, BasicParserWithCast, TypingDictParser, DecimalParser, ValuesEnumParser
BasicDictParser, SetParser, BasicParserWithCast, TypingDictParser, DecimalParser, ValuesEnumParser, UnionParser
from pyckson.providers import ParserProvider, ModelProvider


Expand Down Expand Up @@ -66,6 +66,9 @@ def get(self, obj_type, parent_class, name_in_parent) -> Parser:
if obj_type.__args__[0] != str:
raise TypeError('typing.Dict key can only be str in class {}'.format(parent_class))
return TypingDictParser(self.get(obj_type.__args__[1], parent_class, name_in_parent))
if is_union_annotation(obj_type):
return UnionParser([self.get(obj_type_arg, parent_class, name_in_parent)
for obj_type_arg in obj_type.__args__])
if has_cls_attr(obj_type, PYCKSON_PARSER):
return CustomDeferredParser(obj_type)
return ClassParser(obj_type, self.model_provider)
8 changes: 5 additions & 3 deletions src/pyckson/serializers/advanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,22 @@

from pyckson.const import PYCKSON_SERIALIZER, has_cls_attr
from pyckson.dates.helpers import get_class_date_formatter, get_class_use_explicit_nulls
from pyckson.helpers import is_base_type, get_custom_serializer
from pyckson.helpers import is_base_type, get_custom_serializer, is_base_type_with_cast
from pyckson.providers import ModelProvider
from pyckson.serializers.base import Serializer, BasicSerializer
from pyckson.serializers.base import Serializer, BasicSerializer, ListSerializer


class GenericSerializer(Serializer):
def __init__(self, model_provider: ModelProvider):
self.model_provider = model_provider

def serialize(self, obj):
if is_base_type(obj):
if is_base_type(obj) or is_base_type_with_cast(obj):
return BasicSerializer().serialize(obj)
elif has_cls_attr(obj.__class__, PYCKSON_SERIALIZER):
return get_custom_serializer(obj.__class__).serialize(obj)
elif isinstance(obj, list):
return ListSerializer(GenericSerializer(self.model_provider)).serialize(obj)
else:
return ClassSerializer(self.model_provider).serialize(obj)

Expand Down
4 changes: 3 additions & 1 deletion src/pyckson/serializers/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from pyckson.defaults import apply_enum_default
from pyckson.helpers import is_list_annotation, is_set_annotation, is_enum_annotation, is_basic_dict_annotation, \
is_typing_dict_annotation
is_typing_dict_annotation, is_union_annotation

try:
from typing import _ForwardRef as ForwardRef
Expand Down Expand Up @@ -52,6 +52,8 @@ def get(self, obj_type, parent_class, name_in_parent) -> Serializer:
if obj_type.__args__[0] != str:
raise TypeError('typing.Dict key can only be str in class {}'.format(parent_class))
return TypingDictSerializer(self.get(obj_type.__args__[1], parent_class, name_in_parent))
if is_union_annotation(obj_type):
return GenericSerializer(self.model_provider)
if has_cls_attr(obj_type, PYCKSON_SERIALIZER):
return CustomDeferredSerializer(obj_type)
return ClassSerializer(self.model_provider)
7 changes: 5 additions & 2 deletions tests/model/test_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,8 @@ def test_union_with_none_should_be_optional():


def test_other_unions_should_not_be_optional():
assert inspect_optional_typing(Union[int, str]) == (False, int)
assert inspect_optional_typing(Union[int, str, None]) == (False, int)
assert inspect_optional_typing(Union[int, str]) == (False, Union[int, str])


def test_multiple_union_with_none_should_be_optional():
assert inspect_optional_typing(Union[int, str, None]) == (True, Union[int, str])
88 changes: 88 additions & 0 deletions tests/parsers/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from assertpy import assert_that

from pyckson.parsers.base import ParserException, SetParser, UnionParser, BasicParserWithCast, ListParser, BasicParser


class TestBasicParserWithCast:
def test_should_handle_simple_type(self):
parser = BasicParserWithCast(int)

result = parser.parse(5)

assert_that(result).is_equal_to(5)

def test_should_raise_when_it_is_not_the_correct_type(self):
parser = BasicParserWithCast(str)

assert_that(parser.parse).raises(ParserException).when_called_with(5)


class TestUnionParser:
def test_should_parse_simple_union(self):
parser = UnionParser([BasicParserWithCast(int)])

result = parser.parse(5)

assert result == 5

def test_should_parse_list_in_union(self):
parser = UnionParser([ListParser(BasicParserWithCast(int))])

result = parser.parse([5, 6])

assert result == [5, 6]

def test_should_raise_if_parser_does_not_correspond_to_union_type(self):
parser = UnionParser([BasicParserWithCast(int)])

assert_that(parser.parse).raises(TypeError).when_called_with("str")

def test_should_not_raise_if_parser_does_not_have_cls(self):
parser = UnionParser([BasicParser(), BasicParserWithCast(int)])

result = parser.parse(5)

assert_that(result).is_equal_to(5)

def test_should_parse_list_of_list_in_union(self):
parser = UnionParser([ListParser(BasicParserWithCast(int)), ListParser(ListParser(BasicParserWithCast(int)))])

result = parser.parse([[5], [6]])

assert result == [[5], [6]]



class TestListParser:
def test_should_accept_list(self):
parser = ListParser(BasicParserWithCast(int))

result = parser.parse([5])

assert_that(result).is_equal_to([5])

def test_should_raise_when_parse_other_than_list(self):
parser = ListParser(BasicParserWithCast(int))

assert_that(parser.parse).raises(ParserException).when_called_with(5)


class TestSetParser:
def test_should_accept_set(self):
parser = SetParser(BasicParserWithCast(int))

result = parser.parse({5})

assert_that(result).is_equal_to({5})

def test_should_accept_list_as_set(self):
parser = SetParser(BasicParserWithCast(int))

result = parser.parse([5])

assert_that(result).is_equal_to({5})

def test_should_raise_when_parse_other_than_list(self):
parser = SetParser(BasicParserWithCast(int))

assert_that(parser.parse).raises(ParserException).when_called_with(5)
8 changes: 4 additions & 4 deletions tests/parsers/test_enum.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from enum import Enum
from unittest import TestCase

from pyckson.parsers.base import DefaultEnumParser, CaseInsensitiveEnumParser
from pyckson.parsers.base import DefaultEnumParser, CaseInsensitiveEnumParser, ParserException


class MyEnum(Enum):
Expand All @@ -20,11 +20,11 @@ def test_should_parse_value_in_enum(self):
self.assertEqual(self.parser.parse('b'), MyEnum.b)

def test_should_not_parse_uppercase_not_in_enum(self):
with self.assertRaises(KeyError):
with self.assertRaises(ParserException):
self.parser.parse('B')

def test_should_not_parse_value_not_in_enum(self):
with self.assertRaises(KeyError):
with self.assertRaises(ParserException):
self.parser.parse('c')


Expand All @@ -46,5 +46,5 @@ def test_should_parse_case_insensitive(self):
self.assertEqual(self.parser.parse('b'), MyInsensitiveEnum.B)

def test_should_not_parse_value_not_in_enum(self):
with self.assertRaises(KeyError):
with self.assertRaises(ParserException):
self.parser.parse('c')
26 changes: 25 additions & 1 deletion tests/parsers/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from datetime import datetime, date
from decimal import Decimal
from enum import Enum
from typing import List, Dict, Set, Optional
from typing import List, Dict, Set, Optional, Union
from unittest import TestCase

from pyckson import date_formatter, loads
Expand Down Expand Up @@ -377,3 +377,27 @@ def __init__(self, e: MyEnum):
self.e = e

assert parse(Foo, {'e': 'fooo'}).e == MyEnum.FOO


def test_parse_union_str_values():
class Foo:
def __init__(self, e: Union[str, int]):
self.e = e

assert parse(Foo, {'e': 'fooo'}).e == 'fooo'


def test_parse_union_int_values():
class Foo:
def __init__(self, e: Union[str, int]):
self.e = e

assert parse(Foo, {'e': 5}).e == 5


def test_parse_union_list_values():
class Foo:
def __init__(self, e: Union[str, List[str]]):
self.e = e

assert parse(Foo, {'e': ['yo']}).e == ['yo']
24 changes: 24 additions & 0 deletions tests/serializers/test_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,30 @@ def __init__(self, foo: Union[X, Y]):
assert serialize(Foo(X('a'))) == {'foo': {'x': 'a'}}


def test_serialize_union_str_values():
class Foo:
def __init__(self, e: Union[str, int]):
self.e = e

assert serialize(Foo('fooo')) == {'e': 'fooo'}


def test_serialize_union_int_values():
class Foo:
def __init__(self, e: Union[str, int]):
self.e = e

assert serialize(Foo(5)) == {'e': 5}


def test_serialize_union_list_values():
class Foo:
def __init__(self, e: Union[str, List[str]]):
self.e = e

assert serialize(Foo(['yo'])) == {'e': ['yo']}


def test_should_serialize_decimal():
class Foo:
def __init__(self, x: Decimal):
Expand Down
Loading

0 comments on commit 1f95777

Please sign in to comment.