feat: several Table.sample improvements #10207

Merged 4 commits on Sep 24, 2024
18 changes: 18 additions & 0 deletions ibis/backends/polars/compiler.py
@@ -819,6 +819,24 @@ def distinct(op, **kw):
return table.unique()


@translate.register(ops.Sample)
def sample(op, **kw):
if op.seed is not None:
raise com.UnsupportedOperationError(
"`Table.sample` with a random seed is unsupported"
)
table = translate(op.parent, **kw)
# Disable predicate pushdown since `t.filter(...).sample(...)` could have
# different statistical or performance characteristics than
# `t.sample(...).filter(...)`. Same for slice pushdown.
return table.map_batches(
lambda df: df.sample(fraction=op.fraction),
predicate_pushdown=False,
slice_pushdown=False,
streamable=True,
)
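A minimal, hedged sketch of why the pushdowns are disabled (hypothetical data, not part of this diff): `DataFrame.sample(fraction=...)` draws a fixed fraction of whatever rows it receives, so letting the optimizer move a predicate (or a slice) across the sample silently changes what is being sampled:

```python
import polars as pl

df = pl.DataFrame({"x": list(range(100))})

# Filter first, then sample: exactly 10% of the 50 matching rows -> 5 rows.
a = df.filter(pl.col("x") < 50).sample(fraction=0.1)

# Sample first, then filter: 10 rows drawn from all 100, of which only
# roughly half survive the filter -> a different count and distribution.
b = df.sample(fraction=0.1).filter(pl.col("x") < 50)

print(a.height, b.height)  # 5 vs. a value that varies run to run
```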


@translate.register(ops.CountStar)
def count_star(op, **kw):
if (where := op.where) is not None:
19 changes: 18 additions & 1 deletion ibis/backends/sql/compilers/base.py
@@ -293,7 +293,7 @@
LOWERED_OPS: dict[type[ops.Node], pats.Replace | None] = {
ops.Bucket: lower_bucket,
ops.Capitalize: lower_capitalize,
ops.Sample: lower_sample,
ops.Sample: lower_sample(supported_methods=()),
ops.StringSlice: lower_stringslice,
}
"""A mapping from an operation class to either a rewrite rule for rewriting that
@@ -1448,6 +1448,23 @@
copy=False,
)

def visit_Sample(
self, op, *, parent, fraction: float, method: str, seed: int | None, **_
):
sample = sge.TableSample(
method="bernoulli" if method == "row" else "system",
percent=sge.convert(fraction * 100.0),
seed=None if seed is None else sge.convert(seed),
)
# sample was changed to be owned by the table being sampled in sqlglot 25.17.0
#
# this is a small workaround for backwards compatibility
if "this" in sample.__class__.arg_types:
sample.args["this"] = parent

else:
parent.args["sample"] = sample
return sg.select(STAR).from_(parent)
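For backends that leave `ops.Sample` un-lowered (DuckDB and Snowflake in this PR), this default now emits the `TABLESAMPLE` clause. A hedged usage sketch — the table name is hypothetical and the exact SQL text depends on the sqlglot version and dialect:

```python
import ibis

t = ibis.table({"x": "int64"}, name="t")  # hypothetical unbound table
expr = t.sample(0.5, method="row", seed=42)

# Prints a query of the rough form
#   SELECT * FROM "t" TABLESAMPLE BERNOULLI (50.0) ... (42)
# with the seed clause spelled per dialect.
print(ibis.to_sql(expr, dialect="duckdb"))
```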

def visit_Limit(self, op, *, parent, n, offset):
# push limit/offset into subqueries
if isinstance(parent, sge.Subquery) and parent.this.args.get("limit") is None:
9 changes: 9 additions & 0 deletions ibis/backends/sql/compilers/bigquery/__init__.py
@@ -22,6 +22,7 @@
exclude_unsupported_window_frame_from_ops,
exclude_unsupported_window_frame_from_rank,
exclude_unsupported_window_frame_from_row_number,
lower_sample,
split_select_distinct_with_order_by,
)
from ibis.common.temporal import DateUnit, IntervalUnit, TimestampUnit, TimeUnit
@@ -118,6 +119,14 @@ class BigQueryCompiler(SQLGlotCompiler):

supports_qualify = True

LOWERED_OPS = {
ops.Sample: lower_sample(
supported_methods=("block",),
supports_seed=False,
physical_tables_only=True,
),
}
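A hedged sketch of what these `lower_sample` knobs imply, inferred from the call sites in this PR rather than copied from `ibis/backends/sql/rewrites.py` (which this diff does not show): when the requested method, seed, or relation kind falls outside what the backend's native `TABLESAMPLE` supports, the `Sample` node is rewritten into an ordinary random-number filter so the query still compiles.

```python
# Self-contained, hedged sketch; assumed semantics and hypothetical names,
# not the actual code in ibis/backends/sql/rewrites.py.
def needs_fallback(
    method: str,
    seed: int | None,
    parent_is_physical_table: bool,
    *,
    supported_methods=("row", "block"),
    supports_seed=True,
    physical_tables_only=False,
) -> bool:
    """True when native TABLESAMPLE can't be used and the Sample op should
    instead be rewritten as `filter(random() <= fraction)`."""
    return (
        method not in supported_methods
        or (seed is not None and not supports_seed)
        or (physical_tables_only and not parent_is_physical_table)
    )


# With BigQuery's settings above, a row-method sample over a view or subquery
# would fall back to the random() filter:
print(needs_fallback("row", None, False,
                     supported_methods=("block",),
                     supports_seed=False,
                     physical_tables_only=True))  # True
```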

UNSUPPORTED_OPS = (
ops.DateDiff,
ops.ExtractAuthority,
4 changes: 3 additions & 1 deletion ibis/backends/sql/compilers/druid.py
@@ -21,7 +21,7 @@ class DruidCompiler(SQLGlotCompiler):

agg = AggGen(supports_filter=True)

LOWERED_OPS = {ops.Capitalize: None}
LOWERED_OPS = {ops.Capitalize: None, ops.Sample: None}
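A hedged reading of the `None` entries (inferred from how this PR pairs them with `UNSUPPORTED_OPS`; the merge logic itself is not shown in this diff): per-backend `LOWERED_OPS` appear to overlay the base mapping, and `None` removes the inherited rewrite so `ops.Sample` surfaces as unsupported rather than being lowered to a `random()` filter that Druid could not run anyway (it lacks `ops.RandomScalar`). A toy sketch of that assumed overlay:

```python
# Illustrative only -- the names and the merge rule are assumptions, not ibis internals.
base_lowered = {"Sample": "lower_sample(supported_methods=())", "Capitalize": "lower_capitalize"}
druid_lowered = {"Sample": None, "Capitalize": None}

effective = {op: rule for op, rule in {**base_lowered, **druid_lowered}.items() if rule is not None}
print(effective)  # {} -> neither op is lowered; both are reported via UNSUPPORTED_OPS
```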

UNSUPPORTED_OPS = (
ops.ApproxMedian,
@@ -47,6 +47,7 @@ class DruidCompiler(SQLGlotCompiler):
ops.IsInf,
ops.Levenshtein,
ops.Median,
ops.RandomScalar,
ops.RandomUUID,
ops.RegexReplace,
ops.RegexSplit,
@@ -64,6 +65,7 @@ class DruidCompiler(SQLGlotCompiler):
ops.TypeOf,
ops.Unnest,
ops.Variance,
ops.Sample,
)

SIMPLE_OPS = {
15 changes: 2 additions & 13 deletions ibis/backends/sql/compilers/duckdb.py
@@ -14,6 +14,7 @@
from ibis import util
from ibis.backends.sql.compilers.base import NULL, STAR, AggGen, SQLGlotCompiler
from ibis.backends.sql.datatypes import DuckDBType
from ibis.backends.sql.rewrites import lower_sample
from ibis.util import gen_name

if TYPE_CHECKING:
@@ -45,7 +46,7 @@ class DuckDBCompiler(SQLGlotCompiler):
supports_qualify = True

LOWERED_OPS = {
ops.Sample: None,
ops.Sample: lower_sample(),
ops.StringSlice: None,
}

@@ -171,18 +172,6 @@ def visit_ArrayRepeat(self, op, *, arg, times):
func = sge.Lambda(this=arg, expressions=[sg.to_identifier("_")])
return self.f.flatten(self.f.list_apply(self.f.range(times), func))

# TODO(kszucs): this could be moved to the base SQLGlotCompiler
def visit_Sample(
self, op, *, parent, fraction: float, method: str, seed: int | None, **_
):
sample = sge.TableSample(
method="bernoulli" if method == "row" else "system",
percent=sge.convert(fraction * 100.0),
seed=None if seed is None else sge.convert(seed),
)

return self._make_sample_backwards_compatible(sample=sample, parent=parent)

def visit_ArraySlice(self, op, *, arg, start, stop):
arg_length = self.f.len(arg)

8 changes: 7 additions & 1 deletion ibis/backends/sql/compilers/impala.py
@@ -10,7 +10,7 @@
from ibis.backends.sql.compilers.base import NULL, STAR, SQLGlotCompiler
from ibis.backends.sql.datatypes import ImpalaType
from ibis.backends.sql.dialects import Impala
from ibis.backends.sql.rewrites import rewrite_empty_order_by_window
from ibis.backends.sql.rewrites import lower_sample, rewrite_empty_order_by_window


class ImpalaCompiler(SQLGlotCompiler):
@@ -23,6 +23,12 @@ class ImpalaCompiler(SQLGlotCompiler):
*SQLGlotCompiler.rewrites,
)

LOWERED_OPS = {
ops.Sample: lower_sample(
supported_methods=("block",), physical_tables_only=True
),
}

UNSUPPORTED_OPS = (
ops.ArgMax,
ops.ArgMin,
7 changes: 7 additions & 0 deletions ibis/backends/sql/compilers/mssql.py
@@ -22,6 +22,7 @@
from ibis.backends.sql.rewrites import (
exclude_unsupported_window_frame_from_ops,
exclude_unsupported_window_frame_from_row_number,
lower_sample,
p,
replace,
split_select_distinct_with_order_by,
@@ -73,6 +74,12 @@ class MSSQLCompiler(SQLGlotCompiler):
post_rewrites = (split_select_distinct_with_order_by,)
copy_func_args = True

LOWERED_OPS = {
ops.Sample: lower_sample(
supported_methods=("block",), physical_tables_only=True
),
}

UNSUPPORTED_OPS = (
ops.ApproxMedian,
ops.ArgMax,
2 changes: 2 additions & 0 deletions ibis/backends/sql/compilers/oracle.py
@@ -16,6 +16,7 @@
exclude_unsupported_window_frame_from_row_number,
lower_log2,
lower_log10,
lower_sample,
rewrite_empty_order_by_window,
)

@@ -46,6 +47,7 @@ class OracleCompiler(SQLGlotCompiler):
LOWERED_OPS = {
ops.Log2: lower_log2,
ops.Log10: lower_log10,
ops.Sample: lower_sample(physical_tables_only=True),
}

UNSUPPORTED_OPS = (
4 changes: 3 additions & 1 deletion ibis/backends/sql/compilers/postgres.py
@@ -17,7 +17,7 @@
from ibis.backends.sql.compilers.base import NULL, STAR, AggGen, SQLGlotCompiler
from ibis.backends.sql.datatypes import PostgresType
from ibis.backends.sql.dialects import Postgres
from ibis.backends.sql.rewrites import split_select_distinct_with_order_by
from ibis.backends.sql.rewrites import lower_sample, split_select_distinct_with_order_by
from ibis.common.exceptions import InvalidDecoratorError
from ibis.util import gen_name

@@ -50,6 +50,8 @@ class PostgresCompiler(SQLGlotCompiler):
POS_INF = sge.Literal.number("'Inf'::double precision")
NEG_INF = sge.Literal.number("'-Inf'::double precision")

LOWERED_OPS = {ops.Sample: lower_sample(physical_tables_only=True)}

UNSUPPORTED_OPS = (
ops.RowID,
ops.TimeDelta,
13 changes: 2 additions & 11 deletions ibis/backends/sql/compilers/pyspark.py
@@ -18,6 +18,7 @@
from ibis.backends.sql.rewrites import (
FirstValue,
LastValue,
lower_sample,
p,
split_select_distinct_with_order_by,
)
@@ -65,7 +66,7 @@ class PySparkCompiler(SQLGlotCompiler):
)

LOWERED_OPS = {
ops.Sample: None,
ops.Sample: lower_sample(supports_seed=False),
}

SIMPLE_OPS = {
@@ -333,16 +334,6 @@ def visit_TimestampRange(self, op, *, start, stop, step):
zero = sge.Interval(this=sge.convert(0), unit=unit)
return self._build_sequence(start, stop, step, zero)

def visit_Sample(
self, op, *, parent, fraction: float, method: str, seed: int | None, **_
):
if seed is not None:
raise com.UnsupportedOperationError(
"PySpark backend does not support sampling with seed."
)
sample = sge.TableSample(percent=sge.convert(int(fraction * 100.0)))
return self._make_sample_backwards_compatible(sample=sample, parent=parent)

def visit_WindowBoundary(self, op, *, value, preceding):
if isinstance(op.value, ops.Literal) and op.value.value == 0:
value = "CURRENT ROW"
4 changes: 4 additions & 0 deletions ibis/backends/sql/compilers/risingwave.py
@@ -17,12 +17,16 @@ class RisingWaveCompiler(PostgresCompiler):
dialect = RisingWave
type_mapper = RisingWaveType

LOWERED_OPS = {ops.Sample: None}

UNSUPPORTED_OPS = (
ops.Arbitrary,
ops.Mode,
ops.RandomScalar,
ops.RandomUUID,
ops.MultiQuantile,
ops.ApproxMultiQuantile,
ops.Sample,
*(
op
for op in ALL_OPERATIONS
13 changes: 2 additions & 11 deletions ibis/backends/sql/compilers/snowflake.py
@@ -29,6 +29,7 @@
exclude_unsupported_window_frame_from_row_number,
lower_log2,
lower_log10,
lower_sample,
rewrite_empty_order_by_window,
)

@@ -59,7 +60,7 @@ class SnowflakeCompiler(SQLGlotCompiler):
LOWERED_OPS = {
ops.Log2: lower_log2,
ops.Log10: lower_log10,
ops.Sample: None,
ops.Sample: lower_sample(),
}

UNSUPPORTED_OPS = (
@@ -762,16 +763,6 @@ def visit_TimestampRange(self, op, *, start, stop, step):
.subquery()
)

def visit_Sample(
self, op, *, parent, fraction: float, method: str, seed: int | None, **_
):
sample = sge.TableSample(
method="bernoulli" if method == "row" else "system",
percent=sge.convert(fraction * 100.0),
seed=None if seed is None else sge.convert(seed),
)
return self._make_sample_backwards_compatible(sample=sample, parent=parent)

def visit_ArrayMap(self, op, *, arg, param, body):
return self.f.transform(arg, sge.Lambda(this=body, expressions=[param]))

17 changes: 2 additions & 15 deletions ibis/backends/sql/compilers/trino.py
@@ -23,6 +23,7 @@
from ibis.backends.sql.dialects import Trino
from ibis.backends.sql.rewrites import (
exclude_unsupported_window_frame_from_ops,
lower_sample,
split_select_distinct_with_order_by,
)
from ibis.util import gen_name
@@ -54,7 +55,7 @@ class TrinoCompiler(SQLGlotCompiler):
)

LOWERED_OPS = {
ops.Sample: None,
ops.Sample: lower_sample(supports_seed=False),
}

SIMPLE_OPS = {
@@ -107,20 +108,6 @@ def _minimize_spec(start, end, spec):
return None
return spec

def visit_Sample(
self, op, *, parent, fraction: float, method: str, seed: int | None, **_
):
if seed is not None:
raise com.UnsupportedOperationError(
"`Table.sample` with a random seed is unsupported"
)
sample = sge.TableSample(
method="bernoulli" if method == "row" else "system",
percent=sge.convert(fraction * 100.0),
seed=None if seed is None else sge.convert(seed),
)
return self._make_sample_backwards_compatible(sample=sample, parent=parent)

def visit_Correlation(self, op, *, left, right, how, where):
if how == "sample":
raise com.UnsupportedOperationError(
12 changes: 12 additions & 0 deletions ibis/backends/sql/dialects.py
@@ -307,18 +307,30 @@ class Tokenizer(Hive.Tokenizer):
STRING_ESCAPES = ["'"]


def tablesample_percent_to_int(self, expr):
"""Impala's TABLESAMPLE only supports integer percentages."""
expr = expr.copy()
expr.args["percent"] = sge.convert(round(float(expr.args["percent"].this)))
return self.tablesample_sql(expr)
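A minimal, hedged illustration of the rounding above, run outside the generator with hypothetical numbers: `Table.sample(0.123)` arrives as the percentage literal `12.3`, which Impala rejects, so it is rounded to `12` before the `TABLESAMPLE` clause is emitted.

```python
import sqlglot.expressions as sge

percent = sge.convert(12.3)                        # literal produced for fraction=0.123
rounded = sge.convert(round(float(percent.this)))  # same coercion as above
print(rounded.sql())  # 12 -> rendered as, e.g., TABLESAMPLE SYSTEM(12)
```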


class Impala(Hive):
NULL_ORDERING = "nulls_are_large"
REGEXP_EXTRACT_DEFAULT_GROUP = 0
TABLESAMPLE_SIZE_IS_PERCENT = True
ALIAS_POST_TABLESAMPLE = False

class Generator(Hive.Generator):
TABLESAMPLE_WITH_METHOD = True

TRANSFORMS = Hive.Generator.TRANSFORMS.copy() | {
sge.ApproxDistinct: rename_func("ndv"),
sge.IsNan: rename_func("is_nan"),
sge.IsInf: rename_func("is_inf"),
sge.DayOfWeek: rename_func("dayofweek"),
sge.Interval: lambda self, e: _interval(self, e, quote_arg=False),
sge.CurrentDate: rename_func("current_date"),
sge.TableSample: tablesample_percent_to_int,
}

