Skip to content

Commit

Permalink
refactor(ir): accept any relation in ops.ExistsSubquery (#8264)
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs committed Feb 12, 2024
1 parent 6361bed commit 68287db
Show file tree
Hide file tree
Showing 20 changed files with 98 additions and 74 deletions.
3 changes: 2 additions & 1 deletion ibis/backends/base/sqlglot/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,7 +771,8 @@ def visit_SimpleCase(self, op, *, base=None, cases, results, default):

@visit_node.register(ops.ExistsSubquery)
def visit_ExistsSubquery(self, op, *, rel):
    """Compile an ``ops.ExistsSubquery`` node into an ``EXISTS(...)`` call.

    Parameters
    ----------
    op
        The node being compiled (not used by this method).
    rel
        The already-compiled relation argument.
        NOTE(review): presumably an aliased subquery whose ``.this`` is the
        underlying SELECT expression -- confirm against the compiler.

    Returns
    -------
    The result of ``self.f.exists(...)`` applied to the rewritten SELECT.
    """
    # EXISTS only tests whether a row is produced, so whatever columns the
    # relation projects are replaced by the constant 1; ``append=False``
    # resets the projection list instead of adding to it.
    select = rel.this.select(1, append=False)
    return self.f.exists(select)

@visit_node.register(ops.InSubquery)
def visit_InSubquery(self, op, *, rel, needle):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ FROM "t1" AS "t0"
WHERE
EXISTS(
SELECT
CAST(1 AS TINYINT) AS "1"
1
FROM "t2" AS "t1"
WHERE
"t0"."key1" = "t1"."key1"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ FROM "events" AS "t0"
WHERE
EXISTS(
SELECT
CAST(1 AS TINYINT) AS "1"
1
FROM "purchases" AS "t1"
WHERE
"t1"."ts" > '2015-08-15' AND "t0"."user_id" = "t1"."user_id"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ FROM "foo_t" AS "t0"
WHERE
EXISTS(
SELECT
CAST(1 AS TINYINT) AS "1"
1
FROM "bar_t" AS "t1"
WHERE
"t0"."key1" = "t1"."key1"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ FROM "foo_t" AS "t0"
WHERE
EXISTS(
SELECT
CAST(1 AS TINYINT) AS "1"
1
FROM "bar_t" AS "t1"
WHERE
(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ WHERE
NOT (
EXISTS(
SELECT
CAST(1 AS TINYINT) AS "1"
1
FROM "bar_t" AS "t1"
WHERE
"t0"."key1" = "t1"."key1"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ WHERE
NOT (
EXISTS(
SELECT
CAST(1 AS TINYINT) AS "1"
1
FROM "functional_alltypes" AS "t1"
WHERE
"t0"."string_col" = "t1"."string_col"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ FROM "functional_alltypes" AS "t0"
WHERE
EXISTS(
SELECT
CAST(1 AS TINYINT) AS "1"
1
FROM "functional_alltypes" AS "t1"
WHERE
"t0"."string_col" = "t1"."string_col"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ FROM (
WHERE
EXISTS(
SELECT
CAST(1 AS TINYINT) AS "1"
1
FROM "lineitem" AS "t1"
WHERE
(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ FROM (
WHERE
EXISTS(
SELECT
1 AS "1"
1
FROM "hive"."ibis_sf1"."lineitem" AS "t1"
WHERE
(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ FROM (
AND "t10"."n_name" = 'SAUDI ARABIA'
AND EXISTS(
SELECT
CAST(1 AS TINYINT) AS "1"
1
FROM "lineitem" AS "t6"
WHERE
(
Expand All @@ -50,7 +50,7 @@ FROM (
AND NOT (
EXISTS(
SELECT
CAST(1 AS TINYINT) AS "1"
1
FROM "lineitem" AS "t7"
WHERE
(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ FROM (
AND "t15"."n_name" = 'SAUDI ARABIA'
AND EXISTS(
SELECT
1 AS "1"
1
FROM "t8" AS "t13"
WHERE
(
Expand All @@ -99,7 +99,7 @@ FROM (
AND NOT (
EXISTS(
SELECT
1 AS "1"
1
FROM "t8" AS "t14"
WHERE
(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ FROM (
AND NOT (
EXISTS(
SELECT
CAST(1 AS TINYINT) AS "1"
1
FROM "orders" AS "t1"
WHERE
"t1"."o_custkey" = "t0"."c_custkey"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ FROM (
AND NOT (
EXISTS(
SELECT
1 AS "1"
1
FROM "hive"."ibis_sf1"."orders" AS "t1"
WHERE
"t1"."o_custkey" = "t2"."c_custkey"
Expand Down
1 change: 1 addition & 0 deletions ibis/expr/operations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from ibis.expr.operations.sortkeys import * # noqa: F403
from ibis.expr.operations.strings import * # noqa: F403
from ibis.expr.operations.structs import * # noqa: F403
from ibis.expr.operations.subqueries import * # noqa: F403
from ibis.expr.operations.temporal import * # noqa: F403
from ibis.expr.operations.temporal_windows import * # noqa: F403
from ibis.expr.operations.udf import * # noqa: F403
Expand Down
54 changes: 0 additions & 54 deletions ibis/expr/operations/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import ibis.expr.datashape as ds
import ibis.expr.datatypes as dt
import ibis.expr.rules as rlz
from ibis.common.annotations import attribute
from ibis.common.collections import FrozenDict
from ibis.common.exceptions import IbisTypeError, IntegrityError, RelationError
Expand Down Expand Up @@ -99,59 +98,6 @@ def relations(self):
return frozenset({self.rel})


@public
class Subquery(Value):
    """Value node wrapping a relation that has exactly one column.

    Construction raises ``IntegrityError`` when the relation's schema does
    not contain exactly one column.
    """

    rel: Relation
    shape = ds.columnar

    def __init__(self, rel, **kwargs):
        # Reject relations with zero or several columns up front.
        if len(rel.schema) != 1:
            message = f"Subquery must have exactly one column, got {len(rel.schema)}"
            raise IntegrityError(message)
        super().__init__(rel=rel, **kwargs)

    @attribute
    def value(self):
        """Return the single value of the wrapped relation."""
        (only_value,) = self.rel.values.values()
        return only_value

    @attribute
    def relations(self):
        """Return an empty frozenset of relations."""
        return frozenset()

    @property
    def dtype(self):
        """Return the datatype of the relation's single value."""
        return self.value.dtype


@public
class ScalarSubquery(Subquery):
    """A ``Subquery`` whose shape is scalar rather than columnar."""

    shape = ds.scalar


@public
class ExistsSubquery(Subquery):
    """A ``Subquery`` whose datatype is always boolean."""

    dtype = dt.boolean


@public
class InSubquery(Subquery):
    """Boolean ``Subquery`` pairing a one-column relation with a needle value.

    Construction raises ``IntegrityError`` when ``rlz.comparable`` rejects
    the pair made of the relation's single value and the needle.
    """

    needle: Value
    dtype = dt.boolean

    def __init__(self, rel, needle):
        # The parent constructor runs first, so the comparability check below
        # reads ``self.value`` and ``self.needle`` from the built instance.
        super().__init__(rel=rel, needle=needle)
        if rlz.comparable(self.value, self.needle):
            return
        raise IntegrityError(
            f"Subquery {self.needle!r} is not comparable to {self.value!r}"
        )

    @attribute
    def relations(self):
        """Return the relations reported by the needle."""
        return self.needle.relations


def _check_integrity(values, allowed_parents):
for value in values:
for rel in value.relations:
Expand Down
76 changes: 76 additions & 0 deletions ibis/expr/operations/subqueries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from __future__ import annotations

from public import public

import ibis.expr.datashape as ds
import ibis.expr.datatypes as dt
import ibis.expr.rules as rlz
from ibis.common.annotations import attribute
from ibis.common.exceptions import IntegrityError
from ibis.expr.operations.core import Value
from ibis.expr.operations.relations import Relation # noqa: TCH001


@public
class Subquery(Value):
    """Base value node that wraps a relation.

    The base class performs no validation of the wrapped relation.
    """

    rel: Relation

    @attribute
    def relations(self):
        """Return an empty frozenset of relations."""
        # NOTE(review): presumably the wrapped relation is deliberately not
        # reported as a relation this value depends on -- confirm intent.
        return frozenset()


@public
class ExistsSubquery(Subquery):
    """Boolean, columnar ``Subquery``.

    The class adds no validation of its own on the wrapped relation.
    """

    dtype = dt.boolean
    shape = ds.columnar


@public
class ScalarSubquery(Subquery):
    """Scalar-shaped ``Subquery`` over a relation with exactly one column.

    Construction raises ``IntegrityError`` when the relation's schema does
    not contain exactly one column.
    """

    shape = ds.scalar

    def __init__(self, rel):
        # Reject relations with zero or several columns up front.
        if len(rel.schema) != 1:
            message = (
                "Relation passed to ScalarSubquery() must have exactly one "
                f"column, got {len(rel.schema)}"
            )
            raise IntegrityError(message)
        super().__init__(rel=rel)

    @attribute
    def value(self):
        """Return the single value of the wrapped relation."""
        (only_value,) = self.rel.values.values()
        return only_value

    @attribute
    def dtype(self):
        """Return the datatype of the relation's single value."""
        return self.value.dtype


@public
class InSubquery(Subquery):
    """Boolean, columnar ``Subquery`` pairing a one-column relation with a needle.

    Construction raises ``IntegrityError`` when the relation's schema does
    not contain exactly one column, or when ``rlz.comparable`` rejects the
    pair made of the relation's single value and the needle.
    """

    needle: Value

    dtype = dt.boolean
    shape = ds.columnar

    def __init__(self, rel, needle):
        # Both checks run on the raw arguments, before the parent constructor
        # is called.
        if len(rel.schema) != 1:
            message = (
                "Relation passed to InSubquery() must have exactly one "
                f"column, got {len(rel.schema)}"
            )
            raise IntegrityError(message)

        # The schema check above is what makes this one-element unpacking safe.
        (value,) = rel.values.values()
        comparable = rlz.comparable(value, needle)
        if not comparable:
            raise IntegrityError(f"{needle!r} is not comparable to {value!r}")

        super().__init__(rel=rel, needle=needle)

    @attribute
    def value(self):
        """Return the single value of the wrapped relation."""
        (only_value,) = self.rel.values.values()
        return only_value

    @attribute
    def relations(self):
        """Return the relations reported by the needle."""
        return self.needle.relations
2 changes: 1 addition & 1 deletion ibis/expr/tests/test_newrels.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def test_subquery_integrity_check():
t = ibis.table(name="t", schema={"a": "int64", "b": "string"})
agg = t.agg([t.a.sum(), t.a.mean()])

msg = "Subquery must have exactly one column, got 2"
msg = "must have exactly one column, got 2"
with pytest.raises(IntegrityError, match=msg):
ops.ScalarSubquery(agg)
with pytest.raises(IntegrityError, match=msg):
Expand Down
2 changes: 1 addition & 1 deletion ibis/expr/types/logical.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ def any(self, where: BooleanValue | None = None) -> BooleanValue:
def resolve_exists_subquery(outer):
    """An exists subquery whose outer leaf table is unknown.

    ``outer`` is the table expression the predicate is eventually applied
    to; the single entry of the enclosing ``parents`` that differs from
    ``outer.op()`` is used as the inner relation of the subquery.
    """
    # Exactly one parent other than the outer table must remain; the
    # one-element unpacking raises ValueError otherwise.
    (inner,) = (t for t in parents if t != outer.op())
    # The filtered relation is passed to ExistsSubquery directly, with no
    # constant-column projection wrapped around it.
    relation = ops.Filter(inner, [self])
    return ops.ExistsSubquery(relation).to_expr()

if len(parents) == 2:
Expand Down
4 changes: 2 additions & 2 deletions ibis/tests/expr/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1468,12 +1468,12 @@ def test_unresolved_existence_predicate(t1, t2):
expr = (t1.key1 == t2.key1).any()
assert isinstance(expr, Deferred)

filtered = t2.filter(t1.key1 == t2.key1).select(ibis.literal(1))
filtered = t2.filter(t1.key1 == t2.key1)
subquery = ops.ExistsSubquery(filtered)
expected = ops.Filter(parent=t1, predicates=[subquery])
assert t1[expr].op() == expected

filtered = t1.filter(t1.key1 == t2.key1).select(ibis.literal(1))
filtered = t1.filter(t1.key1 == t2.key1)
subquery = ops.ExistsSubquery(filtered)
expected = ops.Filter(parent=t2, predicates=[subquery])
assert t2[expr].op() == expected
Expand Down

0 comments on commit 68287db

Please sign in to comment.