From f42fcd89a09908161a90de69ccd0dc4f5551e6d7 Mon Sep 17 00:00:00 2001 From: pavel Date: Sat, 18 Jan 2025 17:59:13 +0300 Subject: [PATCH] Some changes to JSON Schema generation --- src/adaptix/_internal/datastructures.py | 2 +- src/adaptix/_internal/morphing/facade/func.py | 52 ++++++++++++++++++- .../_internal/morphing/facade/retort.py | 8 +++ .../morphing/json_schema/definitions.py | 4 +- .../morphing/json_schema/mangling.py | 16 +++--- .../morphing/json_schema/ref_generator.py | 8 +++ .../morphing/json_schema/resolver.py | 42 ++++++--------- .../morphing/json_schema/schema_tools.py | 47 +++++++++-------- 8 files changed, 120 insertions(+), 59 deletions(-) create mode 100644 src/adaptix/_internal/morphing/json_schema/ref_generator.py diff --git a/src/adaptix/_internal/datastructures.py b/src/adaptix/_internal/datastructures.py index ef22ec7f..28944ed4 100644 --- a/src/adaptix/_internal/datastructures.py +++ b/src/adaptix/_internal/datastructures.py @@ -285,7 +285,7 @@ def __init__(self): self._key_to_item_set = defaultdict(set) def add(self, key: K, item: ItemT) -> None: - if key in self._key_to_item_set and item not in self._key_to_item_set[key]: + if item not in self._key_to_item_set[key]: self._key_to_item_set[key].add(item) self._key_to_item_list[key].append(item) diff --git a/src/adaptix/_internal/morphing/facade/func.py b/src/adaptix/_internal/morphing/facade/func.py index 5281a3c3..aae5b0e9 100644 --- a/src/adaptix/_internal/morphing/facade/func.py +++ b/src/adaptix/_internal/morphing/facade/func.py @@ -1,7 +1,15 @@ +from collections.abc import Iterable, Mapping from typing import Any, Optional, TypeVar, overload from ...common import TypeHint -from .retort import Retort +from ...definitions import Direction +from ..json_schema.definitions import ResolvedJSONSchema +from ..json_schema.mangling import CompoundRefMangler, IndexRefMangler, QualnameRefMangler +from ..json_schema.ref_generator import BuiltinRefGenerator +from ..json_schema.request_cls import JSONSchemaContext +from ..json_schema.resolver import BuiltinJSONSchemaResolver, JSONSchemaResolver +from ..json_schema.schema_model import JSONSchemaDialect +from .retort import AdornedRetort, Retort _global_retort = Retort() T = TypeVar("T") @@ -33,3 +41,45 @@ def dump(data: Any, tp: Optional[TypeHint] = None, /) -> Any: def dump(data: Any, tp: Optional[TypeHint] = None, /) -> Any: return _global_retort.dump(data, tp) + + +_global_resolver = BuiltinJSONSchemaResolver( + ref_generator=BuiltinRefGenerator(), + ref_mangler=CompoundRefMangler(QualnameRefMangler(), IndexRefMangler()), +) + + +DumpedJSONSchema = Mapping[str, Any] + + +def generate_json_schemas( + retort: AdornedRetort, + tps: Iterable[TypeHint], + *, + direction: Direction, + resolver: JSONSchemaResolver = _global_resolver, + dialect: str = JSONSchemaDialect.DRAFT_2020_12, +) -> tuple[DumpedJSONSchema, Iterable[DumpedJSONSchema]]: + ctx = JSONSchemaContext(dialect=dialect, direction=direction) + defs, schemas = resolver.resolve((), [retort.make_json_schema(tp, ctx) for tp in tps]) + dumped_defs = _global_retort.dump(defs, dict[str, ResolvedJSONSchema]) + dumped_schemas = _global_retort.dump(schemas, Iterable[ResolvedJSONSchema]) + return dumped_defs, dumped_schemas + + +def generate_json_schema( + retort: AdornedRetort, + tp: TypeHint, + *, + direction: Direction, + resolver: JSONSchemaResolver = _global_resolver, + dialect: str = JSONSchemaDialect.DRAFT_2020_12, +) -> Mapping[str, Any]: + defs, [schema] = generate_json_schemas( + retort, + [tp], + direction=direction, + resolver=resolver, + dialect=dialect, + ) + return {**schema, "$defs": defs} diff --git a/src/adaptix/_internal/morphing/facade/retort.py b/src/adaptix/_internal/morphing/facade/retort.py index 8d2f1658..9d115dc8 100644 --- a/src/adaptix/_internal/morphing/facade/retort.py +++ b/src/adaptix/_internal/morphing/facade/retort.py @@ -48,7 +48,9 @@ UnionProvider, ) from ..iterable_provider import IterableProvider +from ..json_schema.definitions import JSONSchema from ..json_schema.providers import InlineJSONSchemaProvider, JSONSchemaRefProvider +from ..json_schema.request_cls import JSONSchemaContext, JSONSchemaRequest from ..model.crown_definitions import ExtraSkip from ..model.dumper_provider import ModelDumperProvider from ..model.loader_provider import ModelLoaderProvider @@ -313,6 +315,12 @@ def dump(self, data: Any, tp: Optional[TypeHint] = None, /) -> Any: ) return self.get_dumper(tp)(data) + def make_json_schema(self, tp: TypeHint, ctx: JSONSchemaContext) -> JSONSchema: + return self._facade_provide( + JSONSchemaRequest(loc_stack=LocStack(TypeHintLoc(type=tp)), ctx=ctx), + error_message=f"Cannot produce JSONSchema for type {tp!r}", + ) + class Retort(FilledRetort, AdornedRetort): pass diff --git a/src/adaptix/_internal/morphing/json_schema/definitions.py b/src/adaptix/_internal/morphing/json_schema/definitions.py index c1917088..e6f37faf 100644 --- a/src/adaptix/_internal/morphing/json_schema/definitions.py +++ b/src/adaptix/_internal/morphing/json_schema/definitions.py @@ -13,17 +13,19 @@ @dataclass(frozen=True) class RefSource(Generic[JSONSchemaT]): value: Optional[str] - json_schema: JSONSchemaT + json_schema: JSONSchemaT = field(hash=False) loc_stack: LocStack = field(repr=False) Boolable = Union[T, bool] +@dataclass(repr=False) class JSONSchema(BaseJSONSchema[RefSource[FwdRef["JSONSchema"]], Boolable[FwdRef["JSONSchema"]]]): pass +@dataclass(repr=False) class ResolvedJSONSchema(BaseJSONSchema[str, Boolable[FwdRef["ResolvedJSONSchema"]]]): pass diff --git a/src/adaptix/_internal/morphing/json_schema/mangling.py b/src/adaptix/_internal/morphing/json_schema/mangling.py index 9adbf2ab..89dc70ea 100644 --- a/src/adaptix/_internal/morphing/json_schema/mangling.py +++ b/src/adaptix/_internal/morphing/json_schema/mangling.py @@ -1,9 +1,9 @@ -from collections.abc import Mapping, Sequence +from collections.abc import Container, Mapping, Sequence from itertools import count from typing import Optional from ...datastructures import OrderedUniqueGrouper -from .definitions import RefSource, ResolvedJSONSchema +from .definitions import RefSource from .resolver import RefMangler @@ -14,7 +14,7 @@ def __init__(self, start: int = 1, separator: str = "-"): def mangle_refs( self, - defs: Mapping[str, ResolvedJSONSchema], + occupied_refs: Container[str], common_ref: str, sources: Sequence[RefSource], ) -> Mapping[RefSource, str]: @@ -24,7 +24,7 @@ def mangle_refs( while True: idx = next(counter) mangled = self._with_index(common_ref, idx) - if mangled not in defs: + if mangled not in occupied_refs: result[source] = mangled break @@ -37,7 +37,7 @@ def _with_index(self, common_ref: str, index: int) -> str: class QualnameRefMangler(RefMangler): def mangle_refs( self, - defs: Mapping[str, ResolvedJSONSchema], + occupied_refs: Container[str], common_ref: str, sources: Sequence[RefSource], ) -> Mapping[RefSource, str]: @@ -55,11 +55,11 @@ def __init__(self, base: RefMangler, wrapper: RefMangler): def mangle_refs( self, - defs: Mapping[str, ResolvedJSONSchema], + occupied_refs: Container[str], common_ref: str, sources: Sequence[RefSource], ) -> Mapping[RefSource, str]: - mangled = self._base.mangle_refs(defs, common_ref, sources) + mangled = self._base.mangle_refs(occupied_refs, common_ref, sources) grouper = OrderedUniqueGrouper[str, RefSource]() for source, ref in mangled.items(): @@ -67,6 +67,6 @@ def mangle_refs( for ref, ref_sources in grouper.finalize().items(): if len(ref_sources) > 1: - mangled = {**mangled, **self._wrapper.mangle_refs(defs, ref, ref_sources)} + mangled = {**mangled, **self._wrapper.mangle_refs(occupied_refs, ref, ref_sources)} return mangled diff --git a/src/adaptix/_internal/morphing/json_schema/ref_generator.py b/src/adaptix/_internal/morphing/json_schema/ref_generator.py new file mode 100644 index 00000000..f28a3275 --- /dev/null +++ b/src/adaptix/_internal/morphing/json_schema/ref_generator.py @@ -0,0 +1,8 @@ +from ...provider.loc_stack_filtering import LocStack +from .definitions import JSONSchema +from .resolver import RefGenerator + + +class BuiltinRefGenerator(RefGenerator): + def generate_ref(self, json_schema: JSONSchema, loc_stack: LocStack) -> str: + return str(loc_stack.last.type) diff --git a/src/adaptix/_internal/morphing/json_schema/resolver.py b/src/adaptix/_internal/morphing/json_schema/resolver.py index 0b18a475..1a0111db 100644 --- a/src/adaptix/_internal/morphing/json_schema/resolver.py +++ b/src/adaptix/_internal/morphing/json_schema/resolver.py @@ -1,12 +1,11 @@ from abc import ABC, abstractmethod from collections import defaultdict -from collections.abc import Mapping, MutableMapping, Sequence - -from adaptix import Omitted +from collections.abc import Container, Mapping, Sequence from ...datastructures import OrderedUniqueGrouper from ...provider.loc_stack_filtering import LocStack from ...provider.loc_stack_tools import format_loc_stack +from ...utils import Omitted from .definitions import JSONSchema, RefSource, ResolvedJSONSchema from .schema_tools import replace_json_schema_ref, traverse_json_schema @@ -15,9 +14,9 @@ class JSONSchemaResolver(ABC): @abstractmethod def resolve( self, - defs: MutableMapping[str, ResolvedJSONSchema], + occupied_refs: Container[str], root_schemas: Sequence[JSONSchema], - ) -> Sequence[ResolvedJSONSchema]: + ) -> tuple[Mapping[str, ResolvedJSONSchema], Sequence[ResolvedJSONSchema]]: ... @@ -31,7 +30,7 @@ class RefMangler(ABC): @abstractmethod def mangle_refs( self, - defs: Mapping[str, ResolvedJSONSchema], + occupied_refs: Container[str], common_ref: str, sources: Sequence[RefSource], ) -> Mapping[RefSource, str]: @@ -45,16 +44,20 @@ def __init__(self, ref_generator: RefGenerator, ref_mangler: RefMangler): def resolve( self, - defs: MutableMapping[str, ResolvedJSONSchema], + occupied_refs: Container[str], root_schemas: Sequence[JSONSchema], - ) -> Sequence[ResolvedJSONSchema]: + ) -> tuple[Mapping[str, ResolvedJSONSchema], Sequence[ResolvedJSONSchema]]: ref_to_sources = self._collect_ref_to_sources(root_schemas) - source_determinator = self._get_source_determinator(defs, ref_to_sources) - self._write_to_defs(defs, source_determinator) - return [ + source_determinator = self._get_source_determinator(occupied_refs, ref_to_sources) + defs = { + ref: replace_json_schema_ref(source.json_schema, source_determinator) + for source, ref in source_determinator.items() + } + schemas = [ replace_json_schema_ref(root, source_determinator) for root in root_schemas ] + return defs, schemas def _collect_ref_to_sources(self, root_schemas: Sequence[JSONSchema]) -> Mapping[str, Sequence[RefSource]]: grouper = OrderedUniqueGrouper[str, RefSource[JSONSchema]]() @@ -74,16 +77,16 @@ def _collect_ref_to_sources(self, root_schemas: Sequence[JSONSchema]) -> Mapping def _get_source_determinator( self, - defs: Mapping[str, ResolvedJSONSchema], + occupied_refs: Container[str], ref_to_sources: Mapping[str, Sequence[RefSource]], ) -> Mapping[RefSource, str]: source_determinator = {} for common_ref, sources in ref_to_sources.items(): - if len(sources) == 1 and common_ref not in defs: + if len(sources) == 1 and common_ref not in occupied_refs: source_determinator[sources[0]] = common_ref else: self._validate_sources(common_ref, sources) - mangling_result = self._ref_mangler.mangle_refs(defs, common_ref, sources) + mangling_result = self._ref_mangler.mangle_refs(occupied_refs, common_ref, sources) source_determinator.update(mangling_result) self._validate_mangling(source_determinator) return source_determinator @@ -117,14 +120,3 @@ def _validate_mangling(self, source_determinator: Mapping[RefSource, str]) -> No f" can not mangle some refs." f" {unmangled_desc}", ) - - def _write_to_defs( - self, - defs: MutableMapping[str, ResolvedJSONSchema], - source_determinator: Mapping[RefSource, str], - ) -> None: - for source, ref in source_determinator.items(): - resolved_json_schema = replace_json_schema_ref(source.json_schema, source_determinator) - if ref in defs and defs[ref] == resolved_json_schema: - continue - defs[ref] = resolved_json_schema diff --git a/src/adaptix/_internal/morphing/json_schema/schema_tools.py b/src/adaptix/_internal/morphing/json_schema/schema_tools.py index acd0d907..1cc5ee8e 100644 --- a/src/adaptix/_internal/morphing/json_schema/schema_tools.py +++ b/src/adaptix/_internal/morphing/json_schema/schema_tools.py @@ -7,16 +7,7 @@ from ...utils import Omittable, Omitted from .definitions import JSONSchema, RefSource, ResolvedJSONSchema -from .schema_model import ( - BaseJSONSchema, - JSONNumeric, - JSONObject, - JSONSchemaBuiltinFormat, - JSONSchemaT, - JSONSchemaType, - JSONValue, - RefT, -) +from .schema_model import JSONNumeric, JSONObject, JSONSchemaBuiltinFormat, JSONSchemaT, JSONSchemaType, JSONValue, RefT _non_generic_fields_types = [ Omittable[Union[JSONSchemaType, Sequence[JSONSchemaType]]], # type: ignore[misc] @@ -39,22 +30,19 @@ """ if __value__ != Omitted(): for item in __value__: - yield item yield from __traverser__(item) """, ), Omittable[JSONSchemaT]: dedent( # type: ignore[misc, valid-type] """ if __value__ != Omitted(): - yield __value__ - yield from __traverser__(value) + yield from __traverser__(__value__) """, ), Omittable[JSONObject[JSONSchemaT]]: dedent( # type: ignore[misc, valid-type] """ if __value__ != Omitted(): for item in __value__.values(): - yield item yield from __traverser__(item) """, ), @@ -65,7 +53,7 @@ Omittable[RefT]: dedent( # type: ignore[misc, valid-type] """ if __value__ != Omitted(): - yield __value__.json_schema + yield from __traverser__(__value__.json_schema) """, ), } @@ -90,18 +78,21 @@ def _generate_json_schema_traverser( .strip("\n"), ) - module_code = f"def {function_name}(obj, /):\n" + "\n\n".join(indent(item, " " * 4) for item in result) + module_code = dedent( + f""" + def {function_name}(obj, /): + if isinstance(obj, bool): + return + + yield obj + + """, + ) + "\n\n".join(indent(item, " " * 4) for item in result) namespace: dict[str, Any] = {"Omitted": Omitted} exec(compile(module_code, file_name, "exec"), namespace, namespace) # noqa: S102 return namespace[function_name] -traverse_base_json_schema = _generate_json_schema_traverser( - function_name="traverse_base_json_schema", - file_name="", - templates=_base_json_schema_templates, - cls=BaseJSONSchema, -) traverse_json_schema = _generate_json_schema_traverser( function_name="traverse_json_schema", file_name="", @@ -158,7 +149,17 @@ def _generate_json_schema_replacer( ) body = "\n".join(indent(item, " " * 8) for item in result) - module_code = f"def {function_name}(obj, ctx, /):\n return {target_cls.__name__}(\n{body}\n )" + module_code = dedent( + f""" + def {function_name}(obj, ctx, /): + if isinstance(obj, bool): + return obj + + return {target_cls.__name__}( + {body} + ) + """, + ) namespace: dict[str, Any] = {target_cls.__name__: target_cls, "Omitted": Omitted} exec(compile(module_code, file_name, "exec"), namespace, namespace) # noqa: S102 return namespace[function_name]