From 78865c0194a5c1f11e6d564828316f0dc7c7fcc5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Tue, 1 Aug 2023 12:47:06 +0200 Subject: [PATCH] feat(common): match and replace graph nodes --- ibis/backends/polars/__init__.py | 13 +- ibis/common/graph.py | 387 +++++++++++++++++---- ibis/common/patterns.py | 541 +++++++++++++++++++++++++---- ibis/common/tests/test_patterns.py | 188 ++++++++-- ibis/expr/analysis.py | 10 +- ibis/tests/expr/test_operations.py | 7 +- ibis/util.py | 10 - 7 files changed, 988 insertions(+), 168 deletions(-) diff --git a/ibis/backends/polars/__init__.py b/ibis/backends/polars/__init__.py index f81314743880..b29438d6b44b 100644 --- a/ibis/backends/polars/__init__.py +++ b/ibis/backends/polars/__init__.py @@ -14,6 +14,7 @@ from ibis.backends.base import BaseBackend, Database from ibis.backends.polars.compiler import translate from ibis.backends.polars.datatypes import dtype_to_polars, schema_from_polars +from ibis.common.patterns import Replace from ibis.util import gen_name, normalize_filename if TYPE_CHECKING: @@ -347,12 +348,14 @@ def compile( ): node = expr.op() ctx = self._context + if params: - replacements = {} - for p, v in params.items(): - op = p.op() if isinstance(p, ir.Expr) else p - replacements[op] = ibis.literal(v, type=op.dtype).op() - node = node.replace(replacements) + params = {param.op(): value for param, value in params.items()} + rule = Replace( + ops.ScalarParameter, + lambda op, ctx: ops.Literal(value=params[op], dtype=op.dtype), + ) + node = node.replace(rule) expr = node.to_expr() node = expr.as_table().op() diff --git a/ibis/common/graph.py b/ibis/common/graph.py index a5da61b7c931..5dcff1680d65 100644 --- a/ibis/common/graph.py +++ b/ibis/common/graph.py @@ -4,9 +4,68 @@ from abc import abstractmethod from collections import deque from collections.abc import Hashable, Iterable, Iterator, Mapping -from typing import Any, Callable, Dict, Sequence +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence -from ibis.util import recursive_get +from ibis.common.patterns import NoMatch, pattern +from ibis.util import experimental + +if TYPE_CHECKING: + from typing_extensions import Self + + +def _flatten_collections(node: Node, filter: type) -> Iterator[Node]: + """Flatten collections of nodes into a single iterator. + + We treat common collection types inherently traversable (e.g. list, tuple, dict) + but as undesired in a graph representation, so we traverse them implicitly. + + Parameters + ---------- + node : Any + Flattaneble object unless it's an instance of the types passed as filter. + filter : type, default Node + Type to filter out for the traversal, e.g. Node. + + Returns + ------- + Iterator : Any + """ + if isinstance(node, filter): + yield node + elif isinstance(node, (str, bytes)): + pass + elif isinstance(node, Sequence): + for item in node: + yield from _flatten_collections(item, filter) + elif isinstance(node, Mapping): + for key, value in node.items(): + yield from _flatten_collections(key, filter) + yield from _flatten_collections(value, filter) + + +def _recursive_get(obj: Any, dct: dict[Node, Any]) -> Any: + """Recursively replace objects in a nested structure with values from a dict. + + Since we treat collection types inherently traversable (e.g. list, tuple, dict) we + need to traverse them implicitly and replace the values given a result mapping. + + Parameters + ---------- + obj : Any + Object to replace. + dct : dict[Node, Any] + Mapping of objects to replace with their values. + + Returns + ------- + Object with replaced values. + """ + if isinstance(obj, tuple): + return tuple(_recursive_get(o, dct) for o in obj) + elif isinstance(obj, dict): + return {k: _recursive_get(v, dct) for k, v in obj.items()} + else: + return dct.get(obj, obj) class Node(Hashable): @@ -14,90 +73,233 @@ class Node(Hashable): @property @abstractmethod - def __args__(self) -> Sequence: - ... + def __args__(self) -> tuple[Any, ...]: + """Sequence of arguments to traverse.""" @property @abstractmethod - def __argnames__(self) -> Sequence: - ... - - def __children__(self, filter=None): + def __argnames__(self) -> tuple[str, ...]: + """Sequence of argument names.""" + + def __children__(self, filter: Optional[type] = None) -> tuple[Node, ...]: + """Return the children of this node. + + This method is used to traverse the Node so it returns the children of the node + in the order they should be traversed. We treat common collection types + inherently traversable (e.g. list, tuple, dict), so this method flattens and + optionally filters the arguments of the node. + + Parameters + ---------- + filter : type, default Node + Type to filter out for the traversal, Node is used by default. + + Returns + ------- + Child nodes of this node. + """ return tuple(_flatten_collections(self.__args__, filter or Node)) def __rich_repr__(self): + """Support for rich reprerentation of the node.""" return zip(self.__argnames__, self.__args__) - def map(self, fn, filter=None): + @experimental + def branches(self) -> Iterator[Sequence[Node]]: + """Yield all branches of the graph. + + A branch is a path from the root to a leaf node. This method is primarily + used to implement the `path` method supporting `XPath`-like queries. + + Yields + ------ + A sequence of nodes representing a branch. + """ + stack = [(self, [])] + + while stack: + node, path = stack.pop() + + if children := node.__children__(): + for child in reversed(children): + stack.append((child, path + [node])) + else: + yield path + [node] + + @experimental + def path( + self, *pats: Any, context: Optional[dict] = None + ) -> Iterator[Sequence[Node]]: + """Return the first tree branch matching a given sequence pattern. + + This method provides a way to query the graph using `XPath`-like expressions. + The following XPath expression "//Alias//Value[dtype==int64]" would roughly + translate to the following Python code: + + node.path(..., Alias, ..., Object(Value, dtype=dt.Int64), ...) + + Parameters + ---------- + pats + Sequence which is coerced to a sequence pattern. See `ibis.common.patterns` + for more details. + context + Optional context to use for the pattern matching. + """ + pat = pattern(list(pats)) + for branch in self.branches(): + result = pat.match(branch, context) + if result is not NoMatch: + return result + return NoMatch + + def map(self, fn: Callable, filter: Optional[type] = None) -> dict[Node, Any]: + """Apply a function to all nodes in the graph. + + The traversal is done in a topological order, so the function receives the + results of its immediate children as keyword arguments. + + Parameters + ---------- + fn : Callable + Function to apply to each node. It receives the node as the first argument, + the results as the second and the results of the children as keyword + arguments. + filter : Optional[type], default None + Type to filter out for the traversal, Node is filtered out by default. + + Returns + ------- + A mapping of nodes to their results. + """ results = {} for node in Graph.from_bfs(self, filter=filter).toposort(): kwargs = dict(zip(node.__argnames__, node.__args__)) - kwargs = recursive_get(kwargs, results) + kwargs = _recursive_get(kwargs, results) results[node] = fn(node, results, **kwargs) - return results - def find(self, type, filter=None): - def fn(node, _, **kwargs): - if isinstance(node, type): - return node - return None - - result = self.map(fn, filter=filter) - - return {node for node in result.values() if node is not None} - - def substitute(self, fn, filter=None): - return self.map(fn, filter=filter)[self] + def find( + self, type: type | tuple[type], filter: Optional[type] = None + ) -> set[Node]: + """Find all nodes of a given type in the graph. + + Parameters + ---------- + type : type | tuple[type] + Type or tuple of types to find. + filter : Optional[type], default None + Type to filter out for the traversal, Node is filtered out by default. + + Returns + ------- + The set of nodes matching the given type. + """ + nodes = Graph.from_bfs(self, filter=filter).nodes() + return {node for node in nodes if isinstance(node, type)} + + @experimental + def match( + self, pat: Any, filter: Optional[type] = None, context: Optional[dict] = None + ) -> set[Node]: + """Find all nodes matching a given pattern in the graph. + + A more advanced version of find, this method allows to match nodes based on + the more flexible pattern matching system implemented in the pattern module. + + Parameters + ---------- + pat : Any + Pattern to match. `ibis.common.pattern()` function is used to coerce the + input value into a pattern. See the pattern module for more details. + filter : Optional[type], default None + Type to filter out for the traversal, Node is filtered out by default. + context : Optional[dict], default None + Optional context to use for the pattern matching. + + Returns + ------- + The set of nodes matching the given pattern. + """ + pat = pattern(pat) + ctx = context or {} + nodes = Graph.from_bfs(self, filter=filter).nodes() + return {node for node in nodes if pat.is_match(node, ctx)} + + @experimental + def replace( + self, pat: Any, filter: Optional[type] = None, context: Optional[dict] = None + ) -> Any: + """Match and replace nodes in the graph according to a given pattern. + + The pattern matching system is used to match nodes in the graph and replace them + with the results of the pattern. + + Parameters + ---------- + pat : Any + Pattern to match. `ibis.common.pattern()` function is used to coerce the + input value into a pattern. See the pattern module for more details. + Actual replacement is done by the `ibis.common.pattern.Replace` pattern. + filter : Optional[type], default None + Type to filter out for the traversal, Node is filtered out by default. + context : Optional[dict], default None + Optional context to use for the pattern matching. + + Returns + ------- + The root node of the graph with the replaced nodes. + """ + pat = pattern(pat) + ctx = context or {} - def replace(self, subs, filter=None): def fn(node, _, **kwargs): - try: - return subs[node] - except KeyError: + # TODO(kszucs): pass the reconstructed node from the results provided by the + # kwargs to the pattern rather than the original one node object, this way + # we can match on already replaced nodes + if (result := pat.match(node, ctx)) is NoMatch: return node.__class__(**kwargs) + else: + return result - return self.substitute(fn, filter=filter) + return self.map(fn, filter=filter)[self] -def _flatten_collections(node, filter=Node): - """Flatten collections of nodes into a single iterator. +class Graph(Dict[Node, Sequence[Node]]): + """A mapping-like graph data structure for easier graph traversal and manipulation. - We treat common collection types inherently Node (e.g. list, tuple, dict) - but as undesired in a graph representation, so we traverse them implicitly. + The data structure is a mapping of nodes to their children. The children are + represented as a sequence of nodes. The graph can be constructed from a root node + using the `from_bfs` or `from_dfs` class methods. Parameters ---------- - node : Any - Flattaneble object unless it's an instance of the types passed as filter. - filter : type, default Node - Type to filter out for the traversal, e.g. Node. - - Returns - ------- - Iterator : Any + mapping : Node or Mapping[Node, Sequence[Node]], default () + Either a root node or a mapping of nodes to their children. """ - if isinstance(node, filter): - yield node - elif isinstance(node, (str, bytes)): - pass - elif isinstance(node, Sequence): - for item in node: - yield from _flatten_collections(item, filter) - elif isinstance(node, Mapping): - for key, value in node.items(): - yield from _flatten_collections(key, filter) - yield from _flatten_collections(value, filter) - -class Graph(Dict[Node, Sequence[Node]]): def __init__(self, mapping=(), /, **kwargs): if isinstance(mapping, Node): mapping = self.from_bfs(mapping) super().__init__(mapping, **kwargs) @classmethod - def from_bfs(cls, root: Node, filter=Node) -> Graph: + def from_bfs(cls, root: Node, filter=Node) -> Self: + """Construct a graph from a root node using a breadth-first search. + + The traversal is implemented in an iterative fashion using a queue. + + Parameters + ---------- + root : Node + Root node of the graph. + filter : Optional[type], default None + Type to filter out for the traversal, Node is filtered out by default. + + Returns + ------- + A graph constructed from the root node. + """ if not isinstance(root, Node): raise TypeError("node must be an instance of ibis.common.graph.Node") @@ -112,7 +314,22 @@ def from_bfs(cls, root: Node, filter=Node) -> Graph: return graph @classmethod - def from_dfs(cls, root: Node, filter=Node) -> Graph: + def from_dfs(cls, root: Node, filter=Node) -> Self: + """Construct a graph from a root node using a depth-first search. + + The traversal is implemented in an iterative fashion using a stack. + + Parameters + ---------- + root : Node + Root node of the graph. + filter : Optional[type], default None + Type to filter out for the traversal, Node is filtered out by default. + + Returns + ------- + A graph constructed from the root node. + """ if not isinstance(root, Node): raise TypeError("node must be an instance of ibis.common.graph.Node") @@ -129,17 +346,38 @@ def from_dfs(cls, root: Node, filter=Node) -> Graph: def __repr__(self): return f"{self.__class__.__name__}({super().__repr__()})" - def nodes(self): + def nodes(self) -> set[Node]: + """Return all unique nodes in the graph.""" return self.keys() - def invert(self) -> Graph: + def invert(self) -> Self: + """Invert the data structure. + + The graph originally maps nodes to their children, this method inverts the + mapping to map nodes to their parents. + + Returns + ------- + The inverted graph. + """ result = {node: [] for node in self} for node, dependencies in self.items(): for dependency in dependencies: result[dependency].append(node) return self.__class__({k: tuple(v) for k, v in result.items()}) - def toposort(self) -> Graph: + def toposort(self) -> Self: + """Topologically sort the graph using Kahn's algorithm. + + The graph is sorted in a way that all the dependencies of a node are placed + before the node itself. The graph must not contain any cycles. Especially useful + for mutating the graph in a way that the dependencies of a node are mutated + before the node itself. + + Returns + ------- + The topologically sorted graph. + """ dependents = self.invert() in_degree = {k: len(v) for k, v in self.items()} @@ -162,14 +400,47 @@ def toposort(self) -> Graph: def bfs(node: Node) -> Graph: + """Construct a graph from a root node using a breadth-first search. + + Parameters + ---------- + node : Node + Root node of the graph. + + Returns + ------- + A graph constructed from the root node. + """ return Graph.from_bfs(node) def dfs(node: Node) -> Graph: + """Construct a graph from a root node using a depth-first search. + + Parameters + ---------- + node : Node + Root node of the graph. + + Returns + ------- + A graph constructed from the root node. + """ return Graph.from_dfs(node) def toposort(node: Node) -> Graph: + """Construct a graph from a root node then topologically sort it. + + Parameters + ---------- + node : Node + Root node of the graph. + + Returns + ------- + A topologically sorted graph constructed from the root node. + """ return Graph(node).toposort() diff --git a/ibis/common/patterns.py b/ibis/common/patterns.py index 2b50b5c128a7..55736d58a1fe 100644 --- a/ibis/common/patterns.py +++ b/ibis/common/patterns.py @@ -2,6 +2,8 @@ import math import numbers +import operator +import sys from abc import ABC, abstractmethod from collections.abc import Callable, Hashable, Mapping, Sequence from enum import Enum @@ -32,7 +34,7 @@ UnionType = object() -T_cov = TypeVar("T_cov", covariant=True) +T_co = TypeVar("T_co", covariant=True) class CoercionError(Exception): @@ -67,6 +69,12 @@ class NoMatch(metaclass=Sentinel): # would be used to annotate an argument as coercible to int or to a certain type # without needing for the type to inherit from Coercible class Pattern(Hashable): + """Base class for all patterns. + + Patterns are used to match values against a given condition. They are extensively + used by other core components of Ibis to validate and/or coerce user inputs. + """ + __slots__ = () @classmethod @@ -94,7 +102,7 @@ def from_typehint(cls, annot: type, allow_coercion: bool = True) -> Pattern: # the typehint is not generic if annot is Ellipsis or annot is AnyType: # treat both `Any` and `...` as wildcard - return Any() + return _any elif isinstance(annot, type): # the typehint is a concrete type (e.g. int, str, etc.) if allow_coercion and issubclass(annot, Coercible): @@ -113,7 +121,7 @@ def from_typehint(cls, annot: type, allow_coercion: bool = True) -> Pattern: if annot.__bound__: return cls.from_typehint(annot.__bound__) else: - return Any() + return _any elif isinstance(annot, Enum): # for enums we check the value against the enum values return EqualTo(annot) @@ -216,27 +224,89 @@ def match(self, value: AnyType, context: dict[str, AnyType]) -> AnyType: """ ... + def is_match(self, value: AnyType, context: dict[str, AnyType]) -> bool: + """Check if the value matches the pattern. + + Parameters + ---------- + value + The value to match the pattern against. + context + A dictionary providing arbitrary context for the pattern matching. + + Returns + ------- + bool + Whether the value matches the pattern. + """ + return self.match(value, context) is not NoMatch + @abstractmethod def __eq__(self, other: Pattern) -> bool: ... - def __invert__(self) -> Pattern: + def __invert__(self) -> Not: + """Syntax sugar for matching the inverse of the pattern.""" return Not(self) - def __or__(self, other: Pattern) -> Pattern: + def __or__(self, other: Pattern) -> AnyOf: + """Syntax sugar for matching either of the patterns. + + Parameters + ---------- + other + The other pattern to match against. + + Returns + ------- + New pattern that matches if either of the patterns match. + """ return AnyOf(self, other) - def __and__(self, other: Pattern) -> Pattern: + def __and__(self, other: Pattern) -> AllOf: + """Syntax sugar for matching both of the patterns. + + Parameters + ---------- + other + The other pattern to match against. + + Returns + ------- + New pattern that matches if both of the patterns match. + """ return AllOf(self, other) - def __rshift__(self, name: str) -> Pattern: - return Capture(self, name) + def __rshift__(self, other: Builder) -> Replace: + """Syntax sugar for replacing a value. - def __rmatmul__(self, name: str) -> Pattern: - return Capture(self, name) + Parameters + ---------- + other + The builder to use for constructing the replacement value. + + Returns + ------- + New replace pattern. + """ + return Replace(self, other) + + def __rmatmul__(self, name: str) -> Capture: + """Syntax sugar for capturing a value. + Parameters + ---------- + name + The name of the capture. -class Matcher(Pattern): + Returns + ------- + New capture pattern. + """ + return Capture(name, self) + + +class _Slotted: """A lightweight alternative to `ibis.common.grounds.Concrete`. This class is used to create immutable dataclasses with slots and a precomputed @@ -275,6 +345,212 @@ def __rich_repr__(self): yield name, getattr(self, name) +class Builder(Hashable): + """A builder is a function that takes a context and returns a new object. + + The context is a dictionary that contains all the captured values and + information relevant for the builder. The builder construct a new object + only given by the context. + + The builder is used in the right hand side of the replace pattern: + `Replace(pattern, builder)`. Semantically when a match occurs for the + replace pattern, the builder is called with the context and the result + of the builder is used as the replacement value. + """ + + __slots__ = () + + @abstractmethod + def __eq__(self, other): + ... + + @abstractmethod + def make(self, context: dict): + """Construct a new object from the context. + + Parameters + ---------- + context + A dictionary containing all the captured values and information + relevant for the builder. + + Returns + ------- + The constructed object. + """ + + +def builder(obj): + """Convert an object to a builder. + + It encapsulates: + - callable objects into a `Factory` builder + - non-callable objects into a `Just` builder + + Parameters + ---------- + obj + The object to convert to a builder. + + Returns + ------- + The builder instance. + """ + # TODO(kszucs): the replacer object must be handled differently from patterns + # basically a replacer is just a lazy way to construct objects from the context + # we should have a separate base class for replacers like Variable, Function, + # Just, Apply and Call. Something like Replacer with a specific method e.g. + # apply() could work + if isinstance(obj, Builder): + return obj + elif callable(obj): + # not function but something else + return Factory(obj) + else: + return Just(obj) + + +class Variable(_Slotted, Builder): + """Retrieve a value from the context. + + Parameters + ---------- + name + The key to retrieve from the state. + """ + + __slots__ = ("name",) + + def make(self, context): + return context[self] + + def __getattr__(self, name): + return Call(operator.attrgetter(name), self) + + def __getitem__(self, name): + return Call(operator.itemgetter(name), self) + + +class Just(_Slotted, Builder): + """Construct exactly the given value. + + Parameters + ---------- + value + The value to return when the builder is called. + """ + + __slots__ = ("value",) + + def make(self, context): + return self.value + + +class Factory(_Slotted, Builder): + """Construct a value by calling a function. + + The function is called with two positional arguments: + 1. the value being matched + 2. the context dictionary + + The function must return the constructed value. + + Parameters + ---------- + func + The function to apply. + """ + + __slots__ = ("func",) + + def make(self, context): + value = context[_] + return self.func(value, context) + + +class Call(_Slotted, Builder): + """Pattern that calls a function with the given arguments. + + Both positional and keyword arguments are coerced into patterns. + + Parameters + ---------- + func + The function to call. + args + The positional argument patterns. + kwargs + The keyword argument patterns. + """ + + __slots__ = ("func", "args", "kwargs") + + def __init__(self, func, *args, **kwargs): + args = tuple(map(builder, args)) + kwargs = frozendict({k: builder(v) for k, v in kwargs.items()}) + super().__init__(func, args, kwargs) + + def make(self, context): + args = tuple(arg.make(context) for arg in self.args) + kwargs = {k: v.make(context) for k, v in self.kwargs.items()} + return self.func(*args, **kwargs) + + def __call__(self, *args, **kwargs): + if self.args or self.kwargs: + raise TypeError("Further specification of Call object is not allowed") + return Call(self.func, *args, **kwargs) + + @classmethod + def namespace(cls, module) -> Namespace: + """Convenience method to create a namespace for easy object construction. + + Parameters + ---------- + module + The module object or name to look up the types. + + Examples + -------- + >>> from ibis.common.patterns import Call + >>> from ibis.expr.operations import Negate + >>> + >>> c = Call.namespace('ibis.expr.operations') + >>> x = Variable('x') + >>> pattern = c.Negate(x) + >>> pattern + Call(func=, args=(Variable(name='x'),), kwargs=FrozenDict({})) + >>> pattern.make({x: 5}) + + """ + return Namespace(cls, module) + + +# reserved variable name for the value being matched +_ = Variable("_") + + +class Matcher(_Slotted, Pattern): + __slots__ = () + + +class Always(Matcher): + """Pattern that matches everything.""" + + __slots__ = () + + def match(self, value, context): + return value + + +class Never(Matcher): + """Pattern that matches nothing.""" + + __slots__ = () + + def match(self, value, context): + return NoMatch + + class Is(Matcher): """Pattern that matches a value against a reference value. @@ -302,6 +578,9 @@ def match(self, value, context): return value +_any = Any() + + class Capture(Matcher): """Pattern that captures a value in the context. @@ -309,33 +588,47 @@ class Capture(Matcher): ---------- pattern The pattern to match against. - name - The name to use in the context if the pattern matches. + key + The key to use in the context if the pattern matches. """ - __slots__ = ("pattern", "name") + __slots__ = ("key", "pattern") + + def __init__(self, key, pat=_any): + super().__init__(key, pattern(pat)) def match(self, value, context): - value = self.pattern.match(value, context=context) + value = self.pattern.match(value, context) if value is NoMatch: return NoMatch - context[self.name] = value + context[self.key] = value return value -class Reference(Matcher): - """Retrieve a value from the context. +class Replace(Matcher): + """Pattern that replaces a value with the output of another pattern. Parameters ---------- - key - The key to retrieve from the state. + matcher + The pattern to match against. + replacer + The pattern to use as a replacement. """ - __slots__ = ("key",) + __slots__ = ("matcher", "builder") - def match(self, context): - return context[self.key] + def __init__(self, matcher, replacer): + super().__init__(pattern(matcher), builder(replacer)) + + def match(self, value, context): + value = self.matcher.match(value, context) + if value is NoMatch: + return NoMatch + # use the `_` reserved variable to record the value being replaced + # in the context, so that it can be used in the replacer pattern + context[_] = value + return self.builder.make(context) class Check(Matcher): @@ -356,7 +649,7 @@ def match(self, value, context): return NoMatch -class Apply(Matcher): +class Function(Matcher): """Pattern that applies a function to the value. Parameters @@ -368,22 +661,73 @@ class Apply(Matcher): __slots__ = ("func",) def match(self, value, context): - return self.func(value) + return self.func(value, context) -class Function(Matcher): - """Pattern that checks a value against a function. +class Namespace: + """Convenience class for creating patterns for various types from a module. + + Useful to reduce boilerplate when creating patterns for various types from + a module. + + Parameters + ---------- + pattern + The pattern to construct with the looked up types. + module + The module object or name to look up the types. + + Examples + -------- + >>> from ibis.common.patterns import Namespace + >>> import ibis.expr.operations as ops + >>> + >>> ns = Namespace(InstanceOf, ops) + >>> ns.Negate + InstanceOf(type=) + >>> + >>> ns.Negate(5) + Object(type=InstanceOf(type=), args=(EqualTo(value=5),), kwargs=FrozenDict({})) + """ + + __slots__ = ("module", "pattern") + + def __init__(self, pattern, module): + if isinstance(module, str): + module = sys.modules[module] + self.module = module + self.pattern = pattern + + def __getattr__(self, name: str) -> Pattern: + return self.pattern(getattr(self.module, name)) + + +class Apply(Matcher): + """Pattern that applies a function to the value. + + The function must accept a single argument. Parameters ---------- func - The function to use. + The function to apply. + + Examples + -------- + >>> from ibis.common.patterns import Apply, match + >>> + >>> match("a" @ Apply(lambda x: x + 1), 5) + 6 """ __slots__ = ("func",) def match(self, value, context): - return self.func(value, context) + return self.func(value) + + def __call__(self, *args, **kwargs): + """Convenience method to create a Call pattern.""" + return Call(self.func, *args, **kwargs) class EqualTo(Matcher): @@ -425,7 +769,7 @@ def match(self, value, context): else: return self.default else: - return self.pattern.match(value, context=context) + return self.pattern.match(value, context) class TypeOf(Matcher): @@ -475,6 +819,9 @@ def match(self, value, context): else: return NoMatch + def __call__(self, *args, **kwargs): + return Object(self.type, *args, **kwargs) + class GenericInstanceOf(Matcher): """Pattern that matches a value that is an instance of a given generic type. @@ -486,10 +833,10 @@ class GenericInstanceOf(Matcher): Examples -------- - >>> class MyNumber(Generic[T_cov]): - ... value: T_cov + >>> class MyNumber(Generic[T_co]): + ... value: T_co ... - ... def __init__(self, value: T_cov): + ... def __init__(self, value: T_co): ... self.value = value ... ... def __eq__(self, other): @@ -551,7 +898,7 @@ def __init__(self, types): check.register(types, lambda x: True) super().__init__(promote_tuple(types), check) - def match(self, value, *, context): + def match(self, value, context): if self.check(value): return value else: @@ -670,8 +1017,11 @@ class Not(Matcher): __slots__ = ("pattern",) + def __init__(self, inner): + super().__init__(pattern(inner)) + def match(self, value, context): - if self.pattern.match(value, context=context) is NoMatch: + if self.pattern.match(value, context) is NoMatch: return value else: return NoMatch @@ -694,7 +1044,7 @@ def __init__(self, *patterns): def match(self, value, context): for pattern in self.patterns: - result = pattern.match(value, context=context) + result = pattern.match(value, context) if result is not NoMatch: return result return NoMatch @@ -718,7 +1068,7 @@ def __init__(self, *patterns): def match(self, value, context): for pattern in self.patterns: - value = pattern.match(value, context=context) + value = pattern.match(value, context) if value is NoMatch: return NoMatch return value @@ -753,7 +1103,7 @@ def __init__( at_most = exactly super().__init__(at_least, at_most) - def match(self, value, *, context): + def match(self, value, context): length = len(value) if self.at_least is not None and length < self.at_least: return NoMatch @@ -842,16 +1192,16 @@ def match(self, values, context): result = [] for value in values: - value = self.item_pattern.match(value, context=context) + value = self.item_pattern.match(value, context) if value is NoMatch: return NoMatch result.append(value) - result = self.type_pattern.match(result, context=context) + result = self.type_pattern.match(result, context) if result is NoMatch: return NoMatch - return self.length_pattern.match(result, context=context) + return self.length_pattern.match(result, context) class TupleOf(Matcher): @@ -880,7 +1230,7 @@ def match(self, values, context): result = [] for pattern, value in zip(self.field_patterns, values): - value = pattern.match(value, context=context) + value = pattern.match(value, context) if value is NoMatch: return NoMatch result.append(value) @@ -912,13 +1262,13 @@ def match(self, value, context): result = {} for k, v in value.items(): - if (k := self.key_pattern.match(k, context=context)) is NoMatch: + if (k := self.key_pattern.match(k, context)) is NoMatch: return NoMatch - if (v := self.value_pattern.match(v, context=context)) is NoMatch: + if (v := self.value_pattern.match(v, context)) is NoMatch: return NoMatch result[k] = v - result = self.type_pattern.match(result, context=context) + result = self.type_pattern.match(result, context) if result is NoMatch: return NoMatch @@ -937,7 +1287,7 @@ def match(self, value, context): return NoMatch v = getattr(value, attr) - if match(pattern, v, context=context) is NoMatch: + if match(pattern, v, context) is NoMatch: return NoMatch return value @@ -959,27 +1309,84 @@ class Object(Matcher): The keyword arguments to match against the attributes of the object. """ - __slots__ = ("type", "attrs_pattern") + __slots__ = ("type", "args", "kwargs") + + def __new__(cls, type, *args, **kwargs): + if not args and not kwargs: + return InstanceOf(type) + else: + return super().__new__(cls) def __init__(self, type, *args, **kwargs): - kwargs.update(dict(zip(type.__match_args__, args))) - super().__init__(type, Attrs(**kwargs)) + type = pattern(type) + args = tuple(map(pattern, args)) + kwargs = frozendict(toolz.valmap(pattern, kwargs)) + super().__init__(type, args, kwargs) def match(self, value, context): - if not isinstance(value, self.type): + if self.type.match(value, context) is NoMatch: return NoMatch - if not self.attrs_pattern.match(value, context=context): + patterns = {**self.kwargs, **dict(zip(value.__match_args__, self.args))} + + fields = {} + changed = False + for name, pattern in patterns.items(): + try: + attr = getattr(value, name) + except AttributeError: + return NoMatch + + result = pattern.match(attr, context) + if result is NoMatch: + return NoMatch + elif result != attr: + changed = True + fields[name] = result + else: + fields[name] = attr + + if changed: + return type(value)(**fields) + else: + return value + + @classmethod + def namespace(cls, module): + return Namespace(InstanceOf, module) + + +class Node(Matcher): + __slots__ = ("type", "each_arg") + + def __init__(self, type, each_arg): + super().__init__(pattern(type), pattern(each_arg)) + + def match(self, value, context): + if self.type.match(value, context) is NoMatch: return NoMatch - return value + newargs = {} + changed = False + for name, arg in zip(value.__argnames__, value.__args__): + result = self.each_arg.match(arg, context) + if result is NoMatch: + newargs[name] = arg + else: + newargs[name] = result + changed = True + + if changed: + return value.__class__(**newargs) + else: + return value class CallableWith(Matcher): __slots__ = ("arg_patterns", "return_pattern") def __init__(self, args, return_=None): - super().__init__(tuple(args), return_ or Any()) + super().__init__(tuple(args), return_ or _any) def match(self, value, context): from ibis.common.annotations import annotated @@ -1014,13 +1421,16 @@ def match(self, value, context): class PatternSequence(Matcher): + # TODO(kszucs): add a length optimization to not even try to match if the + # length of the sequence is lower than the length of the pattern sequence + __slots__ = ("pattern_window",) def __init__(self, patterns): current_patterns = [ - SequenceOf(Any()) if p is Ellipsis else pattern(p) for p in patterns + SequenceOf(_any) if p is Ellipsis else pattern(p) for p in patterns ] - following_patterns = chain(current_patterns[1:], [Not(Any())]) + following_patterns = chain(current_patterns[1:], [Not(_any)]) pattern_window = tuple(zip(current_patterns, following_patterns)) super().__init__(pattern_window) @@ -1069,7 +1479,7 @@ def match(self, value, context): it.rewind() break - res = original.match(matches, context=context) + res = original.match(matches, context) if res is NoMatch: return NoMatch else: @@ -1080,7 +1490,7 @@ def match(self, value, context): except StopIteration: return NoMatch - res = original.match(item, context=context) + res = original.match(item, context) if res is NoMatch: return NoMatch else: @@ -1102,11 +1512,11 @@ def match(self, value, context): return NoMatch keys = value.keys() - if (keys := self.keys_pattern.match(keys, context=context)) is NoMatch: + if (keys := self.keys_pattern.match(keys, context)) is NoMatch: return NoMatch values = value.values() - if (values := self.values_pattern.match(values, context=context)) is NoMatch: + if (values := self.values_pattern.match(values, context)) is NoMatch: return NoMatch return dict(zip(keys, values)) @@ -1135,11 +1545,6 @@ def match(self, value, context): return NoMatch -IsTruish = Check(lambda x: bool(x)) -IsNumber = InstanceOf(numbers.Number) & ~InstanceOf(bool) -IsString = InstanceOf(str) - - def NoneOf(*args) -> Pattern: """Match none of the passed patterns.""" return Not(AnyOf(*args)) @@ -1186,7 +1591,7 @@ def pattern(obj: AnyType) -> Pattern: The constructed pattern. """ if obj is Ellipsis: - return Any() + return _any elif isinstance(obj, Pattern): return obj elif isinstance(obj, Mapping): @@ -1261,6 +1666,7 @@ def match(self, value, context): class Innermost(Matcher): + # matches items in the innermost layer first, but all matches belong to the same layer """Traverse the value tree innermost first and match the first value that matches.""" __slots__ = ("searcher", "filter") @@ -1275,3 +1681,8 @@ def match(self, value, context): return result return self.searcher.match(value, context) + + +IsTruish = Check(lambda x: bool(x)) +IsNumber = InstanceOf(numbers.Number) & ~InstanceOf(bool) +IsString = InstanceOf(str) diff --git a/ibis/common/tests/test_patterns.py b/ibis/common/tests/test_patterns.py index 34866007f1b3..66507ead141d 100644 --- a/ibis/common/tests/test_patterns.py +++ b/ibis/common/tests/test_patterns.py @@ -25,12 +25,14 @@ from ibis.common.annotations import ValidationError from ibis.common.collections import FrozenDict -from ibis.common.graph import Node +from ibis.common.graph import Node as GraphNode from ibis.common.patterns import ( AllOf, + Always, Any, AnyOf, Between, + Call, CallableWith, Capture, Check, @@ -45,11 +47,14 @@ Innermost, InstanceOf, IsIn, + Just, LazyInstanceOf, Length, ListOf, MappingOf, MatchError, + Never, + Node, NoMatch, NoneOf, Not, @@ -58,12 +63,14 @@ Pattern, PatternMapping, PatternSequence, - Reference, + Replace, SequenceOf, SubclassOf, Topmost, TupleOf, TypeOf, + Variable, + _, match, pattern, ) @@ -97,6 +104,29 @@ def __eq__(self, other): return self.__class__ == other.__class__ and self.min == other.min +x = Variable("x") +y = Variable("y") +z = Variable("z") + + +def test_always(): + p = Always() + assert p.match(1, context={}) == 1 + assert p.match(2, context={}) == 2 + + +def test_never(): + p = Never() + assert p.match(1, context={}) is NoMatch + assert p.match(2, context={}) is NoMatch + + +def test_just(): + p = Just(1) + assert p.make({}) == 1 + assert p.make({"a": 1}) == 1 + + def test_min(): p = Min(10) assert p.match(10, context={}) == 10 @@ -114,16 +144,16 @@ def test_any(): assert p.match("foo", context={}) == "foo" -def test_reference(): - p = Reference("other") - context = {"other": 10} - assert p.match(context=context) == 10 +def test_variable(): + p = Variable("other") + context = {p: 10} + assert p.make(context) == 10 def test_capture(): ctx = {} - p = Capture(Min(11), "result") + p = Capture("result", Min(11)) assert p.match(10, context=ctx) is NoMatch assert ctx == {} @@ -401,17 +431,29 @@ def test_mapping_of(): assert p.match({"foo": 1}, context={}) is NoMatch -def test_object_pattern(): - class Foo: - __match_args__ = ("a", "b") +class Foo: + __match_args__ = ("a", "b") - def __init__(self, a, b): - self.a = a - self.b = b + def __init__(self, a, b): + self.a = a + self.b = b - def __eq__(self, other): - return type(self) == type(other) and self.a == other.a and self.b == other.b + def __eq__(self, other): + return type(self) == type(other) and self.a == other.a and self.b == other.b + + +class Bar: + __match_args__ = ("c", "d") + + def __init__(self, c, d): + self.c = c + self.d = d + + def __eq__(self, other): + return type(self) == type(other) and self.c == other.c and self.d == other.d + +def test_object_pattern(): p = Object(Foo, 1, b=2) o = Foo(1, 2) r = match(p, o) @@ -419,6 +461,20 @@ def __eq__(self, other): assert r == Foo(1, 2) +def test_object_pattern_complex_type(): + p = Object(Not(Foo), 1, 2) + o = Bar(1, 2) + + # test that the pattern isn't changing the input object if none of + # its arguments are changed by subpatterns + assert match(p, o) is o + assert match(p, Foo(1, 2)) is NoMatch + assert match(p, Bar(1, 3)) is NoMatch + + p = Object(Not(Foo), 1, b=2) + assert match(p, Bar(1, 2)) is NoMatch + + def test_callable_with(): def func(a, b): return str(a) + b @@ -482,10 +538,12 @@ def test_matching(): assert match(InstanceOf(int), 1) == 1 assert match(InstanceOf(int), "foo") is NoMatch - assert Capture(InstanceOf(float), "pi") == "pi" @ InstanceOf(float) - assert Capture(InstanceOf(float), "pi") == InstanceOf(float) >> "pi" + assert Capture("pi", InstanceOf(float)) == "pi" @ InstanceOf(float) + assert Capture("pi", InstanceOf(float)) == "pi" @ InstanceOf(float) - assert match(Capture(InstanceOf(float), "pi"), 3.14, ctx := {}) == 3.14 + assert match(Capture("pi", InstanceOf(float)), 3.14, ctx := {}) == 3.14 + assert ctx == {"pi": 3.14} + assert match("pi" @ InstanceOf(float), 3.14, ctx := {}) == 3.14 assert ctx == {"pi": 3.14} assert match("pi" @ InstanceOf(float), 3.14, ctx := {}) == 3.14 @@ -495,6 +553,58 @@ def test_matching(): assert match(InstanceOf(object) & InstanceOf(float), 3.14) == 3.14 +def test_replace_passes_matched_value_as_underscore(): + class MyInt: + def __init__(self, value): + self.value = value + + def __eq__(self, other): + return self.value == other.value + + p = InstanceOf(int) >> Call(MyInt, value=_) + assert p.match(1, context={}) == MyInt(1) + + +def test_replace_in_nested_object_pattern(): + # simple example using reference to replace a value + b = Variable("b") + p = Object(Foo, 1, b=Replace(..., b)) + f = p.match(Foo(1, 2), {b: 3}) + assert f.a == 1 + assert f.b == 3 + + # nested example using reference to replace a value + d = Variable("d") + p = Object(Foo, 1, b=Object(Bar, 2, d=Replace(..., d))) + g = p.match(Foo(1, Bar(2, 3)), {d: 4}) + assert g.b.c == 2 + assert g.b.d == 4 + + # nested example using reference to replace a value with a captured value + p = Object( + Foo, + 1, + b=Replace(Object(Bar, 2, d="d" @ Any()), lambda _, ctx: Foo(-1, b=ctx["d"])), + ) + h = p.match(Foo(1, Bar(2, 3)), {}) + assert isinstance(h, Foo) + assert h.a == 1 + assert isinstance(h.b, Foo) + assert h.b.b == 3 + + # same example with more syntactic sugar + o = Object.namespace(__name__) + c = Call.namespace(__name__) + + d = Variable("d") + p = o.Foo(1, b=o.Bar(2, d=d @ Any()) >> c.Foo(-1, b=d)) + h1 = p.match(Foo(1, Bar(2, 3)), {}) + assert isinstance(h1, Foo) + assert h1.a == 1 + assert isinstance(h1.b, Foo) + assert h1.b.b == 3 + + def test_matching_sequence_pattern(): assert match([], []) == [] assert match([], [1]) is NoMatch @@ -523,8 +633,8 @@ def test_matching_sequence_with_captures(): assert ctx == {"rest": (5, 6, 7, 8)} v = list(range(5)) - assert match([0, 1, "var" @ SequenceOf(...), 4], v, ctx := {}) == v - assert ctx == {"var": (2, 3)} + assert match([0, 1, x @ SequenceOf(...), 4], v, ctx := {}) == v + assert ctx == {x: (2, 3)} assert match([0, 1, "var" @ SequenceOf(...), 4], v, ctx := {}) == v assert ctx == {"var": (2, 3)} @@ -582,8 +692,8 @@ def test_matching_sequence_complicated(): pattern = [ 0, - PatternSequence([1, 2]) >> "first", - PatternSequence([4, 5]) >> "second", + "first" @ PatternSequence([1, 2]), + "second" @ PatternSequence([4, 5]), 3, ] expected = {"first": [1, 2], "second": [4, 5]} @@ -702,7 +812,7 @@ def test_various_not_matching_patterns(pattern, value): @pattern -def endswith_d(s, context): +def endswith_d(s, ctx): if not s.endswith("d"): return NoMatch return s @@ -865,13 +975,16 @@ def f(x): assert pattern(f) == Function(f) -class Term(Node): +class Term(GraphNode): def __eq__(self, other): return type(self) is type(other) and self.__args__ == other.__args__ def __hash__(self): return hash((self.__class__, self.__args__)) + def __repr__(self): + return f"{self.__class__.__name__}({', '.join(map(repr, self.__args__))})" + class Lit(Term): __argnames__ = ("value",) @@ -912,10 +1025,11 @@ class Mul(Binary): three = Add(one, two) six = Mul(two, three) seven = Add(one, six) +fourteen = Add(seven, seven) def test_topmost_innermost(): - inner = Object(Mul, Capture(Any(), "a"), Capture(Any(), "b")) + inner = Object(Mul, Capture("a"), Capture("b")) assert inner.match(six, {}) is six context = {} @@ -928,3 +1042,27 @@ def test_topmost_innermost(): m = p.match(seven, context) assert m is two assert context == {"a": Lit(2), "b": one} + + +def test_graph_path(): + p = Object(Mul, left="lit" @ Object(Lit)) + ctx = {} + r = p.match(two, ctx) + assert r == two + assert ctx["lit"] == Lit(2) + + ctx = {} + r = six.path(..., Object(Mul, left="lit" @ Object(Lit)), ..., context=ctx) + assert ctx["lit"] == Lit(2) + assert r == [six, two, Lit(2)] + + +def test_node(): + pat = Node( + InstanceOf(Add), + each_arg=Replace( + Object(Lit, value=Capture("v")), lambda _, ctx: Lit(ctx["v"] + 100) + ), + ) + result = six.replace(pat) + assert result == Mul(two, Add(Lit(101), two)) diff --git a/ibis/expr/analysis.py b/ibis/expr/analysis.py index 694a042919fc..e61acfa5808f 100644 --- a/ibis/expr/analysis.py +++ b/ibis/expr/analysis.py @@ -14,6 +14,10 @@ from ibis import util from ibis.common.annotations import ValidationError from ibis.common.exceptions import IbisTypeError, IntegrityError +from ibis.common.patterns import Call, Object + +p = Object.namespace(ops) +c = Call.namespace(ops) # --------------------------------------------------------------------- # Some expression metaprogramming / graph transformations to support @@ -177,13 +181,13 @@ def substitute_unbound(node): """Rewrite `node` by replacing table expressions with an equivalent unbound table.""" assert isinstance(node, ops.Node), type(node) - def fn(node, _, *args, **kwargs): + def fn(node, _, **kwargs): if isinstance(node, ops.DatabaseTable): return ops.UnboundTable(name=node.name, schema=node.schema) else: - return node.__class__(*args, **kwargs) + return node.__class__(**kwargs) - return node.substitute(fn) + return node.map(fn)[node] def get_mutation_exprs(exprs: list[ir.Expr], table: ir.Table) -> list[ir.Expr | None]: diff --git a/ibis/tests/expr/test_operations.py b/ibis/tests/expr/test_operations.py index 8f44b5906a97..8124853a963b 100644 --- a/ibis/tests/expr/test_operations.py +++ b/ibis/tests/expr/test_operations.py @@ -12,6 +12,7 @@ import ibis.expr.rules as rlz import ibis.expr.types as ir from ibis.common.annotations import ValidationError +from ibis.common.patterns import EqualTo t = ibis.table([("a", "int64")], name="t") @@ -162,9 +163,11 @@ class Aliased(Base): name: str ketto = Aliased(one, "ketto") - subs = {Name("one"): Name("zero"), two: ketto} - new_values = values.replace(subs) + first_rule = EqualTo(Name("one")) >> Name("zero") + second_rule = EqualTo(two) >> ketto + + new_values = values.replace(first_rule | second_rule) expected = Values((NamedValue(value=1, name=Name("zero")), ketto, three)) assert expected == new_values diff --git a/ibis/util.py b/ibis/util.py index 09e96d7dc598..5f9714456df5 100644 --- a/ibis/util.py +++ b/ibis/util.py @@ -351,16 +351,6 @@ def consume(iterator: Iterator[T], n: int | None = None) -> None: next(itertools.islice(iterator, n, n), None) -# TODO(kszucs): make it a more robust to better align with graph._flatten_collections() -def recursive_get(obj, mapping): - if isinstance(obj, tuple): - return tuple(recursive_get(o, mapping) for o in obj) - elif isinstance(obj, dict): - return {k: recursive_get(v, mapping) for k, v in obj.items()} - else: - return mapping.get(obj, obj) - - def flatten_iterable(iterable): """Recursively flatten the iterable `iterable`.""" if not is_iterable(iterable):