Skip to content

Commit

Permalink
refactor: remove duplicated rewrite_sample rule
Browse files Browse the repository at this point in the history
  • Loading branch information
jcrist authored and kszucs committed Feb 7, 2024
1 parent 25ef8cc commit ea40f57
Show file tree
Hide file tree
Showing 13 changed files with 27 additions and 40 deletions.
4 changes: 2 additions & 2 deletions ibis/backends/bigquery/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
exclude_unsupported_window_frame_from_row_number,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
rewrite_sample_as_filter,
)
from ibis.common.temporal import DateUnit, IntervalUnit, TimestampUnit, TimeUnit
from ibis.expr.rewrites import rewrite_sample

_NAME_REGEX = re.compile(r'[^!"$()*,./;?@[\\\]^`{}~\n]+')

Expand All @@ -32,7 +32,7 @@ class BigQueryCompiler(SQLGlotCompiler):
type_mapper = BigQueryType
udf_type_mapper = BigQueryUDFType
rewrites = (
rewrite_sample,
rewrite_sample_as_filter,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
exclude_unsupported_window_frame_from_ops,
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/clickhouse/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
SQLGlotCompiler,
parenthesize,
)
from ibis.backends.base.sqlglot.rewrites import rewrite_sample_as_filter
from ibis.backends.clickhouse.datatypes import ClickhouseType
from ibis.expr.rewrites import rewrite_sample

ClickHouse.Generator.TRANSFORMS |= {
exp.ArraySize: rename_func("length"),
Expand All @@ -37,7 +37,7 @@ class ClickHouseCompiler(SQLGlotCompiler):

dialect = "clickhouse"
type_mapper = ClickhouseType
rewrites = (rewrite_sample, *SQLGlotCompiler.rewrites)
rewrites = (rewrite_sample_as_filter, *SQLGlotCompiler.rewrites)

def _aggregate(self, funcname: str, *args, where):
has_filter = where is not None
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/datafusion/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
paren,
)
from ibis.backends.base.sqlglot.datatypes import DataFusionType
from ibis.backends.base.sqlglot.rewrites import rewrite_sample_as_filter
from ibis.common.temporal import IntervalUnit, TimestampUnit
from ibis.expr.operations.udf import InputType
from ibis.expr.rewrites import rewrite_sample
from ibis.formats.pyarrow import PyArrowType


Expand All @@ -48,7 +48,7 @@ class DataFusionCompiler(SQLGlotCompiler):
dialect = "datafusion"
type_mapper = DataFusionType
quoted = True
rewrites = (rewrite_sample, *SQLGlotCompiler.rewrites)
rewrites = (rewrite_sample_as_filter, *SQLGlotCompiler.rewrites)

def _aggregate(self, funcname: str, *args, where):
expr = self.f[funcname](*args)
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/druid/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import ibis.expr.operations as ops
from ibis.backends.base.sqlglot.compiler import NULL, SQLGlotCompiler
from ibis.backends.base.sqlglot.datatypes import DruidType
from ibis.expr.rewrites import rewrite_sample
from ibis.backends.base.sqlglot.rewrites import rewrite_sample_as_filter


# Is postgres the best dialect to inherit from?
Expand All @@ -34,7 +34,7 @@ class DruidCompiler(SQLGlotCompiler):
dialect = "druid"
type_mapper = DruidType
quoted = True
rewrites = (rewrite_sample, *SQLGlotCompiler.rewrites)
rewrites = (rewrite_sample_as_filter, *SQLGlotCompiler.rewrites)

def _aggregate(self, funcname: str, *args, where):
expr = self.f[funcname](*args)
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/exasol/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
exclude_unsupported_window_frame_from_rank,
exclude_unsupported_window_frame_from_row_number,
rewrite_empty_order_by_window,
rewrite_sample_as_filter,
)
from ibis.expr.rewrites import rewrite_sample


def _interval(self, e):
Expand Down Expand Up @@ -50,7 +50,7 @@ class ExasolCompiler(SQLGlotCompiler):
type_mapper = ExasolType
quoted = True
rewrites = (
rewrite_sample,
rewrite_sample_as_filter,
exclude_unsupported_window_frame_from_ops,
exclude_unsupported_window_frame_from_rank,
exclude_unsupported_window_frame_from_row_number,
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/impala/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
rewrite_empty_order_by_window,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
rewrite_sample_as_filter,
)
from ibis.expr.rewrites import rewrite_sample


def _interval(self, e):
Expand Down Expand Up @@ -48,7 +48,7 @@ class ImpalaCompiler(SQLGlotCompiler):
dialect = "impala"
type_mapper = ImpalaType
rewrites = (
rewrite_sample,
rewrite_sample_as_filter,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
rewrite_empty_order_by_window,
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/mssql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@
exclude_unsupported_window_frame_from_row_number,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
rewrite_sample_as_filter,
)
from ibis.common.deferred import var
from ibis.expr.rewrites import rewrite_sample


class MSSQL(TSQL):
Expand Down Expand Up @@ -73,7 +73,7 @@ class MSSQLCompiler(SQLGlotCompiler):
dialect = "mssql"
type_mapper = MSSQLType
rewrites = (
rewrite_sample,
rewrite_sample_as_filter,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
exclude_unsupported_window_frame_from_ops,
Expand Down
5 changes: 3 additions & 2 deletions ibis/backends/mysql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@
rewrite_empty_order_by_window,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
rewrite_sample_as_filter,
)
from ibis.common.patterns import replace
from ibis.expr.rewrites import p, rewrite_sample
from ibis.expr.rewrites import p

MySQL.Generator.TRANSFORMS |= {
sge.LogicalOr: rename_func("max"),
Expand Down Expand Up @@ -65,7 +66,7 @@ class MySQLCompiler(SQLGlotCompiler):
type_mapper = MySQLType
rewrites = (
rewrite_limit,
rewrite_sample,
rewrite_sample_as_filter,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
exclude_unsupported_window_frame_from_ops,
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/oracle/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
rewrite_empty_order_by_window,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
rewrite_sample_as_filter,
)
from ibis.expr.rewrites import rewrite_sample


def _create_sql(self, expression: sge.Create) -> str:
Expand Down Expand Up @@ -78,7 +78,7 @@ class OracleCompiler(SQLGlotCompiler):
rewrite_first_to_first_value,
rewrite_last_to_last_value,
rewrite_empty_order_by_window,
rewrite_sample,
rewrite_sample_as_filter,
replace_log2,
replace_log10,
*SQLGlotCompiler.rewrites,
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/postgres/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import ibis.expr.rules as rlz
from ibis.backends.base.sqlglot.compiler import NULL, STAR, SQLGlotCompiler, paren
from ibis.backends.base.sqlglot.datatypes import PostgresType
from ibis.expr.rewrites import rewrite_sample
from ibis.backends.base.sqlglot.rewrites import rewrite_sample_as_filter

Postgres.Generator.TRANSFORMS |= {
sge.Map: rename_func("hstore"),
Expand All @@ -37,7 +37,7 @@ class PostgresCompiler(SQLGlotCompiler):

dialect = "postgres"
type_mapper = PostgresType
rewrites = rewrite_sample, *SQLGlotCompiler.rewrites
rewrites = (rewrite_sample_as_filter, *SQLGlotCompiler.rewrites)
quoted = True

NAN = sge.Literal.number("'NaN'::double precision")
Expand Down
7 changes: 4 additions & 3 deletions ibis/backends/sqlite/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
from ibis.backends.base.sqlglot.rewrites import (
rewrite_first_to_first_value,
rewrite_last_to_last_value,
rewrite_sample_as_filter,
)
from ibis.common.temporal import DateUnit, IntervalUnit
from ibis.expr.rewrites import rewrite_sample

SQLite.Generator.TYPE_MAPPING |= {
sge.DataType.Type.BOOLEAN: "BOOLEAN",
Expand All @@ -31,10 +31,11 @@ class SQLiteCompiler(SQLGlotCompiler):
dialect = "sqlite"
quoted = True
type_mapper = SQLiteType
rewrites = SQLGlotCompiler.rewrites + (
rewrite_sample,
rewrites = (
rewrite_sample_as_filter,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
*SQLGlotCompiler.rewrites,
)

NAN = NULL
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/trino/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
exclude_unsupported_window_frame_from_ops,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
rewrite_sample_as_filter,
)
from ibis.expr.rewrites import rewrite_sample


# TODO(cpcloud): remove this hack once
Expand All @@ -43,7 +43,7 @@ class TrinoCompiler(SQLGlotCompiler):
dialect = "trino"
type_mapper = TrinoType
rewrites = (
rewrite_sample,
rewrite_sample_as_filter,
rewrite_first_to_first_value,
rewrite_last_to_last_value,
exclude_unsupported_window_frame_from_ops,
Expand Down
15 changes: 0 additions & 15 deletions ibis/expr/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import toolz

import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.common.deferred import Item, _, deferred, var
Expand Down Expand Up @@ -82,20 +81,6 @@ def rewrite_dropna(_):
return ops.Filter(_.parent, tuple(preds))


@replace(p.Sample)
def rewrite_sample(_):
    """Lower a `Sample` node to a filter on a random scalar.

    The replacement is equivalent to `t.filter(random() <= fraction)`.
    A `seed` cannot be expressed in this form, so specifying one raises
    `UnsupportedOperationError`.
    """
    if _.seed is None:
        # Keep a row when a random scalar is <= the requested fraction.
        keep_row = ops.LessEqual(ops.RandomScalar(), _.fraction)
        return ops.Filter(_.parent, (keep_row,))
    raise com.UnsupportedOperationError(
        "`Table.sample` with a random seed is unsupported"
    )


@replace(p.Analytic)
def project_wrap_analytic(_, rel):
# Wrap analytic functions in a window function
Expand Down

0 comments on commit ea40f57

Please sign in to comment.