feat(api): Add Table.sample #7377

Merged: 9 commits, Oct 17, 2023
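This PR adds a `Table.sample` method for sampling a random fraction of a table's rows. Judging from the test suite changes below, usage looks roughly like the following sketch (the schema and table name here are made up for illustration; the `method` and `seed` arguments are taken from the tests, and the exact signature may differ):

```python
import ibis

t = ibis.table({"int_col": "int64", "string_col": "string"}, name="t")

# Row-level sampling: keep roughly 10% of rows
expr = t.sample(0.1, method="row")

# Block-level sampling, on backends that support it
expr = t.sample(0.1, method="block")

# Seeded sampling for reproducible results (only some backends support a seed)
expr = t.sample(0.1, seed=1234)
```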
44 changes: 21 additions & 23 deletions ibis/backends/base/sql/alchemy/query_builder.py
@@ -6,6 +6,7 @@
import toolz
from sqlalchemy import sql

import ibis.common.exceptions as com
import ibis.expr.analysis as an
import ibis.expr.operations as ops
from ibis.backends.base.sql.alchemy.translator import (
@@ -84,7 +85,7 @@
ctx = self.context

orig_op = op
if isinstance(op, ops.SelfReference):
if isinstance(op, (ops.SelfReference, ops.Sample)):
op = op.table

alias = ctx.get_ref(orig_op)
@@ -128,28 +129,27 @@
for name, value in zip(op.schema.names, op.values)
)
)
elif ctx.is_extracted(op):
if isinstance(orig_op, ops.SelfReference):
result = ctx.get_ref(op)
else:
result = alias
else:
# A subquery
if ctx.is_extracted(op):
# Was put elsewhere, e.g. WITH block, we just need to grab
# its alias
alias = ctx.get_ref(orig_op)

# hack
if isinstance(orig_op, ops.SelfReference):
table = ctx.get_ref(op)
self_ref = alias if hasattr(alias, "name") else table.alias(alias)
ctx.set_ref(orig_op, self_ref)
return self_ref
return alias

alias = ctx.get_ref(orig_op)
result = ctx.get_compiled_expr(orig_op)
result = ctx.get_compiled_expr(op)

result = alias if hasattr(alias, "name") else result.alias(alias)

if isinstance(orig_op, ops.Sample):
result = self._format_sample(orig_op, result)

ctx.set_ref(orig_op, result)
return result

def _format_sample(self, op, table):
# Should never be hit in practice, as Sample operations should be rewritten
# before this point for all backends without TABLESAMPLE support
raise com.UnsupportedOperationError("`Table.sample` is not supported")

def _format_in_memory_table(self, op, translator):
columns = translator._schema_to_sqlalchemy_columns(op.schema)
if self.context.compiler.cheap_in_memory_tables:
@@ -168,7 +168,7 @@
).limit(0)
elif self.context.compiler.support_values_syntax_in_select:
rows = list(op.data.to_frame().itertuples(index=False))
result = sa.values(*columns, name=op.name).data(rows)
result = sa.values(*columns, name=op.name).data(rows).select().subquery()
else:
raw_rows = (
sa.select(
@@ -219,13 +219,11 @@
self.context.set_ref(expr, result)

def _compile_table_set(self):
if self.table_set is not None:
helper = self.table_set_formatter_class(self, self.table_set)
result = helper.get_result()
return result
else:
if self.table_set is None:
return None

return self.table_set_formatter_class(self, self.table_set).get_result()

def _add_select(self, table_set):
if not self.select_set:
return table_set.element
34 changes: 19 additions & 15 deletions ibis/backends/base/sql/compiler/query_builder.py
@@ -100,9 +100,11 @@
ctx = self.context

orig_op = op
if isinstance(op, ops.SelfReference):
if isinstance(op, (ops.SelfReference, ops.Sample)):
op = op.table

alias = ctx.get_ref(orig_op)

if isinstance(op, ops.InMemoryTable):
result = self._format_in_memory_table(op)
elif isinstance(op, ops.PhysicalTable):
@@ -117,26 +119,28 @@
db=getattr(op, "namespace", None),
quoted=self.parent.translator_class._quote_identifiers,
).sql(dialect=self.parent.translator_class._dialect_name)
elif ctx.is_extracted(op):
if isinstance(orig_op, ops.SelfReference):
result = ctx.get_ref(op)
else:
result = alias
else:
# A subquery
if ctx.is_extracted(op):
# Was put elsewhere, e.g. WITH block, we just need to grab its
# alias
alias = ctx.get_ref(orig_op)

# HACK: self-references have to be treated more carefully here
if isinstance(orig_op, ops.SelfReference):
return f"{ctx.get_ref(op)} {alias}"
else:
return alias

subquery = ctx.get_compiled_expr(orig_op)
subquery = ctx.get_compiled_expr(op)
result = f"(\n{util.indent(subquery, self.indent)}\n)"

result += f" {ctx.get_ref(orig_op)}"
if result != alias:
result = f"{result} {alias}"

if isinstance(orig_op, ops.Sample):
result = self._format_sample(orig_op, result)

return result

def _format_sample(self, op, table):
# Should never be hit in practice, as Sample operations should be rewritten
# before this point for all backends without TABLESAMPLE support
raise com.UnsupportedOperationError("`Table.sample` is not supported")

def get_result(self):
# Got to unravel the join stack; the nesting order could be
# arbitrary, so we do a depth first search and push the join tokens
5 changes: 5 additions & 0 deletions ibis/backends/base/sql/compiler/select_builder.py
@@ -147,6 +147,11 @@ def _collect_Limit(self, op, toplevel=False):
assert self.limit is None
self.limit = _LimitSpec(op.n, op.offset)

def _collect_Sample(self, op, toplevel=False):
if toplevel:
self.table_set = op
self.select_set = [op]

def _collect_Union(self, op, toplevel=False):
if toplevel:
self.table_set = op
3 changes: 2 additions & 1 deletion ibis/backends/clickhouse/compiler/core.py
@@ -29,7 +29,7 @@
from ibis.backends.clickhouse.compiler.values import translate_val
from ibis.common.deferred import _
from ibis.expr.analysis import c, find_first_base_table, p, x, y
from ibis.expr.rewrites import rewrite_dropna, rewrite_fillna
from ibis.expr.rewrites import rewrite_dropna, rewrite_fillna, rewrite_sample


if TYPE_CHECKING:
from collections.abc import Mapping
@@ -125,6 +125,7 @@
| nullify_empty_string_results
| rewrite_fillna
| rewrite_dropna
| rewrite_sample
)
# apply translate rules in topological order
node = op.map(fn)[op]
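`rewrite_sample` itself (imported here from `ibis/expr/rewrites.py`) is not part of this diff. Going by the comment in the base compilers that `Sample` operations must be rewritten for backends without `TABLESAMPLE` support, the rewrite presumably lowers a sample into a filter on a random value, along the lines of this hypothetical sketch:

```python
# Hypothetical sketch only; the real rewrite_sample is not shown in this PR
# and may be implemented differently.
import ibis


def sample_as_filter(t, fraction: float):
    # Keep each row independently with probability `fraction`,
    # i.e. roughly SELECT * FROM t WHERE random() <= fraction.
    return t.filter(ibis.random() <= fraction)
```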
5 changes: 5 additions & 0 deletions ibis/backends/dask/execution/generic.py
@@ -552,3 +552,8 @@ def execute_table_array_view(op, _, **kwargs):
# Need to compute dataframe in order to squeeze into a scalar
ddf = execute(op.table)
return ddf.compute().squeeze()


@execute_node.register(ops.Sample, dd.DataFrame, object, object)
def execute_sample(op, data, fraction, seed, **kwargs):
return data.sample(frac=fraction, random_state=seed)
2 changes: 2 additions & 0 deletions ibis/backends/druid/compiler.py
@@ -7,6 +7,7 @@
import ibis.backends.druid.datatypes as ddt
from ibis.backends.base.sql.alchemy import AlchemyCompiler, AlchemyExprTranslator
from ibis.backends.druid.registry import operation_registry
from ibis.expr.rewrites import rewrite_sample


class DruidExprTranslator(AlchemyExprTranslator):
@@ -29,3 +30,4 @@ def translate(self, op):
class DruidCompiler(AlchemyCompiler):
translator_class = DruidExprTranslator
null_limit = sa.literal_column("ALL")
rewrites = AlchemyCompiler.rewrites | rewrite_sample
14 changes: 14 additions & 0 deletions ibis/backends/duckdb/compiler.py
@@ -6,6 +6,7 @@
import ibis.backends.base.sql.alchemy.datatypes as sat
import ibis.expr.operations as ops
from ibis.backends.base.sql.alchemy import AlchemyCompiler, AlchemyExprTranslator
from ibis.backends.base.sql.alchemy.query_builder import _AlchemyTableSetFormatter
from ibis.backends.duckdb.datatypes import DuckDBType
from ibis.backends.duckdb.registry import operation_registry

@@ -60,6 +61,19 @@ def _no_op(expr):
return expr


class DuckDBTableSetFormatter(_AlchemyTableSetFormatter):
def _format_sample(self, op, table):
if op.method == "row":
method = sa.func.bernoulli
else:
method = sa.func.system
return table.tablesample(
sampling=method(sa.literal_column(f"{op.fraction * 100} PERCENT")),
seed=(None if op.seed is None else sa.literal_column(str(op.seed))),
)


class DuckDBSQLCompiler(AlchemyCompiler):
cheap_in_memory_tables = True
translator_class = DuckDBSQLExprTranslator
table_set_formatter_class = DuckDBTableSetFormatter
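DuckDB supports `TABLESAMPLE` natively, so its formatter emits a sampling clause instead of relying on `rewrite_sample`: `bernoulli` for row-level sampling and `system` for block-level sampling, with the fraction rendered as a `PERCENT` literal. A hedged sketch of what the compiled SQL is expected to look like (exact formatting, and how a seed is rendered, may differ):

```python
import ibis

t = ibis.table({"x": "int64"}, name="t")
expr = t.sample(0.1, method="row")

# Expected shape of the generated query, roughly:
#   SELECT t0.x FROM t AS t0 TABLESAMPLE bernoulli(10.0 PERCENT)
print(ibis.to_sql(expr, dialect="duckdb"))
```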
2 changes: 2 additions & 0 deletions ibis/backends/impala/compiler.py
@@ -3,6 +3,7 @@
import ibis.expr.operations as ops
from ibis.backends.base.sql.compiler import Compiler, ExprTranslator, TableSetFormatter
from ibis.backends.base.sql.registry import binary_infix_ops, operation_registry, unary
from ibis.expr.rewrites import rewrite_sample


class ImpalaTableSetFormatter(TableSetFormatter):
@@ -58,3 +59,4 @@ def _floor_divide(op):
class ImpalaCompiler(Compiler):
translator_class = ImpalaExprTranslator
table_set_formatter_class = ImpalaTableSetFormatter
rewrites = Compiler.rewrites | rewrite_sample
2 changes: 2 additions & 0 deletions ibis/backends/mssql/compiler.py
@@ -6,6 +6,7 @@
from ibis.backends.base.sql.alchemy import AlchemyCompiler, AlchemyExprTranslator
from ibis.backends.mssql.datatypes import MSSQLType
from ibis.backends.mssql.registry import _timestamp_from_unix, operation_registry
from ibis.expr.rewrites import rewrite_sample


class MsSqlExprTranslator(AlchemyExprTranslator):
@@ -35,3 +36,4 @@ class MsSqlCompiler(AlchemyCompiler):

supports_indexed_grouping_keys = False
null_limit = None
rewrites = AlchemyCompiler.rewrites | rewrite_sample
2 changes: 2 additions & 0 deletions ibis/backends/mysql/compiler.py
@@ -5,6 +5,7 @@
from ibis.backends.base.sql.alchemy import AlchemyCompiler, AlchemyExprTranslator
from ibis.backends.mysql.datatypes import MySQLType
from ibis.backends.mysql.registry import operation_registry
from ibis.expr.rewrites import rewrite_sample


class MySQLExprTranslator(AlchemyExprTranslator):
@@ -24,3 +25,4 @@ class MySQLCompiler(AlchemyCompiler):
translator_class = MySQLExprTranslator
support_values_syntax_in_select = False
null_limit = None
rewrites = AlchemyCompiler.rewrites | rewrite_sample
2 changes: 2 additions & 0 deletions ibis/backends/oracle/__init__.py
@@ -39,6 +39,7 @@
)
from ibis.backends.oracle.datatypes import OracleType # noqa: E402
from ibis.backends.oracle.registry import operation_registry # noqa: E402
from ibis.expr.rewrites import rewrite_sample # noqa: E402

if TYPE_CHECKING:
from collections.abc import Iterable
@@ -73,6 +74,7 @@ class OracleCompiler(AlchemyCompiler):
support_values_syntax_in_select = False
supports_indexed_grouping_keys = False
null_limit = None
rewrites = AlchemyCompiler.rewrites | rewrite_sample


class Backend(BaseAlchemyBackend):
5 changes: 5 additions & 0 deletions ibis/backends/pandas/execution/generic.py
@@ -1450,3 +1450,8 @@ def execute_table_array_view(op, _, **kwargs):
@execute_node.register(ops.InMemoryTable)
def execute_in_memory_table(op, **kwargs):
return op.data.to_frame()


@execute_node.register(ops.Sample, pd.DataFrame, object, object)
def execute_sample(op, data, fraction, seed, **kwargs):
return data.sample(frac=fraction, random_state=seed)
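Both the pandas and dask executors simply delegate to `DataFrame.sample` with `frac` and `random_state`. A minimal standalone illustration of the behavior they rely on:

```python
import pandas as pd

df = pd.DataFrame({"x": range(100)})

# frac samples that fraction of rows without replacement; a fixed random_state
# (the `seed` argument of Table.sample) makes the draw repeatable.
s1 = df.sample(frac=0.1, random_state=1234)
s2 = df.sample(frac=0.1, random_state=1234)
assert s1.equals(s2)
```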
2 changes: 2 additions & 0 deletions ibis/backends/postgres/compiler.py
@@ -5,6 +5,7 @@
from ibis.backends.base.sql.alchemy import AlchemyCompiler, AlchemyExprTranslator
from ibis.backends.postgres.datatypes import PostgresType
from ibis.backends.postgres.registry import operation_registry
from ibis.expr.rewrites import rewrite_sample


class PostgresUDFNode(ops.Value):
@@ -35,3 +36,4 @@ def _any_all_no_op(expr):

class PostgreSQLCompiler(AlchemyCompiler):
translator_class = PostgreSQLExprTranslator
rewrites = AlchemyCompiler.rewrites | rewrite_sample
6 changes: 6 additions & 0 deletions ibis/backends/pyspark/compiler.py
@@ -300,6 +300,12 @@
return df


@compiles(ops.Sample)
def compile_sample(t, op, **kwargs):
df = t.translate(op.table, **kwargs)
return df.sample(fraction=op.fraction, seed=op.seed)


@compiles(ops.And)
def compile_and(t, op, **kwargs):
return t.translate(op.left, **kwargs) & t.translate(op.right, **kwargs)
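For PySpark the translator maps directly onto `DataFrame.sample(fraction=..., seed=...)`, which performs per-row Bernoulli sampling, so the number of returned rows varies from run to run and a fixed seed makes the draw deterministic. A small sketch of that underlying behavior (assumes a local Spark session is available):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.range(1000)

# fraction is a per-row probability, not an exact row count; with a fixed seed
# the draw is deterministic for a given partitioning of the data.
sampled = df.sample(fraction=0.1, seed=1234)
print(sampled.count())
```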
2 changes: 2 additions & 0 deletions ibis/backends/sqlite/compiler.py
@@ -16,6 +16,7 @@
from ibis.backends.base.sql.alchemy import AlchemyCompiler, AlchemyExprTranslator
from ibis.backends.sqlite.datatypes import SqliteType
from ibis.backends.sqlite.registry import operation_registry
from ibis.expr.rewrites import rewrite_sample


class SQLiteExprTranslator(AlchemyExprTranslator):
@@ -32,3 +33,4 @@ class SQLiteCompiler(AlchemyCompiler):
translator_class = SQLiteExprTranslator
support_values_syntax_in_select = False
null_limit = None
rewrites = AlchemyCompiler.rewrites | rewrite_sample
68 changes: 68 additions & 0 deletions ibis/backends/tests/test_generic.py
@@ -1527,3 +1527,71 @@ def test_dynamic_table_slice_with_computed_offset(backend):
result = expr.to_pandas()

backend.assert_frame_equal(result, expected)


@pytest.mark.notimpl(
[
"bigquery",
"datafusion",
"druid",
"flink",
"polars",
"snowflake",
]
)
def test_sample(backend):
t = backend.functional_alltypes.filter(_.int_col >= 2)

total_rows = t.count().execute()
empty = t.limit(1).execute().iloc[:0]

df = t.sample(0.1, method="row").execute()
assert len(df) <= total_rows
backend.assert_frame_equal(empty, df.iloc[:0])

df = t.sample(0.1, method="block").execute()
assert len(df) <= total_rows
backend.assert_frame_equal(empty, df.iloc[:0])


@pytest.mark.notimpl(
[
"bigquery",
"datafusion",
"druid",
"flink",
"polars",
"snowflake",
]
)
def test_sample_memtable(con, backend):
df = pd.DataFrame({"x": [1, 2, 3, 4]})
res = con.execute(ibis.memtable(df).sample(0.5))
assert len(res) <= 4
backend.assert_frame_equal(res.iloc[:0], df.iloc[:0])


@pytest.mark.notimpl(
[
"bigquery",
"clickhouse",
"datafusion",
"druid",
"flink",
"impala",
"mssql",
"mysql",
"oracle",
"polars",
"postgres",
"snowflake",
"sqlite",
"trino",
]
)
def test_sample_with_seed(backend):
t = backend.functional_alltypes
expr = t.sample(0.1, seed=1234)
df1 = expr.to_pandas()
df2 = expr.to_pandas()
backend.assert_frame_equal(df1, df2)