diff --git a/hologram/__init__.py b/hologram/__init__.py index 53d0eba..2897eeb 100644 --- a/hologram/__init__.py +++ b/hologram/__init__.py @@ -35,6 +35,8 @@ JsonEncodable = Union[int, float, str, bool, None] JsonDict = Dict[str, Any] +OPTIONAL_TYPES = ["Union", "Optional"] + class ValidationError(jsonschema.ValidationError): pass @@ -83,7 +85,7 @@ def issubclass_safe(klass: Any, base: Type) -> bool: def is_optional(field: Any) -> bool: - if str(field).startswith("typing.Union"): + if str(field).startswith("typing.Union") or str(field).startswith("typing.Optional"): for arg in field.__args__: if isinstance(arg, type) and issubclass(arg, type(None)): return True @@ -333,7 +335,7 @@ def encoder(ft, v, __): def encoder(_, v, __): return v.value - elif field_type_name == "Union": + elif field_type_name in OPTIONAL_TYPES: # Attempt to encode the field with each union variant. # TODO: Find a more reliable method than this since in the case 'Union[List[str], Dict[str, int]]' this # will just output the dict keys as a list @@ -491,7 +493,7 @@ def decoder(_, ft, val): def decoder(_, ft, val): return ft.from_dict(val, validate=validate) - elif field_type_name == "Union": + elif field_type_name in OPTIONAL_TYPES: # Attempt to decode the value using each decoder in turn union_excs = ( AttributeError, @@ -725,7 +727,7 @@ def _get_schema_for_type( field_schema.update(cls._encode_restrictions(restrictions)) # if Union[..., None] or Optional[...] - elif type_name == "Union": + elif type_name in OPTIONAL_TYPES: field_schema = { "oneOf": [ cls._get_field_schema(variant)[0] @@ -828,7 +830,7 @@ def _get_field_definitions(cls, field_type: Any, definitions: JsonDict): cls._get_field_definitions(field_type.__args__[1], definitions) elif field_type_name == "PatternProperty": cls._get_field_definitions(field_type.TARGET_TYPE, definitions) - elif field_type_name == "Union": + elif field_type_name in OPTIONAL_TYPES: for variant in field_type.__args__: cls._get_field_definitions(variant, definitions) elif cls._is_json_schema_subclass(field_type): @@ -956,6 +958,6 @@ def _get_field_type_name(field_type: Any) -> str: def validate(cls, data: Any): schema = _validate_schema(cls) validator = jsonschema.Draft7Validator(schema) - error = next(iter(validator.iter_errors(data)), None) + error = jsonschema.exceptions.best_match(validator.iter_errors(data)) if error is not None: raise ValidationError.create_from(error) from error