Skip to content

Commit

Permalink
fix(mssql): support translation of ops.Neg() when projecting a field
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs authored and cpcloud committed Oct 13, 2023
1 parent d5f7cc0 commit ca49d2a
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 18 deletions.
5 changes: 4 additions & 1 deletion ibis/backends/base/sql/alchemy/query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,10 @@ def _add_where(self, fragment):
if not self.where:
return fragment

args = [self._translate(pred, permit_subquery=True) for pred in self.where]
args = [
self._translate(pred, permit_subquery=True, within_where=True)
for pred in self.where
]
clause = functools.reduce(sql.and_, args)
return fragment.where(clause)

Expand Down
5 changes: 3 additions & 2 deletions ibis/backends/base/sql/compiler/query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,12 +216,13 @@ def __init__(

self.indent = indent

def _translate(self, expr, named=False, permit_subquery=False):
def _translate(self, expr, named=False, permit_subquery=False, within_where=False):
translator = self.translator_class(
expr,
context=self.context,
named=named,
permit_subquery=permit_subquery,
within_where=within_where,
)
return translator.get_result()

Expand Down Expand Up @@ -395,7 +396,7 @@ def format_where(self):
fmt_preds = []
npreds = len(self.where)
for pred in self.where:
new_pred = self._translate(pred, permit_subquery=True)
new_pred = self._translate(pred, permit_subquery=True, within_where=True)
if npreds > 1:
new_pred = f"({new_pred})"
fmt_preds.append(new_pred)
Expand Down
9 changes: 8 additions & 1 deletion ibis/backends/base/sql/compiler/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,9 @@ class ExprTranslator:
_dialect_name = "hive"
_quote_identifiers = None

def __init__(self, node, context, named=False, permit_subquery=False):
def __init__(
self, node, context, named=False, permit_subquery=False, within_where=False
):
self.node = node
self.permit_subquery = permit_subquery

Expand All @@ -198,6 +200,11 @@ def __init__(self, node, context, named=False, permit_subquery=False):
# For now, governing whether the result will have a name
self.named = named

# used to indicate whether the expression being rendered is within a
# WHERE clause. This is used for MSSQL to determine whether to use
# boolean expressions or not.
self.within_where = within_where

def _needs_name(self, op):
if not self.named:
return False
Expand Down
11 changes: 11 additions & 0 deletions ibis/backends/mssql/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,22 @@ def _temporal_delta(t, op):
return sa.func.datediff(sa.literal_column(op.part.value.upper()), right, left)


def _not(t, op):
arg = t.translate(op.arg)
if t.within_where:
return sa.not_(arg)
else:
# mssql doesn't support boolean types or comparisons at selection positions
# so we need to compare the value wrapped in a case statement
return sa.case((arg == 0, True), else_=False)


operation_registry = sqlalchemy_operation_registry.copy()
operation_registry.update(sqlalchemy_window_functions_registry)

operation_registry.update(
{
ops.Not: _not,
# aggregate methods
ops.Count: _reduction(sa.func.count),
ops.Max: _reduction(sa.func.max),
Expand Down
18 changes: 6 additions & 12 deletions ibis/backends/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,16 +644,10 @@ def test_isin_notin_column_expr(backend, alltypes, df, ibis_op, pandas_op):
[
param(True, True, toolz.identity, id="true_noop"),
param(False, False, toolz.identity, id="false_noop"),
param(
True, False, invert, id="true_invert", marks=pytest.mark.notimpl(["mssql"])
),
param(
False, True, invert, id="false_invert", marks=pytest.mark.notimpl(["mssql"])
),
param(True, False, neg, id="true_negate", marks=pytest.mark.notimpl(["mssql"])),
param(
False, True, neg, id="false_negate", marks=pytest.mark.notimpl(["mssql"])
),
param(True, False, invert, id="true_invert"),
param(False, True, invert, id="false_invert"),
param(True, False, neg, id="true_negate"),
param(False, True, neg, id="false_negate"),
],
)
def test_logical_negation_literal(con, expr, expected, op):
Expand All @@ -664,8 +658,8 @@ def test_logical_negation_literal(con, expr, expected, op):
"op",
[
toolz.identity,
param(invert, marks=pytest.mark.notimpl(["mssql"])),
param(neg, marks=pytest.mark.notimpl(["mssql"])),
invert,
neg,
],
)
def test_logical_negation_column(backend, alltypes, df, op):
Expand Down
2 changes: 0 additions & 2 deletions ibis/backends/tests/test_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,6 @@ def calc_zscore(s):
),
id="cumnotany",
marks=[
pytest.mark.broken(["mssql"], raises=sa.exc.ProgrammingError),
pytest.mark.notimpl(["dask"], raises=NotImplementedError),
pytest.mark.broken(["oracle"], raises=sa.exc.DatabaseError),
],
Expand Down Expand Up @@ -238,7 +237,6 @@ def calc_zscore(s):
),
id="cumnotall",
marks=[
pytest.mark.broken(["mssql"], raises=sa.exc.ProgrammingError),
pytest.mark.notimpl(["dask"], raises=NotImplementedError),
pytest.mark.broken(["oracle"], raises=sa.exc.DatabaseError),
],
Expand Down

0 comments on commit ca49d2a

Please sign in to comment.