Skip to content

Commit

Permalink
Add support for STRING_AGG <-> GROUP_CONCAT transpilation (#774)
Browse files Browse the repository at this point in the history
* Add support for STRING_AGG <-> GROUP_CONCAT transpilation

* Fixup

* Add tests, fix group_concat/string_agg for sqlite/postgres

* Format

* Fix test_mysql
  • Loading branch information
georgesittas authored Nov 28, 2022
1 parent 39194d4 commit d5846b6
Show file tree
Hide file tree
Showing 6 changed files with 137 additions and 11 deletions.
23 changes: 23 additions & 0 deletions sqlglot/dialects/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,28 @@ def _trim_sql(self, expression):
return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"


def _string_agg_sql(self, expression):
    """Generate a PostgreSQL STRING_AGG call from an exp.GroupConcat node.

    The ORDER BY clause, if any, is rendered inside the parentheses after the
    separator. The separator defaults to the string literal ','.

    DISTINCT is emitted when no ORDER BY is attached. When both are present,
    DISTINCT is dropped and reported via ``self.unsupported``: PostgreSQL only
    accepts that combination if every ORDER BY expression also appears in the
    argument list, which is not checked here.

    Args:
        self: the generator instance (provides ``sql``, ``format_args`` and ``unsupported``).
        expression: the exp.GroupConcat node to render.

    Returns:
        The STRING_AGG SQL string.
    """
    # Work on a copy, because child nodes are popped out of the tree below.
    expression = expression.copy()

    this = expression.this
    order_node = this if isinstance(this, exp.Order) else None

    distinct_sql = ""
    distinct = expression.find(exp.Distinct)
    if distinct:
        # exp.Distinct can appear below an exp.Order or an exp.GroupConcat expression
        this = distinct.expressions[0]
        distinct.pop()
        if order_node is None:
            distinct_sql = "DISTINCT "
        else:
            self.unsupported(
                "PostgreSQL STRING_AGG with DISTINCT requires ORDER BY expressions "
                "to appear in the argument list; DISTINCT was dropped."
            )

    order = ""
    if order_node is not None:
        if order_node.this:
            # Detach the aggregated operand so only the ORDER BY clause is rendered.
            this = order_node.this
            order_node.this.pop()
        order = self.sql(order_node)  # Order has a leading space

    separator = expression.args.get("separator") or exp.Literal.string(",")
    return f"STRING_AGG({distinct_sql}{self.format_args(this, separator)}{order})"


def _datatype_sql(self, expression):
if expression.this == exp.DataType.Type.ARRAY:
return f"{self.expressions(expression, flat=True)}[]"
Expand Down Expand Up @@ -274,4 +296,5 @@ class Generator(generator.Generator):
exp.TryCast: no_trycast_sql,
exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})",
exp.DataType: _datatype_sql,
exp.GroupConcat: _string_agg_sql,
}
18 changes: 18 additions & 0 deletions sqlglot/dialects/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,23 @@
from sqlglot.tokens import TokenType


# https://www.sqlite.org/lang_aggfunc.html#group_concat
def _group_concat_sql(self, expression):
    """Generate a SQLite GROUP_CONCAT call from an exp.GroupConcat node.

    ORDER BY is dropped and reported via ``self.unsupported``. DISTINCT is kept,
    but SQLite only allows DISTINCT on single-argument aggregates, so combining
    it with a separator is reported as unsupported too. The generated text is
    still returned in that case.

    Args:
        self: the generator instance (provides ``format_args`` and ``unsupported``).
        expression: the exp.GroupConcat node to render.

    Returns:
        The GROUP_CONCAT SQL string.
    """
    this = expression.this

    distinct_sql = ""
    distinct = expression.find(exp.Distinct)
    if distinct:
        this = distinct.expressions[0]
        distinct_sql = "DISTINCT "

    if isinstance(expression.this, exp.Order):
        self.unsupported("SQLite GROUP_CONCAT doesn't support ORDER BY.")
        # When DISTINCT is present, the operand was already taken from the Distinct node.
        if expression.this.this and not distinct:
            this = expression.this.this

    separator = expression.args.get("separator")
    if distinct and separator is not None:
        # SQLite rejects this form: DISTINCT aggregates must have exactly one argument.
        self.unsupported("SQLite GROUP_CONCAT doesn't support DISTINCT with a separator.")

    return f"GROUP_CONCAT({distinct_sql}{self.format_args(this, separator)})"


class SQLite(Dialect):
class Tokenizer(tokens.Tokenizer):
IDENTIFIERS = ['"', ("[", "]"), "`"]
Expand Down Expand Up @@ -62,6 +79,7 @@ class Generator(generator.Generator):
exp.Levenshtein: rename_func("EDITDIST3"),
exp.TableSample: no_tablesample_sql,
exp.TryCast: no_trycast_sql,
exp.GroupConcat: _group_concat_sql,
}

def transaction_sql(self, expression):
Expand Down
41 changes: 33 additions & 8 deletions sqlglot/dialects/tsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"mm": "%B",
"m": "%B",
}

DATE_DELTA_INTERVAL = {
"year": "year",
"yyyy": "year",
Expand All @@ -37,11 +38,12 @@


DATE_FMT_RE = re.compile("([dD]{1,2})|([mM]{1,2})|([yY]{1,4})|([hH]{1,2})|([sS]{1,2})")

# N = Numeric, C=Currency
TRANSPILE_SAFE_NUMBER_FMT = {"N", "C"}


def tsql_format_time_lambda(exp_class, full_format_mapping=None, default=None):
def _format_time_lambda(exp_class, full_format_mapping=None, default=None):
def _format_time(args):
return exp_class(
this=seq_get(args, 1),
Expand All @@ -58,7 +60,7 @@ def _format_time(args):
return _format_time


def parse_format(args):
def _parse_format(args):
fmt = seq_get(args, 1)
number_fmt = fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.this)
if number_fmt:
Expand All @@ -78,7 +80,7 @@ def generate_date_delta_with_unit_sql(self, e):
return f"{func}({self.format_args(e.text('unit'), e.expression, e.this)})"


def generate_format_sql(self, e):
def _format_sql(self, e):
fmt = (
e.args["format"]
if isinstance(e, exp.NumberToStr)
Expand All @@ -87,6 +89,28 @@ def generate_format_sql(self, e):
return f"FORMAT({self.format_args(e.this, fmt)})"


def _string_agg_sql(self, e):
    """Generate a T-SQL STRING_AGG call from an exp.GroupConcat node.

    DISTINCT is dropped and reported via ``self.unsupported``. An ORDER BY
    attached to the node is rendered as a trailing WITHIN GROUP (...) clause.
    The separator defaults to the string literal ','.

    Args:
        self: the generator instance (provides ``sql``, ``format_args`` and ``unsupported``).
        e: the exp.GroupConcat node to render.

    Returns:
        The STRING_AGG SQL string.
    """
    # Copy first: child nodes are popped out of the tree below.
    node = e.copy()

    operand = node.this
    distinct_node = node.find(exp.Distinct)
    if distinct_node:
        # exp.Distinct can appear below an exp.Order or an exp.GroupConcat expression
        self.unsupported("T-SQL STRING_AGG doesn't support DISTINCT.")
        operand = distinct_node.expressions[0]
        distinct_node.pop()

    within_group = ""
    ordering = node.this
    if isinstance(ordering, exp.Order):
        ordered_operand = ordering.this
        if ordered_operand:
            operand = ordered_operand
            ordered_operand.pop()
        # The rendered Order starts with a space; strip it before wrapping in parentheses.
        order_sql = self.sql(ordering)[1:]
        within_group = f" WITHIN GROUP ({order_sql})"

    delimiter = node.args.get("separator") or exp.Literal.string(",")
    call_args = self.format_args(operand, delimiter)
    return f"STRING_AGG({call_args}){within_group}"


class TSQL(Dialect):
null_ordering = "nulls_are_small"
time_format = "'yyyy-mm-dd hh:mm:ss'"
Expand Down Expand Up @@ -228,14 +252,14 @@ class Parser(parser.Parser):
"ISNULL": exp.Coalesce.from_arg_list,
"DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL),
"DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
"DATENAME": tsql_format_time_lambda(exp.TimeToStr, full_format_mapping=True),
"DATEPART": tsql_format_time_lambda(exp.TimeToStr),
"DATENAME": _format_time_lambda(exp.TimeToStr, full_format_mapping=True),
"DATEPART": _format_time_lambda(exp.TimeToStr),
"GETDATE": exp.CurrentDate.from_arg_list,
"IIF": exp.If.from_arg_list,
"LEN": exp.Length.from_arg_list,
"REPLICATE": exp.Repeat.from_arg_list,
"JSON_VALUE": exp.JSONExtractScalar.from_arg_list,
"FORMAT": parse_format,
"FORMAT": _parse_format,
}

VAR_LENGTH_DATATYPES = {
Expand Down Expand Up @@ -298,6 +322,7 @@ class Generator(generator.Generator):
exp.DateDiff: generate_date_delta_with_unit_sql,
exp.CurrentDate: rename_func("GETDATE"),
exp.If: rename_func("IIF"),
exp.NumberToStr: generate_format_sql,
exp.TimeToStr: generate_format_sql,
exp.NumberToStr: _format_sql,
exp.TimeToStr: _format_sql,
exp.GroupConcat: _string_agg_sql,
}
21 changes: 21 additions & 0 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,7 @@ class Parser(metaclass=_Parser):
"TRIM": lambda self: self._parse_trim(),
"CAST": lambda self: self._parse_cast(self.STRICT_CAST),
"TRY_CAST": lambda self: self._parse_cast(False),
"STRING_AGG": lambda self: self._parse_string_agg(),
}

QUERY_MODIFIER_PARSERS = {
Expand Down Expand Up @@ -2259,6 +2260,26 @@ def _parse_cast(self, strict):

return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)

def _parse_string_agg(self):
    """Parse the arguments of a STRING_AGG call into an exp.GroupConcat node.

    Handles both the plain form ``<expression>, <separator>`` and the form
    followed by ``WITHIN GROUP (ORDER BY ...)``.

    Returns:
        An exp.GroupConcat node. When a WITHIN GROUP clause is present, its
        ``this`` is the parsed order wrapping the first argument, and its
        ``separator`` is the second argument.
    """
    # Parses <expression> , <separator>
    arguments = self._parse_csv(self._parse_conjunction)

    # Remember the position so the closing parenthesis can be un-consumed below.
    checkpoint = self._index
    self._match(TokenType.R_PAREN)

    # Checks if we can parse an order clause: WITHIN GROUP (ORDER BY <order_by_expression_list> [ASC | DESC]).
    # This is done "manually", instead of letting _parse_window parse it into an exp.WithinGroup node, so that
    # the STRING_AGG call is parsed like in MySQL / SQLite and can thus be transpiled more easily to them.
    if self._match(TokenType.WITHIN_GROUP):
        # NOTE(review): the closing parenthesis of WITHIN GROUP is not consumed here;
        # presumably the caller matches it. Confirm against the function-parsing code.
        self._match(TokenType.L_PAREN)
        ordering = self._parse_order(this=seq_get(arguments, 0))
        return self.expression(exp.GroupConcat, this=ordering, separator=seq_get(arguments, 1))

    # No WITHIN GROUP: rewind to just before the closing parenthesis.
    self._retreat(checkpoint)
    node = exp.GroupConcat.from_arg_list(arguments)
    self.validate_expression(node, arguments)
    return node

def _parse_convert(self, strict):
this = self._parse_column()
if self._match(TokenType.USING):
Expand Down
17 changes: 15 additions & 2 deletions tests/dialects/test_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,14 +179,27 @@ def test_mysql(self):
"GROUP_CONCAT(DISTINCT x ORDER BY y DESC)",
write={
"mysql": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC SEPARATOR ',')",
"sqlite": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC)",
"sqlite": "GROUP_CONCAT(DISTINCT x)",
"tsql": "STRING_AGG(x, ',') WITHIN GROUP (ORDER BY y DESC)",
"postgres": "STRING_AGG(x, ',' ORDER BY y DESC NULLS LAST)",
},
)
self.validate_all(
"GROUP_CONCAT(x ORDER BY y SEPARATOR z)",
write={
"mysql": "GROUP_CONCAT(x ORDER BY y SEPARATOR z)",
"sqlite": "GROUP_CONCAT(x, z)",
"tsql": "STRING_AGG(x, z) WITHIN GROUP (ORDER BY y)",
"postgres": "STRING_AGG(x, z ORDER BY y NULLS FIRST)",
},
)
self.validate_all(
"GROUP_CONCAT(DISTINCT x ORDER BY y DESC SEPARATOR '')",
write={
"mysql": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC SEPARATOR '')",
"sqlite": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC, '')",
"sqlite": "GROUP_CONCAT(DISTINCT x, '')",
"tsql": "STRING_AGG(x, '') WITHIN GROUP (ORDER BY y DESC)",
"postgres": "STRING_AGG(x, '' ORDER BY y DESC NULLS LAST)",
},
)
self.validate_identity(
Expand Down
28 changes: 27 additions & 1 deletion tests/dialects/test_tsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,40 @@ def test_tsql(self):
"spark": "SELECT CAST(`a`.`b` AS SHORT) FROM foo",
},
)

self.validate_all(
"CONVERT(INT, CONVERT(NUMERIC, '444.75'))",
write={
"mysql": "CAST(CAST('444.75' AS DECIMAL) AS INT)",
"tsql": "CAST(CAST('444.75' AS NUMERIC) AS INTEGER)",
},
)
self.validate_all(
"STRING_AGG(x, y) WITHIN GROUP (ORDER BY z DESC)",
write={
"tsql": "STRING_AGG(x, y) WITHIN GROUP (ORDER BY z DESC)",
"mysql": "GROUP_CONCAT(x ORDER BY z DESC SEPARATOR y)",
"sqlite": "GROUP_CONCAT(x, y)",
"postgres": "STRING_AGG(x, y ORDER BY z DESC NULLS LAST)",
},
)
self.validate_all(
"STRING_AGG(x, '|') WITHIN GROUP (ORDER BY z ASC)",
write={
"tsql": "STRING_AGG(x, '|') WITHIN GROUP (ORDER BY z)",
"mysql": "GROUP_CONCAT(x ORDER BY z SEPARATOR '|')",
"sqlite": "GROUP_CONCAT(x, '|')",
"postgres": "STRING_AGG(x, '|' ORDER BY z NULLS FIRST)",
},
)
self.validate_all(
"STRING_AGG(x, '|')",
write={
"tsql": "STRING_AGG(x, '|')",
"mysql": "GROUP_CONCAT(x SEPARATOR '|')",
"sqlite": "GROUP_CONCAT(x, '|')",
"postgres": "STRING_AGG(x, '|')",
},
)

def test_types(self):
self.validate_identity("CAST(x AS XML)")
Expand Down

0 comments on commit d5846b6

Please sign in to comment.