diff --git a/ibis/backends/bigquery/compiler.py b/ibis/backends/bigquery/compiler.py index 9afbb0838d8f..c70660a7f739 100644 --- a/ibis/backends/bigquery/compiler.py +++ b/ibis/backends/bigquery/compiler.py @@ -20,9 +20,9 @@ exclude_unsupported_window_frame_from_row_number, rewrite_first_to_first_value, rewrite_last_to_last_value, + rewrite_sample_as_filter, ) from ibis.common.temporal import DateUnit, IntervalUnit, TimestampUnit, TimeUnit -from ibis.expr.rewrites import rewrite_sample _NAME_REGEX = re.compile(r'[^!"$()*,./;?@[\\\]^`{}~\n]+') @@ -32,7 +32,7 @@ class BigQueryCompiler(SQLGlotCompiler): type_mapper = BigQueryType udf_type_mapper = BigQueryUDFType rewrites = ( - rewrite_sample, + rewrite_sample_as_filter, rewrite_first_to_first_value, rewrite_last_to_last_value, exclude_unsupported_window_frame_from_ops, diff --git a/ibis/backends/clickhouse/compiler.py b/ibis/backends/clickhouse/compiler.py index 6439fb99a5e2..b561d841f8bd 100644 --- a/ibis/backends/clickhouse/compiler.py +++ b/ibis/backends/clickhouse/compiler.py @@ -21,8 +21,8 @@ SQLGlotCompiler, parenthesize, ) +from ibis.backends.base.sqlglot.rewrites import rewrite_sample_as_filter from ibis.backends.clickhouse.datatypes import ClickhouseType -from ibis.expr.rewrites import rewrite_sample ClickHouse.Generator.TRANSFORMS |= { exp.ArraySize: rename_func("length"), @@ -37,7 +37,7 @@ class ClickHouseCompiler(SQLGlotCompiler): dialect = "clickhouse" type_mapper = ClickhouseType - rewrites = (rewrite_sample, *SQLGlotCompiler.rewrites) + rewrites = (rewrite_sample_as_filter, *SQLGlotCompiler.rewrites) def _aggregate(self, funcname: str, *args, where): has_filter = where is not None diff --git a/ibis/backends/datafusion/compiler.py b/ibis/backends/datafusion/compiler.py index 755f0c366747..84b4ae5dc3e0 100644 --- a/ibis/backends/datafusion/compiler.py +++ b/ibis/backends/datafusion/compiler.py @@ -22,9 +22,9 @@ paren, ) from ibis.backends.base.sqlglot.datatypes import DataFusionType +from ibis.backends.base.sqlglot.rewrites import rewrite_sample_as_filter from ibis.common.temporal import IntervalUnit, TimestampUnit from ibis.expr.operations.udf import InputType -from ibis.expr.rewrites import rewrite_sample from ibis.formats.pyarrow import PyArrowType @@ -48,7 +48,7 @@ class DataFusionCompiler(SQLGlotCompiler): dialect = "datafusion" type_mapper = DataFusionType quoted = True - rewrites = (rewrite_sample, *SQLGlotCompiler.rewrites) + rewrites = (rewrite_sample_as_filter, *SQLGlotCompiler.rewrites) def _aggregate(self, funcname: str, *args, where): expr = self.f[funcname](*args) diff --git a/ibis/backends/druid/compiler.py b/ibis/backends/druid/compiler.py index 43fd1c1559fd..31a60f7e85be 100644 --- a/ibis/backends/druid/compiler.py +++ b/ibis/backends/druid/compiler.py @@ -14,7 +14,7 @@ import ibis.expr.operations as ops from ibis.backends.base.sqlglot.compiler import NULL, SQLGlotCompiler from ibis.backends.base.sqlglot.datatypes import DruidType -from ibis.expr.rewrites import rewrite_sample +from ibis.backends.base.sqlglot.rewrites import rewrite_sample_as_filter # Is postgres the best dialect to inherit from? @@ -34,7 +34,7 @@ class DruidCompiler(SQLGlotCompiler): dialect = "druid" type_mapper = DruidType quoted = True - rewrites = (rewrite_sample, *SQLGlotCompiler.rewrites) + rewrites = (rewrite_sample_as_filter, *SQLGlotCompiler.rewrites) def _aggregate(self, funcname: str, *args, where): expr = self.f[funcname](*args) diff --git a/ibis/backends/exasol/compiler.py b/ibis/backends/exasol/compiler.py index 3f52b732aa10..071e597c6f05 100644 --- a/ibis/backends/exasol/compiler.py +++ b/ibis/backends/exasol/compiler.py @@ -16,8 +16,8 @@ exclude_unsupported_window_frame_from_rank, exclude_unsupported_window_frame_from_row_number, rewrite_empty_order_by_window, + rewrite_sample_as_filter, ) -from ibis.expr.rewrites import rewrite_sample def _interval(self, e): @@ -50,7 +50,7 @@ class ExasolCompiler(SQLGlotCompiler): type_mapper = ExasolType quoted = True rewrites = ( - rewrite_sample, + rewrite_sample_as_filter, exclude_unsupported_window_frame_from_ops, exclude_unsupported_window_frame_from_rank, exclude_unsupported_window_frame_from_row_number, diff --git a/ibis/backends/impala/compiler.py b/ibis/backends/impala/compiler.py index 1afba29a20ea..73840a47cc44 100644 --- a/ibis/backends/impala/compiler.py +++ b/ibis/backends/impala/compiler.py @@ -18,8 +18,8 @@ rewrite_empty_order_by_window, rewrite_first_to_first_value, rewrite_last_to_last_value, + rewrite_sample_as_filter, ) -from ibis.expr.rewrites import rewrite_sample def _interval(self, e): @@ -48,7 +48,7 @@ class ImpalaCompiler(SQLGlotCompiler): dialect = "impala" type_mapper = ImpalaType rewrites = ( - rewrite_sample, + rewrite_sample_as_filter, rewrite_first_to_first_value, rewrite_last_to_last_value, rewrite_empty_order_by_window, diff --git a/ibis/backends/mssql/compiler.py b/ibis/backends/mssql/compiler.py index f34b443a7662..604250ba2b3c 100644 --- a/ibis/backends/mssql/compiler.py +++ b/ibis/backends/mssql/compiler.py @@ -26,9 +26,9 @@ exclude_unsupported_window_frame_from_row_number, rewrite_first_to_first_value, rewrite_last_to_last_value, + rewrite_sample_as_filter, ) from ibis.common.deferred import var -from ibis.expr.rewrites import rewrite_sample class MSSQL(TSQL): @@ -73,7 +73,7 @@ class MSSQLCompiler(SQLGlotCompiler): dialect = "mssql" type_mapper = MSSQLType rewrites = ( - rewrite_sample, + rewrite_sample_as_filter, rewrite_first_to_first_value, rewrite_last_to_last_value, exclude_unsupported_window_frame_from_ops, diff --git a/ibis/backends/mysql/compiler.py b/ibis/backends/mysql/compiler.py index 1a89c77c536d..2277de215fb6 100644 --- a/ibis/backends/mysql/compiler.py +++ b/ibis/backends/mysql/compiler.py @@ -21,9 +21,10 @@ rewrite_empty_order_by_window, rewrite_first_to_first_value, rewrite_last_to_last_value, + rewrite_sample_as_filter, ) from ibis.common.patterns import replace -from ibis.expr.rewrites import p, rewrite_sample +from ibis.expr.rewrites import p MySQL.Generator.TRANSFORMS |= { sge.LogicalOr: rename_func("max"), @@ -65,7 +66,7 @@ class MySQLCompiler(SQLGlotCompiler): type_mapper = MySQLType rewrites = ( rewrite_limit, - rewrite_sample, + rewrite_sample_as_filter, rewrite_first_to_first_value, rewrite_last_to_last_value, exclude_unsupported_window_frame_from_ops, diff --git a/ibis/backends/oracle/compiler.py b/ibis/backends/oracle/compiler.py index 93c5f962455a..ad398e7e7d2a 100644 --- a/ibis/backends/oracle/compiler.py +++ b/ibis/backends/oracle/compiler.py @@ -22,8 +22,8 @@ rewrite_empty_order_by_window, rewrite_first_to_first_value, rewrite_last_to_last_value, + rewrite_sample_as_filter, ) -from ibis.expr.rewrites import rewrite_sample def _create_sql(self, expression: sge.Create) -> str: @@ -78,7 +78,7 @@ class OracleCompiler(SQLGlotCompiler): rewrite_first_to_first_value, rewrite_last_to_last_value, rewrite_empty_order_by_window, - rewrite_sample, + rewrite_sample_as_filter, replace_log2, replace_log10, *SQLGlotCompiler.rewrites, diff --git a/ibis/backends/postgres/compiler.py b/ibis/backends/postgres/compiler.py index e6b44829a49e..af640cdd0f98 100644 --- a/ibis/backends/postgres/compiler.py +++ b/ibis/backends/postgres/compiler.py @@ -15,7 +15,7 @@ import ibis.expr.rules as rlz from ibis.backends.base.sqlglot.compiler import NULL, STAR, SQLGlotCompiler, paren from ibis.backends.base.sqlglot.datatypes import PostgresType -from ibis.expr.rewrites import rewrite_sample +from ibis.backends.base.sqlglot.rewrites import rewrite_sample_as_filter Postgres.Generator.TRANSFORMS |= { sge.Map: rename_func("hstore"), @@ -37,7 +37,7 @@ class PostgresCompiler(SQLGlotCompiler): dialect = "postgres" type_mapper = PostgresType - rewrites = rewrite_sample, *SQLGlotCompiler.rewrites + rewrites = (rewrite_sample_as_filter, *SQLGlotCompiler.rewrites) quoted = True NAN = sge.Literal.number("'NaN'::double precision") diff --git a/ibis/backends/sqlite/compiler.py b/ibis/backends/sqlite/compiler.py index 5e1b5f91abfc..f02a77b9b361 100644 --- a/ibis/backends/sqlite/compiler.py +++ b/ibis/backends/sqlite/compiler.py @@ -15,9 +15,9 @@ from ibis.backends.base.sqlglot.rewrites import ( rewrite_first_to_first_value, rewrite_last_to_last_value, + rewrite_sample_as_filter, ) from ibis.common.temporal import DateUnit, IntervalUnit -from ibis.expr.rewrites import rewrite_sample SQLite.Generator.TYPE_MAPPING |= { sge.DataType.Type.BOOLEAN: "BOOLEAN", @@ -31,10 +31,11 @@ class SQLiteCompiler(SQLGlotCompiler): dialect = "sqlite" quoted = True type_mapper = SQLiteType - rewrites = SQLGlotCompiler.rewrites + ( - rewrite_sample, + rewrites = ( + rewrite_sample_as_filter, rewrite_first_to_first_value, rewrite_last_to_last_value, + *SQLGlotCompiler.rewrites, ) NAN = NULL diff --git a/ibis/backends/trino/compiler.py b/ibis/backends/trino/compiler.py index 2d214c685110..97dd85221744 100644 --- a/ibis/backends/trino/compiler.py +++ b/ibis/backends/trino/compiler.py @@ -18,8 +18,8 @@ exclude_unsupported_window_frame_from_ops, rewrite_first_to_first_value, rewrite_last_to_last_value, + rewrite_sample_as_filter, ) -from ibis.expr.rewrites import rewrite_sample # TODO(cpcloud): remove this hack once @@ -43,7 +43,7 @@ class TrinoCompiler(SQLGlotCompiler): dialect = "trino" type_mapper = TrinoType rewrites = ( - rewrite_sample, + rewrite_sample_as_filter, rewrite_first_to_first_value, rewrite_last_to_last_value, exclude_unsupported_window_frame_from_ops, diff --git a/ibis/expr/rewrites.py b/ibis/expr/rewrites.py index 8b1ea87de95d..fd2799f8ab50 100644 --- a/ibis/expr/rewrites.py +++ b/ibis/expr/rewrites.py @@ -6,7 +6,6 @@ import toolz -import ibis.common.exceptions as com import ibis.expr.datatypes as dt import ibis.expr.operations as ops from ibis.common.deferred import Item, _, deferred, var @@ -82,20 +81,6 @@ def rewrite_dropna(_): return ops.Filter(_.parent, tuple(preds)) -@replace(p.Sample) -def rewrite_sample(_): - """Rewrite Sample as `t.filter(random() <= fraction)`. - - Errors as unsupported if a `seed` is specified. - """ - if _.seed is not None: - raise com.UnsupportedOperationError( - "`Table.sample` with a random seed is unsupported" - ) - pred = ops.LessEqual(ops.RandomScalar(), _.fraction) - return ops.Filter(_.parent, (pred,)) - - @replace(p.Analytic) def project_wrap_analytic(_, rel): # Wrap analytic functions in a window function