Skip to content

Commit

Permalink
The next iteration of JSONSchema machinery
Browse files Browse the repository at this point in the history
  • Loading branch information
zhPavel committed Jan 3, 2025
1 parent ae7f63b commit 169acf4
Show file tree
Hide file tree
Showing 9 changed files with 416 additions and 31 deletions.
22 changes: 22 additions & 0 deletions src/adaptix/_internal/datastructures.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections import defaultdict
from collections.abc import (
Collection,
Hashable,
Expand All @@ -6,6 +7,7 @@
KeysView,
Mapping,
Reversible,
Sequence,
Set,
Sized,
ValuesView,
Expand Down Expand Up @@ -270,3 +272,23 @@ def reversed_slice(self: StackT, end_offset: int) -> StackT:

def count(self, item: T_co) -> int: # type: ignore[misc]
return sum(loc == item for loc in reversed(self))


ItemT = TypeVar("ItemT", bound=Hashable)


class OrderedUniqueGrouper(Generic[K, ItemT]):
__slots__ = ("_key_to_item_list", "_key_to_item_set")

def __init__(self):
self._key_to_item_list = defaultdict(list)
self._key_to_item_set = defaultdict(set)

def add(self, key: K, item: ItemT) -> None:
if key in self._key_to_item_set and item not in self._key_to_item_set[key]:
self._key_to_item_set[key].add(item)
self._key_to_item_list[key].append(item)

def finalize(self) -> Mapping[K, Sequence[ItemT]]:
self._key_to_item_list.default_factory = None
return self._key_to_item_list
12 changes: 4 additions & 8 deletions src/adaptix/_internal/morphing/json_schema/definitions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Generic, TypeVar, Union
from typing import Generic, Optional, TypeVar, Union

from ...provider.loc_stack_filtering import LocStack
from ...type_tools.fwd_ref import FwdRef
Expand All @@ -11,20 +11,16 @@


@dataclass(frozen=True)
class JSONSchemaRef(Generic[JSONSchemaT]):
value: str
is_final: bool
class RefSource(Generic[JSONSchemaT]):
value: Optional[str]
json_schema: JSONSchemaT
loc_stack: LocStack = field(repr=False)

def __hash__(self):
return hash(self.value)


Boolable = Union[T, bool]


class JSONSchema(BaseJSONSchema[JSONSchemaRef[Boolable[FwdRef["JSONSchema"]]], Boolable[FwdRef["JSONSchema"]]]):
class JSONSchema(BaseJSONSchema[RefSource[FwdRef["JSONSchema"]], Boolable[FwdRef["JSONSchema"]]]):
pass


Expand Down
72 changes: 72 additions & 0 deletions src/adaptix/_internal/morphing/json_schema/mangling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from collections.abc import Mapping, Sequence
from itertools import count
from typing import Optional

from ...datastructures import OrderedUniqueGrouper
from .definitions import RefSource, ResolvedJSONSchema
from .resolver import RefMangler


class IndexRefMangler(RefMangler):
def __init__(self, start: int = 1, separator: str = "-"):
self._start = start
self._separator = separator

def mangle_refs(
self,
defs: Mapping[str, ResolvedJSONSchema],
common_ref: str,
sources: Sequence[RefSource],
) -> Mapping[RefSource, str]:
result = {}
counter = count(self._start)
for source in sources:
while True:
idx = next(counter)
mangled = self._with_index(common_ref, idx)
if mangled not in defs:
result[source] = mangled
break

return result

def _with_index(self, common_ref: str, index: int) -> str:
return f"{common_ref}{self._separator}{index}"


class QualnameRefMangler(RefMangler):
def mangle_refs(
self,
defs: Mapping[str, ResolvedJSONSchema],
common_ref: str,
sources: Sequence[RefSource],
) -> Mapping[RefSource, str]:
return {source: self._generate_name(source) or common_ref for source in sources}

def _generate_name(self, source: RefSource) -> Optional[str]:
tp = source.loc_stack.last.type
return getattr(tp, "__qualname__", None)


class CompoundRefMangler(RefMangler):
def __init__(self, base: RefMangler, wrapper: RefMangler):
self._base = base
self._wrapper = wrapper

def mangle_refs(
self,
defs: Mapping[str, ResolvedJSONSchema],
common_ref: str,
sources: Sequence[RefSource],
) -> Mapping[RefSource, str]:
mangled = self._base.mangle_refs(defs, common_ref, sources)

grouper = OrderedUniqueGrouper[str, RefSource]()
for source, ref in mangled.items():
grouper.add(ref, source)

for ref, ref_sources in grouper.finalize().items():
if len(ref_sources) > 1:
mangled = {**mangled, **self._wrapper.mangle_refs(defs, ref, ref_sources)}

return mangled
23 changes: 7 additions & 16 deletions src/adaptix/_internal/morphing/json_schema/providers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from ...provider.essential import Mediator
from ...provider.located_request import LocatedRequestMethodsProvider
from ...provider.methods_provider import method_handler
from .definitions import JSONSchemaRef
from .request_cls import InlineJSONSchemaRequest, JSONSchemaRefRequest
from .definitions import RefSource
from .request_cls import InlineJSONSchemaRequest, RefSourceRequest


class InlineJSONSchemaProvider(LocatedRequestMethodsProvider):
Expand All @@ -16,31 +16,22 @@ def provide_inline_json_schema(self, mediator: Mediator, request: InlineJSONSche

class JSONSchemaRefProvider(LocatedRequestMethodsProvider):
@method_handler
def provide_json_schema_ref(self, mediator: Mediator, request: JSONSchemaRefRequest) -> JSONSchemaRef:
return JSONSchemaRef(
value=self._get_reference_value(request),
is_final=False,
def provide_ref_source(self, mediator: Mediator, request: RefSourceRequest) -> RefSource:
return RefSource(
value=None,
json_schema=request.json_schema,
loc_stack=request.loc_stack,
)

def _get_reference_value(self, request: JSONSchemaRefRequest) -> str:
tp = request.loc_stack.last.type
try:
return tp.__name__
except AttributeError:
return str(tp)


class ConstantJSONSchemaRefProvider(LocatedRequestMethodsProvider):
def __init__(self, ref_value: str):
self._ref_value = ref_value

@method_handler
def provide_json_schema_ref(self, mediator: Mediator, request: JSONSchemaRefRequest) -> JSONSchemaRef:
return JSONSchemaRef(
def provide_ref_source(self, mediator: Mediator, request: RefSourceRequest) -> RefSource:
return RefSource(
value=self._ref_value,
is_final=True,
json_schema=request.json_schema,
loc_stack=request.loc_stack,
)
4 changes: 2 additions & 2 deletions src/adaptix/_internal/morphing/json_schema/request_cls.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from ...definitions import Direction
from ...provider.located_request import LocatedRequest
from .definitions import JSONSchema, JSONSchemaRef
from .definitions import JSONSchema, RefSource


@dataclass(frozen=True)
Expand All @@ -22,7 +22,7 @@ class JSONSchemaRequest(LocatedRequest[JSONSchema], WithJSONSchemaContext):


@dataclass(frozen=True)
class JSONSchemaRefRequest(LocatedRequest[JSONSchemaRef], WithJSONSchemaContext):
class RefSourceRequest(LocatedRequest[RefSource], WithJSONSchemaContext):
json_schema: JSONSchema


Expand Down
130 changes: 130 additions & 0 deletions src/adaptix/_internal/morphing/json_schema/resolver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Mapping, MutableMapping, Sequence

from adaptix import Omitted

from ...datastructures import OrderedUniqueGrouper
from ...provider.loc_stack_filtering import LocStack
from ...provider.loc_stack_tools import format_loc_stack
from .definitions import JSONSchema, RefSource, ResolvedJSONSchema
from .schema_tools import replace_json_schema_ref, traverse_json_schema


class JSONSchemaResolver(ABC):
@abstractmethod
def resolve(
self,
defs: MutableMapping[str, ResolvedJSONSchema],
root_schemas: Sequence[JSONSchema],
) -> Sequence[ResolvedJSONSchema]:
...


class RefGenerator(ABC):
@abstractmethod
def generate_ref(self, json_schema: JSONSchema, loc_stack: LocStack) -> str:
...


class RefMangler(ABC):
@abstractmethod
def mangle_refs(
self,
defs: Mapping[str, ResolvedJSONSchema],
common_ref: str,
sources: Sequence[RefSource],
) -> Mapping[RefSource, str]:
...


class BuiltinJSONSchemaResolver(JSONSchemaResolver):
def __init__(self, ref_generator: RefGenerator, ref_mangler: RefMangler):
self._ref_generator = ref_generator
self._ref_mangler = ref_mangler

def resolve(
self,
defs: MutableMapping[str, ResolvedJSONSchema],
root_schemas: Sequence[JSONSchema],
) -> Sequence[ResolvedJSONSchema]:
ref_to_sources = self._collect_ref_to_sources(root_schemas)
source_determinator = self._get_source_determinator(defs, ref_to_sources)
self._write_to_defs(defs, source_determinator)
return [
replace_json_schema_ref(root, source_determinator)
for root in root_schemas
]

def _collect_ref_to_sources(self, root_schemas: Sequence[JSONSchema]) -> Mapping[str, Sequence[RefSource]]:
grouper = OrderedUniqueGrouper[str, RefSource[JSONSchema]]()
for root in root_schemas:
for schema in traverse_json_schema(root):
ref_source = schema.ref
if isinstance(ref_source, Omitted):
continue

ref = (
self._ref_generator.generate_ref(ref_source.json_schema, ref_source.loc_stack)
if ref_source.value is None else
ref_source.value
)
grouper.add(ref, ref_source)
return grouper.finalize()

def _get_source_determinator(
self,
defs: Mapping[str, ResolvedJSONSchema],
ref_to_sources: Mapping[str, Sequence[RefSource]],
) -> Mapping[RefSource, str]:
source_determinator = {}
for common_ref, sources in ref_to_sources.items():
if len(sources) == 1 and common_ref not in defs:
source_determinator[sources[0]] = common_ref
else:
self._validate_sources(common_ref, sources)
mangling_result = self._ref_mangler.mangle_refs(defs, common_ref, sources)
source_determinator.update(mangling_result)
self._validate_mangling(source_determinator)
return source_determinator

def _validate_sources(self, common_ref: str, sources: Sequence[RefSource]) -> None:
pinned_sources = [source for source in sources if source.value is not None]
if len(pinned_sources) > 1:
pinned = ", ".join(f"`{format_loc_stack(pinned.loc_stack)}`" for pinned in pinned_sources)
raise ValueError(
f"Can not create consistent json schema,"
f" there are different sub schemas with pinned ref {common_ref!r}."
f" {pinned}",
)

def _validate_mangling(self, source_determinator: Mapping[RefSource, str]) -> None:
ref_to_sources = defaultdict(list)
for source, ref in source_determinator.items():
ref_to_sources[ref].append(source)

unmangled = [(ref, sources) for ref, sources in ref_to_sources.items() if len(sources) > 1]
if unmangled:
unmangled_desc = "; ".join(
f"For ref {ref!r} at "
+ ", and at ".join(
f"`{format_loc_stack(source.json_schema)}`" for source in sources
)
for ref, sources in unmangled
)
raise ValueError(
f"Can not create consistent json schema,"
f" can not mangle some refs."
f" {unmangled_desc}",
)

def _write_to_defs(
self,
defs: MutableMapping[str, ResolvedJSONSchema],
source_determinator: Mapping[RefSource, str],
) -> None:
for source, ref in source_determinator.items():
resolved_json_schema = replace_json_schema_ref(source.json_schema, source_determinator)
if ref in defs and defs[ref] == resolved_json_schema:
continue
defs[ref] = resolved_json_schema
Loading

0 comments on commit 169acf4

Please sign in to comment.