From 726cc993da52073d3ca41a45fe814a827b8bbde1 Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Sun, 7 May 2023 11:04:14 +0400 Subject: [PATCH] feat(rust): add support for `HAVING` clause to SQL `GROUP BY` operations (#8704) --- polars/polars-sql/src/context.rs | 12 ++++++-- polars/polars-sql/src/keywords.rs | 1 + py-polars/tests/unit/test_sql.py | 48 +++++++++++++++++++++---------- 3 files changed, 43 insertions(+), 18 deletions(-) diff --git a/polars/polars-sql/src/context.rs b/polars/polars-sql/src/context.rs index 889dd7076a87..9ee4f479bb99 100644 --- a/polars/polars-sql/src/context.rs +++ b/polars/polars-sql/src/context.rs @@ -204,11 +204,11 @@ impl SQLContext { .get(0) .ok_or_else(|| polars_err!(ComputeError: "no table name provided in query"))?; - let lf = self.execute_from_statement(sql_tbl)?; + let mut lf = self.execute_from_statement(sql_tbl)?; let mut contains_wildcard = false; // Filter Expression - let lf = match select_stmt.selection.as_ref() { + lf = match select_stmt.selection.as_ref() { Some(expr) => { let filter_expression = parse_sql_expr(expr, self)?; lf.filter(filter_expression) @@ -262,7 +262,13 @@ impl SQLContext { if groupby_keys.is_empty() { Ok(lf.select(projections)) } else { - self.process_groupby(lf, contains_wildcard, &groupby_keys, &projections) + lf = self.process_groupby(lf, contains_wildcard, &groupby_keys, &projections)?; + + // Apply 'having' clause, post-aggregation + match select_stmt.having.as_ref() { + Some(expr) => Ok(lf.filter(parse_sql_expr(expr, self)?)), + None => Ok(lf), + } } } diff --git a/polars/polars-sql/src/keywords.rs b/polars/polars-sql/src/keywords.rs index 2221834941f3..1926cd58cf51 100644 --- a/polars/polars-sql/src/keywords.rs +++ b/polars/polars-sql/src/keywords.rs @@ -53,6 +53,7 @@ pub fn all_keywords() -> Vec<&'static str> { keywords::NOT, keywords::IN, keywords::WITH, + keywords::HAVING, ]; keywords.extend_from_slice(sql_keywords); keywords diff --git a/py-polars/tests/unit/test_sql.py b/py-polars/tests/unit/test_sql.py index b3c40208686d..b816d59e7b12 100644 --- a/py-polars/tests/unit/test_sql.py +++ b/py-polars/tests/unit/test_sql.py @@ -20,25 +20,43 @@ def test_sql_groupby(foods_ipc_path: Path) -> None: out = c.query( """ - SELECT - category, - count(category) as count, - max(calories), - min(fats_g) - FROM foods - GROUP BY category - ORDER BY count, category DESC - LIMIT 2 - """ + SELECT + category, + count(category) as n, + max(calories), + min(fats_g) + FROM foods + GROUP BY category + HAVING n > 5 + ORDER BY n, category DESC + """ ) - assert out.to_dict(False) == { - "category": ["meat", "vegetables"], - "count": [5, 7], - "calories": [120, 45], - "fats_g": [5.0, 0.0], + "category": ["vegetables", "fruit", "seafood"], + "n": [7, 7, 8], + "calories": [45, 130, 200], + "fats_g": [0.0, 0.0, 1.5], } + lf = pl.LazyFrame( + { + "group": ["a", "b", "c", "c", "b"], + "attr": ["x", "y", "x", "y", "y"], + } + ) + c.register("test", lf) + out = c.query( + """ + SELECT + group, + COUNT(DISTINCT attr) AS n_dist_attr + FROM test + GROUP BY group + HAVING n_dist_attr > 1 + """ + ) + assert out.to_dict(False) == {"group": ["c"], "n_dist_attr": [2]} + def test_sql_join(foods_ipc_path: Path) -> None: c = pl.SQLContext()