From 927865eca8aac60f54ad8363a81cfe095495fbaa Mon Sep 17 00:00:00 2001 From: Jim Crist-Harif Date: Tue, 24 Sep 2024 21:44:29 -0500 Subject: [PATCH] fix(snowflake): only compile `sample` to `TABLESAMPLE` on physical tables --- ibis/backends/sql/compilers/snowflake.py | 7 ++++++- .../test_sample/snowflake-table/block.sql | 4 +++- .../test_sample/snowflake-table/row.sql | 4 +++- ibis/backends/tests/test_generic.py | 17 +---------------- 4 files changed, 13 insertions(+), 19 deletions(-) diff --git a/ibis/backends/sql/compilers/snowflake.py b/ibis/backends/sql/compilers/snowflake.py index 960cb1b79a25..e204f7115f7e 100644 --- a/ibis/backends/sql/compilers/snowflake.py +++ b/ibis/backends/sql/compilers/snowflake.py @@ -60,7 +60,12 @@ class SnowflakeCompiler(SQLGlotCompiler): LOWERED_OPS = { ops.Log2: lower_log2, ops.Log10: lower_log10, - ops.Sample: lower_sample(), + # Snowflake's TABLESAMPLE _can_ work on subqueries, but only by row and without + # a seed. This is effectively the same as `t.filter(random() <= fraction)`, and + # using TABLESAMPLE here would almost certainly have no benefit over the filter + # version in the optimized physical plan. To avoid a special case just for + # snowflake, we only use TABLESAMPLE on physical tables. + ops.Sample: lower_sample(physical_tables_only=True), } UNSUPPORTED_OPS = ( diff --git a/ibis/backends/tests/snapshots/test_sql/test_sample/snowflake-table/block.sql b/ibis/backends/tests/snapshots/test_sql/test_sample/snowflake-table/block.sql index 2c9987d23ddf..15550a434acc 100644 --- a/ibis/backends/tests/snapshots/test_sql/test_sample/snowflake-table/block.sql +++ b/ibis/backends/tests/snapshots/test_sql/test_sample/snowflake-table/block.sql @@ -6,4 +6,6 @@ FROM ( FROM "test" AS "t0" WHERE "t0"."x" > 10 -) AS "t1" TABLESAMPLE system (50.0) \ No newline at end of file +) AS "t1" +WHERE + UNIFORM(TO_DOUBLE(0.0), TO_DOUBLE(1.0), RANDOM()) <= 0.5 \ No newline at end of file diff --git a/ibis/backends/tests/snapshots/test_sql/test_sample/snowflake-table/row.sql b/ibis/backends/tests/snapshots/test_sql/test_sample/snowflake-table/row.sql index 38eb63631277..15550a434acc 100644 --- a/ibis/backends/tests/snapshots/test_sql/test_sample/snowflake-table/row.sql +++ b/ibis/backends/tests/snapshots/test_sql/test_sample/snowflake-table/row.sql @@ -6,4 +6,6 @@ FROM ( FROM "test" AS "t0" WHERE "t0"."x" > 10 -) AS "t1" TABLESAMPLE bernoulli (50.0) \ No newline at end of file +) AS "t1" +WHERE + UNIFORM(TO_DOUBLE(0.0), TO_DOUBLE(1.0), RANDOM()) <= 0.5 \ No newline at end of file diff --git a/ibis/backends/tests/test_generic.py b/ibis/backends/tests/test_generic.py index 714f987d76cb..0ca4dc1be0a6 100644 --- a/ibis/backends/tests/test_generic.py +++ b/ibis/backends/tests/test_generic.py @@ -2125,22 +2125,7 @@ def test_dynamic_table_slice_with_computed_offset(backend): @pytest.mark.notimpl(["druid", "risingwave"], raises=com.OperationNotDefinedError) -@pytest.mark.parametrize( - "method", - [ - "row", - param( - "block", - marks=[ - pytest.mark.notimpl( - ["snowflake"], - raises=SnowflakeProgrammingError, - reason="SAMPLE clause on views only supports row wise sampling without seed.", - ) - ], - ), - ], -) +@pytest.mark.parametrize("method", ["row", "block"]) @pytest.mark.parametrize("subquery", [True, False], ids=["subquery", "table"]) @pytest.mark.xfail_version(pyspark=["sqlglot==25.17.0"]) def test_sample(backend, method, alltypes, subquery):