Skip to content

Commit

Permalink
refactor(sql): extract aggregate handling out into common utility class
Browse files Browse the repository at this point in the history
  • Loading branch information
jcrist committed May 21, 2024
1 parent 5151906 commit 9d12ebc
Show file tree
Hide file tree
Showing 17 changed files with 139 additions and 156 deletions.
8 changes: 0 additions & 8 deletions ibis/backends/bigquery/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,14 +122,6 @@ class BigQueryCompiler(SQLGlotCompiler):
ops.TimestampNow: "current_timestamp",
}

def _aggregate(self, funcname: str, *args, where):
    """Compile an aggregate call, emulating filtering via ``IF``.

    BigQuery has no ``FILTER (WHERE ...)`` clause here, so when a
    predicate is supplied every argument is rewritten to
    ``IF(where, arg, NULL)``; rows failing the predicate contribute
    NULL, which the aggregate then ignores.
    """
    if where is None:
        call_args = args
    else:
        call_args = tuple(self.if_(where, a, NULL) for a in args)
    return self.f[funcname](*call_args, dialect=self.dialect)

@staticmethod
def _minimize_spec(start, end, spec):
if (
Expand Down
21 changes: 13 additions & 8 deletions ibis/backends/clickhouse/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,29 @@
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis import util
from ibis.backends.sql.compiler import NULL, STAR, SQLGlotCompiler
from ibis.backends.sql.compiler import NULL, STAR, AggGen, SQLGlotCompiler
from ibis.backends.sql.datatypes import ClickHouseType
from ibis.backends.sql.dialects import ClickHouse


class ClickhouseAggGen(AggGen):
    """Aggregate generator using ClickHouse's ``If`` combinator.

    Every ClickHouse aggregate has a filtering variant spelled with an
    ``If`` suffix (e.g. ``SumIf`` instead of ``Sum``) that takes the
    predicate as an extra trailing argument.
    """

    def aggregate(self, compiler, name, *args, where=None):
        funcname = name
        call_args = list(args)
        if where is not None:
            funcname = f"{name}If"
            call_args.append(where)
        return compiler.f[funcname](*call_args, dialect=compiler.dialect)


class ClickHouseCompiler(SQLGlotCompiler):
__slots__ = ()

dialect = ClickHouse
type_mapper = ClickHouseType

agg = ClickhouseAggGen()

UNSUPPORTED_OPS = (
ops.RowID,
ops.CumeDist,
Expand Down Expand Up @@ -104,13 +116,6 @@ class ClickHouseCompiler(SQLGlotCompiler):
ops.Unnest: "arrayJoin",
}

def _aggregate(self, funcname: str, *args, where):
has_filter = where is not None
func = self.f[funcname + "If" * has_filter]
args += (where,) * has_filter

return func(*args, dialect=self.dialect)

@staticmethod
def _minimize_spec(start, end, spec):
if (
Expand Down
10 changes: 3 additions & 7 deletions ibis/backends/datafusion/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.backends.sql.compiler import FALSE, NULL, STAR, SQLGlotCompiler
from ibis.backends.sql.compiler import FALSE, NULL, STAR, AggGen, SQLGlotCompiler
from ibis.backends.sql.datatypes import DataFusionType
from ibis.backends.sql.dialects import DataFusion
from ibis.common.temporal import IntervalUnit, TimestampUnit
Expand All @@ -25,6 +25,8 @@ class DataFusionCompiler(SQLGlotCompiler):
dialect = DataFusion
type_mapper = DataFusionType

agg = AggGen(supports_filter=True)

UNSUPPORTED_OPS = (
ops.ArgMax,
ops.ArgMin,
Expand Down Expand Up @@ -73,12 +75,6 @@ class DataFusionCompiler(SQLGlotCompiler):
ops.ArrayUnion: "array_union",
}

def _aggregate(self, funcname: str, *args, where):
    """Compile an aggregate call, attaching ``FILTER (WHERE ...)`` if needed.

    DataFusion supports the standard FILTER clause natively, so the
    predicate is wrapped around the call rather than the arguments.
    """
    call = self.f[funcname](*args)
    if where is None:
        return call
    return sg.exp.Filter(this=call, expression=sg.exp.Where(this=where))

def _to_timestamp(self, value, target_dtype, literal=False):
tz = (
f'Some("{timezone}")'
Expand Down
10 changes: 3 additions & 7 deletions ibis/backends/druid/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.backends.sql.compiler import NULL, SQLGlotCompiler
from ibis.backends.sql.compiler import NULL, AggGen, SQLGlotCompiler
from ibis.backends.sql.datatypes import DruidType
from ibis.backends.sql.dialects import Druid

Expand All @@ -17,6 +17,8 @@ class DruidCompiler(SQLGlotCompiler):
dialect = Druid
type_mapper = DruidType

agg = AggGen(supports_filter=True)

LOWERED_OPS = {ops.Capitalize: None}

UNSUPPORTED_OPS = (
Expand Down Expand Up @@ -80,12 +82,6 @@ class DruidCompiler(SQLGlotCompiler):
ops.StringContains: "contains_string",
}

def _aggregate(self, funcname: str, *args, where):
    """Compile an aggregate; Druid gets a native ``FILTER (WHERE ...)``."""
    agg_expr = self.f[funcname](*args)
    if where is not None:
        agg_expr = sg.exp.Filter(
            this=agg_expr, expression=sg.exp.Where(this=where)
        )
    return agg_expr

def visit_Modulus(self, op, *, left, right):
    """Compile ``left % right`` as a call to the ``mod`` function."""
    mod = self.f.anon.mod
    return mod(left, right)

Expand Down
10 changes: 3 additions & 7 deletions ibis/backends/duckdb/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.backends.sql.compiler import NULL, STAR, SQLGlotCompiler
from ibis.backends.sql.compiler import NULL, STAR, AggGen, SQLGlotCompiler
from ibis.backends.sql.datatypes import DuckDBType

_INTERVAL_SUFFIXES = {
Expand All @@ -33,6 +33,8 @@ class DuckDBCompiler(SQLGlotCompiler):
dialect = DuckDB
type_mapper = DuckDBType

agg = AggGen(supports_filter=True)

LOWERED_OPS = {
ops.Sample: None,
ops.StringSlice: None,
Expand Down Expand Up @@ -85,12 +87,6 @@ class DuckDBCompiler(SQLGlotCompiler):
ops.GeoY: "st_y",
}

def _aggregate(self, funcname: str, *args, where):
    """Compile an aggregate; DuckDB supports ``FILTER (WHERE ...)`` natively."""
    result = self.f[funcname](*args)
    if where is not None:
        result = sge.Filter(this=result, expression=sge.Where(this=where))
    return result

def visit_StructColumn(self, op, *, names, values):
return sge.Struct.from_arg_list(
[
Expand Down
6 changes: 0 additions & 6 deletions ibis/backends/exasol/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,6 @@ def _minimize_spec(start, end, spec):
return None
return spec

def _aggregate(self, funcname: str, *args, where):
    """Compile an aggregate, emulating filtering with ``IF(where, arg, NULL)``.

    Rows failing the predicate are turned into NULL arguments, which
    the aggregate then ignores, giving filtered semantics without a
    ``FILTER (WHERE ...)`` clause.
    """
    masked = (
        args
        if where is None
        else tuple(self.if_(where, a, NULL) for a in args)
    )
    return self.f[funcname](*masked)

@staticmethod
def _gen_valid_name(name: str) -> str:
"""Exasol does not allow dots in quoted column names."""
Expand Down
61 changes: 32 additions & 29 deletions ibis/backends/flink/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.backends.sql.compiler import NULL, STAR, SQLGlotCompiler
from ibis.backends.sql.compiler import NULL, STAR, AggGen, SQLGlotCompiler
from ibis.backends.sql.datatypes import FlinkType
from ibis.backends.sql.dialects import Flink
from ibis.backends.sql.rewrites import (
Expand All @@ -18,10 +18,41 @@
)


class FlinkAggGen(AggGen):
    """Aggregate generator that emulates filtering with bare CASE expressions.

    Flink nominally supports ``FILTER (WHERE ...)``, but it is broken for:

    1. certain aggregates: std/var don't return the right result
    2. certain kinds of predicates: ``x IN y`` doesn't filter the right
       values out
    3. certain aggregates AND predicates: ``STD(w) FILTER (WHERE x IN y)``
       returns an incorrect result

    The usual workaround, ``IF(predicate, arg, NULL)``, fails here because
    the explicit NULL would need a cast to a concrete type, and by this
    point in the compiler the operation's type information has been thrown
    away (every other engine Ibis supports infers the NULL literal's type
    itself).  A CASE expression with no ELSE branch yields an implicitly
    typed NULL, which is what Flink needs.
    """

    def aggregate(self, compiler, name, *args, where=None):
        if where is None:
            call_args = args
        else:
            call_args = tuple(
                sge.Case(ifs=[sge.If(this=where, true=arg)]) for arg in args
            )
        return compiler.f[name](*call_args)


class FlinkCompiler(SQLGlotCompiler):
quoted = True
dialect = Flink
type_mapper = FlinkType

agg = FlinkAggGen()

rewrites = (
exclude_unsupported_window_frame_from_row_number,
exclude_unsupported_window_frame_from_ops,
Expand Down Expand Up @@ -96,34 +127,6 @@ def POS_INF(self):
def _generate_groups(groups):
return groups

def _aggregate(self, funcname: str, *args, where):
    """Compile an aggregate, masking filtered-out rows with a bare CASE.

    Flink's ``FILTER (WHERE ...)`` is broken for one or both of:

    1. certain aggregates: std/var don't return the right result
    2. certain kinds of predicates: ``x IN y`` doesn't filter the right
       values out
    3. certain aggregates AND predicates: ``STD(w) FILTER (WHERE x IN y)``
       returns an incorrect result

    ``IF(predicate, arg, NULL)`` would be the usual fix, but the explicit
    NULL requires a cast to a concrete type, and the operation's type
    information has already been discarded at this stage (other engines
    infer a NULL literal's type themselves).  A CASE expression with no
    ELSE branch produces an implicitly typed NULL that Flink accepts.
    """
    func = self.f[funcname]
    if where is None:
        return func(*args)
    masked = tuple(
        sge.Case(ifs=[sge.If(this=where, true=arg)]) for arg in args
    )
    return func(*masked)

@staticmethod
def _minimize_spec(start, end, spec):
if (
Expand Down
6 changes: 0 additions & 6 deletions ibis/backends/impala/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,6 @@ class ImpalaCompiler(SQLGlotCompiler):
ops.TypeOf: "typeof",
}

def _aggregate(self, funcname: str, *args, where):
    """Compile an aggregate, nulling out rows that fail the predicate.

    Impala filtering is emulated by wrapping each argument as
    ``IF(where, arg, NULL)``; NULL inputs are skipped by aggregates.
    """
    if where is None:
        call_args = args
    else:
        call_args = tuple(self.if_(where, a, NULL) for a in args)
    return self.f[funcname](*call_args, dialect=self.dialect)

@staticmethod
def _minimize_spec(start, end, spec):
# start is None means unbounded preceding
Expand Down
6 changes: 0 additions & 6 deletions ibis/backends/mssql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,12 +144,6 @@ def POS_INF(self):
def NEG_INF(self):
    # Negative-infinity literal: cast the string '-Infinity' to double
    # via the compiler's function generator.
    return self.f.double("-Infinity")

def _aggregate(self, funcname: str, *args, where):
    """Compile an aggregate, emulating filtering with ``IF(where, arg, NULL)``."""
    masked = (
        args if where is None else tuple(self.if_(where, a, NULL) for a in args)
    )
    return self.f[funcname](*masked)

@staticmethod
def _generate_groups(groups):
return groups
Expand Down
6 changes: 0 additions & 6 deletions ibis/backends/mysql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,6 @@ def POS_INF(self):
ops.Log2: "log2",
}

def _aggregate(self, funcname: str, *args, where):
    """Compile an aggregate; filtering is emulated with ``IF(where, arg, NULL)``."""
    func = self.f[funcname]
    if where is None:
        return func(*args)
    return func(*(self.if_(where, a, NULL) for a in args))

@staticmethod
def _minimize_spec(start, end, spec):
if (
Expand Down
6 changes: 0 additions & 6 deletions ibis/backends/oracle/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,6 @@ class OracleCompiler(SQLGlotCompiler):
ops.Hash: "ora_hash",
}

def _aggregate(self, funcname: str, *args, where):
func = self.f[funcname]
if where is not None:
args = tuple(self.if_(where, arg) for arg in args)
return func(*args)

@staticmethod
def _generate_groups(groups):
return groups
Expand Down
10 changes: 3 additions & 7 deletions ibis/backends/postgres/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
import ibis.expr.rules as rlz
from ibis.backends.sql.compiler import NULL, STAR, SQLGlotCompiler
from ibis.backends.sql.compiler import NULL, STAR, AggGen, SQLGlotCompiler
from ibis.backends.sql.datatypes import PostgresType
from ibis.backends.sql.dialects import Postgres

Expand All @@ -27,6 +27,8 @@ class PostgresCompiler(SQLGlotCompiler):
dialect = Postgres
type_mapper = PostgresType

agg = AggGen(supports_filter=True)

NAN = sge.Literal.number("'NaN'::double precision")
POS_INF = sge.Literal.number("'Inf'::double precision")
NEG_INF = sge.Literal.number("'-Inf'::double precision")
Expand Down Expand Up @@ -96,12 +98,6 @@ class PostgresCompiler(SQLGlotCompiler):
ops.TimeFromHMS: "make_time",
}

def _aggregate(self, funcname: str, *args, where):
    """Compile an aggregate; Postgres supports ``FILTER (WHERE ...)`` natively."""
    call = self.f[funcname](*args)
    if where is None:
        return call
    return sge.Filter(this=call, expression=sge.Where(this=where))

def visit_RandomUUID(self, op, **kwargs):
    # Map the RandomUUID operation to Postgres's gen_random_uuid() function.
    return self.f.gen_random_uuid()

Expand Down
6 changes: 0 additions & 6 deletions ibis/backends/pyspark/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,6 @@ class PySparkCompiler(SQLGlotCompiler):
ops.UnwrapJSONBoolean: "unwrap_json_bool",
}

def _aggregate(self, funcname: str, *args, where):
    """Compile an aggregate, nulling arguments for rows that fail ``where``."""
    func = self.f[funcname]
    if where is None:
        return func(*args)
    masked = tuple(self.if_(where, a, NULL) for a in args)
    return func(*masked)

def visit_InSubquery(self, op, *, rel, needle):
if op.needle.dtype.is_struct():
# construct the outer struct for pyspark
Expand Down
7 changes: 0 additions & 7 deletions ibis/backends/snowflake/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,6 @@ def __init__(self):
super().__init__()
self.f = SnowflakeFuncGen()

def _aggregate(self, funcname: str, *args, where):
    """Compile an aggregate, masking out rows that fail ``where`` with NULL."""
    masked = args
    if where is not None:
        masked = [self.if_(where, a, NULL) for a in masked]
    return self.f[funcname](*masked)

@staticmethod
def _minimize_spec(start, end, spec):
if (
Expand Down
Loading

0 comments on commit 9d12ebc

Please sign in to comment.