Skip to content

Commit

Permalink
Merge pull request mmistakes#120 from pesser/fixretrieve
Browse files Browse the repository at this point in the history
Fix retrieve
  • Loading branch information
jhaux authored Aug 7, 2019
2 parents c64a1ea + e836a3a commit 249cda8
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 8 deletions.
26 changes: 18 additions & 8 deletions edflow/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,11 @@ def call(key, value):
return results


class KeyNotFoundError(Exception):
def __init__(self, cause):
self.cause = cause


def retrieve(
list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
):
Expand Down Expand Up @@ -178,24 +183,29 @@ def retrieve(
for key in keys:
if callable(list_or_dict):
if not expand:
raise ValueError(
"Trying to get past callable node with expand=False."
raise KeyNotFoundError(
ValueError(
"Trying to get past callable node with expand=False."
)
)
list_or_dict = list_or_dict()
parent[last_key] = list_or_dict
last_key = key
parent = list_or_dict

if isinstance(list_or_dict, dict):
list_or_dict = list_or_dict[key]
else:
list_or_dict = list_or_dict[int(key)]
try:
if isinstance(list_or_dict, dict):
list_or_dict = list_or_dict[key]
else:
list_or_dict = list_or_dict[int(key)]
except (KeyError, IndexError) as e:
raise KeyNotFoundError(e)

visited += [key]
except Exception as e:
except KeyNotFoundError as e:
if default is None:
print("Key not found: {}, seen: {}".format(keys, visited))
raise e
raise e.cause
else:
list_or_dict = default
success = False
Expand Down
24 changes: 24 additions & 0 deletions tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,30 @@ def test_retrieve_pass_success_fail_ef_callable():
val = retrieve(dol, "b/c/d", pass_success=True, expand=False)


def failing_leave():
raise Exception()
return {"c": nested_leave}


class CustomException(Exception):
pass


def custom_leave():
raise CustomException()
return {"c": nested_leave}


def test_retrieve_propagates_exception():
dol = {"a": [1, 2], "b": failing_leave, "e": 2}
with pytest.raises(Exception):
val = retrieve(dol, "b/c/d", default=0)

dol = {"a": [1, 2], "b": custom_leave, "e": 2}
with pytest.raises(CustomException):
val = retrieve(dol, "b/c/d", default=0)


# ====================== walk ====================


Expand Down

0 comments on commit 249cda8

Please sign in to comment.