Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(api): refactor the implementation of windowing #9200

Merged
Merged 13 commits on Jul 18, 2024
68 changes: 50 additions & 18 deletions ibis/backends/flink/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,25 @@
if TYPE_CHECKING:
from pyflink.table import StreamTableEnvironment

# Flink's watermark must be declared on a column of type TIMESTAMP(p) or
# TIMESTAMP_LTZ(p) with p in [0, 3], so the shared "functional_alltypes"
# schema is overridden here to pin timestamp_col at precision 3.
_functional_alltypes_fields = {
    "id": "int32",
    "bool_col": "boolean",
    "tinyint_col": "int8",
    "smallint_col": "int16",
    "int_col": "int32",
    "bigint_col": "int64",
    "float_col": "float32",
    "double_col": "float64",
    "date_string_col": "string",
    "string_col": "string",
    "timestamp_col": "timestamp(3)",
    "year": "int32",
    "month": "int32",
}
TEST_TABLES["functional_alltypes"] = ibis.schema(_functional_alltypes_fields)


def get_table_env(
local_env: bool,
Expand Down Expand Up @@ -152,24 +171,7 @@ def awards_players_schema():

@pytest.fixture
def functional_alltypes_schema():
    # Schema fixture for the functional_alltypes test table.
    # NOTE(review): diff residue — the inline ibis.schema(...) literal below was
    # replaced in this change by the module-level TEST_TABLES entry, so the final
    # `return` is the current implementation and the first body is the deleted
    # one; as captured here, the trailing return is unreachable.
    return ibis.schema(
        {
            "id": "int32",
            "bool_col": "boolean",
            "tinyint_col": "int8",
            "smallint_col": "int16",
            "int_col": "int32",
            "bigint_col": "int64",
            "float_col": "float32",
            "double_col": "float64",
            "date_string_col": "string",
            "string_col": "string",
            "timestamp_col": "timestamp(3)",  # overriding the higher level fixture with precision because Flink's
            # watermark must use a field of type TIMESTAMP(p) or TIMESTAMP_LTZ(p), where 'p' is from 0 to 3
            "year": "int32",
            "month": "int32",
        }
    )
    return TEST_TABLES["functional_alltypes"]


@pytest.fixture
Expand All @@ -188,3 +190,33 @@ def generate_csv_configs(csv_file):
}

return generate_csv_configs


@pytest.fixture(scope="session")
def functional_alltypes_no_header(tmpdir_factory, data_dir):
    """Session-scoped copy of the functional_alltypes CSV with its header row removed.

    Returns the path of the headerless copy written under a temporary directory.
    """
    target = tmpdir_factory.mktemp("data") / "functional_alltypes.csv"
    source = data_dir / "csv" / "functional_alltypes.csv"
    with open(source) as src, open(str(target), mode="w") as dst:
        src.readline()  # consume and drop the header line
        dst.writelines(src)  # stream the remaining rows unchanged
    return target


@pytest.fixture(scope="session", autouse=True)
def functional_alltypes_with_watermark(con, functional_alltypes_no_header):
    """Register a temporary streaming table with a 10-second watermark.

    Backed by the headerless CSV copy via the filesystem connector; used by
    event-time based tests. Session-scoped and autouse so it exists for the
    whole test run.
    """
    tbl_properties = {
        "connector": "filesystem",
        "path": functional_alltypes_no_header,
        "format": "csv",
    }
    return con.create_table(
        "functional_alltypes_with_watermark",
        schema=TEST_TABLES["functional_alltypes"],
        tbl_properties=tbl_properties,
        watermark=ibis.watermark("timestamp_col", ibis.interval(seconds=10)),
        temp=True,
    )
50 changes: 0 additions & 50 deletions ibis/backends/flink/tests/test_compiler.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,8 @@
from __future__ import annotations

from operator import methodcaller

import pytest
from pytest import param

import ibis
from ibis.common.deferred import _


def test_sum(simple_table, assert_sql):
expr = simple_table.a.sum()
Expand Down Expand Up @@ -103,48 +98,3 @@ def test_having(simple_table, assert_sql):
.aggregate(simple_table.b.sum().name("b_sum"))
)
assert_sql(expr)


@pytest.mark.parametrize(
    "method",
    [
        param(
            methodcaller("tumble", window_size=ibis.interval(minutes=15)),
            id="tumble",
        ),
        param(
            methodcaller(
                "hop",
                window_size=ibis.interval(minutes=15),
                window_slide=ibis.interval(minutes=1),
            ),
            id="hop",
        ),
        param(
            methodcaller(
                "cumulate",
                window_size=ibis.interval(minutes=1),
                window_step=ibis.interval(seconds=10),
            ),
            id="cumulate",
        ),
    ],
)
def test_windowing_tvf(simple_table, method, assert_sql):
    # Each windowing table-valued function call should compile to SQL.
    windowed = simple_table.window_by(time_col=simple_table.i)
    assert_sql(method(windowed))


def test_window_aggregation(simple_table, assert_sql):
    # Aggregate over 15-minute tumble windows, grouped by the window bounds
    # plus a key column; SQL-compilation check only.
    windowed = simple_table.window_by(time_col=simple_table.i).tumble(
        window_size=ibis.interval(minutes=15)
    )
    expr = windowed.group_by(["window_start", "window_end", "g"]).aggregate(
        mean=_.d.mean()
    )
    assert_sql(expr)


def test_window_topn(simple_table, assert_sql):
    # Top-N per window: number the rows within each 600-second tumble window
    # (descending by g) and keep the first three; SQL-compilation check only.
    windowed = simple_table.window_by(time_col="i").tumble(
        window_size=ibis.interval(seconds=600),
    )["a", "b", "c", "d", "g", "window_start", "window_end"]
    ranked = windowed.mutate(
        rownum=ibis.row_number().over(
            group_by=["window_start", "window_end"], order_by=ibis.desc("g")
        )
    )
    assert_sql(ranked[ranked.rownum <= 3])
166 changes: 0 additions & 166 deletions ibis/backends/flink/tests/test_join.py

This file was deleted.

29 changes: 27 additions & 2 deletions ibis/backends/flink/tests/test_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pytest import param

import ibis
from ibis import _
from ibis.backends.tests.errors import Py4JJavaError


Expand Down Expand Up @@ -53,13 +54,37 @@ def test_window_invalid_start_end(con, window):
con.execute(expr)


# NOTE(review): diff residue — the two signature lines below are the pre- and
# post-change versions; the change drops the unused `con` fixture.
def test_range_window(con, simple_table, assert_sql):
def test_range_window(simple_table, assert_sql):
    # Trailing 500-minute range window over f, ordered by f; compilation check.
    expr = simple_table.f.sum().over(
        range=(-ibis.interval(minutes=500), 0), order_by=simple_table.f
    )
    assert_sql(expr)


# NOTE(review): diff residue — the two signature lines below are the pre- and
# post-change versions; the change drops the unused `con` fixture.
def test_rows_window(con, simple_table, assert_sql):
def test_rows_window(simple_table, assert_sql):
    # Trailing 1000-row window over f, ordered by f; compilation check.
    expr = simple_table.f.sum().over(rows=(-1000, 0), order_by=simple_table.f)
    assert_sql(expr)


def test_tumble_window_by_grouped_agg(con):
    """Grouped mean over 30-second tumble windows on the watermarked table."""
    table = con.table("functional_alltypes_with_watermark")
    windowed = table.window_by(time_col=table.timestamp_col).tumble(
        size=ibis.interval(seconds=30)
    )
    result = windowed.agg(by=["string_col"], avg=_.float_col.mean()).to_pandas()
    # Expect window bounds, the grouping key, and the aggregate column.
    assert list(result.columns) == ["window_start", "window_end", "string_col", "avg"]
    assert result.shape == (610, 4)


def test_tumble_window_by_ungrouped_agg(con):
    """Ungrouped mean over 30-second tumble windows on the watermarked table."""
    table = con.table("functional_alltypes_with_watermark")
    windowed = table.window_by(time_col=table.timestamp_col).tumble(
        size=ibis.interval(seconds=30)
    )
    result = windowed.agg(avg=_.float_col.mean()).to_pandas()
    # Expect only the window bounds plus the aggregate column (no group key).
    assert list(result.columns) == ["window_start", "window_end", "avg"]
    assert result.shape == (610, 3)
Loading