From 54a7f2873b210d8e97804a6e7678a47462f63912 Mon Sep 17 00:00:00 2001 From: Jo <46752250+GeorgeSittas@users.noreply.github.com> Date: Tue, 29 Nov 2022 04:27:19 +0200 Subject: [PATCH] Fix postgres string_agg parsing & array generation (#778) * Fix string_agg so that it parses postgres correctly * Fixup * PR feedback --- sqlglot/dialects/postgres.py | 20 +++++++------------- sqlglot/dialects/tsql.py | 2 +- sqlglot/parser.py | 17 ++++++++++++----- tests/dialects/test_mysql.py | 4 ++-- tests/dialects/test_postgres.py | 6 ++++++ 5 files changed, 28 insertions(+), 21 deletions(-) diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index f535976707..1cb5025062 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -76,23 +76,16 @@ def _trim_sql(self, expression): def _string_agg_sql(self, expression): expression = expression.copy() - - this = expression.this - distinct = expression.find(exp.Distinct) - if distinct: - # exp.Distinct can appear below an exp.Order or an exp.GroupConcat expression - self.unsupported("PostgreSQL STRING_AGG doesn't support DISTINCT.") - this = distinct.expressions[0] - distinct.pop() + separator = expression.args.get("separator") or exp.Literal.string(",") order = "" - if isinstance(expression.this, exp.Order): - if expression.this.this: - this = expression.this.this - expression.this.this.pop() + this = expression.this + if isinstance(this, exp.Order): + if this.this: + this = this.this + this.pop() order = self.sql(expression.this) # Order has a leading space - separator = expression.args.get("separator") or exp.Literal.string(",") return f"STRING_AGG({self.format_args(this, separator)}{order})" @@ -297,4 +290,5 @@ class Generator(generator.Generator): exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})", exp.DataType: _datatype_sql, exp.GroupConcat: _string_agg_sql, + exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]", } diff --git a/sqlglot/dialects/tsql.py 
b/sqlglot/dialects/tsql.py index fbbbab92f7..07ce38b3f2 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -105,7 +105,7 @@ def _string_agg_sql(self, e): if e.this.this: this = e.this.this e.this.this.pop() - order = f" WITHIN GROUP ({self.sql(e.this)[1:]})" + order = f" WITHIN GROUP ({self.sql(e.this)[1:]})" # Order has a leading space separator = e.args.get("separator") or exp.Literal.string(",") return f"STRING_AGG({self.format_args(this, separator)}){order}" diff --git a/sqlglot/parser.py b/sqlglot/parser.py index fec2c3b3fa..bdf0d2d3bc 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -2261,11 +2261,18 @@ def _parse_cast(self, strict): return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to) def _parse_string_agg(self): - # Parses <expression>, <separator> - args = self._parse_csv(self._parse_conjunction) + if self._match(TokenType.DISTINCT): + args = self._parse_csv(self._parse_conjunction) + expression = self.expression(exp.Distinct, expressions=[seq_get(args, 0)]) + else: + args = self._parse_csv(self._parse_conjunction) + expression = seq_get(args, 0) index = self._index - self._match(TokenType.R_PAREN) + if not self._match(TokenType.R_PAREN): + # postgres: STRING_AGG([DISTINCT] expression, separator [ORDER BY expression1 {ASC | DESC} [, ...]]) + order = self._parse_order(this=expression) + return self.expression(exp.GroupConcat, this=order, separator=seq_get(args, 1)) # Checks if we can parse an order clause: WITHIN GROUP (ORDER BY <order_by_expression_list> [ASC | DESC]). 
# This is done "manually", instead of letting _parse_window parse it into an exp.WithinGroup node, so that @@ -2276,8 +2283,8 @@ def _parse_string_agg(self): self.validate_expression(this, args) return this - self._match(TokenType.L_PAREN) - order = self._parse_order(this=seq_get(args, 0)) + self._match_l_paren() # The corresponding match_r_paren will be called in parse_function (caller) + order = self._parse_order(this=expression) return self.expression(exp.GroupConcat, this=order, separator=seq_get(args, 1)) def _parse_convert(self, strict): diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index 246b03a7e5..5064dbec84 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -181,7 +181,7 @@ def test_mysql(self): "mysql": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC SEPARATOR ',')", "sqlite": "GROUP_CONCAT(DISTINCT x)", "tsql": "STRING_AGG(x, ',') WITHIN GROUP (ORDER BY y DESC)", - "postgres": "STRING_AGG(x, ',' ORDER BY y DESC NULLS LAST)", + "postgres": "STRING_AGG(DISTINCT x, ',' ORDER BY y DESC NULLS LAST)", }, ) self.validate_all( @@ -199,7 +199,7 @@ def test_mysql(self): "mysql": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC SEPARATOR '')", "sqlite": "GROUP_CONCAT(DISTINCT x, '')", "tsql": "STRING_AGG(x, '') WITHIN GROUP (ORDER BY y DESC)", - "postgres": "STRING_AGG(x, '' ORDER BY y DESC NULLS LAST)", + "postgres": "STRING_AGG(DISTINCT x, '' ORDER BY y DESC NULLS LAST)", }, ) self.validate_identity( diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index bef4b4cbfb..cd6117c6ba 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -63,6 +63,12 @@ def test_ddl(self): ) def test_postgres(self): + self.validate_identity("SELECT ARRAY[1, 2, 3]") + self.validate_identity("SELECT ARRAY_LENGTH(ARRAY[1, 2, 3], 1)") + self.validate_identity("STRING_AGG(x, y)") + self.validate_identity("STRING_AGG(x, ',' ORDER BY y)") + self.validate_identity("STRING_AGG(x, ',' ORDER BY y 
DESC)") + self.validate_identity("STRING_AGG(DISTINCT x, ',' ORDER BY y DESC)") self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg') IN ('ab') THEN 1 ELSE 0 END") self.validate_identity( "SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1) IN ('ab') THEN 1 ELSE 0 END"