Skip to content

Commit

Permalink
fix(sqlglot): ensure that sample works with multiple sqlglot versions
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Aug 27, 2024
1 parent 29ca899 commit 66a30bc
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 12 deletions.
10 changes: 10 additions & 0 deletions ibis/backends/sql/compilers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1637,6 +1637,16 @@ def add_query_to_expr(self, *, name: str, table: ir.Table, query: str) -> str:
# generate the SQL string
return parsed.sql(dialect)

def _make_sample_backwards_compatible(self, *, sample, parent):
    """Attach `sample` to `parent` and return ``SELECT *`` over the sampled table.

    Parameters
    ----------
    sample
        The ``TableSample`` expression describing how to sample.
    parent
        The expression being sampled.

    Returns
    -------
    A ``SELECT *`` query over the sampled relation.
    """
    # sample was changed to be owned by the table being sampled in 25.17.0
    #
    # this is a small workaround for backwards compatibility
    if "this" in sample.__class__.arg_types:
        # Older sqlglot: the sample expression wraps the table it samples, so
        # the query has to select from the *sample*. Selecting from `parent`
        # here would leave `sample` unreferenced and silently drop the
        # sampling clause from the generated SQL.
        sample.args["this"] = parent
        return sg.select(STAR).from_(sample)

    # Newer sqlglot: the sampled table owns the sample expression, so select
    # from the table itself.
    parent.args["sample"] = sample
    return sg.select(STAR).from_(parent)


# `__init_subclass__` is only invoked for subclasses, never for the base class
# itself - we manually call it here to autogenerate the base class
# implementations as well.
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/sql/compilers/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,12 +168,12 @@ def visit_Sample(
self, op, *, parent, fraction: float, method: str, seed: int | None, **_
):
sample = sge.TableSample(
this=parent,
method="bernoulli" if method == "row" else "system",
percent=sge.convert(fraction * 100.0),
seed=None if seed is None else sge.convert(seed),
)
return sg.select(STAR).from_(sample)

return self._make_sample_backwards_compatible(sample=sample, parent=parent)

def visit_ArraySlice(self, op, *, arg, start, stop):
arg_length = self.f.len(arg)
Expand Down
7 changes: 2 additions & 5 deletions ibis/backends/sql/compilers/pyspark.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,11 +334,8 @@ def visit_Sample(
raise com.UnsupportedOperationError(
"PySpark backend does not support sampling with seed."
)
sample = sge.TableSample(
this=parent,
percent=sge.convert(fraction * 100.0),
)
return sg.select(STAR).from_(sample)
sample = sge.TableSample(percent=sge.convert(fraction * 100.0))
return self._make_sample_backwards_compatible(sample=sample, parent=parent)

Check warning on line 338 in ibis/backends/sql/compilers/pyspark.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/sql/compilers/pyspark.py#L337-L338

Added lines #L337 - L338 were not covered by tests

def visit_WindowBoundary(self, op, *, value, preceding):
if isinstance(op.value, ops.Literal) and op.value.value == 0:
Expand Down
3 changes: 1 addition & 2 deletions ibis/backends/sql/compilers/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,12 +761,11 @@ def visit_Sample(
self, op, *, parent, fraction: float, method: str, seed: int | None, **_
):
sample = sge.TableSample(
this=parent,
method="bernoulli" if method == "row" else "system",
percent=sge.convert(fraction * 100.0),
seed=None if seed is None else sge.convert(seed),
)
return sg.select(STAR).from_(sample)
return self._make_sample_backwards_compatible(sample=sample, parent=parent)

Check warning on line 768 in ibis/backends/sql/compilers/snowflake.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/sql/compilers/snowflake.py#L768

Added line #L768 was not covered by tests

def visit_ArrayMap(self, op, *, arg, param, body):
return self.f.transform(arg, sge.Lambda(this=body, expressions=[param]))
Expand Down
5 changes: 2 additions & 3 deletions ibis/backends/sql/compilers/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,17 +108,16 @@ def _minimize_spec(start, end, spec):
def visit_Sample(
self, op, *, parent, fraction: float, method: str, seed: int | None, **_
):
if op.seed is not None:
if seed is not None:
raise com.UnsupportedOperationError(
"`Table.sample` with a random seed is unsupported"
)
sample = sge.TableSample(
this=parent,
method="bernoulli" if method == "row" else "system",
percent=sge.convert(fraction * 100.0),
seed=None if seed is None else sge.convert(seed),
)
return sg.select(STAR).from_(sample)
return self._make_sample_backwards_compatible(sample=sample, parent=parent)

def visit_Correlation(self, op, *, left, right, how, where):
if how == "sample":
Expand Down

0 comments on commit 66a30bc

Please sign in to comment.