diff --git a/ibis/backends/tests/test_generic.py b/ibis/backends/tests/test_generic.py index 4782aa8efc79..d76cefaafa1a 100644 --- a/ibis/backends/tests/test_generic.py +++ b/ibis/backends/tests/test_generic.py @@ -1541,7 +1541,6 @@ def test_dynamic_table_slice_with_computed_offset(backend): "polars", "pyspark", "snowflake", - "trino", ] ) def test_sample(backend): @@ -1571,7 +1570,6 @@ def test_sample(backend): "polars", "pyspark", "snowflake", - "trino", ] ) def test_sample_memtable(con, backend): diff --git a/ibis/backends/trino/compiler.py b/ibis/backends/trino/compiler.py index 1f9911b61a27..e8d199daead5 100644 --- a/ibis/backends/trino/compiler.py +++ b/ibis/backends/trino/compiler.py @@ -7,6 +7,7 @@ from ibis.backends.base.sql.alchemy.query_builder import _AlchemyTableSetFormatter from ibis.backends.trino.datatypes import TrinoType from ibis.backends.trino.registry import operation_registry +from ibis.common.exceptions import UnsupportedOperationError class TrinoSQLExprTranslator(AlchemyExprTranslator): @@ -47,6 +48,16 @@ def _rewrite_string_contains(op): class TrinoTableSetFormatter(_AlchemyTableSetFormatter): + def _format_sample(self, op, table): + if op.seed is not None: + raise UnsupportedOperationError( + "`Table.sample` with a random seed is unsupported" + ) + method = sa.func.bernoulli if op.method == "row" else sa.func.system + return table.tablesample( + sampling=method(sa.literal_column(f"{op.fraction * 100}")) + ) + def _format_in_memory_table(self, op, translator): if not op.data: return sa.select( @@ -65,7 +76,7 @@ def _format_in_memory_table(self, op, translator): for row in op.data.to_frame().itertuples(index=False) ] columns = translator._schema_to_sqlalchemy_columns(op.schema) - return sa.values(*columns, name=op.name).data(rows) + return sa.values(*columns, name=op.name).data(rows).select().subquery() class TrinoSQLCompiler(AlchemyCompiler):