Skip to content

Commit

Permalink
feat(trino): implement Table.sample as a TABLESAMPLE query
Browse files Browse the repository at this point in the history
  • Loading branch information
jcrist committed Oct 17, 2023
1 parent 3a80f3a commit f3d044c
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 3 deletions.
2 changes: 0 additions & 2 deletions ibis/backends/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1541,7 +1541,6 @@ def test_dynamic_table_slice_with_computed_offset(backend):
"polars",
"pyspark",
"snowflake",
"trino",
]
)
def test_sample(backend):
Expand Down Expand Up @@ -1571,7 +1570,6 @@ def test_sample(backend):
"polars",
"pyspark",
"snowflake",
"trino",
]
)
def test_sample_memtable(con, backend):
Expand Down
13 changes: 12 additions & 1 deletion ibis/backends/trino/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ibis.backends.base.sql.alchemy.query_builder import _AlchemyTableSetFormatter
from ibis.backends.trino.datatypes import TrinoType
from ibis.backends.trino.registry import operation_registry
from ibis.common.exceptions import UnsupportedOperationError


class TrinoSQLExprTranslator(AlchemyExprTranslator):
Expand Down Expand Up @@ -47,6 +48,16 @@ def _rewrite_string_contains(op):


class TrinoTableSetFormatter(_AlchemyTableSetFormatter):
def _format_sample(self, op, table):
    """Attach a TABLESAMPLE clause to ``table`` for a sample operation.

    The sampling function is ``bernoulli`` when ``op.method`` is ``"row"``
    and ``system`` for any other method. Its argument is ``op.fraction``
    multiplied by 100, rendered verbatim as a literal column expression.

    Raises
    ------
    UnsupportedOperationError
        If ``op.seed`` is not ``None``.
    """
    # Guard clause: a seed cannot be honoured by this translation.
    if op.seed is not None:
        raise UnsupportedOperationError(
            "`Table.sample` with a random seed is unsupported"
        )

    if op.method == "row":
        sampler = sa.func.bernoulli
    else:
        sampler = sa.func.system

    # Scale the 0-1 fraction by 100 before handing it to the sampler.
    amount = sa.literal_column(f"{op.fraction * 100}")
    return table.tablesample(sampling=sampler(amount))

def _format_in_memory_table(self, op, translator):
if not op.data:
return sa.select(
Expand All @@ -65,7 +76,7 @@ def _format_in_memory_table(self, op, translator):
for row in op.data.to_frame().itertuples(index=False)
]
columns = translator._schema_to_sqlalchemy_columns(op.schema)
return sa.values(*columns, name=op.name).data(rows)
return sa.values(*columns, name=op.name).data(rows).select().subquery()


class TrinoSQLCompiler(AlchemyCompiler):
Expand Down

0 comments on commit f3d044c

Please sign in to comment.