Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

refactor(sql): extract aggregate handling out into common utility class #9222

Merged
merged 1 commit into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions ibis/backends/bigquery/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,14 +122,6 @@ class BigQueryCompiler(SQLGlotCompiler):
ops.TimestampNow: "current_timestamp",
}

def _aggregate(self, funcname: str, *args, where):
    """Compile an aggregate call, emulating FILTER with IF(pred, arg, NULL).

    Aggregate functions skip NULL inputs, so replacing the arguments of
    filtered-out rows with NULL reproduces ``FILTER (WHERE ...)`` semantics.
    """
    if where is None:
        final_args = args
    else:
        final_args = tuple(self.if_(where, arg, NULL) for arg in args)
    return self.f[funcname](*final_args, dialect=self.dialect)

@staticmethod
def _minimize_spec(start, end, spec):
if (
Expand Down
21 changes: 13 additions & 8 deletions ibis/backends/clickhouse/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,29 @@
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis import util
from ibis.backends.sql.compiler import NULL, STAR, SQLGlotCompiler
from ibis.backends.sql.compiler import NULL, STAR, AggGen, SQLGlotCompiler
from ibis.backends.sql.datatypes import ClickHouseType
from ibis.backends.sql.dialects import ClickHouse


class ClickhouseAggGen(AggGen):
    """Aggregate generator using ClickHouse's ``If``-suffixed combinators."""

    def aggregate(self, compiler, name, *args, where=None):
        """Compile an aggregate, switching to its filtering variant if needed.

        Every ClickHouse aggregate has a filtering twin spelled with an
        ``If`` suffix (e.g. ``SumIf`` instead of ``Sum``) that takes the
        predicate as a trailing argument.
        """
        if where is None:
            return compiler.f[name](*args, dialect=compiler.dialect)
        return compiler.f[name + "If"](*args, where, dialect=compiler.dialect)


class ClickHouseCompiler(SQLGlotCompiler):
__slots__ = ()

dialect = ClickHouse
type_mapper = ClickHouseType

agg = ClickhouseAggGen()

UNSUPPORTED_OPS = (
ops.RowID,
ops.CumeDist,
Expand Down Expand Up @@ -104,13 +116,6 @@ class ClickHouseCompiler(SQLGlotCompiler):
ops.Unnest: "arrayJoin",
}

def _aggregate(self, funcname: str, *args, where):
    """Compile an aggregate via ClickHouse's ``If``-suffixed variants.

    ClickHouse aggregates all have a filtering form named with an ``If``
    suffix (e.g. ``sumIf``) that accepts the predicate as a final argument.
    """
    if where is None:
        return self.f[funcname](*args, dialect=self.dialect)
    return self.f[funcname + "If"](*args, where, dialect=self.dialect)

@staticmethod
def _minimize_spec(start, end, spec):
if (
Expand Down
10 changes: 3 additions & 7 deletions ibis/backends/datafusion/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.backends.sql.compiler import FALSE, NULL, STAR, SQLGlotCompiler
from ibis.backends.sql.compiler import FALSE, NULL, STAR, AggGen, SQLGlotCompiler
from ibis.backends.sql.datatypes import DataFusionType
from ibis.backends.sql.dialects import DataFusion
from ibis.common.temporal import IntervalUnit, TimestampUnit
Expand All @@ -25,6 +25,8 @@ class DataFusionCompiler(SQLGlotCompiler):
dialect = DataFusion
type_mapper = DataFusionType

agg = AggGen(supports_filter=True)

UNSUPPORTED_OPS = (
ops.ArgMax,
ops.ArgMin,
Expand Down Expand Up @@ -73,12 +75,6 @@ class DataFusionCompiler(SQLGlotCompiler):
ops.ArrayUnion: "array_union",
}

def _aggregate(self, funcname: str, *args, where):
    """Compile an aggregate call, attaching a FILTER clause when given.

    DataFusion accepts the SQL-standard ``FILTER (WHERE ...)`` syntax on
    aggregate invocations.
    """
    agg = self.f[funcname](*args)
    if where is None:
        return agg
    return sg.exp.Filter(this=agg, expression=sg.exp.Where(this=where))

def _to_timestamp(self, value, target_dtype, literal=False):
tz = (
f'Some("{timezone}")'
Expand Down
10 changes: 3 additions & 7 deletions ibis/backends/druid/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.backends.sql.compiler import NULL, SQLGlotCompiler
from ibis.backends.sql.compiler import NULL, AggGen, SQLGlotCompiler
from ibis.backends.sql.datatypes import DruidType
from ibis.backends.sql.dialects import Druid

Expand All @@ -17,6 +17,8 @@ class DruidCompiler(SQLGlotCompiler):
dialect = Druid
type_mapper = DruidType

agg = AggGen(supports_filter=True)

LOWERED_OPS = {ops.Capitalize: None}

UNSUPPORTED_OPS = (
Expand Down Expand Up @@ -80,12 +82,6 @@ class DruidCompiler(SQLGlotCompiler):
ops.StringContains: "contains_string",
}

def _aggregate(self, funcname: str, *args, where):
    """Compile an aggregate call, wrapping it in FILTER (WHERE ...) if any."""
    agg = self.f[funcname](*args)
    if where is None:
        return agg
    # Druid accepts the SQL-standard FILTER clause on aggregates.
    return sg.exp.Filter(this=agg, expression=sg.exp.Where(this=where))

def visit_Modulus(self, op, *, left, right):
    # Emit a verbatim MOD(left, right) call rather than a `%` operator;
    # the `anon` namespace presumably bypasses sqlglot's function-name
    # mapping so the name is emitted as-is — confirm against FuncGen.
    return self.f.anon.mod(left, right)

Expand Down
10 changes: 3 additions & 7 deletions ibis/backends/duckdb/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.backends.sql.compiler import NULL, STAR, SQLGlotCompiler
from ibis.backends.sql.compiler import NULL, STAR, AggGen, SQLGlotCompiler
from ibis.backends.sql.datatypes import DuckDBType

_INTERVAL_SUFFIXES = {
Expand All @@ -33,6 +33,8 @@ class DuckDBCompiler(SQLGlotCompiler):
dialect = DuckDB
type_mapper = DuckDBType

agg = AggGen(supports_filter=True)

LOWERED_OPS = {
ops.Sample: None,
ops.StringSlice: None,
Expand Down Expand Up @@ -85,12 +87,6 @@ class DuckDBCompiler(SQLGlotCompiler):
ops.GeoY: "st_y",
}

def _aggregate(self, funcname: str, *args, where):
    """Compile an aggregate call, adding a FILTER (WHERE ...) clause if given."""
    agg = self.f[funcname](*args)
    if where is None:
        return agg
    # DuckDB supports the SQL-standard FILTER clause on aggregates.
    return sge.Filter(this=agg, expression=sge.Where(this=where))

def visit_StructColumn(self, op, *, names, values):
return sge.Struct.from_arg_list(
[
Expand Down
6 changes: 0 additions & 6 deletions ibis/backends/exasol/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,6 @@ def _minimize_spec(start, end, spec):
return None
return spec

def _aggregate(self, funcname: str, *args, where):
    """Compile an aggregate, emulating FILTER by NULLing non-matching rows."""
    if where is None:
        return self.f[funcname](*args)
    # Aggregates ignore NULL inputs, so IF(pred, arg, NULL) acts as a filter.
    filtered = tuple(self.if_(where, a, NULL) for a in args)
    return self.f[funcname](*filtered)

@staticmethod
def _gen_valid_name(name: str) -> str:
"""Exasol does not allow dots in quoted column names."""
Expand Down
61 changes: 32 additions & 29 deletions ibis/backends/flink/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.backends.sql.compiler import NULL, STAR, SQLGlotCompiler
from ibis.backends.sql.compiler import NULL, STAR, AggGen, SQLGlotCompiler
from ibis.backends.sql.datatypes import FlinkType
from ibis.backends.sql.dialects import Flink
from ibis.backends.sql.rewrites import (
Expand All @@ -18,10 +18,41 @@
)


class FlinkAggGen(AggGen):
    """Aggregate generator that emulates FILTER with CASE expressions."""

    def aggregate(self, compiler, name, *args, where=None):
        """Compile an aggregate call, filtering rows with CASE WHEN.

        Flink nominally supports ``FILTER (WHERE ...)``, but it is broken
        for:

        1. certain aggregates: std/var don't return the right result
        2. certain kinds of predicates: ``x IN y`` doesn't filter the right
           values out
        3. certain aggregates AND predicates: ``STD(w) FILTER (WHERE x IN Y)``
           returns an incorrect result

        The usual fallback, ``IF(predicate, arg, NULL)``, won't work here
        without casting the NULL to a specific type, and at this point in
        the Ibis compiler the operation's type information has already been
        thrown away.  (Every other engine Ibis supports infers the type of
        a NULL literal itself.)

        A CASE expression with the explicit NULL (ELSE branch) left out
        sidesteps both problems, so each argument is wrapped in
        ``CASE WHEN predicate THEN arg END``.
        """
        func = compiler.f[name]
        if where is None:
            return func(*args)
        wrapped = tuple(
            sge.Case(ifs=[sge.If(this=where, true=arg)]) for arg in args
        )
        return func(*wrapped)


class FlinkCompiler(SQLGlotCompiler):
quoted = True
dialect = Flink
type_mapper = FlinkType

agg = FlinkAggGen()

rewrites = (
exclude_unsupported_window_frame_from_row_number,
exclude_unsupported_window_frame_from_ops,
Expand Down Expand Up @@ -96,34 +127,6 @@ def POS_INF(self):
def _generate_groups(groups):
return groups

def _aggregate(self, funcname: str, *args, where):
    """Compile an aggregate call, emulating FILTER with CASE expressions.

    Flink's native ``FILTER (WHERE ...)`` clause is broken for one or
    both of:

    1. certain aggregates: std/var don't return the right result
    2. certain kinds of predicates: ``x IN y`` doesn't filter the right
       values out
    3. certain aggregates AND predicates: ``STD(w) FILTER (WHERE x IN Y)``
       returns an incorrect result

    ``IF(predicate, arg, NULL)`` would be the usual workaround, but that
    won't work without casting the NULL to a specific type, and by this
    point in the compiler the Ibis operation's type information has
    already been thrown away.  (Every other engine Ibis supports infers
    the type of a NULL literal itself.)

    Using a CASE statement and leaving out the explicit NULL (the ELSE
    branch) does the trick.
    """
    func = self.f[funcname]
    if where is None:
        return func(*args)
    guarded = tuple(
        sge.Case(ifs=[sge.If(this=where, true=arg)]) for arg in args
    )
    return func(*guarded)

@staticmethod
def _minimize_spec(start, end, spec):
if (
Expand Down
6 changes: 0 additions & 6 deletions ibis/backends/impala/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,6 @@ class ImpalaCompiler(SQLGlotCompiler):
ops.TypeOf: "typeof",
}

def _aggregate(self, funcname: str, *args, where):
    """Compile an aggregate, emulating FILTER with IF(pred, arg, NULL)."""
    if where is None:
        return self.f[funcname](*args, dialect=self.dialect)
    # Aggregates skip NULL inputs, so NULLing unmatched rows filters them.
    filtered = tuple(self.if_(where, arg, NULL) for arg in args)
    return self.f[funcname](*filtered, dialect=self.dialect)

@staticmethod
def _minimize_spec(start, end, spec):
# start is None means unbounded preceding
Expand Down
6 changes: 0 additions & 6 deletions ibis/backends/mssql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,12 +144,6 @@ def POS_INF(self):
def NEG_INF(self):
    # Negative infinity, produced by converting the string literal
    # '-Infinity' through the double() function (MSSQL has no -inf literal
    # of its own — NOTE(review): confirm against the dialect docs).
    return self.f.double("-Infinity")

def _aggregate(self, funcname: str, *args, where):
    """Compile an aggregate, emulating FILTER with IF(pred, arg, NULL)."""
    if where is None:
        return self.f[funcname](*args)
    # NULL inputs are skipped by aggregate functions, so this is a filter.
    guarded = tuple(self.if_(where, a, NULL) for a in args)
    return self.f[funcname](*guarded)

@staticmethod
def _generate_groups(groups):
    # Identity hook: this dialect emits GROUP BY expressions verbatim.
    # NOTE(review): the base class presumably transforms them (e.g. to
    # positional references) — confirm there.
    return groups
Expand Down
6 changes: 0 additions & 6 deletions ibis/backends/mysql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,6 @@ def POS_INF(self):
ops.Log2: "log2",
}

def _aggregate(self, funcname: str, *args, where):
    """Compile an aggregate, emulating FILTER with IF(pred, arg, NULL)."""
    if where is None:
        return self.f[funcname](*args)
    # Rows failing the predicate contribute NULL, which aggregates ignore.
    masked = tuple(self.if_(where, arg, NULL) for arg in args)
    return self.f[funcname](*masked)

@staticmethod
def _minimize_spec(start, end, spec):
if (
Expand Down
6 changes: 0 additions & 6 deletions ibis/backends/oracle/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,6 @@ class OracleCompiler(SQLGlotCompiler):
ops.Hash: "ora_hash",
}

def _aggregate(self, funcname: str, *args, where):
    """Compile an aggregate, filtering rows via IF(pred, arg).

    The else-value is intentionally omitted: a conditional with no ELSE
    branch evaluates to NULL, which aggregate functions then skip.
    """
    if where is None:
        return self.f[funcname](*args)
    filtered = tuple(self.if_(where, a) for a in args)
    return self.f[funcname](*filtered)

@staticmethod
def _generate_groups(groups):
    # Identity hook: emit GROUP BY keys unchanged for this dialect.
    # NOTE(review): base-class behavior presumably differs — confirm.
    return groups
Expand Down
10 changes: 3 additions & 7 deletions ibis/backends/postgres/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
import ibis.expr.rules as rlz
from ibis.backends.sql.compiler import NULL, STAR, SQLGlotCompiler
from ibis.backends.sql.compiler import NULL, STAR, AggGen, SQLGlotCompiler
from ibis.backends.sql.datatypes import PostgresType
from ibis.backends.sql.dialects import Postgres

Expand All @@ -27,6 +27,8 @@ class PostgresCompiler(SQLGlotCompiler):
dialect = Postgres
type_mapper = PostgresType

agg = AggGen(supports_filter=True)

NAN = sge.Literal.number("'NaN'::double precision")
POS_INF = sge.Literal.number("'Inf'::double precision")
NEG_INF = sge.Literal.number("'-Inf'::double precision")
Expand Down Expand Up @@ -96,12 +98,6 @@ class PostgresCompiler(SQLGlotCompiler):
ops.TimeFromHMS: "make_time",
}

def _aggregate(self, funcname: str, *args, where):
    """Compile an aggregate call, attaching FILTER (WHERE ...) when given."""
    agg = self.f[funcname](*args)
    if where is None:
        return agg
    # Postgres supports the SQL-standard FILTER clause on aggregates.
    return sge.Filter(this=agg, expression=sge.Where(this=where))

def visit_RandomUUID(self, op, **kwargs):
    # Delegate to Postgres's built-in gen_random_uuid() function.
    return self.f.gen_random_uuid()

Expand Down
6 changes: 0 additions & 6 deletions ibis/backends/pyspark/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,6 @@ class PySparkCompiler(SQLGlotCompiler):
ops.UnwrapJSONBoolean: "unwrap_json_bool",
}

def _aggregate(self, funcname: str, *args, where):
    """Compile an aggregate, emulating FILTER with IF(pred, arg, NULL)."""
    if where is None:
        return self.f[funcname](*args)
    # Aggregates skip NULLs, so NULLing unmatched rows filters them out.
    masked = tuple(self.if_(where, arg, NULL) for arg in args)
    return self.f[funcname](*masked)

def visit_InSubquery(self, op, *, rel, needle):
if op.needle.dtype.is_struct():
# construct the outer struct for pyspark
Expand Down
7 changes: 0 additions & 7 deletions ibis/backends/snowflake/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,6 @@ def __init__(self):
super().__init__()
self.f = SnowflakeFuncGen()

def _aggregate(self, funcname: str, *args, where):
    """Compile an aggregate, emulating FILTER with if_(pred, arg, NULL)."""
    if where is None:
        return self.f[funcname](*args)
    # Replace filtered-out rows' arguments with NULL; aggregates skip NULLs.
    masked = [self.if_(where, arg, NULL) for arg in args]
    return self.f[funcname](*masked)

@staticmethod
def _minimize_spec(start, end, spec):
if (
Expand Down
Loading
Loading