Skip to content

Commit

Permalink
feat(api): implement type for SortExpr
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Apr 5, 2022
1 parent a6f2aa8 commit ab19bd6
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 21 deletions.
39 changes: 18 additions & 21 deletions ibis/backends/tests/test_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,9 +525,9 @@ def test_grouped_bounded_range_window(backend, alltypes, df):
# `group_by='string_col'`:
# The window at a particular row will only contain other rows that
# have the same 'string_col' value.
#
preceding = 10
window = ibis.range_window(
preceding=10,
preceding=preceding,
following=0,
order_by='id',
group_by='string_col',
Expand All @@ -536,21 +536,22 @@ def test_grouped_bounded_range_window(backend, alltypes, df):
result = expr.execute().set_index('id').sort_index()

def gb_fn(df):
indices = np.searchsorted(
df.id,
[
df.id - 10,
# add 1 to get the upper bound without having to make two
# searchsorted calls
df.id + 1,
],
side="left",
).T
indices = np.searchsorted(df.id, [df["prec"], df["foll"]], side="left")
double_col = df.double_col.values
result = [double_col[start:stop].sum() for start, stop in indices]
return pd.Series(result, index=df.index)
return pd.Series(
[double_col[start:stop].sum() for start, stop in indices.T],
index=df.index,
)

res = df.sort_values("id").groupby("string_col").apply(gb_fn).droplevel(0)
res = (
# add 1 to get the upper bound without having to make two
# searchsorted calls
df.assign(prec=lambda t: t.id - preceding, foll=lambda t: t.id + 1)
.sort_values("id")
.groupby("string_col")
.apply(gb_fn)
.droplevel(0)
)
expected = (
df.assign(
# Mimic our range window spec using .apply()
Expand All @@ -560,9 +561,7 @@ def gb_fn(df):
.sort_index()
)

left, right = result.val, expected.val

backend.assert_series_equal(left, right)
backend.assert_series_equal(result.val, expected.val)


@pytest.mark.notimpl(["clickhouse", "dask", "datafusion", "pyspark"])
Expand All @@ -573,6 +572,4 @@ def test_percent_rank_whole_table_no_order_by(backend, alltypes, df):
column = df.id.rank(method="min").sub(1).div(len(df) - 1)
expected = df.assign(val=column).set_index('id').sort_index()

left, right = result.val, expected.val

backend.assert_series_equal(left, right)
backend.assert_series_equal(result.val, expected.val)
8 changes: 8 additions & 0 deletions ibis/expr/types/sortkeys.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from public import public

if TYPE_CHECKING:
import ibis.expr.datatypes as dt

from .generic import Expr


@public
class SortExpr(Expr):
    """Expression whose name and data type are delegated to its operation."""

    def get_name(self) -> str | None:
        """Return the name resolved by the underlying operation."""
        operation = self.op()
        return operation.resolve_name()

    def type(self) -> dt.DataType:
        """Return the data type of the expression held by the operation."""
        wrapped = self.op().expr
        return wrapped.type()

0 comments on commit ab19bd6

Please sign in to comment.