Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(common): do not convert callables to resolveable objects #7956

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions ibis/backends/clickhouse/compiler/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@
from ibis.backends.clickhouse.compiler.relations import translate_rel
from ibis.backends.clickhouse.compiler.values import translate_val
from ibis.common.deferred import _
from ibis.expr.analysis import c, find_first_base_table, p, x, y
from ibis.common.patterns import replace
from ibis.expr.analysis import c, find_first_base_table, p, x
from ibis.expr.rewrites import rewrite_dropna, rewrite_fillna, rewrite_sample

if TYPE_CHECKING:
Expand All @@ -42,6 +43,14 @@ def _translate_node(node, **kwargs):
return translate_rel(node, **kwargs)


@replace(ops.InColumn)
def replace_in_column_with_table_array_view(_):
# replace the right side of InColumn into a scalar subquery for sql backends
base = find_first_base_table(_.options)
options = ops.TableArrayView(ops.Selection(table=base, selections=(_.options,)))
return _.copy(options=options)


def translate(op: ops.TableNode, params: Mapping[ir.Value, Any]) -> sg.exp.Expression:
"""Translate an ibis operation to a sqlglot expression.

Expand Down Expand Up @@ -88,14 +97,6 @@ def fn(node, _, **kwargs):
lambda _, x: ops.Literal(value=params[_], dtype=x)
)

# replace the right side of InColumn into a scalar subquery for sql
# backends
replace_in_column_with_table_array_view = p.InColumn(options=y) >> _.copy(
options=c.TableArrayView(
c.Selection(table=lambda _, y: find_first_base_table(y), selections=(y,))
),
)

# replace any checks against an empty right side of the IN operation with
# `False`
replace_empty_in_values_with_false = p.InValues(options=()) >> c.Literal(
Expand Down
4 changes: 0 additions & 4 deletions ibis/common/deferred.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,10 +527,6 @@ def resolver(obj):
return Just(obj)
else:
return Sequence(obj)
elif isinstance(obj, type):
return Just(obj)
elif callable(obj):
return Factory(obj)
else:
# the object is used as a constant value
return Just(obj)
Expand Down
28 changes: 18 additions & 10 deletions ibis/common/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from ibis.common.collections import FrozenDict, RewindableIterator, frozendict
from ibis.common.deferred import (
Deferred,
Factory,
Resolver,
Variable,
_, # noqa: F401
Expand All @@ -46,6 +47,13 @@
T_co = TypeVar("T_co", covariant=True)


def as_resolver(obj):
if callable(obj) and not isinstance(obj, Deferred):
return Factory(obj)
else:
return resolver(obj)


class NoMatch(metaclass=Sentinel):
"""Marker to indicate that a pattern didn't match."""

Expand Down Expand Up @@ -331,7 +339,7 @@ class Capture(Slotted, Pattern):

def __init__(self, key, pat=_any):
if isinstance(key, (Deferred, Resolver)):
key = resolver(key)
key = as_resolver(key)
if isinstance(key, Variable):
key = key.name
else:
Expand All @@ -353,25 +361,25 @@ class Replace(Slotted, Pattern):
----------
matcher
The pattern to match against.
resolver
replacer
The deferred to use as a replacement.
"""

__slots__ = ("pattern", "resolver")
pattern: Pattern
resolver: Resolver
__slots__ = ("matcher", "replacer")
matcher: Pattern
replacer: Resolver

def __init__(self, matcher, replacer):
super().__init__(pattern=pattern(matcher), resolver=resolver(replacer))
super().__init__(matcher=pattern(matcher), replacer=as_resolver(replacer))

def match(self, value, context):
value = self.pattern.match(value, context)
value = self.matcher.match(value, context)
if value is NoMatch:
return NoMatch
# use the `_` reserved variable to record the value being replaced
# in the context, so that it can be used in the replacer pattern
context["_"] = value
return self.resolver.resolve(context)
return self.replacer.resolve(context)


def replace(matcher):
Expand Down Expand Up @@ -424,7 +432,7 @@ class DeferredCheck(Slotted, Pattern):
resolver: Resolver

def __init__(self, obj):
super().__init__(resolver=resolver(obj))
super().__init__(resolver=as_resolver(obj))

def describe(self, plural=False):
if plural:
Expand Down Expand Up @@ -505,7 +513,7 @@ class DeferredEqualTo(Slotted, Pattern):
resolver: Resolver

def __init__(self, obj):
super().__init__(resolver=resolver(obj))
super().__init__(resolver=as_resolver(obj))

def match(self, value, context):
context["_"] = value
Expand Down
4 changes: 2 additions & 2 deletions ibis/common/tests/test_deferred.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,14 +145,14 @@ def test_builder():
class MyClass:
pass

def fn(x, ctx):
def fn(x):
return x + 1

assert resolver(1) == Just(1)
assert resolver(Just(1)) == Just(1)
assert resolver(Just(Just(1))) == Just(1)
assert resolver(MyClass) == Just(MyClass)
assert resolver(fn) == Factory(fn)
assert resolver(fn) == Just(fn)
assert resolver(()) == Sequence(())
assert resolver((1, 2, _)) == Sequence((Just(1), Just(2), _))
assert resolver({}) == Mapping({})
Expand Down
9 changes: 9 additions & 0 deletions ibis/tests/expr/test_value_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1691,3 +1691,12 @@ def test_sample():
assert op.fraction == 0.5
assert op.method == "block"
assert op.seed == 1234


def test_deferred_doesnt_convert_callables():
t = ibis.table([("a", "int64"), ("b", "string")])
expr = t.mutate(b=_.b.split(",").filter(lambda pp: ~pp.isin(("word1", "word2"))))
expected = t.mutate(
b=t.b.split(",").filter(lambda pp: ~pp.isin(("word1", "word2")))
)
assert expr.equals(expected)