From 2cbca74cefb12b9aa25c2e27c24f51aacdf8b1c0 Mon Sep 17 00:00:00 2001 From: Deepyaman Datta Date: Sat, 7 Oct 2023 20:34:47 -0600 Subject: [PATCH] fix(flink): correct the filtered count translation --- ibis/backends/flink/registry.py | 8 +++++++- .../test_translate_complex_filtered_agg/out.sql | 2 +- .../test_translate_complex_groupby_aggregation/out.sql | 2 +- .../test_translator/test_translate_count_star/out.sql | 2 +- .../test_translator/test_translate_having/out.sql | 2 +- .../test_translator/test_translate_value_counts/out.sql | 2 +- 6 files changed, 12 insertions(+), 6 deletions(-) diff --git a/ibis/backends/flink/registry.py b/ibis/backends/flink/registry.py index 86dc1b059b7d..6be715766361 100644 --- a/ibis/backends/flink/registry.py +++ b/ibis/backends/flink/registry.py @@ -18,7 +18,13 @@ def _count_star(translator: ExprTranslator, op: ops.Node) -> str: - return "count(*)" + # TODO(deepyaman): Use `FILTER` syntax; see note on `_filter` below. + if (where := op.where) is not None: + condition = f"CASE WHEN {translator.translate(where)} THEN 1 END" + else: + condition = "*" + + return f"COUNT({condition})" def _date(translator: ExprTranslator, op: ops.Node) -> str: diff --git a/ibis/backends/flink/tests/snapshots/test_translator/test_translate_complex_filtered_agg/out.sql b/ibis/backends/flink/tests/snapshots/test_translator/test_translate_complex_filtered_agg/out.sql index 84e53a6950ae..5b9990fe3892 100644 --- a/ibis/backends/flink/tests/snapshots/test_translator/test_translate_complex_filtered_agg/out.sql +++ b/ibis/backends/flink/tests/snapshots/test_translator/test_translate_complex_filtered_agg/out.sql @@ -1,4 +1,4 @@ -SELECT t0.`b`, count(*) AS `total`, avg(t0.`a`) AS `avg_a`, +SELECT t0.`b`, COUNT(*) AS `total`, avg(t0.`a`) AS `avg_a`, avg(CASE WHEN t0.`g` = 'A' THEN t0.`a` ELSE NULL END) AS `avg_a_A`, avg(CASE WHEN t0.`g` = 'B' THEN t0.`a` ELSE NULL END) AS `avg_a_B` FROM table t0 diff --git a/ibis/backends/flink/tests/snapshots/test_translator/test_translate_complex_groupby_aggregation/out.sql b/ibis/backends/flink/tests/snapshots/test_translator/test_translate_complex_groupby_aggregation/out.sql index 00df624ca45e..1ce799579cd5 100644 --- a/ibis/backends/flink/tests/snapshots/test_translator/test_translate_complex_groupby_aggregation/out.sql +++ b/ibis/backends/flink/tests/snapshots/test_translator/test_translate_complex_groupby_aggregation/out.sql @@ -1,5 +1,5 @@ SELECT EXTRACT(year from t0.`i`) AS `year`, - EXTRACT(month from t0.`i`) AS `month`, count(*) AS `total`, + EXTRACT(month from t0.`i`) AS `month`, COUNT(*) AS `total`, count(DISTINCT t0.`b`) AS `b_unique` FROM table t0 GROUP BY EXTRACT(year from t0.`i`), EXTRACT(month from t0.`i`) \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_translator/test_translate_count_star/out.sql b/ibis/backends/flink/tests/snapshots/test_translator/test_translate_count_star/out.sql index 99aa48f2791f..e3e3e49089cd 100644 --- a/ibis/backends/flink/tests/snapshots/test_translator/test_translate_count_star/out.sql +++ b/ibis/backends/flink/tests/snapshots/test_translator/test_translate_count_star/out.sql @@ -1,3 +1,3 @@ -SELECT t0.`i`, count(*) AS `CountStar(table)` +SELECT t0.`i`, COUNT(*) AS `CountStar(table)` FROM table t0 GROUP BY t0.`i` \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_translator/test_translate_having/out.sql b/ibis/backends/flink/tests/snapshots/test_translator/test_translate_having/out.sql index 1f4153d357e6..3744dd045f0d 100644 --- a/ibis/backends/flink/tests/snapshots/test_translator/test_translate_having/out.sql +++ b/ibis/backends/flink/tests/snapshots/test_translator/test_translate_having/out.sql @@ -1,4 +1,4 @@ SELECT t0.`g`, sum(t0.`b`) AS `b_sum` FROM table t0 GROUP BY t0.`g` -HAVING count(*) >= CAST(1000 AS SMALLINT) \ No newline at end of file +HAVING COUNT(*) >= CAST(1000 AS SMALLINT) \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_translator/test_translate_value_counts/out.sql b/ibis/backends/flink/tests/snapshots/test_translator/test_translate_value_counts/out.sql index 42da71c14695..53e792c45392 100644 --- a/ibis/backends/flink/tests/snapshots/test_translator/test_translate_value_counts/out.sql +++ b/ibis/backends/flink/tests/snapshots/test_translator/test_translate_value_counts/out.sql @@ -1,4 +1,4 @@ -SELECT t0.`ExtractYear(i)`, count(*) AS `ExtractYear(i)_count` +SELECT t0.`ExtractYear(i)`, COUNT(*) AS `ExtractYear(i)_count` FROM ( SELECT EXTRACT(year from t1.`i`) AS `ExtractYear(i)` FROM table t1