From 02c66078e6551584d2e10a32552b4decf7ec7762 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Thu, 25 Apr 2024 15:12:29 +0200 Subject: [PATCH] feat(common): also traverse nodes used as dictionary keys (#9041) This change allows repr-ing node mappings, like the dereference mappings for easier inspection. Inevitably this is slowing down the traversals, but changing the instance checks to use a plain dict instead of `(dict, frozendict)` maintains the previous traversal speed. --- ibis/common/graph.py | 14 +++++++++----- ibis/common/tests/test_graph.py | 12 ++++++++++-- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/ibis/common/graph.py b/ibis/common/graph.py index 7654e4b54f39..c4975e37d360 100644 --- a/ibis/common/graph.py +++ b/ibis/common/graph.py @@ -2,13 +2,13 @@ from __future__ import annotations +import itertools from abc import abstractmethod from collections import deque from collections.abc import Iterable, Iterator, KeysView, Mapping, Sequence from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union from ibis.common.bases import Hashable -from ibis.common.collections import frozendict from ibis.common.patterns import NoMatch, Pattern from ibis.common.typing import _ClassInfo from ibis.util import experimental, promote_list @@ -66,8 +66,9 @@ def _flatten_collections(node: Any) -> Iterator[N]: yield item elif isinstance(item, (tuple, list)): yield from _flatten_collections(item) - elif isinstance(item, (dict, frozendict)): - yield from _flatten_collections(item.values()) + elif isinstance(item, dict): + items = itertools.chain.from_iterable(item.items()) + yield from _flatten_collections(items) def _recursive_lookup(obj: Any, dct: dict) -> Any: @@ -89,6 +90,7 @@ def _recursive_lookup(obj: Any, dct: dict) -> Any: Examples -------- + >>> from ibis.common.collections import frozendict >>> from ibis.common.grounds import Concrete >>> from ibis.common.graph import Node >>> @@ -117,8 +119,10 @@ def _recursive_lookup(obj: Any, dct: dict) -> Any: return dct.get(obj, obj) elif isinstance(obj, (tuple, list)): return tuple(_recursive_lookup(o, dct) for o in obj) - elif isinstance(obj, (dict, frozendict)): - return {k: _recursive_lookup(v, dct) for k, v in obj.items()} + elif isinstance(obj, dict): + return { + _recursive_lookup(k, dct): _recursive_lookup(v, dct) for k, v in obj.items() + } else: return obj diff --git a/ibis/common/tests/test_graph.py b/ibis/common/tests/test_graph.py index b0da3bdf2a0f..7b05287b4ce6 100644 --- a/ibis/common/tests/test_graph.py +++ b/ibis/common/tests/test_graph.py @@ -55,6 +55,7 @@ def copy(self, name=None, children=None): E = MyNode(name="E", children=[]) B = MyNode(name="B", children=[D, E]) A = MyNode(name="A", children=[B, C]) +F = MyNode(name="F", children=[{C: D, E: None}]) def test_bfs(): @@ -68,8 +69,8 @@ def test_construction(): def test_graph_nodes(): - g = Graph(A) - assert g.nodes() == {A, B, C, D, E} + assert Graph(A).nodes() == {A, B, C, D, E} + assert Graph(F).nodes() == {F, C, D, E} def test_graph_repr(): @@ -286,6 +287,10 @@ def test_flatten_collections(): ) assert list(result) == [A, C, D] + # test that dictionary keys are also flattened + result = _flatten_collections([0.0, {A: B, C: [D]}, frozendict({E: 6})]) + assert list(result) == [A, B, C, D, E] + def test_recursive_lookup(): results = {A: "A", B: "B", C: "C", D: "D"} @@ -312,6 +317,9 @@ def test_recursive_lookup(): my_map, ) + # test that dictionary nodes as dictionary keys are also looked up + assert _recursive_lookup({A: B, C: D}, results) == {"A": "B", "C": "D"} + def test_coerce_finder(): f = _coerce_finder(int)