From eb4bd8b6dbad77327c43b9828d7f7d2b0859316a Mon Sep 17 00:00:00 2001 From: Tim Connor Date: Fri, 27 Sep 2024 16:26:46 +1200 Subject: [PATCH] Add support for python 3.10+ typing --- .github/workflows/main.yml | 8 +- hologram/__init__.py | 204 +++++++++++++++++++++++-------------- setup.py | 7 +- tests/test_dict_fields.py | 8 +- tests/test_tuple.py | 11 +- tests/test_union.py | 17 ++-- 6 files changed, 156 insertions(+), 99 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index edb67a0..983807d 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -19,10 +19,10 @@ jobs: - name: Checkout repository uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.12 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.12 - name: Install dependencies run: | @@ -38,13 +38,13 @@ jobs: - name: Run mypy run: | source venv/bin/activate - mypy hologram --ignore-missing-imports + mypy hologram --ignore-missing-imports --install-types --non-interactive test: runs-on: ubuntu-latest strategy: matrix: - python-version: ['3.7', '3.8', '3.9', '3.10', '3.11'] + python-version: ['3.10', '3.11', '3.12'] steps: - name: Checkout repository diff --git a/hologram/__init__.py b/hologram/__init__.py index e6955cb..4861bea 100644 --- a/hologram/__init__.py +++ b/hologram/__init__.py @@ -1,20 +1,19 @@ import functools from typing import ( - Optional, Type, Union, Any, - Dict, cast, - Tuple, - List, TypeVar, get_type_hints, Callable, Generic, Hashable, ClassVar, + get_origin, + get_args, ) +from types import UnionType import re from datetime import datetime from dataclasses import fields, is_dataclass, Field, MISSING, dataclass, asdict @@ -22,6 +21,7 @@ from enum import Enum import threading import warnings +from collections.abc import Mapping from dateutil.parser import parse import jsonschema @@ -34,10 +34,8 @@ type(None): {"type": "null"}, } -JsonEncodable = Union[int, float, str, bool, None] -JsonDict = Dict[str, Any] - -OPTIONAL_TYPES = ["Union", "Optional"] +JsonEncodable = int | float | str | bool | None +JsonDict = dict[str, Any] class ValidationError(jsonschema.ValidationError): @@ -46,14 +44,14 @@ class ValidationError(jsonschema.ValidationError): class FutureValidationError(ValidationError): # a validation error where we haven't called str() on inputs yet. - def __init__(self, field: str, errors: Dict[str, Exception]): + def __init__(self, field: str, errors: dict[str, Exception]): self.errors = errors self.field = field super().__init__("generic validation error") self.initialized = False - def late_initialize(self): - lines: List[str] = [] + def late_initialize(self) -> None: + lines: list[str] = [] for name, exc in self.errors.items(): # do not use getattr(exc, 'message', str(exc)), it's slow! if hasattr(exc, "message"): @@ -86,15 +84,67 @@ def issubclass_safe(klass: Any, base: Type) -> bool: return False +def is_union(field_type: Any) -> bool: + """ + Checks if a field type is a union. + + Examples:: + + >>> assert is_union(int | str) == True + >>> assert is_union(Union[int, str]) == True + """ + return get_origin(field_type) in [Union, UnionType] + + def is_optional(field: Any) -> bool: - if str(field).startswith("typing.Union") or str(field).startswith( - "typing.Optional" - ): - for arg in field.__args__: - if isinstance(arg, type) and issubclass(arg, type(None)): - return True + """ + Checks if a field type is optional. 
+ + Examples:: + + >>> assert is_optional(int | None) == True + >>> assert is_optional(Optional[int]) == True + """ + return is_union(field) and (type(None) in get_args(field)) + + +def is_list(field: Any) -> bool: + """ + Checks if a field type is a list. + + Examples:: + + >>> assert is_list(list[int]) == True + >>> assert is_list(List[int]) == True + """ + return get_origin(field) == list + + +def is_dict(field: Any) -> bool: + """ + Checks if a field type is a dict. + + Examples:: + + >>> assert is_dict(dict[str, Any]) == True + >>> assert is_dict(Dict[str, Any]) == True + >>> assert is_dict(Mapping[str, Any]) == True + """ + return get_origin(field) in (dict, Mapping) + - return False +def is_tuple(field: Any) -> bool: + """ + Checks if a field type is a tuple. + + Examples:: + + >>> assert is_tuple(tuple[str, int]) == True + >>> assert is_tuple(Tuple[str, int]) == True + >>> assert is_tuple(tuple[str, ...]) == True + >>> assert is_tuple(Tuple[str, ...]) == True + """ + return get_origin(field) == tuple TV = TypeVar("TV") @@ -173,10 +223,10 @@ def _to_camel_case(value: str) -> str: @dataclass class FieldMeta: default: Any = None - description: Optional[str] = None + description: str | None = None @property - def as_dict(self) -> Dict: + def as_dict(self) -> dict: return { _to_camel_case(k): v for k, v in asdict(self).items() @@ -193,11 +243,11 @@ def _validate_schema(h_schema_cls: Hashable) -> JsonDict: # a restriction is a list of Field, str pairs -Restriction = List[Tuple[Field, str]] +Restriction = list[tuple[Field, str]] # a restricted variant is a pair of an object that has fields with restrictions # and those restrictions. Only JsonSchemaMixin subclasses may have restrictied # fields. -Variant = Tuple[Type[T], Optional[Restriction]] +Variant = tuple[Type[T], Restriction | None] def _get_restrictions(variant_type: Type) -> Restriction: @@ -216,7 +266,7 @@ def _get_restrictions(variant_type: Type) -> Restriction: return restrictions -def get_union_fields(field_type: Union[Any]) -> List[Variant]: +def get_union_fields(field_type: Union[Any]) -> list[Variant]: """ Unions have a __args__ that is all their variants (after typing's type-collapsing magic has run, so caveat emptor...) @@ -232,9 +282,9 @@ def get_union_fields(field_type: Union[Any]) -> List[Variant]: The list will be sorted so that unrestricted variants will always be at the end. 
""" - fields: List[Variant] = [] - for variant in field_type.__args__: - restrictions: Optional[Restriction] = _get_restrictions(variant) + fields: list[Variant] = [] + for variant in get_args(field_type): + restrictions: Restriction | None = _get_restrictions(variant) if not restrictions: restrictions = None fields.append((variant, restrictions)) @@ -245,9 +295,9 @@ def get_union_fields(field_type: Union[Any]) -> List[Variant]: def _encode_restrictions_met( - value: Any, restrict_fields: Optional[List[Tuple[Field, str]]] + value: Any, restrict_fields: list[tuple[Field, str]] | None ) -> bool: - if restrict_fields is None: + if restrict_fields is None or len(restrict_fields) == 0: return True return all( ( @@ -259,7 +309,7 @@ def _encode_restrictions_met( def _decode_restrictions_met( - value: Any, restrict_fields: Optional[List[Tuple[Field, str]]] + value: Any, restrict_fields: list[tuple[Field, str]] | None ) -> bool: if restrict_fields is None: return True @@ -283,25 +333,23 @@ class JsonSchemaMixin: convert to and from JSON encodable dicts with validation against the schema """ - _field_encoders: ClassVar[Dict[Type, FieldEncoder]] = { + _field_encoders: ClassVar[dict[Type, FieldEncoder]] = { datetime: DateTimeFieldEncoder(), UUID: UuidField(), } # Cache of the generated schema - _schema: ClassVar[Optional[Dict[str, CompleteSchema]]] = None + _schema: ClassVar[dict[str, CompleteSchema] | None] = None # Cache of field encode / decode functions - _encode_cache: ClassVar[Optional[Dict[Any, _ValueEncoder]]] = None - _decode_cache: ClassVar[Optional[Dict[Any, _ValueDecoder]]] = None - _mapped_fields: ClassVar[ - Optional[Dict[Any, List[Tuple[Field, str]]]] - ] = None + _encode_cache: ClassVar[dict[Any, _ValueEncoder] | None] = None + _decode_cache: ClassVar[dict[Any, _ValueDecoder] | None] = None + _mapped_fields: ClassVar[dict[Any, list[tuple[Field, str]]] | None] = None ADDITIONAL_PROPERTIES: ClassVar[bool] = False @classmethod - def field_mapping(cls) -> Dict[str, str]: + def field_mapping(cls) -> dict[str, str]: """Defines the mapping of python field names to JSON field names. The main use-case is to allow JSON field names which are Python keywords @@ -309,7 +357,7 @@ def field_mapping(cls) -> Dict[str, str]: return {} @classmethod - def register_field_encoders(cls, field_encoders: Dict[Type, FieldEncoder]): + def register_field_encoders(cls, field_encoders: dict[Type, FieldEncoder]): """Registers additional custom field encoders. If called on the base, these are added globally. The DateTimeFieldEncoder is included by default. @@ -345,7 +393,7 @@ def encoder(ft, v, __): def encoder(_, v, __): return v.value - elif field_type_name in OPTIONAL_TYPES: + elif is_union(field_type): # Attempt to encode the field with each union variant. # TODO: Find a more reliable method than this since in the case 'Union[List[str], Dict[str, int]]' this # will just output the dict keys as a list @@ -367,7 +415,7 @@ def encoder(_, v, __): ) ) return encoded - elif field_type_name in ("Mapping", "Dict"): + elif is_dict(field_type): def encoder(ft, val, o): return { @@ -381,6 +429,7 @@ def encoder(ft, val, o): # TODO: is there some way to set __args__ on this so it can # just re-use Dict/Mapping? def encoder(ft, val, o): + return { cls._encode_field(str, k, o): cls._encode_field( ft.TARGET_TYPE, v, o @@ -388,8 +437,8 @@ def encoder(ft, val, o): for k, v in val.items() } - elif field_type_name == "List" or ( - field_type_name == "Tuple" and ... 
in field_type.__args__ + elif is_list(field_type) or ( + is_tuple(field_type) and ... in field_type.__args__ ): def encoder(ft, val, o): @@ -410,7 +459,7 @@ def encoder(ft, val, o): cls._encode_field(ft.__args__[0], v, o) for v in val ] - elif field_type_name == "Tuple": + elif is_tuple(field_type): def encoder(ft, val, o): return [ @@ -439,14 +488,14 @@ def encoder(_, v, __): return encoder(field_type, value, omit_none) @classmethod - def _get_fields(cls) -> List[Tuple[Field, str]]: + def _get_fields(cls) -> list[tuple[Field, str]]: if cls._mapped_fields is None: cls._mapped_fields = {} if cls.__name__ not in cls._mapped_fields: mapped_fields = [] type_hints = get_type_hints(cls) - for f in fields(cls): + for f in fields(cls): # type: ignore # Skip internal fields if f.name.startswith("_"): continue @@ -517,7 +566,7 @@ def decoder(_, ft, val): def decoder(_, ft, val): return ft.from_dict(val, validate=validate) - elif field_type_name in OPTIONAL_TYPES: + elif is_union(field_type): # Attempt to decode the value using each decoder in turn union_excs = ( AttributeError, @@ -525,7 +574,7 @@ def decoder(_, ft, val): ValueError, ValidationError, ) - errors: Dict[str, Exception] = {} + errors: dict[str, Exception] = {} union_fields = get_union_fields(field_type) for variant, restrict_fields in union_fields: @@ -541,7 +590,7 @@ def decoder(_, ft, val): # none of the unions decoded, so report about all of them raise FutureValidationError(field, errors) - elif field_type_name in ("Mapping", "Dict"): + elif is_dict(field_type): def decoder(f, ft, val): return { @@ -551,10 +600,10 @@ def decoder(f, ft, val): for k, v in val.items() } - elif field_type_name == "List" or ( - field_type_name == "Tuple" and ... in field_type.__args__ + elif is_list(field_type) or ( + is_tuple(field_type) and ... 
in field_type.__args__ ): - seq_type = tuple if field_type_name == "Tuple" else list + seq_type = tuple if is_tuple(field_type) else list def decoder(f, ft, val): if not isinstance(val, (tuple, list)): @@ -578,7 +627,7 @@ def decoder(f, ft, val): for v in val ) - elif field_type_name == "Tuple": + elif is_tuple(field_type): def decoder(f, ft, val): return tuple( @@ -621,7 +670,7 @@ def _find_matching_validator(cls: Type[T], data: JsonDict) -> T: for subclass in cls.__subclasses__(): try: if is_dataclass(subclass): - return subclass.from_dict(data) + return subclass.from_dict(data) # type: ignore except ValidationError: continue @@ -637,8 +686,8 @@ def from_dict(cls: Type[T], data: JsonDict, validate=True) -> T: if cls is JsonSchemaMixin: return cls._find_matching_validator(data) - init_values: Dict[str, Any] = {} - non_init_values: Dict[str, Any] = {} + init_values: dict[str, Any] = {} + non_init_values: dict[str, Any] = {} if validate: cls.validate(data) @@ -674,10 +723,10 @@ def _has_definition(field_type: Type) -> bool: ) @classmethod - def _get_field_meta(cls, field: Field) -> Tuple[FieldMeta, bool]: + def _get_field_meta(cls, field: Field) -> tuple[FieldMeta, bool]: required = True field_meta = FieldMeta() - default_value: Optional[Callable[[], Any]] = None + default_value: Callable[[], Any] | None = None if field.default is not MISSING and field.default is not None: # In case of default value given default_value = field.default @@ -702,7 +751,7 @@ def _get_field_meta(cls, field: Field) -> Tuple[FieldMeta, bool]: @classmethod def _encode_restrictions( - cls, restrictions: Union[List[Any], Type[Enum]] + cls, restrictions: list[Any] | Type[Enum] ) -> JsonDict: field_schema: JsonDict = {} member_types = set() @@ -737,8 +786,9 @@ def _get_schema_for_type( cls, target: Type, required: bool = True, - restrictions: Optional[List[Any]] = None, - ) -> Tuple[JsonDict, bool]: + restrictions: list[Any] | None = None, + ) -> tuple[JsonDict, bool]: + field_schema: JsonDict = {"type": "object"} type_name = cls._get_field_type_name(target) @@ -749,12 +799,12 @@ def _get_schema_for_type( elif restrictions: field_schema.update(cls._encode_restrictions(restrictions)) - # if Union[..., None] or Optional[...] - elif type_name in OPTIONAL_TYPES: + # if ... | None, Union[..., None] or Optional[...] + elif is_union(target): field_schema = { "oneOf": [ cls._get_field_schema(variant)[0] - for variant in target.__args__ + for variant in get_args(target) ] } @@ -764,7 +814,7 @@ def _get_schema_for_type( elif is_enum(target): field_schema.update(cls._encode_restrictions(target)) - elif type_name in ("Dict", "Mapping"): + elif is_dict(target): field_schema = {"type": "object"} if target.__args__[1] is not Any: field_schema["additionalProperties"] = cls._get_field_schema( @@ -776,8 +826,10 @@ def _get_schema_for_type( ".*": cls._get_field_schema(target.TARGET_TYPE)[0] } - elif type_name in ("Sequence", "List") or ( - type_name == "Tuple" and ... in target.__args__ + elif ( + type_name == "Sequence" + or is_list(target) + or (is_tuple(target) and ... 
in target.__args__) ): field_schema = {"type": "array"} if target.__args__[0] is not Any: @@ -785,7 +837,7 @@ def _get_schema_for_type( target.__args__[0] )[0] - elif type_name == "Tuple": + elif is_tuple(target): tuple_len = len(target.__args__) # TODO: How do we handle Optional type within lists / tuples field_schema = { @@ -809,14 +861,12 @@ def _get_schema_for_type( return field_schema, required @classmethod - def _get_field_schema( - cls, field: Union[Field, Type] - ) -> Tuple[JsonDict, bool]: + def _get_field_schema(cls, field: Field | Type) -> tuple[JsonDict, bool]: required = True restrictions = None if isinstance(field, Field): - field_type = field.type + field_type = cast(Type, field.type) field_meta, required = cls._get_field_meta(field) if field.metadata is not None: restrictions = field.metadata.get("restrict") @@ -842,19 +892,19 @@ def _get_field_schema( @classmethod def _get_field_definitions(cls, field_type: Any, definitions: JsonDict): field_type_name = cls._get_field_type_name(field_type) - if field_type_name == "Tuple": + if is_tuple(field_type): # tuples are either like Tuple[T, ...] or Tuple[T1, T2, T3]. for member in field_type.__args__: if member is not ...: cls._get_field_definitions(member, definitions) - elif field_type_name in ("Sequence", "List"): + elif field_type_name == "Sequence" or is_list(field_type): cls._get_field_definitions(field_type.__args__[0], definitions) - elif field_type_name in ("Dict", "Mapping"): + elif is_dict(field_type): cls._get_field_definitions(field_type.__args__[1], definitions) elif field_type_name == "PatternProperty": cls._get_field_definitions(field_type.TARGET_TYPE, definitions) - elif field_type_name in OPTIONAL_TYPES: - for variant in field_type.__args__: + elif is_union(field_type): + for variant in get_args(field_type): cls._get_field_definitions(variant, definitions) elif cls._is_json_schema_subclass(field_type): # Prevent recursion from forward refs & circular type dependencies @@ -872,7 +922,7 @@ def all_json_schemas(cls) -> JsonDict: definitions = {} for subclass in cls.__subclasses__(): if is_dataclass(subclass): - definitions.update(subclass.json_schema(embeddable=True)) + definitions.update(subclass.json_schema(embeddable=True)) # type: ignore else: definitions.update(subclass.all_json_schemas()) return definitions diff --git a/setup.py b/setup.py index 3267d4c..390d0da 100644 --- a/setup.py +++ b/setup.py @@ -31,10 +31,9 @@ def read(f): "Intended Audience :: Developers", "License :: OSI Approved :: MIT License", "Programming Language :: Python", - "Programming Language :: Python :: 3.6", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Topic :: Software Development :: Libraries", ], ) diff --git a/tests/test_dict_fields.py b/tests/test_dict_fields.py index 78bc4ec..4294f07 100644 --- a/tests/test_dict_fields.py +++ b/tests/test_dict_fields.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from hologram import JsonSchemaMixin -from typing import Dict, Union +from typing import Dict, Union, Mapping @dataclass @@ -19,17 +19,19 @@ class SecondDictFieldValue(JsonSchemaMixin): class HasDictFields(JsonSchemaMixin): a: str x: Dict[str, str] - z: Dict[str, Union[DictFieldValue, SecondDictFieldValue]] + y: Mapping[str, str] + z: dict[str, Union[DictFieldValue, SecondDictFieldValue]] def test_schema(): schema = 
HasDictFields.json_schema() assert schema["type"] == "object" - assert schema["required"] == ["a", "x", "z"] + assert schema["required"] == ["a", "x", "y", "z"] assert schema["properties"] == { "a": {"type": "string"}, "x": {"type": "object", "additionalProperties": {"type": "string"}}, + "y": {"type": "object", "additionalProperties": {"type": "string"}}, "z": { "type": "object", "additionalProperties": { diff --git a/tests/test_tuple.py b/tests/test_tuple.py index a835178..669c13c 100644 --- a/tests/test_tuple.py +++ b/tests/test_tuple.py @@ -11,7 +11,8 @@ class TupleMember(JsonSchemaMixin): @dataclass class TupleEllipsisHolder(JsonSchemaMixin): - member: Tuple[TupleMember, ...] + member1: Tuple[TupleMember, ...] + member2: tuple[TupleMember, ...] @dataclass @@ -25,9 +26,13 @@ class TupleMemberSecondHolder(JsonSchemaMixin): def test_ellipsis_tuples(): - dct = {"member": [{"a": 1}, {"a": 2}, {"a": 3}]} + dct = { + "member1": [{"a": 1}, {"a": 2}, {"a": 3}], + "member2": [{"a": 1}, {"a": 2}, {"a": 3}], + } value = TupleEllipsisHolder( - member=(TupleMember(1), TupleMember(2), TupleMember(3)) + member1=(TupleMember(1), TupleMember(2), TupleMember(3)), + member2=(TupleMember(1), TupleMember(2), TupleMember(3)), ) assert value.to_dict() == dct assert TupleEllipsisHolder.from_dict(dct) == value diff --git a/tests/test_union.py b/tests/test_union.py index 1cddef4..c3d4230 100644 --- a/tests/test_union.py +++ b/tests/test_union.py @@ -9,7 +9,8 @@ @dataclass class IHaveAnnoyingUnions(JsonSchemaMixin): - my_field: Optional[Union[List[str], str]] + my_field1: list[str] | str | None + my_field2: Optional[Union[List[str], str]] @dataclass @@ -19,15 +20,12 @@ class IHaveAnnoyingUnionsReversed(JsonSchemaMixin): def test_union_decoding(): for field_value in (None, [">=0.0.0"], ">=0.0.0"): - obj = IHaveAnnoyingUnions(my_field=field_value) - dct = {"my_field": field_value} + obj = IHaveAnnoyingUnions(my_field1=field_value, my_field2=field_value) + dct = {"my_field1": field_value, "my_field2": field_value} decoded = IHaveAnnoyingUnions.from_dict(dct) assert decoded == obj assert obj.to_dict(omit_none=False) == dct - # this is allowed, for backwards-compatibility reasons - IHaveAnnoyingUnions(my_field=(">=0.0.0",)) == {"my_field": (">=0.0.0",)} - def test_union_decoding_ordering(): for field_value in (None, [">=0.0.0"], ">=0.0.0"): @@ -44,12 +42,15 @@ def test_union_decoding_ordering(): def test_union_decode_error(): - x = IHaveAnnoyingUnions(my_field={">=0.0.0"}) + x = IHaveAnnoyingUnions(my_field1={">=0.0.0"}, my_field2={">=0.0.0"}) with pytest.raises(ValidationError): x.to_dict(validate=True) with pytest.raises(ValidationError): - IHaveAnnoyingUnions.from_dict({"my_field": {">=0.0.0"}}) + IHaveAnnoyingUnions.from_dict({"my_field1": {">=0.0.0"}}) + + with pytest.raises(ValidationError): + IHaveAnnoyingUnions.from_dict({"my_field2": {">=0.0.0"}}) @dataclass
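
Note for reviewers (not part of the diff): the sketch below illustrates end to end what this patch enables. `ExampleConfig` is a hypothetical class invented for illustration, and the assertions assume hologram's existing public `JsonSchemaMixin` API (`json_schema()`, `to_dict()`, `from_dict()`) behaves as exercised by the tests updated above on Python 3.10+; exact schema details may differ.

    # Minimal usage sketch, assuming this patch is applied and Python >= 3.10.
    from dataclasses import dataclass

    from hologram import JsonSchemaMixin


    @dataclass
    class ExampleConfig(JsonSchemaMixin):
        # Hypothetical example class, not part of the patch.
        name: str
        threads: int | None   # previously spelled Optional[int]
        tags: list[str]       # previously List[str]
        env: dict[str, str]   # previously Dict[str, str] / Mapping[str, str]


    cfg = ExampleConfig(
        name="dev", threads=None, tags=["a", "b"], env={"region": "us"}
    )

    # Round-trip through JSON-encodable dicts, keeping explicit nulls as the
    # updated union tests do.
    dct = cfg.to_dict(omit_none=False)
    assert dct == {
        "name": "dev",
        "threads": None,
        "tags": ["a", "b"],
        "env": {"region": "us"},
    }
    assert ExampleConfig.from_dict(dct) == cfg

    # The generated schema treats `int | None` the same as Optional[int], and
    # builtin generics the same as their typing.* counterparts.
    schema = ExampleConfig.json_schema()
    assert schema["type"] == "object"
    assert schema["properties"]["threads"] == {
        "oneOf": [{"type": "integer"}, {"type": "null"}]
    }
    assert schema["properties"]["env"] == {
        "type": "object",
        "additionalProperties": {"type": "string"},
    }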