Skip to content

Commit

Permalink
support multiple aggregations in an expression
Browse files Browse the repository at this point in the history
  • Loading branch information
tobymao committed Nov 19, 2022
1 parent 0349b34 commit 8a32dd2
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 11 deletions.
2 changes: 0 additions & 2 deletions sqlglot/executor/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,11 +311,9 @@ def aggregate(self, step, context):

def sort(self, step, context):
projections = self.generate_tuple(step.projections)

projection_columns = [p.alias_or_name for p in step.projections]
all_columns = list(context.columns) + projection_columns
sink = self.table(all_columns)

for reader, ctx in context:
sink.append(reader.row + ctx.eval_tuple(projections))

Expand Down
17 changes: 9 additions & 8 deletions sqlglot/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,17 +131,18 @@ def from_expression(
sequence = itertools.count()

for e in expression.expressions:
aggregation = e.find(exp.AggFunc)
aggs = list(e.find_all(exp.AggFunc))

if aggregation:
if aggs:
projections.append(exp.column(e.alias_or_name, step.name, quoted=True))
aggregations.append(e)
for operand in aggregation.unnest_operands():
if isinstance(operand, exp.Column):
continue
if operand not in operands:
operands[operand] = f"_a_{next(sequence)}"
operand.replace(exp.column(operands[operand], quoted=True))
for agg in aggs:
for operand in agg.unnest_operands():
if isinstance(operand, exp.Column):
continue
if operand not in operands:
operands[operand] = f"_a_{next(sequence)}"
operand.replace(exp.column(operands[operand], quoted=True))
else:
projections.append(e)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def to_csv(expression):
)
return expression

for i, (sql, _) in enumerate(self.sqls[0:7]):
for i, (sql, _) in enumerate(self.sqls[0:10]):
with self.subTest(f"tpch-h {i + 1}"):
a = self.cached_execute(sql)
sql = parse_one(sql).transform(to_csv).sql(pretty=True)
Expand Down

0 comments on commit 8a32dd2

Please sign in to comment.