Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Nested structured config validation #1133

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
87 changes: 70 additions & 17 deletions omegaconf/basecontainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,9 @@ def _map_merge(
list_merge_mode: ListMergeMode = ListMergeMode.REPLACE,
) -> None:
"""merge src into dest and return a new copy, does not modified input"""
from omegaconf import AnyNode, DictConfig, ValueNode
from omegaconf import AnyNode, DictConfig, ListConfig, ValueNode

from ._utils import get_dict_key_value_types, get_list_element_type

assert isinstance(dest, DictConfig)
assert isinstance(src, DictConfig)
Expand Down Expand Up @@ -380,16 +382,48 @@ def expand(node: Container) -> None:
if dest_node is not None and dest_node._is_interpolation():
target_node = dest_node._maybe_dereference_node()
if isinstance(target_node, Container):
dest[key] = target_node
dest.__setitem__(key, target_node)
dest_node = dest._get_node(key)

is_optional, et = _resolve_optional(dest._metadata.element_type)
if dest_node is None and is_structured_config(et) and not src_node_missing:
# merging into a new node. Use element_type as a base
dest[key] = DictConfig(
et, parent=dest, ref_type=et, is_optional=is_optional
)
dest_node = dest._get_node(key)
if dest_node is None and not src_node_missing:
# check if merging into a new node
if is_structured_config(et):
# Use element_type as a base
dest.__setitem__(
key,
DictConfig(
et, parent=dest, ref_type=et, is_optional=is_optional
),
)
dest_node = dest._get_node(key)
elif is_dict_annotation(et):
key_type, element_type = get_dict_key_value_types(et)
dest.__setitem__(
key,
DictConfig(
{},
parent=dest,
ref_type=et,
key_type=key_type,
element_type=element_type,
is_optional=is_optional,
),
)
dest_node = dest._get_node(key)
elif is_list_annotation(et):
element_type = get_list_element_type(et)
dest.__setitem__(
key,
ListConfig(
[],
parent=dest,
ref_type=et,
element_type=element_type,
is_optional=is_optional,
),
)
dest_node = dest._get_node(key)

if dest_node is not None:
if isinstance(dest_node, BaseContainer):
Expand Down Expand Up @@ -428,9 +462,9 @@ def expand(node: Container) -> None:
if is_structured_config(src_type):
# verified to be compatible above in _validate_merge
with open_dict(dest):
dest[key] = src._get_node(key)
dest.__setitem__(key, src._get_node(key))
else:
dest[key] = src._get_node(key)
dest.__setitem__(key, src._get_node(key))

_update_types(node=dest, ref_type=src_ref_type, object_type=src_type)

Expand All @@ -449,6 +483,8 @@ def _list_merge(
) -> None:
from omegaconf import DictConfig, ListConfig, OmegaConf

from ._utils import get_dict_key_value_types, get_list_element_type

assert isinstance(dest, ListConfig)
assert isinstance(src, ListConfig)

Expand All @@ -466,15 +502,32 @@ def _list_merge(
dest.__dict__["_metadata"]
)
is_optional, et = _resolve_optional(dest._metadata.element_type)

prototype: Optional[Union[DictConfig, ListConfig]] = None

if is_structured_config(et):
prototype = DictConfig(et, ref_type=et, is_optional=is_optional)
for item in src._iter_ex(resolve=False):
if isinstance(item, DictConfig):
item = OmegaConf.merge(prototype, item)
temp_target.append(item)
else:
for item in src._iter_ex(resolve=False):
temp_target.append(item)
elif is_dict_annotation(et):
key_type, element_type = get_dict_key_value_types(et)
prototype = DictConfig(
{},
ref_type=et,
key_type=key_type,
element_type=element_type,
is_optional=is_optional,
)
elif is_list_annotation(et):
element_type = get_list_element_type(et)
prototype = ListConfig(
[], ref_type=et, element_type=element_type, is_optional=is_optional
)

for item in src._iter_ex(resolve=False):
if prototype is not None:
item = OmegaConf.merge(
prototype, item, list_merge_mode=list_merge_mode
)
temp_target.append(item)

if list_merge_mode == ListMergeMode.EXTEND:
dest.__dict__["_content"].extend(temp_target.__dict__["_content"])
Expand Down
8 changes: 8 additions & 0 deletions tests/structured_conf/data/attr_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,14 @@ class NestedWithAny:
var: Any = Nested()


@attr.s(auto_attribs=True)
class DeeplyNestedUser:
dsdsdsu: Dict[str, Dict[str, Dict[str, User]]]
dsdslu: Dict[str, Dict[str, List[User]]]
lldsu: List[List[Dict[str, User]]]
lllu: List[List[List[User]]]


@attr.s(auto_attribs=True)
class NoDefaultValue:
no_default: Any
Expand Down
8 changes: 8 additions & 0 deletions tests/structured_conf/data/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,14 @@ class NestedWithAny:
var: Any = field(default_factory=Nested)


@dataclass
class DeeplyNestedUser:
dsdsdsu: Dict[str, Dict[str, Dict[str, User]]]
dsdslu: Dict[str, Dict[str, List[User]]]
lldsu: List[List[Dict[str, User]]]
lllu: List[List[List[User]]]


@dataclass
class NoDefaultValue:
no_default: Any
Expand Down
8 changes: 8 additions & 0 deletions tests/structured_conf/data/dataclasses_pre_311.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,14 @@ class NestedWithAny:
var: Any = Nested()


@dataclass
class DeeplyNestedUser:
dsdsdsu: Dict[str, Dict[str, Dict[str, User]]]
dsdslu: Dict[str, Dict[str, List[User]]]
lldsu: List[List[Dict[str, User]]]
lllu: List[List[List[User]]]


@dataclass
class NoDefaultValue:
no_default: Any
Expand Down
64 changes: 63 additions & 1 deletion tests/structured_conf/test_structured_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@
ValidationError,
_utils,
)
from omegaconf.errors import ConfigKeyError, InterpolationToMissingValueError
from omegaconf.errors import (
ConfigKeyError,
ConfigTypeError,
InterpolationToMissingValueError,
)
from tests import Color, Enum1, User, warns_dict_subclass_deprecated


Expand Down Expand Up @@ -1892,6 +1896,64 @@ def test_assign_none(
with raises(ValidationError):
node[last_key] = value

@mark.parametrize(
"user_value, expected_error",
[
param({"name": "Bond", "age": 7}, None, id="user-good"),
param({"cat": "Bond", "turnip": 7}, ConfigKeyError, id="user-bad-key"),
param({"name": "Bond", "age": "abc"}, ValidationError, id="user-bad-type1"),
param([1, 2, 3], ConfigTypeError, id="user-bad-type2"),
],
)
@mark.parametrize(
"key, make_nested, indices",
[
param(
"dsdsdsu",
lambda uv: {"l1": {"l2": {"l3": uv}}},
["l1", "l2", "l3"],
id="dsdsdsu",
),
param(
"dsdslu",
lambda uv: {"l1": {"l2": [uv]}},
["l1", "l2", 0],
id="dsdslu",
),
param(
"lldsu",
lambda uv: [[{"l3": uv}]],
[0, 0, "l3"],
id="lldsu",
),
param(
"lllu",
lambda uv: [[[uv]]],
[0, 0, 0],
id="lllu",
),
],
)
def test_merge_with_deep_nesting(
self, module: Any, key: str, make_nested, indices, user_value, expected_error
) -> None:
conf = OmegaConf.structured(module.DeeplyNestedUser)
src = {key: make_nested(user_value)}

if expected_error is None:
merged = OmegaConf.merge(conf, src)
assert isinstance(merged, DictConfig)
value = merged[key]
for index in indices:
previous = value
value = value[index]
assert previous._metadata.key_type is type(index)
assert previous._metadata.element_type is value._metadata.ref_type
assert value._metadata.ref_type is module.User
else:
with raises(expected_error):
OmegaConf.merge(conf, src)


class TestUnionsOfPrimitiveTypes:
@mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_nested_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1408,7 +1408,7 @@ def test_merge_nested_list_promotion() -> None:
),
param(
[DictConfig({}, element_type=Dict[str, int]), {"foo": 123}],
"Value 123 (int) is incompatible with type hint 'typing.Dict[str, int]'",
"Cannot assign int to Dict[str, int]",
id="merge-int-into-dict",
),
param(
Expand Down