diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index a83907f8e..2665e940e 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -304,7 +304,9 @@ def _map_merge( list_merge_mode: ListMergeMode = ListMergeMode.REPLACE, ) -> None: """merge src into dest and return a new copy, does not modified input""" - from omegaconf import AnyNode, DictConfig, ValueNode + from omegaconf import AnyNode, DictConfig, ListConfig, ValueNode + + from ._utils import get_dict_key_value_types, get_list_element_type assert isinstance(dest, DictConfig) assert isinstance(src, DictConfig) @@ -380,16 +382,48 @@ def expand(node: Container) -> None: if dest_node is not None and dest_node._is_interpolation(): target_node = dest_node._maybe_dereference_node() if isinstance(target_node, Container): - dest[key] = target_node + dest.__setitem__(key, target_node) dest_node = dest._get_node(key) is_optional, et = _resolve_optional(dest._metadata.element_type) - if dest_node is None and is_structured_config(et) and not src_node_missing: - # merging into a new node. Use element_type as a base - dest[key] = DictConfig( - et, parent=dest, ref_type=et, is_optional=is_optional - ) - dest_node = dest._get_node(key) + if dest_node is None and not src_node_missing: + # check if merging into a new node + if is_structured_config(et): + # Use element_type as a base + dest.__setitem__( + key, + DictConfig( + et, parent=dest, ref_type=et, is_optional=is_optional + ), + ) + dest_node = dest._get_node(key) + elif is_dict_annotation(et): + key_type, element_type = get_dict_key_value_types(et) + dest.__setitem__( + key, + DictConfig( + {}, + parent=dest, + ref_type=et, + key_type=key_type, + element_type=element_type, + is_optional=is_optional, + ), + ) + dest_node = dest._get_node(key) + elif is_list_annotation(et): + element_type = get_list_element_type(et) + dest.__setitem__( + key, + ListConfig( + [], + parent=dest, + ref_type=et, + element_type=element_type, + is_optional=is_optional, + ), + ) + dest_node = dest._get_node(key) if dest_node is not None: if isinstance(dest_node, BaseContainer): @@ -428,9 +462,9 @@ def expand(node: Container) -> None: if is_structured_config(src_type): # verified to be compatible above in _validate_merge with open_dict(dest): - dest[key] = src._get_node(key) + dest.__setitem__(key, src._get_node(key)) else: - dest[key] = src._get_node(key) + dest.__setitem__(key, src._get_node(key)) _update_types(node=dest, ref_type=src_ref_type, object_type=src_type) @@ -449,6 +483,8 @@ def _list_merge( ) -> None: from omegaconf import DictConfig, ListConfig, OmegaConf + from ._utils import get_dict_key_value_types, get_list_element_type + assert isinstance(dest, ListConfig) assert isinstance(src, ListConfig) @@ -466,15 +502,32 @@ def _list_merge( dest.__dict__["_metadata"] ) is_optional, et = _resolve_optional(dest._metadata.element_type) + + prototype: Optional[Union[DictConfig, ListConfig]] = None + if is_structured_config(et): prototype = DictConfig(et, ref_type=et, is_optional=is_optional) - for item in src._iter_ex(resolve=False): - if isinstance(item, DictConfig): - item = OmegaConf.merge(prototype, item) - temp_target.append(item) - else: - for item in src._iter_ex(resolve=False): - temp_target.append(item) + elif is_dict_annotation(et): + key_type, element_type = get_dict_key_value_types(et) + prototype = DictConfig( + {}, + ref_type=et, + key_type=key_type, + element_type=element_type, + is_optional=is_optional, + ) + elif is_list_annotation(et): + element_type = get_list_element_type(et) + prototype = ListConfig( + [], ref_type=et, element_type=element_type, is_optional=is_optional + ) + + for item in src._iter_ex(resolve=False): + if prototype is not None: + item = OmegaConf.merge( + prototype, item, list_merge_mode=list_merge_mode + ) + temp_target.append(item) if list_merge_mode == ListMergeMode.EXTEND: dest.__dict__["_content"].extend(temp_target.__dict__["_content"]) diff --git a/tests/structured_conf/data/attr_classes.py b/tests/structured_conf/data/attr_classes.py index f0b32d02c..c663cc521 100644 --- a/tests/structured_conf/data/attr_classes.py +++ b/tests/structured_conf/data/attr_classes.py @@ -260,6 +260,14 @@ class NestedWithAny: var: Any = Nested() +@attr.s(auto_attribs=True) +class DeeplyNestedUser: + dsdsdsu: Dict[str, Dict[str, Dict[str, User]]] + dsdslu: Dict[str, Dict[str, List[User]]] + lldsu: List[List[Dict[str, User]]] + lllu: List[List[List[User]]] + + @attr.s(auto_attribs=True) class NoDefaultValue: no_default: Any diff --git a/tests/structured_conf/data/dataclasses.py b/tests/structured_conf/data/dataclasses.py index b79640794..6065e2b29 100644 --- a/tests/structured_conf/data/dataclasses.py +++ b/tests/structured_conf/data/dataclasses.py @@ -263,6 +263,14 @@ class NestedWithAny: var: Any = field(default_factory=Nested) +@dataclass +class DeeplyNestedUser: + dsdsdsu: Dict[str, Dict[str, Dict[str, User]]] + dsdslu: Dict[str, Dict[str, List[User]]] + lldsu: List[List[Dict[str, User]]] + lllu: List[List[List[User]]] + + @dataclass class NoDefaultValue: no_default: Any diff --git a/tests/structured_conf/data/dataclasses_pre_311.py b/tests/structured_conf/data/dataclasses_pre_311.py index 1401d4324..4b74140c4 100644 --- a/tests/structured_conf/data/dataclasses_pre_311.py +++ b/tests/structured_conf/data/dataclasses_pre_311.py @@ -261,6 +261,14 @@ class NestedWithAny: var: Any = Nested() +@dataclass +class DeeplyNestedUser: + dsdsdsu: Dict[str, Dict[str, Dict[str, User]]] + dsdslu: Dict[str, Dict[str, List[User]]] + lldsu: List[List[Dict[str, User]]] + lllu: List[List[List[User]]] + + @dataclass class NoDefaultValue: no_default: Any diff --git a/tests/structured_conf/test_structured_config.py b/tests/structured_conf/test_structured_config.py index 67ac2189f..a0099a855 100644 --- a/tests/structured_conf/test_structured_config.py +++ b/tests/structured_conf/test_structured_config.py @@ -24,7 +24,11 @@ ValidationError, _utils, ) -from omegaconf.errors import ConfigKeyError, InterpolationToMissingValueError +from omegaconf.errors import ( + ConfigKeyError, + ConfigTypeError, + InterpolationToMissingValueError, +) from tests import Color, Enum1, User, warns_dict_subclass_deprecated @@ -1892,6 +1896,64 @@ def test_assign_none( with raises(ValidationError): node[last_key] = value + @mark.parametrize( + "user_value, expected_error", + [ + param({"name": "Bond", "age": 7}, None, id="user-good"), + param({"cat": "Bond", "turnip": 7}, ConfigKeyError, id="user-bad-key"), + param({"name": "Bond", "age": "abc"}, ValidationError, id="user-bad-type1"), + param([1, 2, 3], ConfigTypeError, id="user-bad-type2"), + ], + ) + @mark.parametrize( + "key, make_nested, indices", + [ + param( + "dsdsdsu", + lambda uv: {"l1": {"l2": {"l3": uv}}}, + ["l1", "l2", "l3"], + id="dsdsdsu", + ), + param( + "dsdslu", + lambda uv: {"l1": {"l2": [uv]}}, + ["l1", "l2", 0], + id="dsdslu", + ), + param( + "lldsu", + lambda uv: [[{"l3": uv}]], + [0, 0, "l3"], + id="lldsu", + ), + param( + "lllu", + lambda uv: [[[uv]]], + [0, 0, 0], + id="lllu", + ), + ], + ) + def test_merge_with_deep_nesting( + self, module: Any, key: str, make_nested, indices, user_value, expected_error + ) -> None: + conf = OmegaConf.structured(module.DeeplyNestedUser) + src = {key: make_nested(user_value)} + + if expected_error is None: + merged = OmegaConf.merge(conf, src) + assert isinstance(merged, DictConfig) + value = merged[key] + for index in indices: + previous = value + value = value[index] + assert previous._metadata.key_type is type(index) + assert previous._metadata.element_type is value._metadata.ref_type + assert value._metadata.ref_type is module.User + else: + with raises(expected_error): + OmegaConf.merge(conf, src) + class TestUnionsOfPrimitiveTypes: @mark.parametrize( diff --git a/tests/test_nested_containers.py b/tests/test_nested_containers.py index 0be4a8984..90018e9f8 100644 --- a/tests/test_nested_containers.py +++ b/tests/test_nested_containers.py @@ -1408,7 +1408,7 @@ def test_merge_nested_list_promotion() -> None: ), param( [DictConfig({}, element_type=Dict[str, int]), {"foo": 123}], - "Value 123 (int) is incompatible with type hint 'typing.Dict[str, int]'", + "Cannot assign int to Dict[str, int]", id="merge-int-into-dict", ), param(