Skip to content

Commit

Permalink
__getitem__ and __getattr__ consistency
Browse files Browse the repository at this point in the history
DictConfig __getitem__ is now constent with plain dict (raises a KeyError if a key is not found instead of returning None)
DictConfig __getattr__ is now consistent with plain objects (raises an AttributeError if an attribute is not found instead of returning None).
  • Loading branch information
omry committed Feb 5, 2021
1 parent 0644fab commit d018678
Show file tree
Hide file tree
Showing 13 changed files with 57 additions and 48 deletions.
12 changes: 1 addition & 11 deletions docs/notebook/Tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,6 @@
}
],
"source": [
"conf.missing_key or 'a default value'\n",
"conf.get('missing_key', 'a default value')"
]
},
Expand Down Expand Up @@ -890,17 +889,8 @@
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.1"
},
"pycharm": {
"stem_cell": {
"cell_type": "raw",
"metadata": {
"collapsed": false
},
"source": []
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
}
4 changes: 0 additions & 4 deletions docs/source/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -226,10 +226,6 @@ You can provide default values directly in the accessing code:

.. doctest:: loaded

>>> # providing default values
>>> conf.missing_key or 'a default value'
'a default value'

>>> conf.get('missing_key', 'a default value')
'a default value'

Expand Down
1 change: 1 addition & 0 deletions news/515.api_change.1
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
DictConfig __getitem__ access, e.g. `cfg["foo"]`, is now raising a KeyError if the key "foo" does not exist
1 change: 1 addition & 0 deletions news/515.api_change.2
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
DictConfig __getattr__ access, e.g. `cfg.foo`, is now raising a AttributeError if the key "foo" does not exist
3 changes: 2 additions & 1 deletion omegaconf/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,8 @@ def _get_node(
self,
key: Any,
validate_access: bool = True,
throw_on_missing: bool = False,
throw_on_missing_value: bool = False,
throw_on_missing_key: bool = False,
) -> Union[Optional[Node], List[Optional[Node]]]:
...

Expand Down
25 changes: 14 additions & 11 deletions omegaconf/dictconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,15 +347,15 @@ def __getattr__(self, key: str) -> Any:
:param key:
:return:
"""
# PyCharm is sometimes inspecting __members__, be sure to tell it we don't have that.
if key == "__members__":
raise AttributeError()

if key == "__name__":
raise AttributeError()

try:
return self._get_impl(key=key, default_value=DEFAULT_VALUE_MARKER)
except ConfigKeyError as e:
self._format_and_raise(
key=key, value=None, cause=e, type_override=ConfigAttributeError
)
except Exception as e:
self._format_and_raise(key=key, value=None, cause=e)

Expand Down Expand Up @@ -413,9 +413,9 @@ def get(self, key: DictKeyType, default_value: Any = DEFAULT_VALUE_MARKER) -> An

def _get_impl(self, key: DictKeyType, default_value: Any) -> Any:
try:
node = self._get_node(key=key)
except ConfigAttributeError:
if default_value != DEFAULT_VALUE_MARKER:
node = self._get_node(key=key, throw_on_missing_key=True)
except (ConfigAttributeError, ConfigKeyError):
if default_value is not DEFAULT_VALUE_MARKER:
node = default_value
else:
raise
Expand All @@ -427,7 +427,8 @@ def _get_node(
self,
key: DictKeyType,
validate_access: bool = True,
throw_on_missing: bool = False,
throw_on_missing_value: bool = False,
throw_on_missing_key: bool = False,
) -> Union[Optional[Node], List[Optional[Node]]]:
try:
key = self._validate_and_normalize_key(key)
Expand All @@ -440,10 +441,12 @@ def _get_node(
if validate_access:
self._validate_get(key)

value: Node = self.__dict__["_content"].get(key)
if throw_on_missing and value._is_missing():
value: Optional[Node] = self.__dict__["_content"].get(key)
if value is None:
if throw_on_missing_key:
raise ConfigKeyError(f"Missing key {key}")
elif throw_on_missing_value and value._is_missing():
raise MissingMandatoryValue("Missing mandatory value")

return value

def pop(self, key: DictKeyType, default: Any = DEFAULT_VALUE_MARKER) -> Any:
Expand Down
9 changes: 5 additions & 4 deletions omegaconf/listconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,8 @@ def _get_node(
self,
key: Union[int, slice],
validate_access: bool = True,
throw_on_missing: bool = False,
throw_on_missing_value: bool = False,
throw_on_missing_key: bool = False,
) -> Union[Optional[Node], List[Optional[Node]]]:
try:
if self._is_none():
Expand All @@ -390,15 +391,15 @@ def _get_node(
if isinstance(key, slice):
assert isinstance(value, list)
for v in value:
if throw_on_missing and v._is_missing():
if throw_on_missing_value and v._is_missing():
raise MissingMandatoryValue("Missing mandatory value")
else:
assert isinstance(value, Node)
if throw_on_missing and value._is_missing():
if throw_on_missing_value and value._is_missing():
raise MissingMandatoryValue("Missing mandatory value")
return value
except (IndexError, TypeError, MissingMandatoryValue, KeyValidationError) as e:
if isinstance(e, MissingMandatoryValue) and throw_on_missing:
if isinstance(e, MissingMandatoryValue) and throw_on_missing_value:
raise
if validate_access:
self._format_and_raise(key=key, value=None, cause=e)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_base_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,7 +751,7 @@ def test_get_node(cfg: Any, key: Any, expected: Any) -> None:
pytest.param([10, "???", 30], slice(1, 2), id="list_slice"),
],
)
def test_get_node_throw_on_missing(cfg: Any, key: Any) -> None:
def test_get_node_throw_on_missing_value(cfg: Any, key: Any) -> None:
cfg = OmegaConf.create(cfg)
with pytest.raises(MissingMandatoryValue, match="Missing mandatory value"):
cfg._get_node(key, throw_on_missing=True)
cfg._get_node(key, throw_on_missing_value=True)
35 changes: 26 additions & 9 deletions tests/test_basic_ops_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@
open_dict,
)
from omegaconf.basecontainer import BaseContainer
from omegaconf.errors import ConfigKeyError, ConfigTypeError, KeyValidationError
from omegaconf.errors import (
ConfigAttributeError,
ConfigKeyError,
ConfigTypeError,
KeyValidationError,
)
from tests import (
ConcretePlugin,
Enum1,
Expand Down Expand Up @@ -134,13 +139,14 @@ def test_dict_struct_delitem(
assert key not in c


def test_default_value() -> None:
def test_attribute_error() -> None:
c = OmegaConf.create()
assert c.missing_key or "a default value" == "a default value"
with pytest.raises(ConfigAttributeError):
assert c.missing_key


def test_get_default_value() -> None:
c = OmegaConf.create()
@pytest.mark.parametrize("c", [{}, OmegaConf.create()])
def test_get_default_value(c: Any) -> None:
assert c.get("missing_key", "a default value") == "a default value"


Expand Down Expand Up @@ -611,16 +617,16 @@ def test_creation_with_invalid_key() -> None:
OmegaConf.create({object(): "a"})


def test_set_with_invalid_key() -> None:
def test_setitem_with_invalid_key() -> None:
cfg = OmegaConf.create()
with pytest.raises(KeyValidationError):
cfg[object()] # type: ignore
cfg.__setitem__(object(), "a") # type: ignore


def test_get_with_invalid_key() -> None:
def test_getitem_with_invalid_key() -> None:
cfg = OmegaConf.create()
with pytest.raises(KeyValidationError):
cfg[object()] # type: ignore
cfg.__getitem__(object()) # type: ignore


def test_hasattr() -> None:
Expand Down Expand Up @@ -873,3 +879,14 @@ def test_assign_to_sc_field_without_ref_type() -> None:

cfg.plugin = 10
assert cfg.plugin == 10


def test_dict_getitem_not_found() -> None:
cfg = OmegaConf.create()
with pytest.raises(ConfigKeyError):
cfg["aaa"]


def test_dict_getitem_none_output() -> None:
cfg = OmegaConf.create({"a": None})
assert cfg["a"] is None
4 changes: 2 additions & 2 deletions tests/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def finalize(self, cfg: Any) -> None:
Expected(
create=lambda: OmegaConf.create({"foo": "${missing}"}),
op=lambda cfg: getattr(cfg, "foo"),
exception_type=ConfigKeyError,
exception_type=ConfigAttributeError,
msg="str interpolation key 'missing' not found",
key="foo",
child_node=lambda cfg: cfg._get_node("foo"),
Expand All @@ -207,7 +207,7 @@ def finalize(self, cfg: Any) -> None:
Expected(
create=lambda: OmegaConf.create({"foo": "foo_${missing}"}),
op=lambda cfg: getattr(cfg, "foo"),
exception_type=ConfigKeyError,
exception_type=ConfigAttributeError,
msg="str interpolation key 'missing' not found",
key="foo",
child_node=lambda cfg: cfg._get_node("foo"),
Expand Down
4 changes: 2 additions & 2 deletions tests/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
ValidationError,
)
from omegaconf._utils import _ensure_container
from omegaconf.errors import ConfigKeyError, OmegaConfBaseException
from omegaconf.errors import ConfigAttributeError, OmegaConfBaseException


@pytest.mark.parametrize(
Expand Down Expand Up @@ -123,7 +123,7 @@ def test_merge_with_interpolation() -> None:

def test_non_container_interpolation() -> None:
cfg = OmegaConf.create(dict(foo=0, bar="${foo.baz}"))
with pytest.raises(ConfigKeyError):
with pytest.raises(ConfigAttributeError):
cfg.bar


Expand Down
2 changes: 1 addition & 1 deletion tests/test_omegaconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
@pytest.mark.parametrize(
"cfg, key, expected_is_missing, expectation",
[
({}, "foo", False, does_not_raise()),
({}, "foo", False, raises(ConfigKeyError)),
({"foo": True}, "foo", False, does_not_raise()),
({"foo": "${no_such_key}"}, "foo", False, raises(ConfigKeyError)),
({"foo": MISSING}, "foo", True, raises(MissingMandatoryValue)),
Expand Down
1 change: 0 additions & 1 deletion tests/test_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

def test_struct_default() -> None:
c = OmegaConf.create()
assert c.not_found is None
assert OmegaConf.is_struct(c) is None


Expand Down

0 comments on commit d018678

Please sign in to comment.