Skip to content

Commit

Permalink
add JSON schema providing to ModelLoaderProvider and ModelDumperProvider
Browse files Browse the repository at this point in the history
  • Loading branch information
zhPavel committed Jul 28, 2024
1 parent 4f6eb94 commit b4f4e72
Show file tree
Hide file tree
Showing 6 changed files with 164 additions and 39 deletions.
2 changes: 2 additions & 0 deletions src/adaptix/_internal/morphing/json_schema/request_cls.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from dataclasses import dataclass

from ...definitions import Direction
from ...provider.located_request import LocatedRequest
from .definitions import JSONSchema, JSONSchemaRef


@dataclass(frozen=True)
class JSONSchemaContext:
dialect: str
direction: Direction


@dataclass(frozen=True)
Expand Down
12 changes: 7 additions & 5 deletions src/adaptix/_internal/morphing/model/dumper_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from string import Template
from typing import Any, Callable, Dict, Mapping, NamedTuple, Tuple

from ...utils import Omittable, Omitted

from ...code_tools.cascade_namespace import BuiltinCascadeNamespace, CascadeNamespace
from ...code_tools.code_builder import CodeBuilder
from ...code_tools.utils import get_literal_expr, get_literal_from_factory, is_singleton
Expand All @@ -15,7 +17,6 @@
DefaultValue,
DescriptorAccessor,
ItemAccessor,
NoDefault,
OutputField,
OutputShape,
)
Expand Down Expand Up @@ -661,7 +662,7 @@ def __init__(
shape: OutputShape,
extra_move: OutExtraMove,
field_json_schema_getter: Callable[[OutputField], JSONSchema],
field_default_dumper: Callable[[OutputField], JSONValue],
field_default_dumper: Callable[[OutputField], Omittable[JSONValue]],
placeholder_dumper: Callable[[Any], JSONValue],
):
self._shape = shape
Expand Down Expand Up @@ -700,9 +701,10 @@ def _convert_list_crown(self, crown: OutListCrown) -> JSONSchema:
def _convert_field_crown(self, crown: OutFieldCrown) -> JSONSchema:
field = self._shape.fields_dict[crown.id]
json_schema = self._field_json_schema_getter(field)
if field.default == NoDefault():
return json_schema
return replace(json_schema, default=self._field_default_dumper(field))
default = self._field_default_dumper(field)
if default != Omitted():
return replace(json_schema, default=default)
return json_schema

def _convert_none_crown(self, crown: OutNoneCrown) -> JSONSchema:
value = (
Expand Down
96 changes: 81 additions & 15 deletions src/adaptix/_internal/morphing/model/dumper_provider.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
from typing import Mapping

from adaptix._internal.provider.fields import output_field_to_loc
from functools import partial
from typing import Any, Mapping, Sequence

from ...code_tools.compiler import BasicClosureCompiler, ClosureCompiler
from ...code_tools.name_sanitizer import BuiltinNameSanitizer, NameSanitizer
from ...common import Dumper
from ...definitions import DebugTrail
from ...model_tools.definitions import OutputShape
from ...provider.essential import Mediator
from ...definitions import DebugTrail, Direction
from ...model_tools.definitions import DefaultFactory, DefaultValue, OutputField, OutputShape
from ...provider.essential import CannotProvide, Mediator
from ...provider.fields import output_field_to_loc
from ...provider.located_request import LocatedRequest
from ...provider.shape_provider import OutputShapeRequest, provide_generic_resolved_shape
from ..provider_template import DumperProvider
from ...utils import Omittable, Omitted
from ..json_schema.definitions import JSONSchema
from ..json_schema.request_cls import JSONSchemaRequest
from ..json_schema.schema_model import JSONValue
from ..provider_template import DumperProvider, JSONSchemaProvider
from ..request_cls import DebugTrailRequest, DumperRequest
from .basic_gen import (
ModelDumperGen,
Expand All @@ -19,11 +24,11 @@
get_optional_fields_at_list_crown,
get_wild_extra_targets,
)
from .crown_definitions import OutputNameLayout, OutputNameLayoutRequest
from .dumper_gen import BuiltinModelDumperGen
from .crown_definitions import OutExtraMove, OutputNameLayout, OutputNameLayoutRequest
from .dumper_gen import BuiltinModelDumperGen, ModelOutputJSONSchemaGen


class ModelDumperProvider(DumperProvider):
class ModelDumperProvider(DumperProvider, JSONSchemaProvider):
def __init__(self, *, name_sanitizer: NameSanitizer = BuiltinNameSanitizer()):
self._name_sanitizer = name_sanitizer

Expand All @@ -40,7 +45,68 @@ def provide_dumper(self, mediator: Mediator, request: DumperRequest) -> Dumper:
file_name=self._get_file_name(request),
)

def _fetch_model_dumper_gen(self, mediator: Mediator, request: DumperRequest) -> ModelDumperGen:
def _generate_json_schema(self, mediator: Mediator, request: JSONSchemaRequest) -> JSONSchema:
if request.ctx.direction != Direction.OUTPUT:
raise CannotProvide

shape = self._fetch_shape(mediator, request)
name_layout = self._fetch_name_layout(mediator, request, shape)
self._validate_params(shape, name_layout)

schema_gen = self._get_schema_gen(mediator, request, shape, name_layout.extra_move)
return schema_gen.convert_crown(name_layout.crown)

def _get_schema_gen(
self,
mediator: Mediator,
request: JSONSchemaRequest,
shape: OutputShape,
extra_move: OutExtraMove,
) -> ModelOutputJSONSchemaGen:
return ModelOutputJSONSchemaGen(
shape=shape,
field_default_dumper=partial(self._dump_field_default, mediator, request),
field_json_schema_getter=partial(self._get_field_json_schema, mediator, request),
extra_move=extra_move,
placeholder_dumper=self._dump_placeholder,
)

def _dump_field_default(
self,
mediator: Mediator,
request: JSONSchemaRequest,
field: OutputField,
) -> Omittable[JSONValue]:
if isinstance(field.default, DefaultValue):
default_value = field.default.value
elif isinstance(field.default, DefaultFactory):
default_value = field.default.factory()
else:
return Omitted()

dumper = mediator.mandatory_provide(
DumperRequest(loc_stack=request.loc_stack.append_with(output_field_to_loc(field))),
)
return dumper(default_value)

def _dump_placeholder(self, data: Any) -> JSONValue:
if isinstance(data, Mapping):
return {str(self._dump_placeholder(key)): self._dump_placeholder(value) for key, value in data.items()}
if isinstance(data, Sequence):
return [self._dump_placeholder(element) for element in data]
if isinstance(data, (str, int, float, bool)) or data is None:
return data
raise TypeError(f"Can not dump placeholder {data}")

def _get_field_json_schema(
self,
mediator: Mediator,
request: JSONSchemaRequest,
field: OutputField,
) -> JSONSchema:
return mediator.mandatory_provide(request.append_loc(output_field_to_loc(field)))

def _fetch_model_dumper_gen(self, mediator: Mediator, request: LocatedRequest) -> ModelDumperGen:
shape = self._fetch_shape(mediator, request)
name_layout = self._fetch_name_layout(mediator, request, shape)
self._validate_params(shape, name_layout)
Expand All @@ -58,7 +124,7 @@ def _fetch_model_dumper_gen(self, mediator: Mediator, request: DumperRequest) ->
def _fetch_model_identity(
self,
mediator: Mediator,
request: DumperRequest,
request: LocatedRequest,
shape: OutputShape,
name_layout: OutputNameLayout,
) -> str:
Expand Down Expand Up @@ -102,10 +168,10 @@ def _get_closure_name(self, request: DumperRequest) -> str:
def _get_compiler(self) -> ClosureCompiler:
return BasicClosureCompiler()

def _fetch_shape(self, mediator: Mediator, request: DumperRequest) -> OutputShape:
def _fetch_shape(self, mediator: Mediator, request: LocatedRequest) -> OutputShape:
return provide_generic_resolved_shape(mediator, OutputShapeRequest(loc_stack=request.loc_stack))

def _fetch_name_layout(self, mediator: Mediator, request: DumperRequest, shape: OutputShape) -> OutputNameLayout:
def _fetch_name_layout(self, mediator: Mediator, request: LocatedRequest, shape: OutputShape) -> OutputNameLayout:
return mediator.delegating_provide(
OutputNameLayoutRequest(
loc_stack=request.loc_stack,
Expand All @@ -116,7 +182,7 @@ def _fetch_name_layout(self, mediator: Mediator, request: DumperRequest, shape:
def _fetch_field_dumpers(
self,
mediator: Mediator,
request: DumperRequest,
request: LocatedRequest,
shape: OutputShape,
) -> Mapping[str, Dumper]:
dumpers = mediator.mandatory_provide_by_iterable(
Expand Down
13 changes: 7 additions & 6 deletions src/adaptix/_internal/morphing/model/loader_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
from ...common import Loader
from ...compat import CompatExceptionGroup
from ...definitions import DebugTrail
from ...model_tools.definitions import DefaultFactory, DefaultValue, InputField, InputShape, NoDefault, Param, ParamKind
from ...model_tools.definitions import DefaultFactory, DefaultValue, InputField, InputShape, Param, ParamKind
from ...special_cases_optimization import as_is_stub
from ...struct_trail import append_trail, extend_trail, render_trail_as_note
from ...utils import Omitted
from ...utils import Omittable, Omitted
from ..json_schema.definitions import JSONSchema
from ..json_schema.schema_model import JSONSchemaType, JSONValue
from ..load_error import (
Expand Down Expand Up @@ -789,7 +789,7 @@ def __init__(
self,
shape: InputShape,
field_json_schema_getter: Callable[[InputField], JSONSchema],
field_default_dumper: Callable[[InputField], JSONValue],
field_default_dumper: Callable[[InputField], Omittable[JSONValue]],
):
self._shape = shape
self._field_json_schema_getter = field_json_schema_getter
Expand Down Expand Up @@ -825,9 +825,10 @@ def _convert_list_crown(self, crown: InpListCrown) -> JSONSchema:
def _convert_field_crown(self, crown: InpFieldCrown) -> JSONSchema:
field = self._shape.fields_dict[crown.id]
json_schema = self._field_json_schema_getter(field)
if field.default == NoDefault():
return json_schema
return replace(json_schema, default=self._field_default_dumper(field))
default = self._field_default_dumper(field)
if default != Omitted():
return replace(json_schema, default=default)
return json_schema

def _convert_none_crown(self, crown: InpNoneCrown) -> JSONSchema:
return JSONSchema()
Expand Down
78 changes: 66 additions & 12 deletions src/adaptix/_internal/morphing/model/loader_provider.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
from functools import partial
from typing import AbstractSet, Mapping

from adaptix._internal.provider.fields import input_field_to_loc

from ...code_tools.compiler import BasicClosureCompiler, ClosureCompiler
from ...code_tools.name_sanitizer import BuiltinNameSanitizer, NameSanitizer
from ...common import Loader
from ...definitions import DebugTrail
from ...model_tools.definitions import InputShape
from ...provider.essential import Mediator
from ...definitions import DebugTrail, Direction
from ...model_tools.definitions import DefaultFactory, DefaultValue, InputField, InputShape
from ...provider.essential import CannotProvide, Mediator
from ...provider.fields import input_field_to_loc
from ...provider.located_request import LocatedRequest
from ...provider.shape_provider import InputShapeRequest, provide_generic_resolved_shape
from ..model.loader_gen import BuiltinModelLoaderGen, ModelLoaderProps
from ..provider_template import LoaderProvider
from ..request_cls import DebugTrailRequest, LoaderRequest, StrictCoercionRequest
from ...utils import Omittable, Omitted
from ..json_schema.definitions import JSONSchema
from ..json_schema.request_cls import JSONSchemaRequest
from ..json_schema.schema_model import JSONValue
from ..model.loader_gen import BuiltinModelLoaderGen, ModelInputJSONSchemaGen, ModelLoaderProps
from ..provider_template import JSONSchemaProvider, LoaderProvider
from ..request_cls import DebugTrailRequest, DumperRequest, LoaderRequest, StrictCoercionRequest
from .basic_gen import (
ModelLoaderGen,
compile_closure_with_globals_capturing,
Expand All @@ -25,7 +30,7 @@
from .crown_definitions import InputNameLayout, InputNameLayoutRequest


class ModelLoaderProvider(LoaderProvider):
class ModelLoaderProvider(LoaderProvider, JSONSchemaProvider):
def __init__(
self,
*,
Expand All @@ -48,6 +53,55 @@ def provide_loader(self, mediator: Mediator, request: LoaderRequest) -> Loader:
file_name=self._get_file_name(request),
)

def _generate_json_schema(self, mediator: Mediator, request: JSONSchemaRequest) -> JSONSchema:
if request.ctx.direction != Direction.INPUT:
raise CannotProvide

shape = self._fetch_shape(mediator, request)
name_layout = self._fetch_name_layout(mediator, request, shape)
skipped_fields = get_skipped_fields(shape, name_layout)
self._validate_params(shape, name_layout, skipped_fields)
schema_gen = self._get_schema_gen(mediator, request, shape)
return schema_gen.convert_crown(name_layout.crown)

def _get_schema_gen(
self,
mediator: Mediator,
request: JSONSchemaRequest,
shape: InputShape,
) -> ModelInputJSONSchemaGen:
return ModelInputJSONSchemaGen(
shape=shape,
field_default_dumper=partial(self._dump_field_default, mediator, request),
field_json_schema_getter=partial(self._get_field_json_schema, mediator, request),
)

def _dump_field_default(
self,
mediator: Mediator,
request: JSONSchemaRequest,
field: InputField,
) -> Omittable[JSONValue]:
if isinstance(field.default, DefaultValue):
default_value = field.default.value
elif isinstance(field.default, DefaultFactory):
default_value = field.default.factory()
else:
return Omitted()

dumper = mediator.mandatory_provide(
DumperRequest(loc_stack=request.loc_stack.append_with(input_field_to_loc(field))),
)
return dumper(default_value)

def _get_field_json_schema(
self,
mediator: Mediator,
request: JSONSchemaRequest,
field: InputField,
) -> JSONSchema:
return mediator.mandatory_provide(request.append_loc(input_field_to_loc(field)))

def _fetch_model_loader_gen(self, mediator: Mediator, request: LoaderRequest) -> ModelLoaderGen:
shape = self._fetch_shape(mediator, request)
name_layout = self._fetch_name_layout(mediator, request, shape)
Expand Down Expand Up @@ -120,10 +174,10 @@ def _get_closure_name(self, request: LoaderRequest) -> str:
def _get_compiler(self) -> ClosureCompiler:
return BasicClosureCompiler()

def _fetch_shape(self, mediator: Mediator, request: LoaderRequest) -> InputShape:
def _fetch_shape(self, mediator: Mediator, request: LocatedRequest) -> InputShape:
return provide_generic_resolved_shape(mediator, InputShapeRequest(loc_stack=request.loc_stack))

def _fetch_name_layout(self, mediator: Mediator, request: LoaderRequest, shape: InputShape) -> InputNameLayout:
def _fetch_name_layout(self, mediator: Mediator, request: LocatedRequest, shape: InputShape) -> InputNameLayout:
return mediator.mandatory_provide(
InputNameLayoutRequest(
loc_stack=request.loc_stack,
Expand Down Expand Up @@ -202,5 +256,5 @@ def __init__(
super().__init__(name_sanitizer=name_sanitizer, props=props)
self._shape = shape

def _fetch_shape(self, mediator: Mediator, request: LoaderRequest) -> InputShape:
def _fetch_shape(self, mediator: Mediator, request: LocatedRequest) -> InputShape:
return self._shape
2 changes: 1 addition & 1 deletion src/adaptix/_internal/morphing/provider_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class JSONSchemaProvider(LocatedRequestMethodsProvider, ABC):

@final
@method_handler
def generate_json_schema(self, mediator: Mediator, request: JSONSchemaRequest) -> JSONSchema:
def provide_json_schema(self, mediator: Mediator, request: JSONSchemaRequest) -> JSONSchema:
if request.ctx.dialect not in self.SUPPORTED_JSON_SCHEMA_DIALECTS:
raise CannotProvide(f"Dialect {request.ctx.dialect} is not supported for this type")
return self._generate_json_schema(mediator, request)
Expand Down

0 comments on commit b4f4e72

Please sign in to comment.