Skip to content

Commit

Permalink
feat(flink): add tests and translation rules for additional operators
Browse files Browse the repository at this point in the history
  • Loading branch information
chloeh13q authored and jcrist committed Jul 6, 2023
1 parent 91ec3bc commit fc2aa5d
Show file tree
Hide file tree
Showing 19 changed files with 202 additions and 30 deletions.
42 changes: 41 additions & 1 deletion ibis/backends/flink/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def _count_star(translator: ExprTranslator, op: ops.Node) -> str:
return "count(*)"


def _timestamp_from_unix(translator, op):
def _timestamp_from_unix(translator: ExprTranslator, op: ops.Node) -> str:
arg, unit = op.args

if unit == TimestampUnit.MILLISECOND:
Expand All @@ -23,9 +23,49 @@ def _timestamp_from_unix(translator, op):
raise ValueError(f"{unit!r} unit is not supported!")


def _extract_field(sql_attr: str) -> str:
def extract_field_formatter(translator: ExprTranslator, op: ops.Node) -> str:
arg = translator.translate(op.args[0])
if sql_attr == "epochseconds":
return f"UNIX_SECONDS({arg})"
else:
return f"EXTRACT({sql_attr} from {arg})"

return extract_field_formatter


def _filter(translator: ExprTranslator, op: ops.Node) -> str:
    """Translate an ``ops.Where`` node into a ``CASE WHEN`` expression."""
    # Render all three operands (condition, then-value, else-value) up front.
    condition, if_true, if_false = (
        translator.translate(part)
        for part in (op.bool_expr, op.true_expr, op.false_null_expr)
    )

    # [TODO](chloeh13q): It's preferable to use the FILTER syntax instead of
    # CASE WHEN to let the planner do more optimizations to reduce the state
    # size; besides, FILTER is more compliant with the SQL standard.
    # For example,
    #   COUNT(DISTINCT CASE WHEN flag = 'app' THEN user_id ELSE NULL END) AS app_uv
    # is equivalent to
    #   COUNT(DISTINCT) FILTER (WHERE flag = 'app') AS app_uv
    return f"CASE WHEN {condition} THEN {if_true} ELSE {if_false} END"


# Extend the shared translation registry with Flink-specific rules.
# Keys are ibis operation types; values are formatter callables with the
# signature (translator, op) -> SQL string.
operation_registry.update(
    {
        ops.CountStar: _count_star,
        # Temporal field extraction — all delegate to Flink's EXTRACT except
        # the epoch-seconds case handled inside _extract_field.
        ops.ExtractYear: _extract_field("year"),  # equivalent to YEAR(date)
        ops.ExtractQuarter: _extract_field("quarter"),  # equivalent to QUARTER(date)
        ops.ExtractMonth: _extract_field("month"),  # equivalent to MONTH(date)
        ops.ExtractWeekOfYear: _extract_field("week"),  # equivalent to WEEK(date)
        ops.ExtractDayOfYear: _extract_field("doy"),  # equivalent to DAYOFYEAR(date)
        ops.ExtractDay: _extract_field("day"),  # equivalent to DAYOFMONTH(date)
        ops.ExtractHour: _extract_field("hour"),  # equivalent to HOUR(timestamp)
        ops.ExtractMinute: _extract_field("minute"),  # equivalent to MINUTE(timestamp)
        ops.ExtractSecond: _extract_field("second"),  # equivalent to SECOND(timestamp)
        ops.TimestampFromUNIX: _timestamp_from_unix,
        ops.Where: _filter,  # rendered as CASE WHEN ... END (see _filter's TODO)
    }
)
22 changes: 22 additions & 0 deletions ibis/backends/flink/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,28 @@
from ibis.backends.conftest import TEST_TABLES


@pytest.fixture
def simple_schema():
    """(name, type) column pairs spanning ints, floats, string, boolean and temporals."""
    dtypes = [
        'int8',
        'int16',
        'int32',
        'int64',
        'float32',
        'float64',
        'string',
        'boolean',
        'timestamp',
        'date',
        'time',
    ]
    # Columns are named 'a' through 'k', one per dtype.
    return list(zip('abcdefghijk', dtypes))


@pytest.fixture
def simple_table(simple_schema):
    """An unbound ibis table named 'table' built from ``simple_schema``."""
    return ibis.table(schema=simple_schema, name='table')


@pytest.fixture
def batting() -> ir.Table:
    # Unbound table using the shared "batting" schema from the backend
    # test-suite fixtures.
    return ibis.table(schema=TEST_TABLES["batting"], name="batting")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
SELECT t0.`b`, count(*) AS `total`, avg(t0.`a`) AS `avg_a`,
avg(CASE WHEN t0.`g` = 'A' THEN t0.`a` ELSE NULL END) AS `avg_a_A`,
avg(CASE WHEN t0.`g` = 'B' THEN t0.`a` ELSE NULL END) AS `avg_a_B`
FROM table t0
GROUP BY t0.`b`
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
SELECT EXTRACT(year from t0.`i`) AS `year`,
EXTRACT(month from t0.`i`) AS `month`, count(*) AS `total`,
count(DISTINCT t0.`b`) AS `b_unique`
FROM table t0
GROUP BY EXTRACT(year from t0.`i`), EXTRACT(month from t0.`i`)
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
SELECT t0.`a`, avg(abs(t0.`the_sum`)) AS `mad`
FROM (
SELECT t1.`a`, t1.`c`, sum(t1.`b`) AS `the_sum`
FROM table t1
GROUP BY t1.`a`, t1.`c`
) t0
GROUP BY t0.`a`
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
SELECT EXTRACT(day from t0.`i`) AS `tmp`
FROM table t0
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
SELECT EXTRACT(doy from t0.`i`) AS `tmp`
FROM table t0
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
SELECT EXTRACT(hour from t0.`i`) AS `tmp`
FROM table t0
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
SELECT EXTRACT(minute from t0.`i`) AS `tmp`
FROM table t0
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
SELECT EXTRACT(month from t0.`i`) AS `tmp`
FROM table t0
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
SELECT EXTRACT(quarter from t0.`i`) AS `tmp`
FROM table t0
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
SELECT EXTRACT(second from t0.`i`) AS `tmp`
FROM table t0
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
SELECT EXTRACT(week from t0.`i`) AS `tmp`
FROM table t0
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
SELECT EXTRACT(year from t0.`i`) AS `tmp`
FROM table t0
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
SELECT t0.*
FROM table t0
WHERE ((t0.`c` > 0) OR (t0.`c` < 0)) AND
(t0.`g` IN ('A', 'B'))
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
SELECT t0.`g`, sum(t0.`b`) AS `b_sum`
FROM table t0
GROUP BY t0.`g`
HAVING count(*) >= 1000
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
SELECT count(DISTINCT CASE WHEN t0.`g` = 'A' THEN t0.`b` ELSE NULL END) AS `CountDistinct(b, Equals(g, 'A'))`
FROM table t0
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
SELECT t0.`ExtractYear(i)`, count(*) AS `ExtractYear(i)_count`
FROM (
SELECT EXTRACT(year from t1.`i`) AS `ExtractYear(i)`
FROM table t1
) t0
GROUP BY t0.`ExtractYear(i)`
117 changes: 88 additions & 29 deletions ibis/backends/flink/tests/test_translator.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,17 @@
import pytest
from pytest import param

import ibis
from ibis.backends.flink.compiler.core import translate


@pytest.fixture
def schema():
return [
('a', 'int8'),
('b', 'int16'),
('c', 'int32'),
('d', 'int64'),
('e', 'float32'),
('f', 'float64'),
('g', 'string'),
('h', 'boolean'),
('i', 'timestamp'),
('j', 'date'),
('k', 'time'),
]


@pytest.fixture
def table(schema):
return ibis.table(schema, name='table')


def test_translate_sum(snapshot, simple_table):
    """Summing a single column compiles to the expected SQL snapshot."""
    sql = translate(simple_table.a.sum().as_table().op())
    snapshot.assert_match(str(sql), "out.sql")


def test_translate_count_star(snapshot, simple_table):
    """group_by + size() compiles to count(*) per group."""
    counted = simple_table.group_by(simple_table.i).size()
    sql = translate(counted.as_table().op())
    snapshot.assert_match(str(sql), "out.sql")

Expand All @@ -46,7 +23,89 @@ def test_translate_count_star(snapshot, table):
param("s", id="timestamp_s"),
],
)
def test_translate_timestamp_from_unix(snapshot, table, unit):
expr = table.d.to_timestamp(unit=unit)
def test_translate_timestamp_from_unix(snapshot, simple_table, unit):
expr = simple_table.d.to_timestamp(unit=unit)
result = translate(expr.as_table().op())
snapshot.assert_match(result, "out.sql")


def test_translate_complex_projections(snapshot, simple_table):
    """Two-level aggregation (sum, then mean-of-abs) compiles to nested SELECTs."""
    inner = simple_table.group_by(['a', 'c']).aggregate(the_sum=simple_table.b.sum())
    outer = inner.group_by('a').aggregate(mad=lambda x: x.the_sum.abs().mean())
    snapshot.assert_match(translate(outer.as_table().op()), "out.sql")


def test_translate_filter(snapshot, simple_table):
    """Compound predicates (OR, AND, isin) compile to a single WHERE clause."""
    predicate = (
        (simple_table.c > 0) | (simple_table.c < 0)
    ) & simple_table.g.isin(['A', 'B'])
    sql = translate(simple_table[predicate].as_table().op())
    snapshot.assert_match(sql, "out.sql")


@pytest.mark.parametrize(
    "kind",
    "year quarter month week_of_year day_of_year day hour minute second".split(),
)
def test_translate_extract_fields(snapshot, simple_table, kind):
    """Each timestamp component accessor compiles to the matching EXTRACT call."""
    field = getattr(simple_table.i, kind)().name("tmp")
    snapshot.assert_match(translate(field.as_table().op()), "out.sql")


def test_translate_complex_groupby_aggregation(snapshot, simple_table):
    """Grouping on derived year/month keys with count and nunique aggregates."""
    year_key = simple_table.i.year().name('year')
    month_key = simple_table.i.month().name('month')
    agg = simple_table.group_by([year_key, month_key]).aggregate(
        total=simple_table.count(), b_unique=simple_table.b.nunique()
    )
    snapshot.assert_match(translate(agg.as_table().op()), "out.sql")


def test_translate_simple_filtered_agg(snapshot, simple_table):
    """A filtered nunique compiles to COUNT(DISTINCT CASE WHEN ...)."""
    distinct_b = simple_table.b.nunique(where=simple_table.g == 'A')
    snapshot.assert_match(translate(distinct_b.as_table().op()), "out.sql")


def test_translate_complex_filtered_agg(snapshot, simple_table):
    """Plain and filtered aggregates can be mixed within one group-by."""
    is_a = simple_table.g == 'A'
    is_b = simple_table.g == 'B'
    agg = simple_table.group_by('b').aggregate(
        total=simple_table.count(),
        avg_a=simple_table.a.mean(),
        avg_a_A=simple_table.a.mean(where=is_a),
        avg_a_B=simple_table.a.mean(where=is_b),
    )
    snapshot.assert_match(translate(agg.as_table().op()), "out.sql")


def test_translate_value_counts(snapshot, simple_table):
    """value_counts on a derived column compiles to a nested GROUP BY."""
    yearly = simple_table.i.year().value_counts()
    snapshot.assert_match(translate(yearly.as_table().op()), "out.sql")


def test_translate_having(snapshot, simple_table):
    """A post-aggregation predicate compiles to a HAVING clause."""
    grouped = simple_table.group_by('g').having(simple_table.count() >= 1000)
    agg = grouped.aggregate(simple_table.b.sum().name('b_sum'))
    snapshot.assert_match(translate(agg.as_table().op()), "out.sql")

0 comments on commit fc2aa5d

Please sign in to comment.