Skip to content

Commit

Permalink
Add support for python 3.10+ typing
Browse files Browse the repository at this point in the history
  • Loading branch information
tlconnor committed Oct 14, 2024
1 parent 98c1719 commit 2f23b5c
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 73 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ jobs:
- name: Checkout repository
uses: actions/checkout@v3

- name: Set up Python 3.9
- name: Set up Python 3.12
uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: 3.12

- name: Install dependencies
run: |
Expand All @@ -44,7 +44,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.7', '3.8', '3.9', '3.10', '3.11']
python-version: ['3.10', '3.11', '3.12']

steps:
- name: Checkout repository
Expand Down
111 changes: 59 additions & 52 deletions hologram/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,18 @@
Generic,
Hashable,
ClassVar,
get_origin,
get_args,
)
from types import UnionType
import re
from datetime import datetime
from dataclasses import fields, is_dataclass, Field, MISSING, dataclass, asdict
from uuid import UUID
from enum import Enum
import threading
import warnings
from collections.abc import Mapping, Sequence

from dateutil.parser import parse
import jsonschema
Expand All @@ -37,8 +41,6 @@
JsonEncodable = Union[int, float, str, bool, None]
JsonDict = Dict[str, Any]

OPTIONAL_TYPES = ["Union", "Optional"]


class ValidationError(jsonschema.ValidationError):
pass
Expand Down Expand Up @@ -86,15 +88,37 @@ def issubclass_safe(klass: Any, base: Type) -> bool:
return False


def is_union(field: Any) -> bool:
return get_origin(field) in [Union, UnionType]


def is_optional(field: Any) -> bool:
if str(field).startswith("typing.Union") or str(field).startswith(
"typing.Optional"
):
for arg in field.__args__:
if isinstance(arg, type) and issubclass(arg, type(None)):
return True
"""
Checks if a field is optional. Handles both the `Optional[str]` and `str | None` syntax.
"""
return is_union(field) and (type(None) in get_args(field))


def is_list(field: Any) -> bool:
"""
Checks if the origin of a field is a list. Handles both `typing.List` and `list`
"""
return get_origin(field) in (list, Sequence)


def is_dict(field: Any) -> bool:
"""
Checks if the origin of a field is a dict. Handles all of `typing.Mapping`, `typing.Dict` and
`dict`.
"""
return get_origin(field) in (dict, Mapping)


return False
def is_tuple(field: Any) -> bool:
"""
Checks if the origin of a field is a tuple. Handles both `typing.Tuple` and `tuple`
"""
return get_origin(field) == tuple


TV = TypeVar("TV")
Expand Down Expand Up @@ -233,7 +257,7 @@ def get_union_fields(field_type: Union[Any]) -> List[Variant]:
end.
"""
fields: List[Variant] = []
for variant in field_type.__args__:
for variant in get_args(field_type):
restrictions: Optional[Restriction] = _get_restrictions(variant)
if not restrictions:
restrictions = None
Expand All @@ -247,7 +271,7 @@ def get_union_fields(field_type: Union[Any]) -> List[Variant]:
def _encode_restrictions_met(
value: Any, restrict_fields: Optional[List[Tuple[Field, str]]]
) -> bool:
if restrict_fields is None:
if restrict_fields is None or len(restrict_fields) == 0:
return True
return all(
(
Expand Down Expand Up @@ -345,7 +369,7 @@ def encoder(ft, v, __):
def encoder(_, v, __):
return v.value

elif field_type_name in OPTIONAL_TYPES:
elif is_union(field_type):
# Attempt to encode the field with each union variant.
# TODO: Find a more reliable method than this since in the case 'Union[List[str], Dict[str, int]]' this
# will just output the dict keys as a list
Expand All @@ -367,7 +391,7 @@ def encoder(_, v, __):
)
)
return encoded
elif field_type_name in ("Mapping", "Dict"):
elif is_dict(field_type):

def encoder(ft, val, o):
return {
Expand All @@ -381,15 +405,16 @@ def encoder(ft, val, o):
# TODO: is there some way to set __args__ on this so it can
# just re-use Dict/Mapping?
def encoder(ft, val, o):

return {
cls._encode_field(str, k, o): cls._encode_field(
ft.TARGET_TYPE, v, o
)
for k, v in val.items()
}

elif field_type_name == "List" or (
field_type_name == "Tuple" and ... in field_type.__args__
elif is_list(field_type) or (
is_tuple(field_type) and ... in field_type.__args__
):

def encoder(ft, val, o):
Expand All @@ -403,14 +428,7 @@ def encoder(ft, val, o):
cls._encode_field(ft.__args__[0], v, o) for v in val
]

elif field_type_name == "Sequence":

def encoder(ft, val, o):
return [
cls._encode_field(ft.__args__[0], v, o) for v in val
]

elif field_type_name == "Tuple":
elif is_tuple(field_type):

def encoder(ft, val, o):
return [
Expand Down Expand Up @@ -517,7 +535,7 @@ def decoder(_, ft, val):
def decoder(_, ft, val):
return ft.from_dict(val, validate=validate)

elif field_type_name in OPTIONAL_TYPES:
elif is_union(field_type):
# Attempt to decode the value using each decoder in turn
union_excs = (
AttributeError,
Expand All @@ -541,7 +559,7 @@ def decoder(_, ft, val):
# none of the unions decoded, so report about all of them
raise FutureValidationError(field, errors)

elif field_type_name in ("Mapping", "Dict"):
elif is_dict(field_type):

def decoder(f, ft, val):
return {
Expand All @@ -551,10 +569,10 @@ def decoder(f, ft, val):
for k, v in val.items()
}

elif field_type_name == "List" or (
field_type_name == "Tuple" and ... in field_type.__args__
elif is_list(field_type) or (
is_tuple(field_type) and ... in field_type.__args__
):
seq_type = tuple if field_type_name == "Tuple" else list
seq_type = tuple if is_tuple(field_type) else list

def decoder(f, ft, val):
if not isinstance(val, (tuple, list)):
Expand All @@ -568,17 +586,7 @@ def decoder(f, ft, val):
for v in val
)

# if you want to allow strings as sequences for some reason, you
# can still use 'Sequence' to get back a list of characters...
elif field_type_name == "Sequence":

def decoder(f, ft, val):
return list(
cls._decode_field(f, ft.__args__[0], v, validate)
for v in val
)

elif field_type_name == "Tuple":
elif is_tuple(field_type):

def decoder(f, ft, val):
return tuple(
Expand Down Expand Up @@ -739,6 +747,7 @@ def _get_schema_for_type(
required: bool = True,
restrictions: Optional[List[Any]] = None,
) -> Tuple[JsonDict, bool]:

field_schema: JsonDict = {"type": "object"}

type_name = cls._get_field_type_name(target)
Expand All @@ -749,12 +758,12 @@ def _get_schema_for_type(
elif restrictions:
field_schema.update(cls._encode_restrictions(restrictions))

# if Union[..., None] or Optional[...]
elif type_name in OPTIONAL_TYPES:
# if ... | None, Union[..., None] or Optional[...]
elif is_union(target):
field_schema = {
"oneOf": [
cls._get_field_schema(variant)[0]
for variant in target.__args__
for variant in get_args(target)
]
}

Expand All @@ -764,7 +773,7 @@ def _get_schema_for_type(
elif is_enum(target):
field_schema.update(cls._encode_restrictions(target))

elif type_name in ("Dict", "Mapping"):
elif is_dict(target):
field_schema = {"type": "object"}
if target.__args__[1] is not Any:
field_schema["additionalProperties"] = cls._get_field_schema(
Expand All @@ -776,16 +785,14 @@ def _get_schema_for_type(
".*": cls._get_field_schema(target.TARGET_TYPE)[0]
}

elif type_name in ("Sequence", "List") or (
type_name == "Tuple" and ... in target.__args__
):
elif is_list(target) or (is_tuple(target) and ... in target.__args__):
field_schema = {"type": "array"}
if target.__args__[0] is not Any:
field_schema["items"] = cls._get_field_schema(
target.__args__[0]
)[0]

elif type_name == "Tuple":
elif is_tuple(target):
tuple_len = len(target.__args__)
# TODO: How do we handle Optional type within lists / tuples
field_schema = {
Expand Down Expand Up @@ -842,19 +849,19 @@ def _get_field_schema(
@classmethod
def _get_field_definitions(cls, field_type: Any, definitions: JsonDict):
field_type_name = cls._get_field_type_name(field_type)
if field_type_name == "Tuple":
if is_tuple(field_type):
# tuples are either like Tuple[T, ...] or Tuple[T1, T2, T3].
for member in field_type.__args__:
if member is not ...:
cls._get_field_definitions(member, definitions)
elif field_type_name in ("Sequence", "List"):
elif is_list(field_type):
cls._get_field_definitions(field_type.__args__[0], definitions)
elif field_type_name in ("Dict", "Mapping"):
elif is_dict(field_type):
cls._get_field_definitions(field_type.__args__[1], definitions)
elif field_type_name == "PatternProperty":
cls._get_field_definitions(field_type.TARGET_TYPE, definitions)
elif field_type_name in OPTIONAL_TYPES:
for variant in field_type.__args__:
elif is_union(field_type):
for variant in get_args(field_type):
cls._get_field_definitions(variant, definitions)
elif cls._is_json_schema_subclass(field_type):
# Prevent recursion from forward refs & circular type dependencies
Expand Down
7 changes: 3 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,9 @@ def read(f):
"Intended Audience :: Developers",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Topic :: Software Development :: Libraries",
],
)
8 changes: 5 additions & 3 deletions tests/test_dict_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from dataclasses import dataclass
from hologram import JsonSchemaMixin
from typing import Dict, Union
from typing import Dict, Union, Mapping


@dataclass
Expand All @@ -19,17 +19,19 @@ class SecondDictFieldValue(JsonSchemaMixin):
class HasDictFields(JsonSchemaMixin):
a: str
x: Dict[str, str]
z: Dict[str, Union[DictFieldValue, SecondDictFieldValue]]
y: Mapping[str, str]
z: dict[str, Union[DictFieldValue, SecondDictFieldValue]]


def test_schema():
schema = HasDictFields.json_schema()

assert schema["type"] == "object"
assert schema["required"] == ["a", "x", "z"]
assert schema["required"] == ["a", "x", "y", "z"]
assert schema["properties"] == {
"a": {"type": "string"},
"x": {"type": "object", "additionalProperties": {"type": "string"}},
"y": {"type": "object", "additionalProperties": {"type": "string"}},
"z": {
"type": "object",
"additionalProperties": {
Expand Down
11 changes: 8 additions & 3 deletions tests/test_tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ class TupleMember(JsonSchemaMixin):

@dataclass
class TupleEllipsisHolder(JsonSchemaMixin):
member: Tuple[TupleMember, ...]
member1: Tuple[TupleMember, ...]
member2: tuple[TupleMember, ...]


@dataclass
Expand All @@ -25,9 +26,13 @@ class TupleMemberSecondHolder(JsonSchemaMixin):


def test_ellipsis_tuples():
dct = {"member": [{"a": 1}, {"a": 2}, {"a": 3}]}
dct = {
"member1": [{"a": 1}, {"a": 2}, {"a": 3}],
"member2": [{"a": 1}, {"a": 2}, {"a": 3}],
}
value = TupleEllipsisHolder(
member=(TupleMember(1), TupleMember(2), TupleMember(3))
member1=(TupleMember(1), TupleMember(2), TupleMember(3)),
member2=(TupleMember(1), TupleMember(2), TupleMember(3)),
)
assert value.to_dict() == dct
assert TupleEllipsisHolder.from_dict(dct) == value
Expand Down
Loading

0 comments on commit 2f23b5c

Please sign in to comment.