Skip to content

Commit

Permalink
Some changes to JSON Schema generation
Browse files Browse the repository at this point in the history
  • Loading branch information
zhPavel committed Jan 18, 2025
1 parent 979b00f commit f42fcd8
Show file tree
Hide file tree
Showing 8 changed files with 120 additions and 59 deletions.
2 changes: 1 addition & 1 deletion src/adaptix/_internal/datastructures.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def __init__(self):
self._key_to_item_set = defaultdict(set)

def add(self, key: K, item: ItemT) -> None:
if key in self._key_to_item_set and item not in self._key_to_item_set[key]:
if item not in self._key_to_item_set[key]:
self._key_to_item_set[key].add(item)
self._key_to_item_list[key].append(item)

Expand Down
52 changes: 51 additions & 1 deletion src/adaptix/_internal/morphing/facade/func.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
from collections.abc import Iterable, Mapping
from typing import Any, Optional, TypeVar, overload

from ...common import TypeHint
from .retort import Retort
from ...definitions import Direction
from ..json_schema.definitions import ResolvedJSONSchema
from ..json_schema.mangling import CompoundRefMangler, IndexRefMangler, QualnameRefMangler
from ..json_schema.ref_generator import BuiltinRefGenerator
from ..json_schema.request_cls import JSONSchemaContext
from ..json_schema.resolver import BuiltinJSONSchemaResolver, JSONSchemaResolver
from ..json_schema.schema_model import JSONSchemaDialect
from .retort import AdornedRetort, Retort

_global_retort = Retort()
T = TypeVar("T")
Expand Down Expand Up @@ -33,3 +41,45 @@ def dump(data: Any, tp: Optional[TypeHint] = None, /) -> Any:

def dump(data: Any, tp: Optional[TypeHint] = None, /) -> Any:
return _global_retort.dump(data, tp)


_global_resolver = BuiltinJSONSchemaResolver(
ref_generator=BuiltinRefGenerator(),
ref_mangler=CompoundRefMangler(QualnameRefMangler(), IndexRefMangler()),
)


DumpedJSONSchema = Mapping[str, Any]


def generate_json_schemas(
retort: AdornedRetort,
tps: Iterable[TypeHint],
*,
direction: Direction,
resolver: JSONSchemaResolver = _global_resolver,
dialect: str = JSONSchemaDialect.DRAFT_2020_12,
) -> tuple[DumpedJSONSchema, Iterable[DumpedJSONSchema]]:
ctx = JSONSchemaContext(dialect=dialect, direction=direction)
defs, schemas = resolver.resolve((), [retort.make_json_schema(tp, ctx) for tp in tps])
dumped_defs = _global_retort.dump(defs, dict[str, ResolvedJSONSchema])
dumped_schemas = _global_retort.dump(schemas, Iterable[ResolvedJSONSchema])
return dumped_defs, dumped_schemas


def generate_json_schema(
retort: AdornedRetort,
tp: TypeHint,
*,
direction: Direction,
resolver: JSONSchemaResolver = _global_resolver,
dialect: str = JSONSchemaDialect.DRAFT_2020_12,
) -> Mapping[str, Any]:
defs, [schema] = generate_json_schemas(
retort,
[tp],
direction=direction,
resolver=resolver,
dialect=dialect,
)
return {**schema, "$defs": defs}
8 changes: 8 additions & 0 deletions src/adaptix/_internal/morphing/facade/retort.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@
UnionProvider,
)
from ..iterable_provider import IterableProvider
from ..json_schema.definitions import JSONSchema
from ..json_schema.providers import InlineJSONSchemaProvider, JSONSchemaRefProvider
from ..json_schema.request_cls import JSONSchemaContext, JSONSchemaRequest
from ..model.crown_definitions import ExtraSkip
from ..model.dumper_provider import ModelDumperProvider
from ..model.loader_provider import ModelLoaderProvider
Expand Down Expand Up @@ -313,6 +315,12 @@ def dump(self, data: Any, tp: Optional[TypeHint] = None, /) -> Any:
)
return self.get_dumper(tp)(data)

def make_json_schema(self, tp: TypeHint, ctx: JSONSchemaContext) -> JSONSchema:
return self._facade_provide(
JSONSchemaRequest(loc_stack=LocStack(TypeHintLoc(type=tp)), ctx=ctx),
error_message=f"Cannot produce JSONSchema for type {tp!r}",
)


class Retort(FilledRetort, AdornedRetort):
pass
4 changes: 3 additions & 1 deletion src/adaptix/_internal/morphing/json_schema/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,19 @@
@dataclass(frozen=True)
class RefSource(Generic[JSONSchemaT]):
value: Optional[str]
json_schema: JSONSchemaT
json_schema: JSONSchemaT = field(hash=False)
loc_stack: LocStack = field(repr=False)


Boolable = Union[T, bool]


@dataclass(repr=False)
class JSONSchema(BaseJSONSchema[RefSource[FwdRef["JSONSchema"]], Boolable[FwdRef["JSONSchema"]]]):
pass


@dataclass(repr=False)
class ResolvedJSONSchema(BaseJSONSchema[str, Boolable[FwdRef["ResolvedJSONSchema"]]]):
pass

16 changes: 8 additions & 8 deletions src/adaptix/_internal/morphing/json_schema/mangling.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from collections.abc import Mapping, Sequence
from collections.abc import Container, Mapping, Sequence
from itertools import count
from typing import Optional

from ...datastructures import OrderedUniqueGrouper
from .definitions import RefSource, ResolvedJSONSchema
from .definitions import RefSource
from .resolver import RefMangler


Expand All @@ -14,7 +14,7 @@ def __init__(self, start: int = 1, separator: str = "-"):

def mangle_refs(
self,
defs: Mapping[str, ResolvedJSONSchema],
occupied_refs: Container[str],
common_ref: str,
sources: Sequence[RefSource],
) -> Mapping[RefSource, str]:
Expand All @@ -24,7 +24,7 @@ def mangle_refs(
while True:
idx = next(counter)
mangled = self._with_index(common_ref, idx)
if mangled not in defs:
if mangled not in occupied_refs:
result[source] = mangled
break

Expand All @@ -37,7 +37,7 @@ def _with_index(self, common_ref: str, index: int) -> str:
class QualnameRefMangler(RefMangler):
def mangle_refs(
self,
defs: Mapping[str, ResolvedJSONSchema],
occupied_refs: Container[str],
common_ref: str,
sources: Sequence[RefSource],
) -> Mapping[RefSource, str]:
Expand All @@ -55,18 +55,18 @@ def __init__(self, base: RefMangler, wrapper: RefMangler):

def mangle_refs(
self,
defs: Mapping[str, ResolvedJSONSchema],
occupied_refs: Container[str],
common_ref: str,
sources: Sequence[RefSource],
) -> Mapping[RefSource, str]:
mangled = self._base.mangle_refs(defs, common_ref, sources)
mangled = self._base.mangle_refs(occupied_refs, common_ref, sources)

grouper = OrderedUniqueGrouper[str, RefSource]()
for source, ref in mangled.items():
grouper.add(ref, source)

for ref, ref_sources in grouper.finalize().items():
if len(ref_sources) > 1:
mangled = {**mangled, **self._wrapper.mangle_refs(defs, ref, ref_sources)}
mangled = {**mangled, **self._wrapper.mangle_refs(occupied_refs, ref, ref_sources)}

return mangled
8 changes: 8 additions & 0 deletions src/adaptix/_internal/morphing/json_schema/ref_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from ...provider.loc_stack_filtering import LocStack
from .definitions import JSONSchema
from .resolver import RefGenerator


class BuiltinRefGenerator(RefGenerator):
def generate_ref(self, json_schema: JSONSchema, loc_stack: LocStack) -> str:
return str(loc_stack.last.type)
42 changes: 17 additions & 25 deletions src/adaptix/_internal/morphing/json_schema/resolver.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Mapping, MutableMapping, Sequence

from adaptix import Omitted
from collections.abc import Container, Mapping, Sequence

from ...datastructures import OrderedUniqueGrouper
from ...provider.loc_stack_filtering import LocStack
from ...provider.loc_stack_tools import format_loc_stack
from ...utils import Omitted
from .definitions import JSONSchema, RefSource, ResolvedJSONSchema
from .schema_tools import replace_json_schema_ref, traverse_json_schema

Expand All @@ -15,9 +14,9 @@ class JSONSchemaResolver(ABC):
@abstractmethod
def resolve(
self,
defs: MutableMapping[str, ResolvedJSONSchema],
occupied_refs: Container[str],
root_schemas: Sequence[JSONSchema],
) -> Sequence[ResolvedJSONSchema]:
) -> tuple[Mapping[str, ResolvedJSONSchema], Sequence[ResolvedJSONSchema]]:
...


Expand All @@ -31,7 +30,7 @@ class RefMangler(ABC):
@abstractmethod
def mangle_refs(
self,
defs: Mapping[str, ResolvedJSONSchema],
occupied_refs: Container[str],
common_ref: str,
sources: Sequence[RefSource],
) -> Mapping[RefSource, str]:
Expand All @@ -45,16 +44,20 @@ def __init__(self, ref_generator: RefGenerator, ref_mangler: RefMangler):

def resolve(
self,
defs: MutableMapping[str, ResolvedJSONSchema],
occupied_refs: Container[str],
root_schemas: Sequence[JSONSchema],
) -> Sequence[ResolvedJSONSchema]:
) -> tuple[Mapping[str, ResolvedJSONSchema], Sequence[ResolvedJSONSchema]]:
ref_to_sources = self._collect_ref_to_sources(root_schemas)
source_determinator = self._get_source_determinator(defs, ref_to_sources)
self._write_to_defs(defs, source_determinator)
return [
source_determinator = self._get_source_determinator(occupied_refs, ref_to_sources)
defs = {
ref: replace_json_schema_ref(source.json_schema, source_determinator)
for source, ref in source_determinator.items()
}
schemas = [
replace_json_schema_ref(root, source_determinator)
for root in root_schemas
]
return defs, schemas

def _collect_ref_to_sources(self, root_schemas: Sequence[JSONSchema]) -> Mapping[str, Sequence[RefSource]]:
grouper = OrderedUniqueGrouper[str, RefSource[JSONSchema]]()
Expand All @@ -74,16 +77,16 @@ def _collect_ref_to_sources(self, root_schemas: Sequence[JSONSchema]) -> Mapping

def _get_source_determinator(
self,
defs: Mapping[str, ResolvedJSONSchema],
occupied_refs: Container[str],
ref_to_sources: Mapping[str, Sequence[RefSource]],
) -> Mapping[RefSource, str]:
source_determinator = {}
for common_ref, sources in ref_to_sources.items():
if len(sources) == 1 and common_ref not in defs:
if len(sources) == 1 and common_ref not in occupied_refs:
source_determinator[sources[0]] = common_ref
else:
self._validate_sources(common_ref, sources)
mangling_result = self._ref_mangler.mangle_refs(defs, common_ref, sources)
mangling_result = self._ref_mangler.mangle_refs(occupied_refs, common_ref, sources)
source_determinator.update(mangling_result)
self._validate_mangling(source_determinator)
return source_determinator
Expand Down Expand Up @@ -117,14 +120,3 @@ def _validate_mangling(self, source_determinator: Mapping[RefSource, str]) -> No
f" can not mangle some refs."
f" {unmangled_desc}",
)

def _write_to_defs(
self,
defs: MutableMapping[str, ResolvedJSONSchema],
source_determinator: Mapping[RefSource, str],
) -> None:
for source, ref in source_determinator.items():
resolved_json_schema = replace_json_schema_ref(source.json_schema, source_determinator)
if ref in defs and defs[ref] == resolved_json_schema:
continue
defs[ref] = resolved_json_schema
47 changes: 24 additions & 23 deletions src/adaptix/_internal/morphing/json_schema/schema_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,7 @@

from ...utils import Omittable, Omitted
from .definitions import JSONSchema, RefSource, ResolvedJSONSchema
from .schema_model import (
BaseJSONSchema,
JSONNumeric,
JSONObject,
JSONSchemaBuiltinFormat,
JSONSchemaT,
JSONSchemaType,
JSONValue,
RefT,
)
from .schema_model import JSONNumeric, JSONObject, JSONSchemaBuiltinFormat, JSONSchemaT, JSONSchemaType, JSONValue, RefT

_non_generic_fields_types = [
Omittable[Union[JSONSchemaType, Sequence[JSONSchemaType]]], # type: ignore[misc]
Expand All @@ -39,22 +30,19 @@
"""
if __value__ != Omitted():
for item in __value__:
yield item
yield from __traverser__(item)
""",
),
Omittable[JSONSchemaT]: dedent( # type: ignore[misc, valid-type]
"""
if __value__ != Omitted():
yield __value__
yield from __traverser__(value)
yield from __traverser__(__value__)
""",
),
Omittable[JSONObject[JSONSchemaT]]: dedent( # type: ignore[misc, valid-type]
"""
if __value__ != Omitted():
for item in __value__.values():
yield item
yield from __traverser__(item)
""",
),
Expand All @@ -65,7 +53,7 @@
Omittable[RefT]: dedent( # type: ignore[misc, valid-type]
"""
if __value__ != Omitted():
yield __value__.json_schema
yield from __traverser__(__value__.json_schema)
""",
),
}
Expand All @@ -90,18 +78,21 @@ def _generate_json_schema_traverser(
.strip("\n"),
)

module_code = f"def {function_name}(obj, /):\n" + "\n\n".join(indent(item, " " * 4) for item in result)
module_code = dedent(
f"""
def {function_name}(obj, /):
if isinstance(obj, bool):
return
yield obj
""",
) + "\n\n".join(indent(item, " " * 4) for item in result)
namespace: dict[str, Any] = {"Omitted": Omitted}
exec(compile(module_code, file_name, "exec"), namespace, namespace) # noqa: S102
return namespace[function_name]


traverse_base_json_schema = _generate_json_schema_traverser(
function_name="traverse_base_json_schema",
file_name="<traverse_base_json_schema generation>",
templates=_base_json_schema_templates,
cls=BaseJSONSchema,
)
traverse_json_schema = _generate_json_schema_traverser(
function_name="traverse_json_schema",
file_name="<traverse_json_schema generation>",
Expand Down Expand Up @@ -158,7 +149,17 @@ def _generate_json_schema_replacer(
)

body = "\n".join(indent(item, " " * 8) for item in result)
module_code = f"def {function_name}(obj, ctx, /):\n return {target_cls.__name__}(\n{body}\n )"
module_code = dedent(
f"""
def {function_name}(obj, ctx, /):
if isinstance(obj, bool):
return obj
return {target_cls.__name__}(
{body}
)
""",
)
namespace: dict[str, Any] = {target_cls.__name__: target_cls, "Omitted": Omitted}
exec(compile(module_code, file_name, "exec"), namespace, namespace) # noqa: S102
return namespace[function_name]
Expand Down

0 comments on commit f42fcd8

Please sign in to comment.