From e1870eae2404b43136c81e7de898e92e92c41926 Mon Sep 17 00:00:00 2001 From: Jim Crist-Harif Date: Mon, 16 Oct 2023 15:48:51 -0500 Subject: [PATCH] feat(sql): implement `Table.sample` as a `random()` filter across several SQL backends --- ibis/backends/clickhouse/compiler/core.py | 3 +- ibis/backends/druid/compiler.py | 2 + ibis/backends/mssql/compiler.py | 2 + ibis/backends/mysql/compiler.py | 2 + ibis/backends/oracle/__init__.py | 2 + ibis/backends/postgres/compiler.py | 2 + ibis/backends/sqlite/compiler.py | 2 + ibis/backends/tests/test_generic.py | 84 +++++++++++++++++++++++ ibis/expr/rewrites.py | 21 ++++++ 9 files changed, 119 insertions(+), 1 deletion(-) diff --git a/ibis/backends/clickhouse/compiler/core.py b/ibis/backends/clickhouse/compiler/core.py index 2b6590fcff66..7b00566458c3 100644 --- a/ibis/backends/clickhouse/compiler/core.py +++ b/ibis/backends/clickhouse/compiler/core.py @@ -29,7 +29,7 @@ from ibis.backends.clickhouse.compiler.values import translate_val from ibis.common.deferred import _ from ibis.expr.analysis import c, find_first_base_table, p, x, y -from ibis.expr.rewrites import rewrite_dropna, rewrite_fillna +from ibis.expr.rewrites import rewrite_dropna, rewrite_fillna, rewrite_sample if TYPE_CHECKING: from collections.abc import Mapping @@ -125,6 +125,7 @@ def fn(node, _, **kwargs): | nullify_empty_string_results | rewrite_fillna | rewrite_dropna + | rewrite_sample ) # apply translate rules in topological order node = op.map(fn)[op] diff --git a/ibis/backends/druid/compiler.py b/ibis/backends/druid/compiler.py index 60a48bbf5b45..6c766af97111 100644 --- a/ibis/backends/druid/compiler.py +++ b/ibis/backends/druid/compiler.py @@ -7,6 +7,7 @@ import ibis.backends.druid.datatypes as ddt from ibis.backends.base.sql.alchemy import AlchemyCompiler, AlchemyExprTranslator from ibis.backends.druid.registry import operation_registry +from ibis.expr.rewrites import rewrite_sample class DruidExprTranslator(AlchemyExprTranslator): @@ -29,3 +30,4 @@ 
def translate(self, op): class DruidCompiler(AlchemyCompiler): translator_class = DruidExprTranslator null_limit = sa.literal_column("ALL") + rewrites = AlchemyCompiler.rewrites | rewrite_sample diff --git a/ibis/backends/mssql/compiler.py b/ibis/backends/mssql/compiler.py index 7722702eb95a..6a1d9c31b099 100644 --- a/ibis/backends/mssql/compiler.py +++ b/ibis/backends/mssql/compiler.py @@ -6,6 +6,7 @@ from ibis.backends.base.sql.alchemy import AlchemyCompiler, AlchemyExprTranslator from ibis.backends.mssql.datatypes import MSSQLType from ibis.backends.mssql.registry import _timestamp_from_unix, operation_registry +from ibis.expr.rewrites import rewrite_sample class MsSqlExprTranslator(AlchemyExprTranslator): @@ -35,3 +36,4 @@ class MsSqlCompiler(AlchemyCompiler): supports_indexed_grouping_keys = False null_limit = None + rewrites = AlchemyCompiler.rewrites | rewrite_sample diff --git a/ibis/backends/mysql/compiler.py b/ibis/backends/mysql/compiler.py index c33f49318bad..529dfe84b211 100644 --- a/ibis/backends/mysql/compiler.py +++ b/ibis/backends/mysql/compiler.py @@ -5,6 +5,7 @@ from ibis.backends.base.sql.alchemy import AlchemyCompiler, AlchemyExprTranslator from ibis.backends.mysql.datatypes import MySQLType from ibis.backends.mysql.registry import operation_registry +from ibis.expr.rewrites import rewrite_sample class MySQLExprTranslator(AlchemyExprTranslator): @@ -24,3 +25,4 @@ class MySQLCompiler(AlchemyCompiler): translator_class = MySQLExprTranslator support_values_syntax_in_select = False null_limit = None + rewrites = AlchemyCompiler.rewrites | rewrite_sample diff --git a/ibis/backends/oracle/__init__.py b/ibis/backends/oracle/__init__.py index 03c79e93ac91..2095cb0b54b5 100644 --- a/ibis/backends/oracle/__init__.py +++ b/ibis/backends/oracle/__init__.py @@ -39,6 +39,7 @@ ) from ibis.backends.oracle.datatypes import OracleType # noqa: E402 from ibis.backends.oracle.registry import operation_registry # noqa: E402 +from ibis.expr.rewrites import 
rewrite_sample # noqa: E402 if TYPE_CHECKING: from collections.abc import Iterable @@ -73,6 +74,7 @@ class OracleCompiler(AlchemyCompiler): support_values_syntax_in_select = False supports_indexed_grouping_keys = False null_limit = None + rewrites = AlchemyCompiler.rewrites | rewrite_sample class Backend(BaseAlchemyBackend): diff --git a/ibis/backends/postgres/compiler.py b/ibis/backends/postgres/compiler.py index e5efc76fa7d2..48a5f24b0111 100644 --- a/ibis/backends/postgres/compiler.py +++ b/ibis/backends/postgres/compiler.py @@ -5,6 +5,7 @@ from ibis.backends.base.sql.alchemy import AlchemyCompiler, AlchemyExprTranslator from ibis.backends.postgres.datatypes import PostgresType from ibis.backends.postgres.registry import operation_registry +from ibis.expr.rewrites import rewrite_sample class PostgresUDFNode(ops.Value): @@ -35,3 +36,4 @@ def _any_all_no_op(expr): class PostgreSQLCompiler(AlchemyCompiler): translator_class = PostgreSQLExprTranslator + rewrites = AlchemyCompiler.rewrites | rewrite_sample diff --git a/ibis/backends/sqlite/compiler.py b/ibis/backends/sqlite/compiler.py index 97a4604524af..09db8897ebfc 100644 --- a/ibis/backends/sqlite/compiler.py +++ b/ibis/backends/sqlite/compiler.py @@ -16,6 +16,7 @@ from ibis.backends.base.sql.alchemy import AlchemyCompiler, AlchemyExprTranslator from ibis.backends.sqlite.datatypes import SqliteType from ibis.backends.sqlite.registry import operation_registry +from ibis.expr.rewrites import rewrite_sample class SQLiteExprTranslator(AlchemyExprTranslator): @@ -32,3 +33,4 @@ class SQLiteCompiler(AlchemyCompiler): translator_class = SQLiteExprTranslator support_values_syntax_in_select = False null_limit = None + rewrites = AlchemyCompiler.rewrites | rewrite_sample diff --git a/ibis/backends/tests/test_generic.py b/ibis/backends/tests/test_generic.py index 79c102a47c9a..ee592f7aac88 100644 --- a/ibis/backends/tests/test_generic.py +++ b/ibis/backends/tests/test_generic.py @@ -1527,3 +1527,87 @@ def 
test_dynamic_table_slice_with_computed_offset(backend): result = expr.to_pandas() backend.assert_frame_equal(result, expected)
+
+
+@pytest.mark.notimpl(
+    [
+        "bigquery",
+        "dask",
+        "datafusion",
+        "druid",  # NOTE(review): compiler adds rewrite_sample; still notimpl?
+        "duckdb",
+        "flink",
+        "impala",
+        "pandas",
+        "polars",
+        "pyspark",
+        "snowflake",
+        "trino",
+    ]
+)
+def test_sample(backend):
+    t = backend.functional_alltypes.filter(_.int_col >= 2)
+    # The sample is random, so only bound its size and compare column layout.
+    total_rows = t.count().execute()
+    empty = t.limit(1).execute().iloc[:0]  # zero-row frame: expected layout
+
+    df = t.sample(0.1, method="row").execute()
+    assert len(df) <= total_rows
+    backend.assert_frame_equal(empty, df.iloc[:0])
+
+    df = t.sample(0.1, method="block").execute()
+    assert len(df) <= total_rows
+    backend.assert_frame_equal(empty, df.iloc[:0])
+
+
+@pytest.mark.notimpl(
+    [
+        "bigquery",
+        "dask",
+        "datafusion",
+        "druid",
+        "duckdb",
+        "flink",
+        "impala",
+        "pandas",
+        "polars",
+        "pyspark",
+        "snowflake",
+        "trino",
+    ]
+)
+def test_sample_memtable(con, backend):
+    df = pd.DataFrame({"x": [1, 2, 3, 4]})
+    res = con.execute(ibis.memtable(df).sample(0.5))
+    assert len(res) <= 4  # only an upper bound is checked: at most all 4 rows
+    backend.assert_frame_equal(res.iloc[:0], df.iloc[:0])
+
+
+@pytest.mark.notimpl(
+    [
+        "bigquery",
+        "clickhouse",
+        "dask",
+        "datafusion",
+        "druid",
+        "duckdb",
+        "flink",
+        "impala",
+        "mssql",
+        "mysql",
+        "oracle",
+        "pandas",
+        "polars",
+        "postgres",
+        "pyspark",
+        "snowflake",
+        "sqlite",
+        "trino",
+    ]
+)
+def test_sample_with_seed(backend):
+    t = backend.functional_alltypes
+    expr = t.sample(0.1, seed=1234)
+    df1 = expr.to_pandas()
+    df2 = expr.to_pandas()  # same seeded expression executed a second time
+    backend.assert_frame_equal(df1, df2)
diff --git a/ibis/expr/rewrites.py b/ibis/expr/rewrites.py index 6cdbbbd21523..4ae694f0120c 100644 --- a/ibis/expr/rewrites.py +++ b/ibis/expr/rewrites.py @@ -6,6 +6,7 @@ import ibis.expr.datatypes as dt import ibis.expr.operations as ops +from ibis.common.exceptions import UnsupportedOperationError from ibis.common.patterns import pattern, replace from ibis.util import Namespace @@ -58,3 +59,23 @@ def
rewrite_dropna(_): return _.table return ops.Selection(_.table, (), preds, ())
+
+
+@replace(p.Sample)
+def rewrite_sample(_):
+    """Rewrite Sample as `t.filter(random() <= fraction)`.
+
+    Only `table`, `fraction` and `seed` are read; the `method` is not consulted.
+    Raises `UnsupportedOperationError` if a `seed` is specified.
+    """
+    if _.seed is not None:
+        raise UnsupportedOperationError(
+            "`Table.sample` with a random seed is unsupported"
+        )
+    # Express the filter as a Selection whose only predicate is random() <= fraction.
+    return ops.Selection(
+        _.table,
+        (),
+        (ops.LessEqual(ops.RandomScalar(), _.fraction),),
+        (),
+    )