diff --git a/ibis/backends/base/sql/alchemy/query_builder.py b/ibis/backends/base/sql/alchemy/query_builder.py
index 85cb33a0cdf1..f409d39dd683 100644
--- a/ibis/backends/base/sql/alchemy/query_builder.py
+++ b/ibis/backends/base/sql/alchemy/query_builder.py
@@ -6,6 +6,7 @@
 import toolz
 from sqlalchemy import sql
 
+import ibis.common.exceptions as com
 import ibis.expr.analysis as an
 import ibis.expr.operations as ops
 from ibis.backends.base.sql.alchemy.translator import (
@@ -84,7 +85,7 @@ def _format_table(self, op):
         ctx = self.context
 
         orig_op = op
-        if isinstance(op, ops.SelfReference):
+        if isinstance(op, (ops.SelfReference, ops.Sample)):
             op = op.table
 
         alias = ctx.get_ref(orig_op)
@@ -128,28 +129,27 @@ def _format_table(self, op):
                     for name, value in zip(op.schema.names, op.values)
                 )
             )
+        elif ctx.is_extracted(op):
+            if isinstance(orig_op, ops.SelfReference):
+                result = ctx.get_ref(op)
+            else:
+                result = alias
         else:
-            # A subquery
-            if ctx.is_extracted(op):
-                # Was put elsewhere, e.g. WITH block, we just need to grab
-                # its alias
-                alias = ctx.get_ref(orig_op)
-
-                # hack
-                if isinstance(orig_op, ops.SelfReference):
-                    table = ctx.get_ref(op)
-                    self_ref = alias if hasattr(alias, "name") else table.alias(alias)
-                    ctx.set_ref(orig_op, self_ref)
-                    return self_ref
-                return alias
-
-            alias = ctx.get_ref(orig_op)
-            result = ctx.get_compiled_expr(orig_op)
+            result = ctx.get_compiled_expr(op)
             result = alias if hasattr(alias, "name") else result.alias(alias)
+
+        if isinstance(orig_op, ops.Sample):
+            result = self._format_sample(orig_op, result)
+
         ctx.set_ref(orig_op, result)
         return result
 
+    def _format_sample(self, op, table):
+        # Should never be hit in practice, as Sample operations should be rewritten
+        # before this point for all backends without TABLESAMPLE support
+        raise com.UnsupportedOperationError("`Table.sample` is not supported")
+
     def _format_in_memory_table(self, op, translator):
         columns = translator._schema_to_sqlalchemy_columns(op.schema)
         if self.context.compiler.cheap_in_memory_tables:
@@ -168,7 +168,7 @@ def _format_in_memory_table(self, op, translator):
             ).limit(0)
         elif self.context.compiler.support_values_syntax_in_select:
             rows = list(op.data.to_frame().itertuples(index=False))
-            result = sa.values(*columns, name=op.name).data(rows)
+            result = sa.values(*columns, name=op.name).data(rows).select().subquery()
         else:
             raw_rows = (
                 sa.select(
@@ -219,13 +219,11 @@ def _compile_subqueries(self):
             self.context.set_ref(expr, result)
 
     def _compile_table_set(self):
-        if self.table_set is not None:
-            helper = self.table_set_formatter_class(self, self.table_set)
-            result = helper.get_result()
-            return result
-        else:
+        if self.table_set is None:
             return None
 
+        return self.table_set_formatter_class(self, self.table_set).get_result()
+
     def _add_select(self, table_set):
         if not self.select_set:
             return table_set.element
diff --git a/ibis/backends/base/sql/compiler/query_builder.py b/ibis/backends/base/sql/compiler/query_builder.py
index cb968f01fef3..ee6f4f76fe13 100644
--- a/ibis/backends/base/sql/compiler/query_builder.py
+++ b/ibis/backends/base/sql/compiler/query_builder.py
@@ -100,9 +100,11 @@ def _format_table(self, op):
         ctx = self.context
 
         orig_op = op
-        if isinstance(op, ops.SelfReference):
+        if isinstance(op, (ops.SelfReference, ops.Sample)):
             op = op.table
 
+        alias = ctx.get_ref(orig_op)
+
         if isinstance(op, ops.InMemoryTable):
             result = self._format_in_memory_table(op)
         elif isinstance(op, ops.PhysicalTable):
@@ -117,26 +119,28 @@ def _format_table(self, op):
                 db=getattr(op, "namespace", None),
                 quoted=self.parent.translator_class._quote_identifiers,
             ).sql(dialect=self.parent.translator_class._dialect_name)
+        elif ctx.is_extracted(op):
+            if isinstance(orig_op, ops.SelfReference):
+                result = ctx.get_ref(op)
+            else:
+                result = alias
         else:
-            # A subquery
-            if ctx.is_extracted(op):
-                # Was put elsewhere, e.g. WITH block, we just need to grab its
-                # alias
-                alias = ctx.get_ref(orig_op)
-
-                # HACK: self-references have to be treated more carefully here
-                if isinstance(orig_op, ops.SelfReference):
-                    return f"{ctx.get_ref(op)} {alias}"
-                else:
-                    return alias
-
-            subquery = ctx.get_compiled_expr(orig_op)
+            subquery = ctx.get_compiled_expr(op)
             result = f"(\n{util.indent(subquery, self.indent)}\n)"
 
-        result += f" {ctx.get_ref(orig_op)}"
+        if result != alias:
+            result = f"{result} {alias}"
+
+        if isinstance(orig_op, ops.Sample):
+            result = self._format_sample(orig_op, result)
 
         return result
 
+    def _format_sample(self, op, table):
+        # Should never be hit in practice, as Sample operations should be rewritten
+        # before this point for all backends without TABLESAMPLE support
+        raise com.UnsupportedOperationError("`Table.sample` is not supported")
+
     def get_result(self):
         # Got to unravel the join stack; the nesting order could be
         # arbitrary, so we do a depth first search and push the join tokens
diff --git a/ibis/backends/base/sql/compiler/select_builder.py b/ibis/backends/base/sql/compiler/select_builder.py
index 75ccf9b4b351..b88dd76aa2e4 100644
--- a/ibis/backends/base/sql/compiler/select_builder.py
+++ b/ibis/backends/base/sql/compiler/select_builder.py
@@ -147,6 +147,11 @@ def _collect_Limit(self, op, toplevel=False):
             assert self.limit is None
             self.limit = _LimitSpec(op.n, op.offset)
 
+    def _collect_Sample(self, op, toplevel=False):
+        if toplevel:
+            self.table_set = op
+            self.select_set = [op]
+
     def _collect_Union(self, op, toplevel=False):
         if toplevel:
             self.table_set = op
diff --git a/ibis/backends/clickhouse/compiler/core.py b/ibis/backends/clickhouse/compiler/core.py
index 2b6590fcff66..7b00566458c3 100644
--- a/ibis/backends/clickhouse/compiler/core.py
+++ b/ibis/backends/clickhouse/compiler/core.py
@@ -29,7 +29,7 @@
 from ibis.backends.clickhouse.compiler.values import translate_val
 from ibis.common.deferred import _
 from ibis.expr.analysis import c, find_first_base_table, p, x, y
-from ibis.expr.rewrites import rewrite_dropna, rewrite_fillna
+from ibis.expr.rewrites import rewrite_dropna, rewrite_fillna, rewrite_sample
 
 if TYPE_CHECKING:
     from collections.abc import Mapping
@@ -125,6 +125,7 @@ def fn(node, _, **kwargs):
         | nullify_empty_string_results
         | rewrite_fillna
         | rewrite_dropna
+        | rewrite_sample
     )
     # apply translate rules in topological order
     node = op.map(fn)[op]
diff --git a/ibis/backends/dask/execution/generic.py b/ibis/backends/dask/execution/generic.py
index ce75ca62f0e9..34c56b592e4e 100644
--- a/ibis/backends/dask/execution/generic.py
+++ b/ibis/backends/dask/execution/generic.py
@@ -552,3 +552,8 @@ def execute_table_array_view(op, _, **kwargs):
     # Need to compute dataframe in order to squeeze into a scalar
     ddf = execute(op.table)
     return ddf.compute().squeeze()
+
+
+@execute_node.register(ops.Sample, dd.DataFrame, object, object)
+def execute_sample(op, data, fraction, seed, **kwargs):
+    return data.sample(frac=fraction, random_state=seed)
diff --git a/ibis/backends/druid/compiler.py b/ibis/backends/druid/compiler.py
index 60a48bbf5b45..6c766af97111 100644
--- a/ibis/backends/druid/compiler.py
+++ b/ibis/backends/druid/compiler.py
@@ -7,6 +7,7 @@
 import ibis.backends.druid.datatypes as ddt
 from ibis.backends.base.sql.alchemy import AlchemyCompiler, AlchemyExprTranslator
 from ibis.backends.druid.registry import operation_registry
+from ibis.expr.rewrites import rewrite_sample
 
 
 class DruidExprTranslator(AlchemyExprTranslator):
@@ -29,3 +30,4 @@ def translate(self, op):
 class DruidCompiler(AlchemyCompiler):
     translator_class = DruidExprTranslator
     null_limit = sa.literal_column("ALL")
+    rewrites = AlchemyCompiler.rewrites | rewrite_sample
diff --git a/ibis/backends/duckdb/compiler.py b/ibis/backends/duckdb/compiler.py
index a74136935059..eaac09d0b0a7 100644
--- a/ibis/backends/duckdb/compiler.py
+++ b/ibis/backends/duckdb/compiler.py
@@ -6,6 +6,7 @@
 import ibis.backends.base.sql.alchemy.datatypes as sat
 import ibis.expr.operations as ops
 from ibis.backends.base.sql.alchemy import AlchemyCompiler, AlchemyExprTranslator
+from ibis.backends.base.sql.alchemy.query_builder import _AlchemyTableSetFormatter
 from ibis.backends.duckdb.datatypes import DuckDBType
 from ibis.backends.duckdb.registry import operation_registry
 
@@ -60,6 +61,19 @@ def _no_op(expr):
     return expr
 
 
+class DuckDBTableSetFormatter(_AlchemyTableSetFormatter):
+    def _format_sample(self, op, table):
+        if op.method == "row":
+            method = sa.func.bernoulli
+        else:
+            method = sa.func.system
+        return table.tablesample(
+            sampling=method(sa.literal_column(f"{op.fraction * 100} PERCENT")),
+            seed=(None if op.seed is None else sa.literal_column(str(op.seed))),
+        )
+
+
 class DuckDBSQLCompiler(AlchemyCompiler):
     cheap_in_memory_tables = True
     translator_class = DuckDBSQLExprTranslator
+    table_set_formatter_class = DuckDBTableSetFormatter
diff --git a/ibis/backends/impala/compiler.py b/ibis/backends/impala/compiler.py
index 897c6a619979..558ab877819b 100644
--- a/ibis/backends/impala/compiler.py
+++ b/ibis/backends/impala/compiler.py
@@ -3,6 +3,7 @@
 import ibis.expr.operations as ops
 from ibis.backends.base.sql.compiler import Compiler, ExprTranslator, TableSetFormatter
 from ibis.backends.base.sql.registry import binary_infix_ops, operation_registry, unary
+from ibis.expr.rewrites import rewrite_sample
 
 
 class ImpalaTableSetFormatter(TableSetFormatter):
@@ -58,3 +59,4 @@ def _floor_divide(op):
 class ImpalaCompiler(Compiler):
     translator_class = ImpalaExprTranslator
     table_set_formatter_class = ImpalaTableSetFormatter
+    rewrites = Compiler.rewrites | rewrite_sample
diff --git a/ibis/backends/mssql/compiler.py b/ibis/backends/mssql/compiler.py
index 7722702eb95a..6a1d9c31b099 100644
--- a/ibis/backends/mssql/compiler.py
+++ b/ibis/backends/mssql/compiler.py
@@ -6,6 +6,7 @@
 from ibis.backends.base.sql.alchemy import AlchemyCompiler, AlchemyExprTranslator
 from ibis.backends.mssql.datatypes import MSSQLType
 from ibis.backends.mssql.registry import _timestamp_from_unix, operation_registry
+from ibis.expr.rewrites import rewrite_sample
 
 
 class MsSqlExprTranslator(AlchemyExprTranslator):
@@ -35,3 +36,4 @@ class MsSqlCompiler(AlchemyCompiler):
     supports_indexed_grouping_keys = False
 
     null_limit = None
+    rewrites = AlchemyCompiler.rewrites | rewrite_sample
diff --git a/ibis/backends/mysql/compiler.py b/ibis/backends/mysql/compiler.py
index c33f49318bad..529dfe84b211 100644
--- a/ibis/backends/mysql/compiler.py
+++ b/ibis/backends/mysql/compiler.py
@@ -5,6 +5,7 @@
 from ibis.backends.base.sql.alchemy import AlchemyCompiler, AlchemyExprTranslator
 from ibis.backends.mysql.datatypes import MySQLType
 from ibis.backends.mysql.registry import operation_registry
+from ibis.expr.rewrites import rewrite_sample
 
 
 class MySQLExprTranslator(AlchemyExprTranslator):
@@ -24,3 +25,4 @@ class MySQLCompiler(AlchemyCompiler):
     translator_class = MySQLExprTranslator
     support_values_syntax_in_select = False
     null_limit = None
+    rewrites = AlchemyCompiler.rewrites | rewrite_sample
diff --git a/ibis/backends/oracle/__init__.py b/ibis/backends/oracle/__init__.py
index 03c79e93ac91..2095cb0b54b5 100644
--- a/ibis/backends/oracle/__init__.py
+++ b/ibis/backends/oracle/__init__.py
@@ -39,6 +39,7 @@
 )
 from ibis.backends.oracle.datatypes import OracleType  # noqa: E402
 from ibis.backends.oracle.registry import operation_registry  # noqa: E402
+from ibis.expr.rewrites import rewrite_sample  # noqa: E402
 
 if TYPE_CHECKING:
     from collections.abc import Iterable
@@ -73,6 +74,7 @@ class OracleCompiler(AlchemyCompiler):
     support_values_syntax_in_select = False
     supports_indexed_grouping_keys = False
     null_limit = None
+    rewrites = AlchemyCompiler.rewrites | rewrite_sample
 
 
 class Backend(BaseAlchemyBackend):
diff --git a/ibis/backends/pandas/execution/generic.py b/ibis/backends/pandas/execution/generic.py
index e23ecdd1aa8e..2e30dd33c6b4 100644
--- a/ibis/backends/pandas/execution/generic.py
+++ b/ibis/backends/pandas/execution/generic.py
@@ -1450,3 +1450,8 @@ def execute_table_array_view(op, _, **kwargs):
 @execute_node.register(ops.InMemoryTable)
 def execute_in_memory_table(op, **kwargs):
     return op.data.to_frame()
+
+
+@execute_node.register(ops.Sample, pd.DataFrame, object, object)
+def execute_sample(op, data, fraction, seed, **kwargs):
+    return data.sample(frac=fraction, random_state=seed)
diff --git a/ibis/backends/postgres/compiler.py b/ibis/backends/postgres/compiler.py
index e5efc76fa7d2..48a5f24b0111 100644
--- a/ibis/backends/postgres/compiler.py
+++ b/ibis/backends/postgres/compiler.py
@@ -5,6 +5,7 @@
 from ibis.backends.base.sql.alchemy import AlchemyCompiler, AlchemyExprTranslator
 from ibis.backends.postgres.datatypes import PostgresType
 from ibis.backends.postgres.registry import operation_registry
+from ibis.expr.rewrites import rewrite_sample
 
 
 class PostgresUDFNode(ops.Value):
@@ -35,3 +36,4 @@ def _any_all_no_op(expr):
 
 class PostgreSQLCompiler(AlchemyCompiler):
     translator_class = PostgreSQLExprTranslator
+    rewrites = AlchemyCompiler.rewrites | rewrite_sample
diff --git a/ibis/backends/pyspark/compiler.py b/ibis/backends/pyspark/compiler.py
index 40a82411d13f..06f90be0661f 100644
--- a/ibis/backends/pyspark/compiler.py
+++ b/ibis/backends/pyspark/compiler.py
@@ -300,6 +300,12 @@ def compile_limit(t, op, **kwargs):
     return df
 
 
+@compiles(ops.Sample)
+def compile_sample(t, op, **kwargs):
+    df = t.translate(op.table, **kwargs)
+    return df.sample(fraction=op.fraction, seed=op.seed)
+
+
 @compiles(ops.And)
 def compile_and(t, op, **kwargs):
     return t.translate(op.left, **kwargs) & t.translate(op.right, **kwargs)
diff --git a/ibis/backends/sqlite/compiler.py b/ibis/backends/sqlite/compiler.py
index 97a4604524af..09db8897ebfc 100644
--- a/ibis/backends/sqlite/compiler.py
+++ b/ibis/backends/sqlite/compiler.py
@@ -16,6 +16,7 @@
 from ibis.backends.base.sql.alchemy import AlchemyCompiler, AlchemyExprTranslator
 from ibis.backends.sqlite.datatypes import SqliteType
 from ibis.backends.sqlite.registry import operation_registry
+from ibis.expr.rewrites import rewrite_sample
 
 
 class SQLiteExprTranslator(AlchemyExprTranslator):
@@ -32,3 +33,4 @@ class SQLiteCompiler(AlchemyCompiler):
     translator_class = SQLiteExprTranslator
     support_values_syntax_in_select = False
     null_limit = None
+    rewrites = AlchemyCompiler.rewrites | rewrite_sample
diff --git a/ibis/backends/tests/test_generic.py b/ibis/backends/tests/test_generic.py
index 79c102a47c9a..e5133f6a5550 100644
--- a/ibis/backends/tests/test_generic.py
+++ b/ibis/backends/tests/test_generic.py
@@ -1527,3 +1527,71 @@ def test_dynamic_table_slice_with_computed_offset(backend):
     result = expr.to_pandas()
 
     backend.assert_frame_equal(result, expected)
+
+
+@pytest.mark.notimpl(
+    [
+        "bigquery",
+        "datafusion",
+        "druid",
+        "flink",
+        "polars",
+        "snowflake",
+    ]
+)
+def test_sample(backend):
+    t = backend.functional_alltypes.filter(_.int_col >= 2)
+
+    total_rows = t.count().execute()
+    empty = t.limit(1).execute().iloc[:0]
+
+    df = t.sample(0.1, method="row").execute()
+    assert len(df) <= total_rows
+    backend.assert_frame_equal(empty, df.iloc[:0])
+
+    df = t.sample(0.1, method="block").execute()
+    assert len(df) <= total_rows
+    backend.assert_frame_equal(empty, df.iloc[:0])
+
+
+@pytest.mark.notimpl(
+    [
+        "bigquery",
+        "datafusion",
+        "druid",
+        "flink",
+        "polars",
+        "snowflake",
+    ]
+)
+def test_sample_memtable(con, backend):
+    df = pd.DataFrame({"x": [1, 2, 3, 4]})
+    res = con.execute(ibis.memtable(df).sample(0.5))
+    assert len(res) <= 4
+    backend.assert_frame_equal(res.iloc[:0], df.iloc[:0])
+
+
+@pytest.mark.notimpl(
+    [
+        "bigquery",
+        "clickhouse",
+        "datafusion",
+        "druid",
+        "flink",
+        "impala",
+        "mssql",
+        "mysql",
+        "oracle",
+        "polars",
+        "postgres",
+        "snowflake",
+        "sqlite",
+        "trino",
+    ]
+)
+def test_sample_with_seed(backend):
+    t = backend.functional_alltypes
+    expr = t.sample(0.1, seed=1234)
+    df1 = expr.to_pandas()
+    df2 = expr.to_pandas()
+    backend.assert_frame_equal(df1, df2)
diff --git a/ibis/backends/trino/compiler.py b/ibis/backends/trino/compiler.py
index 1f9911b61a27..e8d199daead5 100644
--- a/ibis/backends/trino/compiler.py
+++ b/ibis/backends/trino/compiler.py
@@ -7,6 +7,7 @@
 from ibis.backends.base.sql.alchemy.query_builder import _AlchemyTableSetFormatter
 from ibis.backends.trino.datatypes import TrinoType
 from ibis.backends.trino.registry import operation_registry
+from ibis.common.exceptions import UnsupportedOperationError
 
 
 class TrinoSQLExprTranslator(AlchemyExprTranslator):
@@ -47,6 +48,16 @@ def _rewrite_string_contains(op):
 
 
 class TrinoTableSetFormatter(_AlchemyTableSetFormatter):
+    def _format_sample(self, op, table):
+        if op.seed is not None:
+            raise UnsupportedOperationError(
+                "`Table.sample` with a random seed is unsupported"
+            )
+        method = sa.func.bernoulli if op.method == "row" else sa.func.system
+        return table.tablesample(
+            sampling=method(sa.literal_column(f"{op.fraction * 100}"))
+        )
+
     def _format_in_memory_table(self, op, translator):
         if not op.data:
             return sa.select(
@@ -65,7 +76,7 @@ def _format_in_memory_table(self, op, translator):
             for row in op.data.to_frame().itertuples(index=False)
         ]
         columns = translator._schema_to_sqlalchemy_columns(op.schema)
-        return sa.values(*columns, name=op.name).data(rows)
+        return sa.values(*columns, name=op.name).data(rows).select().subquery()
 
 
 class TrinoSQLCompiler(AlchemyCompiler):
diff --git a/ibis/expr/format.py b/ibis/expr/format.py
index 39504b818562..2c4028214b3f 100644
--- a/ibis/expr/format.py
+++ b/ibis/expr/format.py
@@ -291,6 +291,7 @@ def _join(op, left, right, predicates, **kwargs):
 
 
 @fmt.register(ops.Limit)
+@fmt.register(ops.Sample)
 def _limit(op, table, **kwargs):
     params = inline_args(kwargs)
     return f"{op.__class__.__name__}[{table}, {params}]"
diff --git a/ibis/expr/operations/relations.py b/ibis/expr/operations/relations.py
index b4740d63a151..84266b968858 100644
--- a/ibis/expr/operations/relations.py
+++ b/ibis/expr/operations/relations.py
@@ -3,7 +3,7 @@
 import abc
 import itertools
 from abc import abstractmethod
-from typing import TYPE_CHECKING, Any, Literal, Optional
+from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional
 from typing import Union as UnionType
 
 from public import public
@@ -17,7 +17,7 @@
 from ibis.common.collections import FrozenDict  # noqa: TCH001
 from ibis.common.deferred import Deferred
 from ibis.common.grounds import Immutable
-from ibis.common.patterns import Coercible, Eq
+from ibis.common.patterns import Between, Coercible, Eq
 from ibis.common.typing import VarTuple  # noqa: TCH001
 from ibis.expr.operations.core import Column, Named, Node, Scalar, Value
 from ibis.expr.operations.sortkeys import SortKey  # noqa: TCH001
@@ -580,6 +580,20 @@ def schema(self):
         return self.table.schema
 
 
+@public
+class Sample(Relation):
+    """Sample performs random sampling of records in a table."""
+
+    table: Relation
+    fraction: Annotated[float, Between(0, 1)]
+    method: Literal["row", "block"]
+    seed: UnionType[int, None] = None
+
+    @attribute
+    def schema(self):
+        return self.table.schema
+
+
 # TODO(kszucs): split it into two operations, one working with a single replacement
 # value and the other with a mapping
 # TODO(kszucs): the single value case was limited to numeric and string types
diff --git a/ibis/expr/rewrites.py b/ibis/expr/rewrites.py
index 6cdbbbd21523..4ae694f0120c 100644
--- a/ibis/expr/rewrites.py
+++ b/ibis/expr/rewrites.py
@@ -6,6 +6,7 @@
 
 import ibis.expr.datatypes as dt
 import ibis.expr.operations as ops
+from ibis.common.exceptions import UnsupportedOperationError
 from ibis.common.patterns import pattern, replace
 from ibis.util import Namespace
 
@@ -58,3 +59,23 @@ def rewrite_dropna(_):
         return _.table
 
     return ops.Selection(_.table, (), preds, ())
+
+
+@replace(p.Sample)
+def rewrite_sample(_):
+    """Rewrite Sample as `t.filter(random() <= fraction)`.
+
+    Errors as unsupported if a `seed` is specified.
+    """
+
+    if _.seed is not None:
+        raise UnsupportedOperationError(
+            "`Table.sample` with a random seed is unsupported"
+        )
+
+    return ops.Selection(
+        _.table,
+        (),
+        (ops.LessEqual(ops.RandomScalar(), _.fraction),),
+        (),
+    )
diff --git a/ibis/expr/types/relations.py b/ibis/expr/types/relations.py
index 076b4d7058b5..df12ffe735db 100644
--- a/ibis/expr/types/relations.py
+++ b/ibis/expr/types/relations.py
@@ -1191,6 +1191,94 @@ def distinct(
             return res.select(self.columns)
         return res
 
+    def sample(
+        self,
+        fraction: float,
+        *,
+        method: Literal["row", "block"] = "row",
+        seed: int | None = None,
+    ) -> Table:
+        """Sample a fraction of rows from a table.
+
+        ::: {.callout-note}
+        ## Results may be non-repeatable
+
+        Sampling is by definition a random operation. Some backends support
+        specifying a `seed` for repeatable results, but not all backends
+        support that option. And some backends (duckdb, for example) do support
+        specifying a seed but may still not have repeatable results in all
+        cases.
+
+        In all cases, results are backend-specific. An execution against one
+        backend is unlikely to sample the same rows when executed against a
+        different backend, even with the same `seed` set.
+        :::
+
+        Parameters
+        ----------
+        fraction
+            The percentage of rows to include in the sample, expressed as a
+            float between 0 and 1.
+        method
+            The sampling method to use. The default is "row", which includes
+            each row with a probability of ``fraction``. If method is "block",
If method is "block", + some backends may instead perform sampling a fraction of blocks of + rows (where "block" is a backend dependent definition). This is + identical to "row" for backends lacking a blockwise sampling + implementation. For those coming from SQL, "row" and "block" + correspond to "bernoulli" and "system" respectively in a + TABLESAMPLE clause. + seed + An optional random seed to use, for repeatable sampling. Backends + that never support specifying a seed for repeatable sampling will + error appropriately. Note that some backends (like DuckDB) do + support specifying a seed, but may still not have repeatable + results in all cases. + + Returns + ------- + Table + The input table, with `fraction` of rows selected. + + Examples + -------- + >>> import ibis + >>> ibis.options.interactive = True + >>> t = ibis.memtable({"x": [1, 2, 3, 4], "y": ["a", "b", "c", "d"]}) + >>> t + ┏━━━━━━━┳━━━━━━━━┓ + ┃ x ┃ y ┃ + ┡━━━━━━━╇━━━━━━━━┩ + │ int64 │ string │ + ├───────┼────────┤ + │ 1 │ a │ + │ 2 │ b │ + │ 3 │ c │ + │ 4 │ d │ + └───────┴────────┘ + + Sample approximately half the rows, with a seed specified for + reproducibility. + + >>> t.sample(0.5, seed=1234) + ┏━━━━━━━┳━━━━━━━━┓ + ┃ x ┃ y ┃ + ┡━━━━━━━╇━━━━━━━━┩ + │ int64 │ string │ + ├───────┼────────┤ + │ 2 │ b │ + │ 3 │ c │ + └───────┴────────┘ + """ + if fraction == 1: + return self + elif fraction == 0: + return self.limit(0) + else: + return ops.Sample( + self, fraction=fraction, method=method, seed=seed + ).to_expr() + def limit(self, n: int | None, offset: int = 0) -> Table: """Select `n` rows from `self` starting at `offset`. diff --git a/ibis/tests/expr/test_value_exprs.py b/ibis/tests/expr/test_value_exprs.py index 506b9c293456..69ee3b873705 100644 --- a/ibis/tests/expr/test_value_exprs.py +++ b/ibis/tests/expr/test_value_exprs.py @@ -1666,3 +1666,20 @@ def test_quantile_shape(): (b1,) = expr.op().selections assert b1.shape.is_columnar() + + +def test_sample(): + t = ibis.table({"x": "int64", "y": "string"}) + + expr = t.sample(1) + assert expr.equals(t) + + expr = t.sample(0) + assert expr.equals(t.limit(0)) + + expr = t.sample(0.5, method="block", seed=1234) + assert expr.schema() == t.schema() + op = expr.op() + assert op.fraction == 0.5 + assert op.method == "block" + assert op.seed == 1234