Skip to content

Commit

Permalink
Improve handling of interpolations pointing to missing nodes
Browse files Browse the repository at this point in the history
* Interpolations are never considered to be missing anymore, even if
  they point to a missing node

* When resolving an expression containing an interpolation pointing to a
  missing node, if `throw_on_missing` is False, then the result is
  always `None` (while it used to either be a missing Node, or an
  expression computed from the "???" string)

* Similarly, this commit also ensures that if
  `throw_on_resolution_failure` is False, then resolving an
  interpolation resulting in a resolution failure always leads to the
  result being `None` (instead of potentially being an expression computed
  from `None`)

Fixes omry#543
  • Loading branch information
odelalleau committed Feb 26, 2021
1 parent 90d64e3 commit af47aa5
Show file tree
Hide file tree
Showing 8 changed files with 74 additions and 59 deletions.
File renamed without changes.
1 change: 0 additions & 1 deletion news/462.api_change.2

This file was deleted.

1 change: 1 addition & 0 deletions news/543.api_change
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Interpolations interpolating a missing key no longer count as missing.
80 changes: 31 additions & 49 deletions omegaconf/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,13 +410,20 @@ def _resolve_interpolation_from_parse_tree(
) -> Optional["Node"]:
from .nodes import StringNode

resolved = self.resolve_parse_tree(
parse_tree=parse_tree,
key=key,
parent=parent,
throw_on_missing=throw_on_missing,
throw_on_resolution_failure=throw_on_resolution_failure,
)
try:
resolved = self.resolve_parse_tree(
parse_tree=parse_tree,
key=key,
parent=parent,
)
except MissingMandatoryValue:
if throw_on_missing:
raise
return None
except Exception:
if throw_on_resolution_failure:
raise
return None

if resolved is None:
return None
Expand All @@ -435,36 +442,28 @@ def _resolve_interpolation_from_parse_tree(
def _resolve_node_interpolation(
self,
inter_key: str,
throw_on_missing: bool,
throw_on_resolution_failure: bool,
) -> Optional["Node"]:
) -> "Node":
"""A node interpolation is of the form `${foo.bar}`"""
root_node, inter_key = self._resolve_key_and_root(inter_key)
parent, last_key, value = root_node._select_impl(
inter_key,
throw_on_missing=throw_on_missing,
throw_on_resolution_failure=throw_on_resolution_failure,
throw_on_missing=True,
throw_on_resolution_failure=True,
)

if parent is None or value is None:
if throw_on_resolution_failure:
raise InterpolationResolutionError(
f"Interpolation key '{inter_key}' not found"
)
else:
return None
assert isinstance(value, Node)
return value
raise InterpolationResolutionError(
f"Interpolation key '{inter_key}' not found"
)
else:
return value

def _evaluate_custom_resolver(
self,
key: Any,
inter_type: str,
inter_args: Tuple[Any, ...],
throw_on_missing: bool,
throw_on_resolution_failure: bool,
inter_args_str: Tuple[str, ...],
) -> Optional["Node"]:
) -> Any:
from omegaconf import OmegaConf

from .nodes import ValueNode
Expand All @@ -482,18 +481,11 @@ def _evaluate_custom_resolver(
),
)
except Exception as e:
if throw_on_resolution_failure:
self._format_and_raise(key=None, value=None, cause=e)
assert False
else:
return None
self._format_and_raise(key=None, value=None, cause=e)
else:
if throw_on_resolution_failure:
raise UnsupportedInterpolationType(
f"Unsupported interpolation type {inter_type}"
)
else:
return None
raise UnsupportedInterpolationType(
f"Unsupported interpolation type {inter_type}"
)

def _maybe_resolve_interpolation(
self,
Expand Down Expand Up @@ -522,8 +514,6 @@ def resolve_parse_tree(
parse_tree: ParserRuleContext,
key: Optional[Any] = None,
parent: Optional["Container"] = None,
throw_on_missing: bool = True,
throw_on_resolution_failure: bool = True,
) -> Any:
"""
Resolve a given parse tree into its value.
Expand All @@ -533,26 +523,17 @@ def resolve_parse_tree(
"""
from .nodes import StringNode

# Common arguments to all callbacks.
callback_args: Dict[str, Any] = dict(
throw_on_missing=throw_on_missing,
throw_on_resolution_failure=throw_on_resolution_failure,
)

def node_interpolation_callback(inter_key: str) -> Optional["Node"]:
return self._resolve_node_interpolation(
inter_key=inter_key, **callback_args
)
return self._resolve_node_interpolation(inter_key=inter_key)

def resolver_interpolation_callback(
name: str, args: Tuple[Any, ...], args_str: Tuple[str, ...]
) -> Optional["Node"]:
) -> Any:
return self._evaluate_custom_resolver(
key=key,
inter_type=name,
inter_args=args,
inter_args_str=args_str,
**callback_args,
)

def quoted_string_callback(quoted_str: str) -> str:
Expand All @@ -565,7 +546,8 @@ def quoted_string_callback(quoted_str: str) -> str:
parent=parent,
is_optional=False,
),
**callback_args,
throw_on_missing=True,
throw_on_resolution_failure=True,
)
return str(quoted_val)

Expand Down
4 changes: 2 additions & 2 deletions tests/test_basic_ops_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -802,10 +802,10 @@ def test_is_missing() -> None:
}
)
assert cfg._get_node("foo")._is_missing() # type: ignore
assert cfg._get_node("inter")._is_missing() # type: ignore
assert not cfg._get_node("inter")._is_missing() # type: ignore
assert not cfg._get_node("str_inter")._is_missing() # type: ignore
assert cfg._get_node("missing_node")._is_missing() # type: ignore
assert cfg._get_node("missing_node_inter")._is_missing() # type: ignore
assert not cfg._get_node("missing_node_inter")._is_missing() # type: ignore


@pytest.mark.parametrize("ref_type", [None, Any])
Expand Down
22 changes: 21 additions & 1 deletion tests/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,9 @@ def test_interpolation_with_missing() -> None:
"x": {"missing": "???"},
}
)
assert OmegaConf.is_missing(cfg.x, "missing")
assert not OmegaConf.is_missing(cfg, "a")
assert OmegaConf.is_missing(cfg, "b")
assert not OmegaConf.is_missing(cfg, "b")


def test_assign_to_interpolation() -> None:
Expand Down Expand Up @@ -709,3 +710,22 @@ def test_empty_stack() -> None:
"""
with pytest.raises(GrammarParseError):
grammar_parser.parse("ab}", lexer_mode="VALUE_MODE")


@pytest.mark.parametrize("ref", ["missing", "invalid"])
def test_invalid_intermediate_result_when_not_throwing(
ref: str, restore_resolvers: Any
) -> None:
"""
Check that the resolution of nested interpolations stops on missing / resolution failure.
"""
OmegaConf.register_new_resolver("check_none", lambda x: x is None)
cfg = OmegaConf.create({"x": f"${{check_none:${{{ref}}}}}", "missing": "???"})
x_node = cfg._get_node("x")
assert isinstance(x_node, Node)
assert (
x_node._dereference_node(
throw_on_missing=False, throw_on_resolution_failure=False
)
is None
)
4 changes: 2 additions & 2 deletions tests/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,9 +298,9 @@ def test_interpolation(
inter=True,
exp=None,
)
verify(cfg, "int_missing", none=False, opt=True, missing=True, inter=True)
verify(cfg, "int_missing", none=False, opt=True, missing=False, inter=True)
verify(
cfg, "int_opt_missing", none=False, opt=True, missing=True, inter=True
cfg, "int_opt_missing", none=False, opt=True, missing=False, inter=True
)

verify(
Expand Down
21 changes: 17 additions & 4 deletions tests/test_omegaconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
IntegerNode,
ListConfig,
MissingMandatoryValue,
Node,
OmegaConf,
StringNode,
)
Expand Down Expand Up @@ -39,7 +40,7 @@
pytest.param(
{"foo": "${bar}", "bar": DictConfig(content=MISSING)},
"foo",
True,
False,
raises(MissingMandatoryValue),
id="missing_interpolated_dict",
),
Expand All @@ -57,7 +58,12 @@
raises(MissingMandatoryValue),
id="missing_dict",
),
({"foo": "${bar}", "bar": MISSING}, "foo", True, raises(MissingMandatoryValue)),
(
{"foo": "${bar}", "bar": MISSING},
"foo",
False,
raises(MissingMandatoryValue),
),
(
{"foo": "foo_${bar}", "bar": MISSING},
"foo",
Expand All @@ -74,7 +80,7 @@
(
{"foo": StringNode(value="???"), "inter": "${foo}"},
"inter",
True,
False,
raises(MissingMandatoryValue),
),
(StructuredWithMissing, "num", True, raises(MissingMandatoryValue)),
Expand All @@ -87,7 +93,7 @@
(StructuredWithMissing, "opt_user", True, raises(MissingMandatoryValue)),
(StructuredWithMissing, "inter_user", True, raises(MissingMandatoryValue)),
(StructuredWithMissing, "inter_opt_user", True, raises(MissingMandatoryValue)),
(StructuredWithMissing, "inter_num", True, raises(MissingMandatoryValue)),
(StructuredWithMissing, "inter_num", False, raises(MissingMandatoryValue)),
],
)
def test_is_missing(
Expand Down Expand Up @@ -117,6 +123,13 @@ def test_is_missing_resets() -> None:
assert OmegaConf.is_missing(cfg, "list")


def test_dereference_interpolation_to_missing() -> None:
cfg = OmegaConf.create({"x": "${y}", "y": "???"})
x_node = cfg._get_node("x")
assert isinstance(x_node, Node)
assert x_node._dereference_node() is None


@pytest.mark.parametrize(
"cfg, expected",
[
Expand Down

0 comments on commit af47aa5

Please sign in to comment.