diff --git a/ibis/backends/base/sql/alchemy/translator.py b/ibis/backends/base/sql/alchemy/translator.py index 70c177be7fe1..285dc8c3a3c7 100644 --- a/ibis/backends/base/sql/alchemy/translator.py +++ b/ibis/backends/base/sql/alchemy/translator.py @@ -1,4 +1,7 @@ +from __future__ import annotations + import ibis +import ibis.expr.datatypes as dt import ibis.expr.operations as ops import ibis.expr.types as ir from ibis import util @@ -40,12 +43,39 @@ class AlchemyExprTranslator(ExprTranslator): context_class = AlchemyContext + _bool_aggs_need_cast_to_int32 = True + _boolean_cast_ops = ops.Sum, ops.Mean, ops.Min, ops.Max + _has_filter_syntax = False + def name(self, translated, name, force=True): return translated.label(name) def get_sqla_type(self, data_type): return to_sqla_type(data_type, type_map=self._type_map) + def _reduction(self, sa_func, expr): + op = expr.op() + arg = op.arg + if ( + self._bool_aggs_need_cast_to_int32 + and isinstance(op, self._boolean_cast_ops) + and isinstance( + type := arg.type(), + dt.Boolean, + ) + ): + arg = arg.cast(dt.Int32(nullable=type.nullable)) + + if (where := op.where) is not None: + if self._has_filter_syntax: + sa_arg = self.translate(arg).filter(self.translate(where)) + else: + sa_arg = self.translate(where.ifelse(arg, None)) + else: + sa_arg = self.translate(arg) + + return sa_func(sa_arg) + rewrites = AlchemyExprTranslator.rewrites diff --git a/ibis/backends/duckdb/compiler.py b/ibis/backends/duckdb/compiler.py index 4be35b461390..fb5282b7c70f 100644 --- a/ibis/backends/duckdb/compiler.py +++ b/ibis/backends/duckdb/compiler.py @@ -13,6 +13,7 @@ class DuckDBSQLExprTranslator(AlchemyExprTranslator): # type that duckdb doesn't understand, but we probably still want # the updated `operation_registry` from postgres _type_map = AlchemyExprTranslator._type_map.copy() + _has_filter_syntax = True rewrites = DuckDBSQLExprTranslator.rewrites diff --git a/ibis/backends/mysql/compiler.py b/ibis/backends/mysql/compiler.py index 108e57791474..d81f3f6cbbf4 100644 --- a/ibis/backends/mysql/compiler.py +++ b/ibis/backends/mysql/compiler.py @@ -28,6 +28,7 @@ class MySQLExprTranslator(AlchemyExprTranslator): dt.String: mysql.VARCHAR, } ) + _bool_aggs_need_cast_to_int32 = False rewrites = MySQLExprTranslator.rewrites diff --git a/ibis/backends/postgres/compiler.py b/ibis/backends/postgres/compiler.py index 54edbc69714d..9baa828913b5 100644 --- a/ibis/backends/postgres/compiler.py +++ b/ibis/backends/postgres/compiler.py @@ -27,6 +27,7 @@ class PostgreSQLExprTranslator(AlchemyExprTranslator): dt.Float64: postgresql.DOUBLE_PRECISION, } ) + _has_filter_syntax = True rewrites = PostgreSQLExprTranslator.rewrites