Skip to content

Commit

Permalink
refactor(ir): stricter scalar subquery integrity checks
Browse files: browse the repository at this point in the history
  • Loading branch information
kszucs committed Feb 12, 2024
1 parent ba31f82 commit d269776
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 20 deletions.
34 changes: 16 additions & 18 deletions ibis/expr/operations/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,8 @@ def __init__(self, rel, **kwargs):

@attribute
def value(self):
name = self.rel.schema.names[0]
return self.rel.values[name]
(value,) = self.rel.values.values()
return value

@attribute
def relations(self):
Expand All @@ -127,12 +127,13 @@ def dtype(self):
@public
class ScalarSubquery(Subquery):
def __init__(self, rel):
from ibis.expr.rewrites import ReductionValue
from ibis.expr.operations import Reduction

super().__init__(rel=rel)
if not self.value.find(ReductionValue, filter=Value):
if not isinstance(self.value, Reduction):
raise IntegrityError(
f"Subquery {self.value!r} is not scalar, it must be turned into a scalar subquery first"
f"Subquery {self.value!r} is not a reduction, only "
"reductions can be used as scalar subqueries"
)


Expand All @@ -146,8 +147,8 @@ class InSubquery(Subquery):
needle: Value
dtype = dt.boolean

def __init__(self, **kwargs):
super().__init__(**kwargs)
def __init__(self, rel, needle):
super().__init__(rel=rel, needle=needle)
if not rlz.comparable(self.value, self.needle):
raise IntegrityError(
f"Subquery {self.needle!r} is not comparable to {self.value!r}"
Expand Down Expand Up @@ -275,12 +276,13 @@ class Filter(Simple):
predicates: VarTuple[Value[dt.Boolean]]

def __init__(self, parent, predicates):
from ibis.expr.rewrites import ReductionValue
from ibis.expr.rewrites import ReductionLike

for pred in predicates:
if pred.find(ReductionValue, filter=Value):
if pred.find(ReductionLike, filter=Value):
raise IntegrityError(
f"Cannot add {pred!r} to filter, it is a reduction"
f"Cannot add {pred!r} to filter, it is a reduction which "
"must be converted to a scalar subquery first"
)
if pred.relations and parent not in pred.relations:
raise IntegrityError(
Expand All @@ -291,6 +293,8 @@ def __init__(self, parent, predicates):

@public
class Limit(Simple):
# TODO(kszucs): dynamic limit should contain ScalarSubqueries rather than
# plain scalar values
n: typing.Union[int, Scalar[dt.Integer], None] = None
offset: typing.Union[int, Scalar[dt.Integer]] = 0

Expand Down Expand Up @@ -324,6 +328,7 @@ class Set(Relation):
left: Relation
right: Relation
distinct: bool = False
values = FrozenDict()

def __init__(self, left, right, **kwargs):
# convert to dictionary first, to get key-unordered comparison semantics
Expand All @@ -336,10 +341,6 @@ def __init__(self, left, right, **kwargs):
right = Project(right, cols)
super().__init__(left=left, right=right, **kwargs)

@attribute
def values(self):
return FrozenDict()

@attribute
def schema(self):
return self.left.schema
Expand All @@ -363,10 +364,7 @@ class Difference(Set):
@public
class PhysicalTable(Relation):
name: str

@attribute
def values(self):
return FrozenDict()
values = FrozenDict()


@public
Expand Down
4 changes: 2 additions & 2 deletions ibis/expr/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,10 @@ def rewrite_project_input(value, relation):
)


ReductionValue = p.Reduction | p.Field(p.Aggregate(groups={}))
ReductionLike = p.Reduction | p.Field(p.Aggregate(groups={}))


@replace(ReductionValue)
@replace(ReductionLike)
def filter_wrap_reduction(_):
# Wrap reductions or fields referencing an aggregation without a group by -
# which are scalar fields - in a scalar subquery. In the latter case we
Expand Down
8 changes: 8 additions & 0 deletions ibis/expr/tests/test_newrels.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,11 +139,19 @@ def test_select_windowizing_analytic_function():

def test_subquery_integrity_check():
    """Check that ``ops.ScalarSubquery`` refuses unsuitable relations.

    An ``IntegrityError`` is expected when the subquery is built from an
    aggregation with two reductions, from the two-column table itself, or
    from an aggregation whose expression is ``sum + 1`` rather than the
    reduction alone.
    """
    table = ibis.table(name="t", schema={"a": "int64", "b": "string"})
    two_reductions = table.agg([table.a.sum(), table.a.mean()])

    # Both relations below are expected to fail with the same
    # column-count message.
    column_count_error = "Subquery must have exactly one column, got 2"
    for relation in (two_reductions, table):
        with pytest.raises(IntegrityError, match=column_count_error):
            ops.ScalarSubquery(relation)

    # An expression built on top of a reduction is expected to be
    # rejected with the "not a reduction" message.
    derived = table.agg(table.a.sum() + 1)
    with pytest.raises(IntegrityError, match="is not a reduction"):
        ops.ScalarSubquery(derived)


def test_select_turns_scalar_reduction_into_subquery():
arr = ibis.literal([1, 2, 3])
Expand Down

0 comments on commit d269776

Please sign in to comment.