diff --git a/ibis/backends/flink/registry.py b/ibis/backends/flink/registry.py index ec9aa3b574b4..b720cc6b3d6e 100644 --- a/ibis/backends/flink/registry.py +++ b/ibis/backends/flink/registry.py @@ -13,7 +13,7 @@ def _count_star(translator: ExprTranslator, op: ops.Node) -> str: return "count(*)" -def _timestamp_from_unix(translator, op): +def _timestamp_from_unix(translator: ExprTranslator, op: ops.Node) -> str: arg, unit = op.args if unit == TimestampUnit.MILLISECOND: @@ -23,9 +23,49 @@ def _timestamp_from_unix(translator, op): raise ValueError(f"{unit!r} unit is not supported!") +def _extract_field(sql_attr: str) -> str: + def extract_field_formatter(translator: ExprTranslator, op: ops.Node) -> str: + arg = translator.translate(op.args[0]) + if sql_attr == "epochseconds": + return f"UNIX_SECONDS({arg})" + else: + return f"EXTRACT({sql_attr} from {arg})" + + return extract_field_formatter + + +def _filter(translator: ExprTranslator, op: ops.Node) -> str: + bool_expr = translator.translate(op.bool_expr) + true_expr = translator.translate(op.true_expr) + false_null_expr = translator.translate(op.false_null_expr) + + # [TODO](chloeh13q): It's preferable to use the FILTER syntax instead of CASE WHEN + # to let the planner do more optimizations to reduce the state size; besides, FILTER + # is more compliant with the SQL standard. 
+ # For example, + # ``` + # COUNT(DISTINCT CASE WHEN flag = 'app' THEN user_id ELSE NULL END) AS app_uv + # ``` + # is equivalent to + # ``` + # COUNT(DISTINCT user_id) FILTER (WHERE flag = 'app') AS app_uv + # ``` + return f"CASE WHEN {bool_expr} THEN {true_expr} ELSE {false_null_expr} END" + + operation_registry.update( { ops.CountStar: _count_star, + ops.ExtractYear: _extract_field("year"), # equivalent to YEAR(date) + ops.ExtractQuarter: _extract_field("quarter"), # equivalent to QUARTER(date) + ops.ExtractMonth: _extract_field("month"), # equivalent to MONTH(date) + ops.ExtractWeekOfYear: _extract_field("week"), # equivalent to WEEK(date) + ops.ExtractDayOfYear: _extract_field("doy"), # equivalent to DAYOFYEAR(date) + ops.ExtractDay: _extract_field("day"), # equivalent to DAYOFMONTH(date) + ops.ExtractHour: _extract_field("hour"), # equivalent to HOUR(timestamp) + ops.ExtractMinute: _extract_field("minute"), # equivalent to MINUTE(timestamp) + ops.ExtractSecond: _extract_field("second"), # equivalent to SECOND(timestamp) ops.TimestampFromUNIX: _timestamp_from_unix, + ops.Where: _filter, } ) diff --git a/ibis/backends/flink/tests/conftest.py b/ibis/backends/flink/tests/conftest.py index c3044f985cea..5305a9070bf3 100644 --- a/ibis/backends/flink/tests/conftest.py +++ b/ibis/backends/flink/tests/conftest.py @@ -5,6 +5,28 @@ from ibis.backends.conftest import TEST_TABLES +@pytest.fixture +def simple_schema(): + return [ + ('a', 'int8'), + ('b', 'int16'), + ('c', 'int32'), + ('d', 'int64'), + ('e', 'float32'), + ('f', 'float64'), + ('g', 'string'), + ('h', 'boolean'), + ('i', 'timestamp'), + ('j', 'date'), + ('k', 'time'), + ] + + +@pytest.fixture +def simple_table(simple_schema): + return ibis.table(simple_schema, name='table') + + @pytest.fixture def batting() -> ir.Table: return ibis.table(schema=TEST_TABLES["batting"], name="batting") diff --git a/ibis/backends/flink/tests/snapshots/test_translator/test_translate_complex_filtered_agg/out.sql 
b/ibis/backends/flink/tests/snapshots/test_translator/test_translate_complex_filtered_agg/out.sql new file mode 100644 index 000000000000..84e53a6950ae --- /dev/null +++ b/ibis/backends/flink/tests/snapshots/test_translator/test_translate_complex_filtered_agg/out.sql @@ -0,0 +1,5 @@ +SELECT t0.`b`, count(*) AS `total`, avg(t0.`a`) AS `avg_a`, + avg(CASE WHEN t0.`g` = 'A' THEN t0.`a` ELSE NULL END) AS `avg_a_A`, + avg(CASE WHEN t0.`g` = 'B' THEN t0.`a` ELSE NULL END) AS `avg_a_B` +FROM table t0 +GROUP BY t0.`b` \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_translator/test_translate_complex_groupby_aggregation/out.sql b/ibis/backends/flink/tests/snapshots/test_translator/test_translate_complex_groupby_aggregation/out.sql new file mode 100644 index 000000000000..00df624ca45e --- /dev/null +++ b/ibis/backends/flink/tests/snapshots/test_translator/test_translate_complex_groupby_aggregation/out.sql @@ -0,0 +1,5 @@ +SELECT EXTRACT(year from t0.`i`) AS `year`, + EXTRACT(month from t0.`i`) AS `month`, count(*) AS `total`, + count(DISTINCT t0.`b`) AS `b_unique` +FROM table t0 +GROUP BY EXTRACT(year from t0.`i`), EXTRACT(month from t0.`i`) \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_translator/test_translate_complex_projections/out.sql b/ibis/backends/flink/tests/snapshots/test_translator/test_translate_complex_projections/out.sql new file mode 100644 index 000000000000..c883cc8ab63a --- /dev/null +++ b/ibis/backends/flink/tests/snapshots/test_translator/test_translate_complex_projections/out.sql @@ -0,0 +1,7 @@ +SELECT t0.`a`, avg(abs(t0.`the_sum`)) AS `mad` +FROM ( + SELECT t1.`a`, t1.`c`, sum(t1.`b`) AS `the_sum` + FROM table t1 + GROUP BY t1.`a`, t1.`c` +) t0 +GROUP BY t0.`a` \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_translator/test_translate_extract_fields/day/out.sql 
b/ibis/backends/flink/tests/snapshots/test_translator/test_translate_extract_fields/day/out.sql new file mode 100644 index 000000000000..14d96a04c89c --- /dev/null +++ b/ibis/backends/flink/tests/snapshots/test_translator/test_translate_extract_fields/day/out.sql @@ -0,0 +1,2 @@ +SELECT EXTRACT(day from t0.`i`) AS `tmp` +FROM table t0 \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_translator/test_translate_extract_fields/day_of_year/out.sql b/ibis/backends/flink/tests/snapshots/test_translator/test_translate_extract_fields/day_of_year/out.sql new file mode 100644 index 000000000000..9774a20af63a --- /dev/null +++ b/ibis/backends/flink/tests/snapshots/test_translator/test_translate_extract_fields/day_of_year/out.sql @@ -0,0 +1,2 @@ +SELECT EXTRACT(doy from t0.`i`) AS `tmp` +FROM table t0 \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_translator/test_translate_extract_fields/hour/out.sql b/ibis/backends/flink/tests/snapshots/test_translator/test_translate_extract_fields/hour/out.sql new file mode 100644 index 000000000000..e19999b7b1f1 --- /dev/null +++ b/ibis/backends/flink/tests/snapshots/test_translator/test_translate_extract_fields/hour/out.sql @@ -0,0 +1,2 @@ +SELECT EXTRACT(hour from t0.`i`) AS `tmp` +FROM table t0 \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_translator/test_translate_extract_fields/minute/out.sql b/ibis/backends/flink/tests/snapshots/test_translator/test_translate_extract_fields/minute/out.sql new file mode 100644 index 000000000000..aeed550bdab2 --- /dev/null +++ b/ibis/backends/flink/tests/snapshots/test_translator/test_translate_extract_fields/minute/out.sql @@ -0,0 +1,2 @@ +SELECT EXTRACT(minute from t0.`i`) AS `tmp` +FROM table t0 \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_translator/test_translate_extract_fields/month/out.sql 
b/ibis/backends/flink/tests/snapshots/test_translator/test_translate_extract_fields/month/out.sql new file mode 100644 index 000000000000..57e3d1c6fa45 --- /dev/null +++ b/ibis/backends/flink/tests/snapshots/test_translator/test_translate_extract_fields/month/out.sql @@ -0,0 +1,2 @@ +SELECT EXTRACT(month from t0.`i`) AS `tmp` +FROM table t0 \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_translator/test_translate_extract_fields/quarter/out.sql b/ibis/backends/flink/tests/snapshots/test_translator/test_translate_extract_fields/quarter/out.sql new file mode 100644 index 000000000000..c6170172bde1 --- /dev/null +++ b/ibis/backends/flink/tests/snapshots/test_translator/test_translate_extract_fields/quarter/out.sql @@ -0,0 +1,2 @@ +SELECT EXTRACT(quarter from t0.`i`) AS `tmp` +FROM table t0 \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_translator/test_translate_extract_fields/second/out.sql b/ibis/backends/flink/tests/snapshots/test_translator/test_translate_extract_fields/second/out.sql new file mode 100644 index 000000000000..0c32dade798e --- /dev/null +++ b/ibis/backends/flink/tests/snapshots/test_translator/test_translate_extract_fields/second/out.sql @@ -0,0 +1,2 @@ +SELECT EXTRACT(second from t0.`i`) AS `tmp` +FROM table t0 \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_translator/test_translate_extract_fields/week_of_year/out.sql b/ibis/backends/flink/tests/snapshots/test_translator/test_translate_extract_fields/week_of_year/out.sql new file mode 100644 index 000000000000..3f6ca60700eb --- /dev/null +++ b/ibis/backends/flink/tests/snapshots/test_translator/test_translate_extract_fields/week_of_year/out.sql @@ -0,0 +1,2 @@ +SELECT EXTRACT(week from t0.`i`) AS `tmp` +FROM table t0 \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_translator/test_translate_extract_fields/year/out.sql 
b/ibis/backends/flink/tests/snapshots/test_translator/test_translate_extract_fields/year/out.sql new file mode 100644 index 000000000000..ab354a7fd861 --- /dev/null +++ b/ibis/backends/flink/tests/snapshots/test_translator/test_translate_extract_fields/year/out.sql @@ -0,0 +1,2 @@ +SELECT EXTRACT(year from t0.`i`) AS `tmp` +FROM table t0 \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_translator/test_translate_filter/out.sql b/ibis/backends/flink/tests/snapshots/test_translator/test_translate_filter/out.sql new file mode 100644 index 000000000000..da36f49f52e3 --- /dev/null +++ b/ibis/backends/flink/tests/snapshots/test_translator/test_translate_filter/out.sql @@ -0,0 +1,4 @@ +SELECT t0.* +FROM table t0 +WHERE ((t0.`c` > 0) OR (t0.`c` < 0)) AND + (t0.`g` IN ('A', 'B')) \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_translator/test_translate_having/out.sql b/ibis/backends/flink/tests/snapshots/test_translator/test_translate_having/out.sql new file mode 100644 index 000000000000..c34fa980e5ae --- /dev/null +++ b/ibis/backends/flink/tests/snapshots/test_translator/test_translate_having/out.sql @@ -0,0 +1,4 @@ +SELECT t0.`g`, sum(t0.`b`) AS `b_sum` +FROM table t0 +GROUP BY t0.`g` +HAVING count(*) >= 1000 \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_translator/test_translate_simple_filtered_agg/out.sql b/ibis/backends/flink/tests/snapshots/test_translator/test_translate_simple_filtered_agg/out.sql new file mode 100644 index 000000000000..19afa7a54cdf --- /dev/null +++ b/ibis/backends/flink/tests/snapshots/test_translator/test_translate_simple_filtered_agg/out.sql @@ -0,0 +1,2 @@ +SELECT count(DISTINCT CASE WHEN t0.`g` = 'A' THEN t0.`b` ELSE NULL END) AS `CountDistinct(b, Equals(g, 'A'))` +FROM table t0 \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_translator/test_translate_value_counts/out.sql 
b/ibis/backends/flink/tests/snapshots/test_translator/test_translate_value_counts/out.sql new file mode 100644 index 000000000000..42da71c14695 --- /dev/null +++ b/ibis/backends/flink/tests/snapshots/test_translator/test_translate_value_counts/out.sql @@ -0,0 +1,6 @@ +SELECT t0.`ExtractYear(i)`, count(*) AS `ExtractYear(i)_count` +FROM ( + SELECT EXTRACT(year from t1.`i`) AS `ExtractYear(i)` + FROM table t1 +) t0 +GROUP BY t0.`ExtractYear(i)` \ No newline at end of file diff --git a/ibis/backends/flink/tests/test_translator.py b/ibis/backends/flink/tests/test_translator.py index 0a1b5f2c5fe3..0a838f5f49a6 100644 --- a/ibis/backends/flink/tests/test_translator.py +++ b/ibis/backends/flink/tests/test_translator.py @@ -1,40 +1,17 @@ import pytest from pytest import param -import ibis from ibis.backends.flink.compiler.core import translate -@pytest.fixture -def schema(): - return [ - ('a', 'int8'), - ('b', 'int16'), - ('c', 'int32'), - ('d', 'int64'), - ('e', 'float32'), - ('f', 'float64'), - ('g', 'string'), - ('h', 'boolean'), - ('i', 'timestamp'), - ('j', 'date'), - ('k', 'time'), - ] - - -@pytest.fixture -def table(schema): - return ibis.table(schema, name='table') - - -def test_translate_sum(snapshot, table): - expr = table.a.sum() +def test_translate_sum(snapshot, simple_table): + expr = simple_table.a.sum() result = translate(expr.as_table().op()) snapshot.assert_match(str(result), "out.sql") -def test_translate_count_star(snapshot, table): - expr = table.group_by(table.i).size() +def test_translate_count_star(snapshot, simple_table): + expr = simple_table.group_by(simple_table.i).size() result = translate(expr.as_table().op()) snapshot.assert_match(str(result), "out.sql") @@ -46,7 +23,89 @@ def test_translate_count_star(snapshot, table): param("s", id="timestamp_s"), ], ) -def test_translate_timestamp_from_unix(snapshot, table, unit): - expr = table.d.to_timestamp(unit=unit) +def test_translate_timestamp_from_unix(snapshot, simple_table, unit): + expr = 
simple_table.d.to_timestamp(unit=unit) + result = translate(expr.as_table().op()) + snapshot.assert_match(result, "out.sql") + + +def test_translate_complex_projections(snapshot, simple_table): + expr = ( + simple_table.group_by(['a', 'c']) + .aggregate(the_sum=simple_table.b.sum()) + .group_by('a') + .aggregate(mad=lambda x: x.the_sum.abs().mean()) + ) + result = translate(expr.as_table().op()) + snapshot.assert_match(result, "out.sql") + + +def test_translate_filter(snapshot, simple_table): + expr = simple_table[ + ((simple_table.c > 0) | (simple_table.c < 0)) & simple_table.g.isin(['A', 'B']) + ] + result = translate(expr.as_table().op()) + snapshot.assert_match(result, "out.sql") + + +@pytest.mark.parametrize( + "kind", + [ + "year", + "quarter", + "month", + "week_of_year", + "day_of_year", + "day", + "hour", + "minute", + "second", + ], +) +def test_translate_extract_fields(snapshot, simple_table, kind): + expr = getattr(simple_table.i, kind)().name("tmp") + result = translate(expr.as_table().op()) + snapshot.assert_match(result, "out.sql") + + +def test_translate_complex_groupby_aggregation(snapshot, simple_table): + keys = [simple_table.i.year().name('year'), simple_table.i.month().name('month')] + b_unique = simple_table.b.nunique() + expr = simple_table.group_by(keys).aggregate( + total=simple_table.count(), b_unique=b_unique + ) + result = translate(expr.as_table().op()) + snapshot.assert_match(result, "out.sql") + + +def test_translate_simple_filtered_agg(snapshot, simple_table): + expr = simple_table.b.nunique(where=simple_table.g == 'A') + result = translate(expr.as_table().op()) + snapshot.assert_match(result, "out.sql") + + +def test_translate_complex_filtered_agg(snapshot, simple_table): + expr = simple_table.group_by('b').aggregate( + total=simple_table.count(), + avg_a=simple_table.a.mean(), + avg_a_A=simple_table.a.mean(where=simple_table.g == 'A'), + avg_a_B=simple_table.a.mean(where=simple_table.g == 'B'), + ) + result = 
translate(expr.as_table().op()) + snapshot.assert_match(result, "out.sql") + + +def test_translate_value_counts(snapshot, simple_table): + expr = simple_table.i.year().value_counts() + result = translate(expr.as_table().op()) + snapshot.assert_match(result, "out.sql") + + +def test_translate_having(snapshot, simple_table): + expr = ( + simple_table.group_by('g') + .having(simple_table.count() >= 1000) + .aggregate(simple_table.b.sum().name('b_sum')) + ) result = translate(expr.as_table().op()) snapshot.assert_match(result, "out.sql")