Skip to content

Commit

Permalink
Merge pull request #346 from reagento/fix/structure-flatenning-missin…
Browse files Browse the repository at this point in the history
…g-field

fix error on loading missing field with name flattening
  • Loading branch information
zhPavel authored Nov 3, 2024
2 parents e2bb2d0 + dc96b57 commit 422a501
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 71 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Unexpected error is replaced with ``NoRequiredFieldsLoadError`` for fields generated by name flattening
150 changes: 79 additions & 71 deletions src/adaptix/_internal/morphing/model/loader_gen.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import collections.abc
import contextlib
from collections.abc import Mapping, Set
from contextlib import AbstractContextManager, contextmanager, nullcontext
from dataclasses import dataclass, replace
from typing import Callable, Optional
from typing import Any, Callable, Optional

from ...code_tools.cascade_namespace import BuiltinCascadeNamespace, CascadeNamespace
from ...code_tools.code_builder import CodeBuilder
Expand Down Expand Up @@ -147,7 +147,7 @@ def parent_path(self) -> CrownPath:
def parent_crown(self) -> BranchInpCrown:
return self._crown_stack[-2] # type: ignore[return-value]

@contextlib.contextmanager
@contextmanager
def add_key(self, crown: InpCrown, key: CrownPathElem):
past = self._path
past_parent = self._parent_path
Expand Down Expand Up @@ -475,7 +475,7 @@ def _gen_add_self_extra_to_parent_extra(self, state: GenState):
state.builder(f"{state.parent.v_extra}[{state.path[-1]!r}] = {state.v_extra}")
state.builder.empty_line()

@contextlib.contextmanager
@contextmanager
def _maybe_wrap_with_type_load_error_catching(self, state: GenState):
if self._debug_trail != DebugTrail.ALL or not state.path:
yield
Expand Down Expand Up @@ -504,38 +504,42 @@ def _gen_dict_crown(self, state: GenState, crown: InpDictCrown):
if state.path:
self._gen_assignment_from_parent_data(state, assign_to=state.v_data)
state.builder.empty_line()

if self._can_collect_extra:
state.builder += f"{state.v_extra} = {{}}"
if self._debug_trail == DebugTrail.ALL:
state.builder += f"{state.v_has_not_found_error} = False"

with self._maybe_wrap_with_type_load_error_catching(state):
for key, value in crown.map.items():
self._gen_crown_dispatch(state, value, key)

if state.path not in state.type_checked_type_paths:
with state.builder(f"if not isinstance({state.v_data}, CollectionsMapping):"):
self._gen_raise_bad_type_error(state, f"TypeLoadError(CollectionsMapping, {state.v_data})")
state.builder.empty_line()
state.type_checked_type_paths.add(state.path)

if crown.extra_policy == ExtraForbid():
state.builder += f"""
{state.v_extra}_set = set({state.v_data}) - {state.v_known_keys}
if {state.v_extra}_set:
{state.emit_error(f"ExtraFieldsLoadError({state.v_extra}_set, {state.v_data})")}
"""
state.builder.empty_line()
elif crown.extra_policy == ExtraCollect():
state.builder += f"""
for key in set({state.v_data}) - {state.v_known_keys}:
{state.v_extra}[key] = {state.v_data}[key]
"""
state.builder.empty_line()

if self._can_collect_extra:
self._gen_add_self_extra_to_parent_extra(state)
ctx: AbstractContextManager[Any] = state.builder("else:")
else:
ctx = nullcontext()

with ctx:
if self._can_collect_extra:
state.builder += f"{state.v_extra} = {{}}"
if self._debug_trail == DebugTrail.ALL:
state.builder += f"{state.v_has_not_found_error} = False"

with self._maybe_wrap_with_type_load_error_catching(state):
for key, value in crown.map.items():
self._gen_crown_dispatch(state, value, key)

if state.path not in state.type_checked_type_paths:
with state.builder(f"if not isinstance({state.v_data}, CollectionsMapping):"):
self._gen_raise_bad_type_error(state, f"TypeLoadError(CollectionsMapping, {state.v_data})")
state.builder.empty_line()
state.type_checked_type_paths.add(state.path)

if crown.extra_policy == ExtraForbid():
state.builder += f"""
{state.v_extra}_set = set({state.v_data}) - {state.v_known_keys}
if {state.v_extra}_set:
{state.emit_error(f"ExtraFieldsLoadError({state.v_extra}_set, {state.v_data})")}
"""
state.builder.empty_line()
elif crown.extra_policy == ExtraCollect():
state.builder += f"""
for key in set({state.v_data}) - {state.v_known_keys}:
{state.v_extra}[key] = {state.v_data}[key]
"""
state.builder.empty_line()

if self._can_collect_extra:
self._gen_add_self_extra_to_parent_extra(state)

def _gen_forbidden_sequence_check(self, state: GenState) -> None:
with state.builder(f"if type({state.v_data}) is str:"):
Expand All @@ -545,44 +549,48 @@ def _gen_list_crown(self, state: GenState, crown: InpListCrown):
if state.path:
self._gen_assignment_from_parent_data(state, assign_to=state.v_data)
state.builder.empty_line()

if self._can_collect_extra:
list_literal: list = [
{} if isinstance(sub_crown, (InpFieldCrown, InpNoneCrown)) else None
for sub_crown in crown.map
]
state.builder(f"{state.v_extra} = {list_literal!r}")

with self._maybe_wrap_with_type_load_error_catching(state):
if self._strict_coercion:
self._gen_forbidden_sequence_check(state)

for key, value in enumerate(crown.map):
self._gen_crown_dispatch(state, value, key)

if state.path not in state.type_checked_type_paths:
with state.builder(f"if not isinstance({state.v_data}, CollectionsSequence):"):
self._gen_raise_bad_type_error(state, f"TypeLoadError(CollectionsSequence, {state.v_data})")
state.builder.empty_line()
state.type_checked_type_paths.add(state.path)

expected_len = len(crown.map)
if crown.extra_policy == ExtraForbid():
state.builder += f"""
if len({state.v_data}) != {expected_len}:
ctx: AbstractContextManager[Any] = state.builder("else:")
else:
ctx = nullcontext()

with ctx:
if self._can_collect_extra:
list_literal: list = [
{} if isinstance(sub_crown, (InpFieldCrown, InpNoneCrown)) else None
for sub_crown in crown.map
]
state.builder(f"{state.v_extra} = {list_literal!r}")

with self._maybe_wrap_with_type_load_error_catching(state):
if self._strict_coercion:
self._gen_forbidden_sequence_check(state)

for key, value in enumerate(crown.map):
self._gen_crown_dispatch(state, value, key)

if state.path not in state.type_checked_type_paths:
with state.builder(f"if not isinstance({state.v_data}, CollectionsSequence):"):
self._gen_raise_bad_type_error(state, f"TypeLoadError(CollectionsSequence, {state.v_data})")
state.builder.empty_line()
state.type_checked_type_paths.add(state.path)

expected_len = len(crown.map)
if crown.extra_policy == ExtraForbid():
state.builder += f"""
if len({state.v_data}) != {expected_len}:
if len({state.v_data}) < {expected_len}:
{state.emit_error(f"NoRequiredItemsLoadError({expected_len}, {state.v_data})")}
else:
{state.emit_error(f"ExtraItemsLoadError({expected_len}, {state.v_data})")}
"""
else:
state.builder += f"""
if len({state.v_data}) < {expected_len}:
{state.emit_error(f"NoRequiredItemsLoadError({expected_len}, {state.v_data})")}
else:
{state.emit_error(f"ExtraItemsLoadError({expected_len}, {state.v_data})")}
"""
else:
state.builder += f"""
if len({state.v_data}) < {expected_len}:
{state.emit_error(f"NoRequiredItemsLoadError({expected_len}, {state.v_data})")}
"""
"""

if self._can_collect_extra:
self._gen_add_self_extra_to_parent_extra(state)
if self._can_collect_extra:
self._gen_add_self_extra_to_parent_extra(state)

def _get_default_clause_expr(self, state: GenState, field: InputField) -> str:
if isinstance(field.default, DefaultValue):
Expand Down
13 changes: 13 additions & 0 deletions tests/unit/morphing/model/test_loader_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def getter():
ValueProvider(InputShapeRequest, shape),
ValueProvider(InputNameLayoutRequest, name_layout),
bound(int, ValueProvider(LoaderRequest, int_loader)),
debug_ctx.accum,
],
)
return retort.replace(
Expand Down Expand Up @@ -928,6 +929,18 @@ def test_structure_flattening(debug_ctx, debug_trail, trail_select):
),
)

raises_exc(
trail_select(
disable=NoRequiredFieldsLoadError(fields={"w", "v", "z"}, input_value={}),
first=with_trail(NoRequiredFieldsLoadError(fields={"w", "v", "z"}, input_value={}), []),
all=AggregateLoadError(
f"while loading model {Gauge}",
[with_trail(NoRequiredFieldsLoadError(fields={"w", "v", "z"}, input_value={}), [])],
),
),
lambda: loader({}),
)


def _replace_value_by_path(data, path, new_value):
sub_data = data
Expand Down

0 comments on commit 422a501

Please sign in to comment.