Skip to content

Commit

Permalink
refactor(sqlalchemy): centralize reduction compilation
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored and kszucs committed May 23, 2022
1 parent 8b7415f commit 505352b
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 0 deletions.
30 changes: 30 additions & 0 deletions ibis/backends/base/sql/alchemy/translator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from __future__ import annotations

import ibis
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
import ibis.expr.types as ir
from ibis import util
Expand Down Expand Up @@ -40,12 +43,39 @@ class AlchemyExprTranslator(ExprTranslator):

context_class = AlchemyContext

_bool_aggs_need_cast_to_int32 = True
_boolean_cast_ops = ops.Sum, ops.Mean, ops.Min, ops.Max
_has_filter_syntax = False

def name(self, translated, name, force=True):
return translated.label(name)

def get_sqla_type(self, data_type):
return to_sqla_type(data_type, type_map=self._type_map)

def _reduction(self, sa_func, expr):
op = expr.op()
arg = op.arg
if (
self._bool_aggs_need_cast_to_int32
and isinstance(op, self._boolean_cast_ops)
and isinstance(
type := arg.type(),
dt.Boolean,
)
):
arg = arg.cast(dt.Int32(nullable=type.nullable))

if (where := op.where) is not None:
if self._has_filter_syntax:
sa_arg = self.translate(arg).filter(self.translate(where))
else:
sa_arg = self.translate(where.ifelse(arg, None))
else:
sa_arg = self.translate(arg)

return sa_func(sa_arg)


rewrites = AlchemyExprTranslator.rewrites

Expand Down
1 change: 1 addition & 0 deletions ibis/backends/duckdb/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class DuckDBSQLExprTranslator(AlchemyExprTranslator):
# type that duckdb doesn't understand, but we probably still want
# the updated `operation_registry` from postgres
_type_map = AlchemyExprTranslator._type_map.copy()
_has_filter_syntax = True


rewrites = DuckDBSQLExprTranslator.rewrites
Expand Down
1 change: 1 addition & 0 deletions ibis/backends/mysql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class MySQLExprTranslator(AlchemyExprTranslator):
dt.String: mysql.VARCHAR,
}
)
_bool_aggs_need_cast_to_int32 = False


rewrites = MySQLExprTranslator.rewrites
Expand Down
1 change: 1 addition & 0 deletions ibis/backends/postgres/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class PostgreSQLExprTranslator(AlchemyExprTranslator):
dt.Float64: postgresql.DOUBLE_PRECISION,
}
)
_has_filter_syntax = True


rewrites = PostgreSQLExprTranslator.rewrites
Expand Down

0 comments on commit 505352b

Please sign in to comment.