diff --git a/ibis/backends/snowflake/compiler.py b/ibis/backends/snowflake/compiler.py index c0915d0598cc..81db090fe9a0 100644 --- a/ibis/backends/snowflake/compiler.py +++ b/ibis/backends/snowflake/compiler.py @@ -11,7 +11,7 @@ import ibis.expr.datatypes as dt import ibis.expr.operations as ops from ibis import util -from ibis.backends.sql.compiler import NULL, C, FuncGen, SQLGlotCompiler +from ibis.backends.sql.compiler import NULL, STAR, C, FuncGen, SQLGlotCompiler from ibis.backends.sql.datatypes import SnowflakeType from ibis.backends.sql.dialects import Snowflake from ibis.backends.sql.rewrites import ( @@ -622,3 +622,14 @@ def visit_TimestampRange(self, op, *, start, stop, step): ) .subquery() ) + + def visit_Sample( + self, op, *, parent, fraction: float, method: str, seed: int | None, **_ + ): + sample = sge.TableSample( + this=parent, + method="bernoulli" if method == "row" else "system", + percent=sge.convert(fraction * 100.0), + seed=None if seed is None else sge.convert(seed), + ) + return sg.select(STAR).from_(sample) diff --git a/ibis/backends/tests/test_generic.py b/ibis/backends/tests/test_generic.py index 6247fc9eddce..92c6a2a96910 100644 --- a/ibis/backends/tests/test_generic.py +++ b/ibis/backends/tests/test_generic.py @@ -1950,28 +1950,40 @@ def test_dynamic_table_slice_with_computed_offset(backend): backend.assert_frame_equal(result, expected) -@pytest.mark.notimpl(["druid", "polars", "snowflake"]) +@pytest.mark.notimpl(["druid", "polars"]) @pytest.mark.notimpl( ["risingwave"], raises=PsycoPg2InternalError, reason="function random() does not exist", ) -def test_sample(backend): +@pytest.mark.parametrize( + "method", + [ + "row", + param( + "block", + marks=[ + pytest.mark.notimpl( + ["snowflake"], + raises=SnowflakeProgrammingError, + reason="SAMPLE clause on views only supports row wise sampling without seed.", + ) + ], + ), + ], +) +def test_sample(backend, method): t = backend.functional_alltypes.filter(_.int_col >= 2) total_rows = t.count().execute() empty = t.limit(1).execute().iloc[:0] - df = t.sample(0.1, method="row").execute() - assert len(df) <= total_rows - backend.assert_frame_equal(empty, df.iloc[:0]) - - df = t.sample(0.1, method="block").execute() + df = t.sample(0.1, method=method).execute() assert len(df) <= total_rows backend.assert_frame_equal(empty, df.iloc[:0]) -@pytest.mark.notimpl(["druid", "polars", "snowflake"]) +@pytest.mark.notimpl(["druid", "polars"]) @pytest.mark.notimpl( ["risingwave"], raises=PsycoPg2InternalError, @@ -1998,7 +2010,6 @@ def test_sample_memtable(con, backend): "polars", "postgres", "risingwave", - "snowflake", "sqlite", "trino", "exasol",