Skip to content

Commit

Permalink
refactor: consolidate rewrite rule implementations
Browse files Browse the repository at this point in the history
  • Loading branch information
jcrist authored and kszucs committed Feb 12, 2024
1 parent 968cd7a commit c9b8a08
Show file tree
Hide file tree
Showing 8 changed files with 41 additions and 136 deletions.
7 changes: 7 additions & 0 deletions ibis/backends/base/sqlglot/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,13 @@ def exclude_unsupported_window_frame_from_row_number(_, y):
return ops.Subtract(_.copy(frame=y.copy(start=None, end=0)), 1)


@replace(p.WindowFunction(p.MinRank | p.DenseRank, y @ p.WindowFrame(start=None)))
def exclude_unsupported_window_frame_from_rank(_, y):
return ops.Subtract(
_.copy(frame=y.copy(start=None, end=0, order_by=y.order_by or (ops.NULL,))), 1
)


@replace(
p.WindowFunction(
p.Lag | p.Lead | p.PercentRank | p.CumeDist | p.Any | p.All,
Expand Down
11 changes: 2 additions & 9 deletions ibis/backends/bigquery/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,17 @@
from ibis.backends.base.sqlglot.datatypes import BigQueryType, BigQueryUDFType
from ibis.backends.base.sqlglot.rewrites import (
exclude_unsupported_window_frame_from_ops,
exclude_unsupported_window_frame_from_rank,
exclude_unsupported_window_frame_from_row_number,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
)
from ibis.common.patterns import replace
from ibis.common.temporal import DateUnit, IntervalUnit, TimestampUnit, TimeUnit
from ibis.expr.rewrites import p, rewrite_sample, y
from ibis.expr.rewrites import rewrite_sample

_NAME_REGEX = re.compile(r'[^!"$()*,./;?@[\\\]^`{}~\n]+')


@replace(p.WindowFunction(p.MinRank | p.DenseRank, y @ p.WindowFrame(start=None)))
def exclude_unsupported_window_frame_from_rank(_, y):
return ops.Subtract(
_.copy(frame=y.copy(start=None, end=0, order_by=y.order_by or (ops.NULL,))), 1
)


class BigQueryCompiler(SQLGlotCompiler):
dialect = "bigquery"
type_mapper = BigQueryType
Expand Down
11 changes: 2 additions & 9 deletions ibis/backends/exasol/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
from ibis.backends.base.sqlglot.datatypes import ExasolType
from ibis.backends.base.sqlglot.rewrites import (
exclude_unsupported_window_frame_from_ops,
exclude_unsupported_window_frame_from_rank,
exclude_unsupported_window_frame_from_row_number,
rewrite_empty_order_by_window,
)
from ibis.common.patterns import replace
from ibis.expr.rewrites import p, rewrite_sample, y
from ibis.expr.rewrites import rewrite_sample


def _interval(self, e):
Expand All @@ -43,13 +43,6 @@ class Generator(Postgres.Generator):
}


@replace(p.WindowFunction(p.MinRank | p.DenseRank, y @ p.WindowFrame(start=None)))
def exclude_unsupported_window_frame_from_rank(_, y):
return ops.Subtract(
_.copy(frame=y.copy(start=None, end=0, order_by=y.order_by or (ops.NULL,))), 1
)


class ExasolCompiler(SQLGlotCompiler):
__slots__ = ()

Expand Down
17 changes: 4 additions & 13 deletions ibis/backends/mssql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,13 @@
)
from ibis.backends.base.sqlglot.datatypes import MSSQLType
from ibis.backends.base.sqlglot.rewrites import (
exclude_unsupported_window_frame_from_ops,
exclude_unsupported_window_frame_from_row_number,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
)
from ibis.common.deferred import var
from ibis.common.patterns import replace
from ibis.expr.rewrites import p, rewrite_sample
from ibis.expr.rewrites import rewrite_sample


class MSSQL(TSQL):
Expand Down Expand Up @@ -65,16 +66,6 @@ class Generator(TSQL.Generator):
# * Boolean expressions MUST be used in a WHERE clause, i.e., SELECT * FROM t WHERE 1 is not allowed


@replace(p.WindowFunction(p.RowNumber | p.NTile, y))
def exclude_unsupported_window_frame_from_ops_with_offset(_, y):
return ops.Subtract(_.copy(frame=y.copy(start=None, end=0)), 1)


@replace(p.WindowFunction(p.Lag | p.Lead | p.PercentRank | p.CumeDist, y))
def exclude_unsupported_window_frame_from_ops(_, y):
return _.copy(frame=y.copy(start=None, end=0))


@public
class MSSQLCompiler(SQLGlotCompiler):
__slots__ = ()
Expand All @@ -86,7 +77,7 @@ class MSSQLCompiler(SQLGlotCompiler):
rewrite_first_to_first_value,
rewrite_last_to_last_value,
exclude_unsupported_window_frame_from_ops,
exclude_unsupported_window_frame_from_ops_with_offset,
exclude_unsupported_window_frame_from_row_number,
*SQLGlotCompiler.rewrites,
)
quoted = True
Expand Down
10 changes: 2 additions & 8 deletions ibis/backends/mysql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@
from ibis.backends.base.sqlglot.datatypes import MySQLType
from ibis.backends.base.sqlglot.rewrites import (
exclude_unsupported_window_frame_from_ops,
exclude_unsupported_window_frame_from_rank,
exclude_unsupported_window_frame_from_row_number,
rewrite_empty_order_by_window,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
)
from ibis.common.patterns import replace
from ibis.expr.rewrites import p, rewrite_sample, y
from ibis.expr.rewrites import p, rewrite_sample

MySQL.Generator.TRANSFORMS |= {
sge.LogicalOr: rename_func("max"),
Expand Down Expand Up @@ -56,13 +57,6 @@ def rewrite_limit(_, **kwargs):
return _


@replace(p.WindowFunction(p.MinRank | p.DenseRank, y @ p.WindowFrame(start=None)))
def exclude_unsupported_window_frame_from_rank(_, y):
return ops.Subtract(
_.copy(frame=y.copy(start=None, end=0, order_by=y.order_by or (ops.NULL,))), 1
)


@public
class MySQLCompiler(SQLGlotCompiler):
__slots__ = ()
Expand Down
56 changes: 12 additions & 44 deletions ibis/backends/oracle/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,20 @@
from sqlglot.dialects import Oracle
from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func

import ibis
import ibis.common.exceptions as com
import ibis.expr.operations as ops
from ibis.backends.base.sqlglot.compiler import NULL, STAR, SQLGlotCompiler
from ibis.backends.base.sqlglot.datatypes import OracleType
from ibis.backends.base.sqlglot.rewrites import Window, replace_log2, replace_log10
from ibis.common.patterns import replace
from ibis.expr.analysis import p, x, y
from ibis.backends.base.sqlglot.rewrites import (
Window,
exclude_unsupported_window_frame_from_ops,
exclude_unsupported_window_frame_from_row_number,
replace_log2,
replace_log10,
rewrite_empty_order_by_window,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
)
from ibis.expr.rewrites import rewrite_sample


Expand Down Expand Up @@ -59,44 +65,6 @@ def _create_sql(self, expression: sge.Create) -> str:
Oracle.Generator.TZ_TO_WITH_TIME_ZONE = True


@replace(p.WindowFunction(p.First(x, y)))
def rewrite_first(_, x, y):
if y is not None:
raise com.UnsupportedOperationError(
"`first` aggregate over window does not support `where`"
)
return _.copy(func=ops.FirstValue(x))


@replace(p.WindowFunction(p.Last(x, y)))
def rewrite_last(_, x, y):
if y is not None:
raise com.UnsupportedOperationError(
"`last` aggregate over window does not support `where`"
)
return _.copy(func=ops.LastValue(x))


@replace(p.WindowFunction(frame=x @ p.WindowFrame(order_by=())))
def rewrite_empty_order_by_window(_, x):
return _.copy(frame=x.copy(order_by=(ibis.NA,)))


@replace(p.WindowFunction(p.RowNumber | p.NTile, x))
def exclude_unsupported_window_frame_from_row_number(_, x):
return ops.Subtract(_.copy(frame=x.copy(start=None, end=None)), 1)


@replace(
p.WindowFunction(
p.Lag | p.Lead | p.PercentRank | p.CumeDist | p.Any | p.All,
x @ p.WindowFrame(start=None),
)
)
def exclude_unsupported_window_frame_from_ops(_, x):
return _.copy(frame=x.copy(start=None, end=None))


@public
class OracleCompiler(SQLGlotCompiler):
__slots__ = ()
Expand All @@ -107,8 +75,8 @@ class OracleCompiler(SQLGlotCompiler):
rewrites = (
exclude_unsupported_window_frame_from_row_number,
exclude_unsupported_window_frame_from_ops,
rewrite_first,
rewrite_last,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
rewrite_empty_order_by_window,
rewrite_sample,
replace_log2,
Expand Down
55 changes: 12 additions & 43 deletions ibis/backends/snowflake/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,61 +10,30 @@
from sqlglot.dialects import Snowflake
from sqlglot.dialects.dialect import rename_func

import ibis
import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis import util
from ibis.backends.base.sqlglot.compiler import NULL, C, FuncGen, SQLGlotCompiler
from ibis.backends.base.sqlglot.datatypes import SnowflakeType
from ibis.backends.base.sqlglot.rewrites import replace_log2, replace_log10
from ibis.backends.base.sqlglot.rewrites import (
exclude_unsupported_window_frame_from_ops,
exclude_unsupported_window_frame_from_row_number,
replace_log2,
replace_log10,
rewrite_empty_order_by_window,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
)
from ibis.common.patterns import replace
from ibis.expr.analysis import p, x, y
from ibis.expr.analysis import p

Snowflake.Generator.TRANSFORMS |= {
exp.ApproxDistinct: rename_func("approx_count_distinct"),
exp.Levenshtein: rename_func("editdistance"),
}


@replace(p.WindowFunction(p.First(x, y)))
def rewrite_first(_, x, y):
if y is not None:
raise com.UnsupportedOperationError(
"`first` aggregate over window does not support `where`"
)
return _.copy(func=ops.FirstValue(x))


@replace(p.WindowFunction(p.Last(x, y)))
def rewrite_last(_, x, y):
if y is not None:
raise com.UnsupportedOperationError(
"`last` aggregate over window does not support `where`"
)
return _.copy(func=ops.LastValue(x))


@replace(p.WindowFunction(frame=x @ p.WindowFrame(order_by=())))
def rewrite_empty_order_by_window(_, x):
return _.copy(frame=x.copy(order_by=(ibis.NA,)))


@replace(p.WindowFunction(p.RowNumber | p.NTile, x))
def exclude_unsupported_window_frame_from_row_number(_, x):
return ops.Subtract(_.copy(frame=x.copy(start=None, end=None)), 1)


@replace(
p.WindowFunction(
p.Lag | p.Lead | p.PercentRank | p.CumeDist | p.Any | p.All,
x @ p.WindowFrame(start=None),
)
)
def exclude_unsupported_window_frame_from_ops(_, x):
return _.copy(frame=x.copy(start=None, end=None))


@replace(p.ToJSONMap | p.ToJSONArray)
def replace_to_json(_):
return ops.Cast(_.arg, to=_.dtype)
Expand All @@ -86,8 +55,8 @@ class SnowflakeCompiler(SQLGlotCompiler):
replace_to_json,
exclude_unsupported_window_frame_from_row_number,
exclude_unsupported_window_frame_from_ops,
rewrite_first,
rewrite_last,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
rewrite_empty_order_by_window,
replace_log2,
replace_log10,
Expand Down
10 changes: 0 additions & 10 deletions ibis/backends/tests/test_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,11 +810,6 @@ def test_simple_ungrouped_window_with_scalar_order_by(alltypes):
raises=AssertionError,
strict=False, # sometimes it passes
),
pytest.mark.notyet(
["oracle"],
raises=com.UnsupportedOperationError,
reason="oracle doesn't allow unordered analytic functions without a windowing clause",
),
pytest.mark.notimpl(
["flink"],
raises=com.UnsupportedOperationError,
Expand Down Expand Up @@ -861,11 +856,6 @@ def test_simple_ungrouped_window_with_scalar_order_by(alltypes):
raises=AssertionError,
strict=False, # sometimes it passes
),
pytest.mark.notyet(
["oracle"],
raises=com.UnsupportedOperationError,
reason="oracle doesn't allow unordered analytic functions without a windowing clause",
),
pytest.mark.notimpl(
["flink"],
raises=com.UnsupportedOperationError,
Expand Down

0 comments on commit c9b8a08

Please sign in to comment.