Skip to content

Commit

Permalink
fix(mssql): restore any, all and cumulative versions (#8409)
Browse files Browse the repository at this point in the history
## Description of changes

I took a look through what sqlalchemy was generating for these failing
test cases and then reimplemented it using sqlglot.

Any and All:
* Instead of `logical or` we do a `Max(IIF(condition, 1, 0))`
* Instead of `logical and` we do a `Min(IIF(condition, 1, 0))`

Cumulative versions:

The cumulative versions work, but required an additional check in the
`visit_Not` node. MSSQL doesn't support doing `IIF(MAX(IIF(...` -- the
argument to IIF has to be a boolean and the Window function isn't
considered a boolean for some reason, so the outer conditional needs to
be a case statement.

- fix(mssql): restore any, all and cumulative versions
- fix(mssql): support aggregations with any/all


## Issues closed

* Resolves #8073
  • Loading branch information
gforsyth authored Feb 21, 2024
1 parent a3b1cc6 commit 99a4022
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 12 deletions.
26 changes: 24 additions & 2 deletions ibis/backends/mssql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,6 @@ class MSSQLCompiler(SQLGlotCompiler):

UNSUPPORTED_OPERATIONS = frozenset(
(
ops.Any,
ops.All,
ops.ApproxMedian,
ops.Arbitrary,
ops.ArgMax,
Expand Down Expand Up @@ -402,6 +400,18 @@ def visit_Mean(self, op, *, arg, where):
def visit_Not(self, op, *, arg):
    """Compile a logical negation of ``arg`` for MSSQL.

    Three shapes of ``arg`` are handled:

    * a boolean literal is flipped directly;
    * a ``Window``, ``Max`` or ``Min`` expression (used for NOT ANY,
      NOT ALL and their cumulative variants) is compared against 0
      inside a ``CASE`` expression;
    * anything else is mapped to 1/0 via ``self.if_`` and compared to 0.
    """
    if isinstance(arg, sge.Boolean):
        if arg == TRUE:
            return FALSE
        return TRUE

    if isinstance(arg, (sge.Window, sge.Max, sge.Min)):
        # MSSQL's boolean handling is strict: Max/Min (and windowed
        # versions of them) do not return booleans, so the equality
        # check cannot be nested inside another IIF. For example
        #   IIF(MAX(IIF(condition, 1, 0)) = 0, true_case, false_case)
        # is invalid, and has to be spelled as
        #   CASE WHEN MAX(IIF(condition, 1, 0)) = 0
        #        THEN true_case ELSE false_case END
        when_zero = self.if_(arg.eq(0), 1)
        return sge.Case(ifs=[when_zero], default=0)

    as_int = self.if_(arg, 1, 0)
    return as_int.eq(0)

def visit_HashBytes(self, op, *, arg, how):
Expand Down Expand Up @@ -435,3 +445,15 @@ def visit_HexDigest(self, op, *, arg, how):
def visit_StringConcat(self, op, *, arg):
    """Compile string concatenation that yields NULL if any operand is NULL.

    ``arg`` is the collection of already-compiled operand expressions.
    """
    # One ``IS NULL`` test per operand, OR-ed together into a single guard.
    null_checks = [operand.is_(NULL) for operand in arg]
    has_null = sg.or_(*null_checks)
    joined = self.f.concat(*arg)
    return self.if_(has_null, NULL, joined)

def visit_Any(self, op, *, arg, where):
    """Compile ``Any`` as ``MAX`` over a 1/0 encoding of ``arg``.

    When ``where`` is given, the encoded value is replaced by NULL for
    rows where the filter does not hold before it is fed to ``MAX``.
    """
    flag = self.if_(arg, 1, 0)
    masked = flag if where is None else self.if_(where, flag, NULL)
    return sge.Max(this=masked)

def visit_All(self, op, *, arg, where):
    """Compile ``All`` as ``MIN`` over a 1/0 encoding of ``arg``.

    When ``where`` is given, the encoded value is replaced by NULL for
    rows where the filter does not hold before it is fed to ``MIN``.
    """
    encoded = self.if_(arg, 1, 0)
    if where is None:
        return sge.Min(this=encoded)
    # Rows excluded by the filter contribute NULL to the aggregate input.
    return sge.Min(this=self.if_(where, encoded, NULL))
6 changes: 0 additions & 6 deletions ibis/backends/tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,6 @@ def mean_and_std(v):
raises=AttributeError,
reason="'IntegerColumn' object has no attribute 'any'",
),
pytest.mark.notimpl(["mssql"], raises=com.OperationNotDefinedError),
],
),
param(
Expand All @@ -307,7 +306,6 @@ def mean_and_std(v):
reason="ORA-02000: missing AS keyword",
),
pytest.mark.notimpl(["exasol"], raises=ExaQueryError),
pytest.mark.notimpl(["mssql"], raises=com.OperationNotDefinedError),
],
),
param(
Expand All @@ -326,7 +324,6 @@ def mean_and_std(v):
reason="ORA-02000: missing AS keyword",
),
pytest.mark.notimpl(["exasol"], raises=ExaQueryError),
pytest.mark.notimpl(["mssql"], raises=com.OperationNotDefinedError),
],
),
param(
Expand All @@ -339,7 +336,6 @@ def mean_and_std(v):
raises=AttributeError,
reason="'IntegerColumn' object has no attribute 'all'",
),
pytest.mark.notimpl(["mssql"], raises=com.OperationNotDefinedError),
],
),
param(
Expand All @@ -358,7 +354,6 @@ def mean_and_std(v):
reason="ORA-02000: missing AS keyword",
),
pytest.mark.notimpl(["exasol"], raises=ExaQueryError),
pytest.mark.notimpl(["mssql"], raises=com.OperationNotDefinedError),
],
),
param(
Expand All @@ -377,7 +372,6 @@ def mean_and_std(v):
reason="ORA-02000: missing AS keyword",
),
pytest.mark.notimpl(["exasol"], raises=ExaQueryError),
pytest.mark.notimpl(["mssql"], raises=com.OperationNotDefinedError),
],
),
param(
Expand Down
4 changes: 0 additions & 4 deletions ibis/backends/tests/test_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,6 @@ def calc_zscore(s):
.astype(bool)
),
id="cumany",
marks=[pytest.mark.broken(["mssql"], raises=com.OperationNotDefinedError)],
),
param(
lambda t, win: (t.double_col == 0).notany().over(win),
Expand All @@ -262,7 +261,6 @@ def calc_zscore(s):
id="cumnotany",
marks=[
pytest.mark.broken(["oracle"], raises=OracleDatabaseError),
pytest.mark.broken(["mssql"], raises=com.OperationNotDefinedError),
],
),
param(
Expand All @@ -274,7 +272,6 @@ def calc_zscore(s):
.astype(bool)
),
id="cumall",
marks=[pytest.mark.broken(["mssql"], raises=com.OperationNotDefinedError)],
),
param(
lambda t, win: (t.double_col == 0).notall().over(win),
Expand All @@ -287,7 +284,6 @@ def calc_zscore(s):
id="cumnotall",
marks=[
pytest.mark.broken(["oracle"], raises=OracleDatabaseError),
pytest.mark.broken(["mssql"], raises=com.OperationNotDefinedError),
],
),
param(
Expand Down

0 comments on commit 99a4022

Please sign in to comment.