Skip to content

Commit

Permalink
fix(flink): cast argument to integer for reduction
Browse files Browse the repository at this point in the history
  • Loading branch information
deepyaman authored and cpcloud committed Nov 9, 2023
1 parent 0347036 commit 5059eed
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 4 deletions.
1 change: 1 addition & 0 deletions ibis/backends/base/sql/compiler/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ class ExprTranslator:
)
_dialect_name = "hive"
_quote_identifiers = None
_bool_aggs_need_cast_to_int32 = False

def __init__(
self, node, context, named=False, permit_subquery=False, within_where=False
Expand Down
22 changes: 19 additions & 3 deletions ibis/backends/base/sql/registry/aggregate.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,26 @@
from __future__ import annotations

import ibis
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops


def _reduction_format(translator, func_name, where, *args):
def _maybe_cast_bool(translator, op, arg):
if (
translator._bool_aggs_need_cast_to_int32
and isinstance(op, (ops.Sum, ops.Mean, ops.Min, ops.Max))
and (dtype := arg.dtype).is_boolean()
):
return ops.Cast(arg, dt.Int32(nullable=dtype.nullable))
return arg


def _reduction_format(translator, op, func_name, where, *args):
args = (
_maybe_cast_bool(translator, op, arg)
for arg in args
if isinstance(arg, ops.Node)
)
if where is not None:
args = (ops.IfElse(where, arg, ibis.NA) for arg in args)

Expand All @@ -17,7 +33,7 @@ def _reduction_format(translator, func_name, where, *args):
def reduction(func_name):
def formatter(translator, op):
*args, where = op.args
return _reduction_format(translator, func_name, where, *args)
return _reduction_format(translator, op, func_name, where, *args)

return formatter

Expand All @@ -29,7 +45,7 @@ def variance_like(func_name):
}

def formatter(translator, op):
return _reduction_format(translator, func_names[op.how], op.where, op.arg)
return _reduction_format(translator, op, func_names[op.how], op.where, op.arg)

return formatter

Expand Down
1 change: 1 addition & 0 deletions ibis/backends/base/sql/registry/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ def sort_key(translator, op):
def count_star(translator, op):
return aggregate._reduction_format(
translator,
op,
"count",
op.where,
ops.Literal(value=1, dtype=dt.int64),
Expand Down
1 change: 1 addition & 0 deletions ibis/backends/flink/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class FlinkExprTranslator(ExprTranslator):
"hive" # TODO: neither sqlglot nor sqlalchemy supports flink dialect
)
_registry = operation_registry
_bool_aggs_need_cast_to_int32 = True


@FlinkExprTranslator.rewrites(ops.Clip)
Expand Down
1 change: 0 additions & 1 deletion ibis/backends/tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,6 @@ def mean_and_std(v):
raises=sa.exc.DatabaseError,
reason="ORA-02000: missing AS keyword",
),
pytest.mark.notimpl(["flink"], "WIP", raises=Py4JError),
],
),
param(
Expand Down

0 comments on commit 5059eed

Please sign in to comment.