diff --git a/ibis/common/graph.py b/ibis/common/graph.py index a899718e372b..c2aaf5d7e046 100644 --- a/ibis/common/graph.py +++ b/ibis/common/graph.py @@ -56,19 +56,18 @@ def _flatten_collections(node: Any) -> Iterator[N]: >>> c = MyNode(2, "c", (a, b)) >>> d = MyNode(1, "d", (c,)) >>> - >>> assert list(_flatten_collections(a)) == [a] >>> assert list(_flatten_collections((c,))) == [c] >>> assert list(_flatten_collections([a, b, (c, a)])) == [a, b, c, a] + >>> assert list(_flatten_collections([{"b": b, "a": a}])) == [b, a] """ - if isinstance(node, Node): - yield node - elif isinstance(node, (tuple, list)): - for item in node: + for item in node: + if isinstance(item, Node): + yield item + elif isinstance(item, (tuple, list)): yield from _flatten_collections(item) - elif isinstance(node, (dict, frozendict)): - for value in node.values(): - yield from _flatten_collections(value) + elif isinstance(item, (dict, frozendict)): + yield from _flatten_collections(item.values()) def _recursive_lookup(obj: Any, dct: dict) -> Any: diff --git a/ibis/common/tests/test_graph.py b/ibis/common/tests/test_graph.py index 59d39a4f7733..677f9f0f204d 100644 --- a/ibis/common/tests/test_graph.py +++ b/ibis/common/tests/test_graph.py @@ -288,12 +288,7 @@ def test_flatten_collections(): assert list(result) == [A, B, C, D, E] result = _flatten_collections( - { - "a": 0.0, - "b": A, - "c": (MyMapping(d=B, e=3), frozendict(f=C)), - "d": [5, "6", {"e": (D, 8.9)}], - } + [0.0, A, (MyMapping(d=B, e=3), frozendict(f=C)), [5, "6", {"e": (D, 8.9)}]] ) assert list(result) == [A, C, D]