Skip to content

Commit

Permalink
refactor(common): restrict implicit traversals to common builtin coll…
Browse files Browse the repository at this point in the history
…ections
  • Loading branch information
kszucs committed Aug 11, 2023
1 parent 775f1cd commit 8531347
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 24 deletions.
84 changes: 61 additions & 23 deletions ibis/common/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,69 +3,103 @@

from abc import abstractmethod
from collections import deque
from collections.abc import Hashable, Iterable, Iterator, Mapping
from collections.abc import Hashable, Iterable, Iterator
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence

from ibis.common.collections import frozendict
from ibis.common.patterns import NoMatch, pattern
from ibis.util import experimental

if TYPE_CHECKING:
from typing_extensions import Self


def _flatten_collections(node: Node, filter: type) -> Iterator[Node]:
def _flatten_collections(node: Any, filter: type) -> Iterator[Node]:
"""Flatten collections of nodes into a single iterator.
We treat common collection types inherently traversable (e.g. list, tuple, dict)
but as undesired in a graph representation, so we traverse them implicitly.
Parameters
----------
node : Any
node
Flattaneble object unless it's an instance of the types passed as filter.
filter : type, default Node
filter
Type to filter out for the traversal, e.g. Node.
Returns
-------
Iterator : Any
A flat generator of the filtered nodes.
Examples
--------
>>> from ibis.common.grounds import Concrete
>>> from ibis.common.graph import Node
>>>
>>> class MyNode(Concrete, Node):
... number: int
... string: str
... children: tuple[Node, ...]
...
>>> a = MyNode(4, "a", ())
>>>
>>> b = MyNode(3, "b", ())
>>> c = MyNode(2, "c", (a, b))
>>> d = MyNode(1, "d", (c,))
>>>
>>> assert list(_flatten_collections(a, Node)) == [a]
>>> assert list(_flatten_collections((c,), Node)) == [c]
>>> assert list(_flatten_collections([a, b, (c, a)], Node)) == [a, b, c, a]
"""
if isinstance(node, filter):
yield node
elif isinstance(node, (str, bytes)):
pass
elif isinstance(node, Sequence):
elif isinstance(node, (tuple, list)):
for item in node:
yield from _flatten_collections(item, filter)
elif isinstance(node, Mapping):
for key, value in node.items():
yield from _flatten_collections(key, filter)
elif isinstance(node, (dict, frozendict)):
for value in node.values():
yield from _flatten_collections(value, filter)


def _recursive_get(obj: Any, dct: dict[Node, Any]) -> Any:
def _recursive_get(obj: Any, dct: dict[Node, Any], filter: type) -> Any:
"""Recursively replace objects in a nested structure with values from a dict.
Since we treat collection types inherently traversable (e.g. list, tuple, dict) we
need to traverse them implicitly and replace the values given a result mapping.
Since we treat common collection types inherently traversable, so we need to
traverse them implicitly and replace the values given a result mapping.
Parameters
----------
obj : Any
obj
Object to replace.
dct : dict[Node, Any]
dct
Mapping of objects to replace with their values.
filter
Type to filter out for the traversal, e.g. Node.
Returns
-------
Object with replaced values.
Examples
--------
>>> from ibis.common.graph import _recursive_get
>>>
>>> dct = {1: 2, 3: 4}
>>> _recursive_get((1, 3), dct, filter=int)
(2, 4)
>>> _recursive_get(frozendict({1: 3}), dct, filter=int)
{1: 4}
>>> _recursive_get(frozendict({1: (1, 3)}), dct, filter=int)
{1: (2, 4)}
"""
if isinstance(obj, tuple):
return tuple(_recursive_get(o, dct) for o in obj)
elif isinstance(obj, dict):
return {k: _recursive_get(v, dct) for k, v in obj.items()}
if isinstance(obj, filter):
return dct[obj]
elif isinstance(obj, (tuple, list)):
return tuple(_recursive_get(o, dct, filter) for o in obj)
elif isinstance(obj, (dict, frozendict)):
return {k: _recursive_get(v, dct, filter) for k, v in obj.items()}
else:
return dct.get(obj, obj)
return obj


class Node(Hashable):
Expand Down Expand Up @@ -123,10 +157,14 @@ def map(self, fn: Callable, filter: Optional[type] = None) -> dict[Node, Any]:
-------
A mapping of nodes to their results.
"""
filter = filter or Node
results = {}
for node in Graph.from_bfs(self, filter=filter).toposort():
kwargs = dict(zip(node.__argnames__, node.__args__))
kwargs = _recursive_get(kwargs, results)
# minor optimization to directly recurse into the children
kwargs = {
k: _recursive_get(v, results, filter)
for k, v in zip(node.__argnames__, node.__args__)
}
results[node] = fn(node, results, **kwargs)
return results

Expand Down
85 changes: 84 additions & 1 deletion ibis/common/tests/test_graph.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,19 @@
from __future__ import annotations

from collections.abc import Mapping, Sequence

import pytest

from ibis.common.graph import Graph, Node, bfs, dfs, toposort
from ibis.common.collections import frozendict
from ibis.common.graph import (
Graph,
Node,
_flatten_collections,
_recursive_get,
bfs,
dfs,
toposort,
)
from ibis.common.grounds import Annotable, Concrete
from ibis.common.patterns import InstanceOf, TupleOf

Expand Down Expand Up @@ -178,3 +189,75 @@ class All(Bool):

copied = node.copy(arguments=(T, F))
assert copied == All((T, F), strict=False)


class MySequence(Sequence):
def __init__(self, *items):
self.items = items

def __getitem__(self, index):
raise AssertionError("must not be called") # pragma: no cover

def __len__(self):
return len(self.items)


class MyMapping(Mapping):
def __init__(self, **items):
self.items = items

def __getitem__(self, key):
raise AssertionError("must not be called") # pragma: no cover

def __iter__(self):
return iter(self.items)

def __len__(self):
return len(self.items)


def test_flatten_collections():
# test that flatten collections doesn't recurse into arbitrary mappings
# and sequences, just the commonly used builtin ones: list, tuple, dict

result = _flatten_collections(
[0.0, 1, 2, [3, 4, (5, 6)], "7", MySequence(8, 9)], filter=int
)
assert list(result) == [1, 2, 3, 4, 5, 6]

result = _flatten_collections(
{
"a": 0.0,
"b": 1,
"c": (MyMapping(d=2, e=3), frozendict(f=4)),
"d": [5, "6", {"e": (7, 8.9)}],
},
filter=int,
)
assert list(result) == [1, 4, 5, 7]


def test_recurse_get():
results = {"a": "A", "b": "B", "c": "C", "d": "D"}

assert _recursive_get((0, 1, "a", {"b": "c"}), results, filter=str) == (
0,
1,
"A",
{"b": "C"},
)
assert _recursive_get({"a": "b", "c": "d"}, results, filter=str) == {
"a": "B",
"c": "D",
}
assert _recursive_get(["a", "b", "c"], results, filter=str) == ("A", "B", "C")
assert _recursive_get("a", results, filter=str) == "A"

my_seq = MySequence("a", "b", "c")
my_map = MyMapping(a="a", b="b", c="c")
assert _recursive_get(("a", my_seq, ["b", "a"], my_map), results, filter=str) == (
"A",
my_seq,
("B", "A"),
my_map,
)

0 comments on commit 8531347

Please sign in to comment.