diff --git a/ibis/common/graph.py b/ibis/common/graph.py index e865abac7ebe..672f3d241c78 100644 --- a/ibis/common/graph.py +++ b/ibis/common/graph.py @@ -209,6 +209,8 @@ def fn(node, _, **kwargs): # kwargs to the pattern rather than the original one node object, this way # we can match on already replaced nodes if (result := pat.match(node, ctx)) is NoMatch: + # TODO(kszucs): annotable instances should use node.__recreate__() + # for quick reconstruction return node.__class__(**kwargs) else: return result diff --git a/ibis/common/patterns.py b/ibis/common/patterns.py index 6d2217302369..5851e14d09f2 100644 --- a/ibis/common/patterns.py +++ b/ibis/common/patterns.py @@ -1591,6 +1591,8 @@ def pattern(obj: AnyType) -> Pattern: return _any elif isinstance(obj, Pattern): return obj + elif isinstance(obj, Variable): + return Capture(obj) elif isinstance(obj, Mapping): return PatternMapping(obj) elif isinstance(obj, type): diff --git a/ibis/common/tests/test_patterns.py b/ibis/common/tests/test_patterns.py index 3ff50e043455..e942b8ef09ad 100644 --- a/ibis/common/tests/test_patterns.py +++ b/ibis/common/tests/test_patterns.py @@ -156,6 +156,16 @@ def test_variable(): assert p.make(context) == 10 +def test_pattern_factory_wraps_variable_with_capture(): + v = Variable("other") + p = pattern(v) + assert p == Capture(v, Any()) + + ctx = {} + assert p.match(10, ctx) == 10 + assert ctx == {v: 10} + + def test_capture(): ctx = {} diff --git a/ibis/expr/analysis.py b/ibis/expr/analysis.py index 6acf31340265..e1c53d9d21be 100644 --- a/ibis/expr/analysis.py +++ b/ibis/expr/analysis.py @@ -14,11 +14,14 @@ from ibis import util from ibis.common.annotations import ValidationError from ibis.common.exceptions import IbisTypeError, IntegrityError -from ibis.common.patterns import Call, Object +from ibis.common.patterns import Call, Object, Variable p = Object.namespace(ops) c = Call.namespace(ops) +x = Variable("x") +y = Variable("y") + # --------------------------------------------------------------------- # Some expression metaprogramming / graph transformations to support # compilation later @@ -167,15 +170,9 @@ def fn(node): def substitute_unbound(node): """Rewrite `node` by replacing table expressions with an equivalent unbound table.""" - assert isinstance(node, ops.Node), type(node) - - def fn(node, _, **kwargs): - if isinstance(node, ops.DatabaseTable): - return ops.UnboundTable(name=node.name, schema=node.schema) - else: - return node.__class__(**kwargs) - - return node.map(fn)[node] + return node.replace( + p.DatabaseTable(name=x, schema=y) >> c.UnboundTable(name=x, schema=y) + ) def get_mutation_exprs(exprs: list[ir.Expr], table: ir.Table) -> list[ir.Expr | None]: