Skip to content

Commit

Permalink
Fix postgres string_agg parsing & array generation (#778)
Browse files (browse the repository at this point in the history)
* Fix string_agg so that it parses postgres correctly

* Fixup

* PR feedback
  • Loading branch information
georgesittas committed Nov 29, 2022
1 parent d5846b6 commit 54a7f28
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 21 deletions.
20 changes: 7 additions & 13 deletions sqlglot/dialects/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,23 +76,16 @@ def _trim_sql(self, expression):

def _string_agg_sql(self, expression):
expression = expression.copy()

this = expression.this
distinct = expression.find(exp.Distinct)
if distinct:
# exp.Distinct can appear below an exp.Order or an exp.GroupConcat expression
self.unsupported("PostgreSQL STRING_AGG doesn't support DISTINCT.")
this = distinct.expressions[0]
distinct.pop()
separator = expression.args.get("separator") or exp.Literal.string(",")

order = ""
if isinstance(expression.this, exp.Order):
if expression.this.this:
this = expression.this.this
expression.this.this.pop()
this = expression.this
if isinstance(this, exp.Order):
if this.this:
this = this.this
this.pop()
order = self.sql(expression.this) # Order has a leading space

separator = expression.args.get("separator") or exp.Literal.string(",")
return f"STRING_AGG({self.format_args(this, separator)}{order})"


Expand Down Expand Up @@ -297,4 +290,5 @@ class Generator(generator.Generator):
exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})",
exp.DataType: _datatype_sql,
exp.GroupConcat: _string_agg_sql,
exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
}
2 changes: 1 addition & 1 deletion sqlglot/dialects/tsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def _string_agg_sql(self, e):
if e.this.this:
this = e.this.this
e.this.this.pop()
order = f" WITHIN GROUP ({self.sql(e.this)[1:]})"
order = f" WITHIN GROUP ({self.sql(e.this)[1:]})" # Order has a leading space

separator = e.args.get("separator") or exp.Literal.string(",")
return f"STRING_AGG({self.format_args(this, separator)}){order}"
Expand Down
17 changes: 12 additions & 5 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2261,11 +2261,18 @@ def _parse_cast(self, strict):
return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)

def _parse_string_agg(self):
# Parses <expression> , <separator>
args = self._parse_csv(self._parse_conjunction)
if self._match(TokenType.DISTINCT):
args = self._parse_csv(self._parse_conjunction)
expression = self.expression(exp.Distinct, expressions=[seq_get(args, 0)])
else:
args = self._parse_csv(self._parse_conjunction)
expression = seq_get(args, 0)

index = self._index
self._match(TokenType.R_PAREN)
if not self._match(TokenType.R_PAREN):
# postgres: STRING_AGG([DISTINCT] expression, separator [ORDER BY expression1 {ASC | DESC} [, ...]])
order = self._parse_order(this=expression)
return self.expression(exp.GroupConcat, this=order, separator=seq_get(args, 1))

# Checks if we can parse an order clause: WITHIN GROUP (ORDER BY <order_by_expression_list> [ASC | DESC]).
# This is done "manually", instead of letting _parse_window parse it into an exp.WithinGroup node, so that
Expand All @@ -2276,8 +2283,8 @@ def _parse_string_agg(self):
self.validate_expression(this, args)
return this

self._match(TokenType.L_PAREN)
order = self._parse_order(this=seq_get(args, 0))
self._match_l_paren() # The corresponding match_r_paren will be called in parse_function (caller)
order = self._parse_order(this=expression)
return self.expression(exp.GroupConcat, this=order, separator=seq_get(args, 1))

def _parse_convert(self, strict):
Expand Down
4 changes: 2 additions & 2 deletions tests/dialects/test_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def test_mysql(self):
"mysql": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC SEPARATOR ',')",
"sqlite": "GROUP_CONCAT(DISTINCT x)",
"tsql": "STRING_AGG(x, ',') WITHIN GROUP (ORDER BY y DESC)",
"postgres": "STRING_AGG(x, ',' ORDER BY y DESC NULLS LAST)",
"postgres": "STRING_AGG(DISTINCT x, ',' ORDER BY y DESC NULLS LAST)",
},
)
self.validate_all(
Expand All @@ -199,7 +199,7 @@ def test_mysql(self):
"mysql": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC SEPARATOR '')",
"sqlite": "GROUP_CONCAT(DISTINCT x, '')",
"tsql": "STRING_AGG(x, '') WITHIN GROUP (ORDER BY y DESC)",
"postgres": "STRING_AGG(x, '' ORDER BY y DESC NULLS LAST)",
"postgres": "STRING_AGG(DISTINCT x, '' ORDER BY y DESC NULLS LAST)",
},
)
self.validate_identity(
Expand Down
6 changes: 6 additions & 0 deletions tests/dialects/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@ def test_ddl(self):
)

def test_postgres(self):
self.validate_identity("SELECT ARRAY[1, 2, 3]")
self.validate_identity("SELECT ARRAY_LENGTH(ARRAY[1, 2, 3], 1)")
self.validate_identity("STRING_AGG(x, y)")
self.validate_identity("STRING_AGG(x, ',' ORDER BY y)")
self.validate_identity("STRING_AGG(x, ',' ORDER BY y DESC)")
self.validate_identity("STRING_AGG(DISTINCT x, ',' ORDER BY y DESC)")
self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg') IN ('ab') THEN 1 ELSE 0 END")
self.validate_identity(
"SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1) IN ('ab') THEN 1 ELSE 0 END"
Expand Down

0 comments on commit 54a7f28

Please sign in to comment.