Skip to content

Commit

Permalink
feat(rust): add support for HAVING clause to SQL GROUP BY operati…
Browse files Browse the repository at this point in the history
…ons (#8704)
  • Loading branch information
alexander-beedie authored May 7, 2023
1 parent 98cfa93 commit 726cc99
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 18 deletions.
12 changes: 9 additions & 3 deletions polars/polars-sql/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,11 +204,11 @@ impl SQLContext {
.get(0)
.ok_or_else(|| polars_err!(ComputeError: "no table name provided in query"))?;

let lf = self.execute_from_statement(sql_tbl)?;
let mut lf = self.execute_from_statement(sql_tbl)?;
let mut contains_wildcard = false;

// Filter Expression
let lf = match select_stmt.selection.as_ref() {
lf = match select_stmt.selection.as_ref() {
Some(expr) => {
let filter_expression = parse_sql_expr(expr, self)?;
lf.filter(filter_expression)
Expand Down Expand Up @@ -262,7 +262,13 @@ impl SQLContext {
if groupby_keys.is_empty() {
Ok(lf.select(projections))
} else {
self.process_groupby(lf, contains_wildcard, &groupby_keys, &projections)
lf = self.process_groupby(lf, contains_wildcard, &groupby_keys, &projections)?;

// Apply 'having' clause, post-aggregation
match select_stmt.having.as_ref() {
Some(expr) => Ok(lf.filter(parse_sql_expr(expr, self)?)),
None => Ok(lf),
}
}
}

Expand Down
1 change: 1 addition & 0 deletions polars/polars-sql/src/keywords.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ pub fn all_keywords() -> Vec<&'static str> {
keywords::NOT,
keywords::IN,
keywords::WITH,
keywords::HAVING,
];
keywords.extend_from_slice(sql_keywords);
keywords
Expand Down
48 changes: 33 additions & 15 deletions py-polars/tests/unit/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,25 +20,43 @@ def test_sql_groupby(foods_ipc_path: Path) -> None:

out = c.query(
"""
SELECT
category,
count(category) as count,
max(calories),
min(fats_g)
FROM foods
GROUP BY category
ORDER BY count, category DESC
LIMIT 2
"""
SELECT
category,
count(category) as n,
max(calories),
min(fats_g)
FROM foods
GROUP BY category
HAVING n > 5
ORDER BY n, category DESC
"""
)

assert out.to_dict(False) == {
"category": ["meat", "vegetables"],
"count": [5, 7],
"calories": [120, 45],
"fats_g": [5.0, 0.0],
"category": ["vegetables", "fruit", "seafood"],
"n": [7, 7, 8],
"calories": [45, 130, 200],
"fats_g": [0.0, 0.0, 1.5],
}

lf = pl.LazyFrame(
{
"group": ["a", "b", "c", "c", "b"],
"attr": ["x", "y", "x", "y", "y"],
}
)
c.register("test", lf)
out = c.query(
"""
SELECT
group,
COUNT(DISTINCT attr) AS n_dist_attr
FROM test
GROUP BY group
HAVING n_dist_attr > 1
"""
)
assert out.to_dict(False) == {"group": ["c"], "n_dist_attr": [2]}


def test_sql_join(foods_ipc_path: Path) -> None:
c = pl.SQLContext()
Expand Down

0 comments on commit 726cc99

Please sign in to comment.