Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for ForwardRef in json schema generation #191

Merged
merged 4 commits into from
Jan 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 4 additions & 14 deletions mashumaro/core/meta/code/builder.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import importlib
import inspect
import math
import sys
import types
import typing
import uuid
Expand Down Expand Up @@ -35,6 +34,7 @@
get_args,
get_class_that_defines_field,
get_class_that_defines_method,
get_forward_ref_referencing_globals,
get_literal_values,
get_name_error_name,
hash_type_args,
Expand Down Expand Up @@ -334,19 +334,9 @@ def evaluate_forward_ref(
typ: typing.ForwardRef,
owner: typing.Optional[typing.Type],
) -> typing.Optional[typing.Type]:
forward_module = getattr(typ, "__forward_module__", None)
if not forward_module and owner:
# We can't get the module in which ForwardRef's value is defined on
# Python < 3.10, ForwardRef evaluation might not work properly
# without this information, so we will consider the namespace of
# the module in which this ForwardRef is used as globalns.
globalns = getattr(
sys.modules.get(owner.__module__, None),
"__dict__",
self.globals,
)
else:
globalns = getattr(forward_module, "__dict__", self.globals)
globalns = get_forward_ref_referencing_globals(
typ, owner, self.globals
)
return evaluate_forward_ref(typ, globalns, self.__dict__)

def get_declared_hook(self, method_name: str) -> typing.Any:
Expand Down
26 changes: 25 additions & 1 deletion mashumaro/core/meta/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import enum
import inspect
import re
import sys
import types
import typing
from contextlib import suppress
Expand Down Expand Up @@ -82,6 +83,7 @@
"is_hashable",
"is_hashable_type",
"evaluate_forward_ref",
"get_forward_ref_referencing_globals",
]


Expand Down Expand Up @@ -670,7 +672,7 @@ def is_not_required(typ: Type) -> bool:


def get_function_arg_annotation(
function: typing.Callable[[Any], Any],
function: typing.Callable[..., Any],
arg_name: typing.Optional[str] = None,
arg_pos: typing.Optional[int] = None,
) -> typing.Type:
Expand Down Expand Up @@ -769,3 +771,25 @@ def evaluate_forward_ref(
) # type: ignore[call-arg]
else:
return typ._evaluate(globalns, localns) # type: ignore[call-arg]


def get_forward_ref_referencing_globals(
referenced_type: typing.ForwardRef,
referencing_object: Optional[Any] = None,
fallback: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
if fallback is None:
fallback = {}
forward_module = getattr(referenced_type, "__forward_module__", None)
if not forward_module and referencing_object:
# We can't get the module in which ForwardRef's value is defined on
# Python < 3.10, ForwardRef evaluation might not work properly
# without this information, so we will consider the namespace of
# the module in which this ForwardRef is used as globalns.
return getattr(
sys.modules.get(referencing_object.__module__, None),
"__dict__",
fallback,
)
else:
return getattr(forward_module, "__dict__", fallback)
18 changes: 18 additions & 0 deletions mashumaro/jsonschema/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Any,
Callable,
Dict,
ForwardRef,
Iterable,
List,
Optional,
Expand All @@ -28,7 +29,9 @@
from mashumaro.core.const import PY_39_MIN, PY_311_MIN
from mashumaro.core.meta.code.builder import CodeBuilder
from mashumaro.core.meta.helpers import (
evaluate_forward_ref,
get_args,
get_forward_ref_referencing_globals,
get_function_return_annotation,
get_literal_values,
get_type_origin,
Expand Down Expand Up @@ -138,6 +141,13 @@ def owner_class(self) -> Optional[Type]:
return None

def derive(self, **changes: Any) -> "Instance":
new_type = changes.get("type")
if isinstance(new_type, ForwardRef):
changes["type"] = evaluate_forward_ref(
new_type,
get_forward_ref_referencing_globals(new_type, self.type),
self.__dict__,
)
new_instance = replace(self, **changes)
if is_dataclass(self.origin_type):
new_instance.__owner_builder = self.__self_builder
Expand Down Expand Up @@ -437,6 +447,14 @@ def on_special_typing_primitive(
)
elif is_type_var_tuple(instance.type):
return get_schema(instance.derive(type=Tuple[Any, ...]), ctx)
elif isinstance(instance.type, ForwardRef):
evaluated = evaluate_forward_ref(
instance.type,
get_forward_ref_referencing_globals(instance.type),
None,
)
if evaluated is not None:
return get_schema(instance.derive(type=evaluated), ctx)


@register
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from __future__ import annotations

from typing import ForwardRef, TypedDict

import pytest

from mashumaro.core.const import PY_39_MIN
from mashumaro.core.meta.helpers import get_function_arg_annotation
from mashumaro.jsonschema import build_json_schema
from mashumaro.jsonschema.models import (
JSONObjectSchema,
JSONSchema,
JSONSchemaInstanceType,
)


class MyTypedDict(TypedDict):
x: int


@pytest.mark.skipif(
not PY_39_MIN,
reason=(
"On Python 3.8 ForwardRef doesn't have __forward_module__ "
"which is needed here"
),
)
def test_jsonschema_generation_for_forward_refs():
def foo(x: int, y: MyTypedDict):
pass

x_type = get_function_arg_annotation(foo, "x")
assert isinstance(x_type, ForwardRef)
assert build_json_schema(x_type).type is JSONSchemaInstanceType.INTEGER

y_type = get_function_arg_annotation(foo, "y")
assert isinstance(y_type, ForwardRef)
assert build_json_schema(y_type) == JSONObjectSchema(
type=JSONSchemaInstanceType.OBJECT,
properties={"x": JSONSchema(type=JSONSchemaInstanceType.INTEGER)},
additionalProperties=False,
required=["x"],
)
Loading