Skip to content

Commit

Permalink
(#34) Add support for Python 3.9
Browse files Browse the repository at this point in the history
  • Loading branch information
drewbanin committed Oct 8, 2020
1 parent d9c079a commit f04a6b8
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions hologram/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
JsonEncodable = Union[int, float, str, bool, None]
JsonDict = Dict[str, Any]

OPTIONAL_TYPES = ["Union", "Optional"]


class ValidationError(jsonschema.ValidationError):
pass
Expand Down Expand Up @@ -83,7 +85,9 @@ def issubclass_safe(klass: Any, base: Type) -> bool:


def is_optional(field: Any) -> bool:
if str(field).startswith("typing.Union"):
if str(field).startswith("typing.Union") or str(field).startswith(
"typing.Optional"
):
for arg in field.__args__:
if isinstance(arg, type) and issubclass(arg, type(None)):
return True
Expand Down Expand Up @@ -333,7 +337,7 @@ def encoder(ft, v, __):
def encoder(_, v, __):
return v.value

elif field_type_name == "Union":
elif field_type_name in OPTIONAL_TYPES:
# Attempt to encode the field with each union variant.
# TODO: Find a more reliable method than this since in the case 'Union[List[str], Dict[str, int]]' this
# will just output the dict keys as a list
Expand Down Expand Up @@ -491,7 +495,7 @@ def decoder(_, ft, val):
def decoder(_, ft, val):
return ft.from_dict(val, validate=validate)

elif field_type_name == "Union":
elif field_type_name in OPTIONAL_TYPES:
# Attempt to decode the value using each decoder in turn
union_excs = (
AttributeError,
Expand Down Expand Up @@ -725,7 +729,7 @@ def _get_schema_for_type(
field_schema.update(cls._encode_restrictions(restrictions))

# if Union[..., None] or Optional[...]
elif type_name == "Union":
elif type_name in OPTIONAL_TYPES:
field_schema = {
"oneOf": [
cls._get_field_schema(variant)[0]
Expand Down Expand Up @@ -828,7 +832,7 @@ def _get_field_definitions(cls, field_type: Any, definitions: JsonDict):
cls._get_field_definitions(field_type.__args__[1], definitions)
elif field_type_name == "PatternProperty":
cls._get_field_definitions(field_type.TARGET_TYPE, definitions)
elif field_type_name == "Union":
elif field_type_name in OPTIONAL_TYPES:
for variant in field_type.__args__:
cls._get_field_definitions(variant, definitions)
elif cls._is_json_schema_subclass(field_type):
Expand Down Expand Up @@ -956,6 +960,6 @@ def _get_field_type_name(field_type: Any) -> str:
def validate(cls, data: Any):
schema = _validate_schema(cls)
validator = jsonschema.Draft7Validator(schema)
error = next(iter(validator.iter_errors(data)), None)
error = jsonschema.exceptions.best_match(validator.iter_errors(data))
if error is not None:
raise ValidationError.create_from(error) from error

0 comments on commit f04a6b8

Please sign in to comment.