Skip to content

Commit

Permalink
Merge pull request #191 from Fatal1ty/jsonschema-forward-refs
Browse files Browse the repository at this point in the history
Add support for ForwardRef in json schema generation
  • Loading branch information
Fatal1ty committed Jan 27, 2024
2 parents d46fd7e + 1cf930c commit 3573dcf
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 15 deletions.
18 changes: 4 additions & 14 deletions mashumaro/core/meta/code/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import importlib
import inspect
import math
import sys
import types
import typing
import uuid
Expand Down Expand Up @@ -36,6 +35,7 @@
get_args,
get_class_that_defines_field,
get_class_that_defines_method,
get_forward_ref_referencing_globals,
get_literal_values,
get_name_error_name,
hash_type_args,
Expand Down Expand Up @@ -335,19 +335,9 @@ def evaluate_forward_ref(
typ: typing.ForwardRef,
owner: typing.Optional[typing.Type],
) -> typing.Optional[typing.Type]:
forward_module = getattr(typ, "__forward_module__", None)
if not forward_module and owner:
# We can't get the module in which ForwardRef's value is defined on
# Python < 3.10, ForwardRef evaluation might not work properly
# without this information, so we will consider the namespace of
# the module in which this ForwardRef is used as globalns.
globalns = getattr(
sys.modules.get(owner.__module__, None),
"__dict__",
self.globals,
)
else:
globalns = getattr(forward_module, "__dict__", self.globals)
globalns = get_forward_ref_referencing_globals(
typ, owner, self.globals
)
return evaluate_forward_ref(typ, globalns, self.__dict__)

def get_declared_hook(self, method_name: str) -> typing.Any:
Expand Down
26 changes: 25 additions & 1 deletion mashumaro/core/meta/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import enum
import inspect
import re
import sys
import types
import typing
from contextlib import suppress
Expand Down Expand Up @@ -82,6 +83,7 @@
"is_hashable",
"is_hashable_type",
"evaluate_forward_ref",
"get_forward_ref_referencing_globals",
]


Expand Down Expand Up @@ -670,7 +672,7 @@ def is_not_required(typ: Type) -> bool:


def get_function_arg_annotation(
function: typing.Callable[[Any], Any],
function: typing.Callable[..., Any],
arg_name: typing.Optional[str] = None,
arg_pos: typing.Optional[int] = None,
) -> typing.Type:
Expand Down Expand Up @@ -769,3 +771,25 @@ def evaluate_forward_ref(
) # type: ignore[call-arg]
else:
return typ._evaluate(globalns, localns) # type: ignore[call-arg]


def get_forward_ref_referencing_globals(
referenced_type: typing.ForwardRef,
referencing_object: Optional[Any] = None,
fallback: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
if fallback is None:
fallback = {}
forward_module = getattr(referenced_type, "__forward_module__", None)
if not forward_module and referencing_object:
# We can't get the module in which ForwardRef's value is defined on
# Python < 3.10, ForwardRef evaluation might not work properly
# without this information, so we will consider the namespace of
# the module in which this ForwardRef is used as globalns.
return getattr(
sys.modules.get(referencing_object.__module__, None),
"__dict__",
fallback,
)
else:
return getattr(forward_module, "__dict__", fallback)
18 changes: 18 additions & 0 deletions mashumaro/jsonschema/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Any,
Callable,
Dict,
ForwardRef,
Iterable,
List,
Optional,
Expand All @@ -28,7 +29,9 @@
from mashumaro.core.const import PY_39_MIN, PY_311_MIN
from mashumaro.core.meta.code.builder import CodeBuilder
from mashumaro.core.meta.helpers import (
evaluate_forward_ref,
get_args,
get_forward_ref_referencing_globals,
get_function_return_annotation,
get_literal_values,
get_type_origin,
Expand Down Expand Up @@ -138,6 +141,13 @@ def owner_class(self) -> Optional[Type]:
return None

def derive(self, **changes: Any) -> "Instance":
new_type = changes.get("type")
if isinstance(new_type, ForwardRef):
changes["type"] = evaluate_forward_ref(
new_type,
get_forward_ref_referencing_globals(new_type, self.type),
self.__dict__,
)
new_instance = replace(self, **changes)
if is_dataclass(self.origin_type):
new_instance.__owner_builder = self.__self_builder
Expand Down Expand Up @@ -437,6 +447,14 @@ def on_special_typing_primitive(
)
elif is_type_var_tuple(instance.type):
return get_schema(instance.derive(type=Tuple[Any, ...]), ctx)
elif isinstance(instance.type, ForwardRef):
evaluated = evaluate_forward_ref(
instance.type,
get_forward_ref_referencing_globals(instance.type),
None,
)
if evaluated is not None:
return get_schema(instance.derive(type=evaluated), ctx)


@register
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from __future__ import annotations

from typing import ForwardRef, TypedDict

import pytest

from mashumaro.core.const import PY_39_MIN
from mashumaro.core.meta.helpers import get_function_arg_annotation
from mashumaro.jsonschema import build_json_schema
from mashumaro.jsonschema.models import (
JSONObjectSchema,
JSONSchema,
JSONSchemaInstanceType,
)


class MyTypedDict(TypedDict):
x: int


@pytest.mark.skipif(
not PY_39_MIN,
reason=(
"On Python 3.8 ForwardRef doesn't have __forward_module__ "
"which is needed here"
),
)
def test_jsonschema_generation_for_forward_refs():
def foo(x: int, y: MyTypedDict):
pass

x_type = get_function_arg_annotation(foo, "x")
assert isinstance(x_type, ForwardRef)
assert build_json_schema(x_type).type is JSONSchemaInstanceType.INTEGER

y_type = get_function_arg_annotation(foo, "y")
assert isinstance(y_type, ForwardRef)
assert build_json_schema(y_type) == JSONObjectSchema(
type=JSONSchemaInstanceType.OBJECT,
properties={"x": JSONSchema(type=JSONSchemaInstanceType.INTEGER)},
additionalProperties=False,
required=["x"],
)

0 comments on commit 3573dcf

Please sign in to comment.