From 8a32dd29b52a548c4941ba76642420a9e189dfbd Mon Sep 17 00:00:00 2001 From: tobymao Date: Fri, 18 Nov 2022 22:44:56 -0800 Subject: [PATCH] support multiple aggregations in an expression --- sqlglot/executor/python.py | 2 -- sqlglot/planner.py | 17 +++++++++-------- tests/test_executor.py | 2 +- 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py index cb2543cf33..b0fe6d1fef 100644 --- a/sqlglot/executor/python.py +++ b/sqlglot/executor/python.py @@ -311,11 +311,9 @@ def aggregate(self, step, context): def sort(self, step, context): projections = self.generate_tuple(step.projections) - projection_columns = [p.alias_or_name for p in step.projections] all_columns = list(context.columns) + projection_columns sink = self.table(all_columns) - for reader, ctx in context: sink.append(reader.row + ctx.eval_tuple(projections)) diff --git a/sqlglot/planner.py b/sqlglot/planner.py index 51db2d4638..3e96ea5707 100644 --- a/sqlglot/planner.py +++ b/sqlglot/planner.py @@ -131,17 +131,18 @@ def from_expression( sequence = itertools.count() for e in expression.expressions: - aggregation = e.find(exp.AggFunc) + aggs = list(e.find_all(exp.AggFunc)) - if aggregation: + if aggs: projections.append(exp.column(e.alias_or_name, step.name, quoted=True)) aggregations.append(e) - for operand in aggregation.unnest_operands(): - if isinstance(operand, exp.Column): - continue - if operand not in operands: - operands[operand] = f"_a_{next(sequence)}" - operand.replace(exp.column(operands[operand], quoted=True)) + for agg in aggs: + for operand in agg.unnest_operands(): + if isinstance(operand, exp.Column): + continue + if operand not in operands: + operands[operand] = f"_a_{next(sequence)}" + operand.replace(exp.column(operands[operand], quoted=True)) else: projections.append(e) diff --git a/tests/test_executor.py b/tests/test_executor.py index 2c4d7cd43f..b596a7d6d7 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -74,7 +74,7 @@ def to_csv(expression): ) return expression - for i, (sql, _) in enumerate(self.sqls[0:7]): + for i, (sql, _) in enumerate(self.sqls[0:10]): with self.subTest(f"tpch-h {i + 1}"): a = self.cached_execute(sql) sql = parse_one(sql).transform(to_csv).sql(pretty=True)