From 49a4099adc93780eeffef8204af36559eab50a9f Mon Sep 17 00:00:00 2001 From: tobymao Date: Tue, 20 Sep 2022 09:45:35 -0700 Subject: [PATCH] add mysql group_concat separator --- sqlglot/dialects/mysql.py | 11 ++++++++++ sqlglot/expressions.py | 4 ++++ sqlglot/parser.py | 42 +++++++++++++----------------------- sqlglot/tokens.py | 13 +---------- tests/dialects/test_mysql.py | 16 ++++++++++++++ 5 files changed, 47 insertions(+), 39 deletions(-) diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index a2d39a020f..0ffed2491f 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -106,6 +106,7 @@ class Tokenizer(Tokenizer): KEYWORDS = { **Tokenizer.KEYWORDS, + "SEPARATOR": TokenType.SEPARATOR, "_ARMSCII8": TokenType.INTRODUCER, "_ASCII": TokenType.INTRODUCER, "_BIG5": TokenType.INTRODUCER, @@ -160,6 +161,15 @@ class Parser(Parser): "STR_TO_DATE": _str_to_date, } + FUNCTION_PARSERS = { + **Parser.FUNCTION_PARSERS, + "GROUP_CONCAT": lambda self: self.expression( + exp.GroupConcat, + this=self._parse_lambda(), + separator=self._match(TokenType.SEPARATOR) and self._parse_field(), + ), + } + class Generator(Generator): NULL_ORDERING_SUPPORTED = False @@ -173,6 +183,7 @@ class Generator(Generator): exp.DateAdd: _date_add_sql("ADD"), exp.DateSub: _date_add_sql("SUB"), exp.DateTrunc: _date_trunc_sql, + exp.GroupConcat: lambda self, e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""", exp.StrToDate: _str_to_date_sql, exp.StrToTime: _str_to_date_sql, exp.Trim: _trim_sql, diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index a2bc72ee8e..0bfeb9d7a9 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -2199,6 +2199,10 @@ class Greatest(Func): is_var_len_args = True +class GroupConcat(Func): + arg_types = {"this": True, "separator": False} + + class If(Func): arg_types = {"this": True, "true": True, "false": False} diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 69a1c342fe..ab13bc463a 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -108,7 +108,6 @@ class Parser: TokenType.COLLATE, TokenType.COMMIT, TokenType.CONSTRAINT, - TokenType.CONVERT, TokenType.DEFAULT, TokenType.DELETE, TokenType.ENGINE, @@ -155,20 +154,13 @@ class Parser: *TYPE_TOKENS, } - CASTS = { - TokenType.CAST, - TokenType.TRY_CAST, - } - TRIM_TYPES = {TokenType.LEADING, TokenType.TRAILING, TokenType.BOTH} FUNC_TOKENS = { - TokenType.CONVERT, TokenType.CURRENT_DATE, TokenType.CURRENT_DATETIME, TokenType.CURRENT_TIMESTAMP, TokenType.CURRENT_TIME, - TokenType.EXTRACT, TokenType.FILTER, TokenType.FIRST, TokenType.FORMAT, @@ -177,8 +169,6 @@ class Parser: TokenType.PRIMARY_KEY, TokenType.REPLACE, TokenType.ROW, - TokenType.SUBSTRING, - TokenType.TRIM, TokenType.UNNEST, TokenType.VAR, TokenType.LEFT, @@ -187,7 +177,6 @@ class Parser: TokenType.DATETIME, TokenType.TIMESTAMP, TokenType.TIMESTAMPTZ, - *CASTS, *NESTED_TYPE_TOKENS, *SUBQUERY_PREDICATES, } @@ -373,16 +362,12 @@ class Parser: } FUNCTION_PARSERS = { - TokenType.CONVERT: lambda self, _: self._parse_convert(), - TokenType.EXTRACT: lambda self, _: self._parse_extract(), - TokenType.SUBSTRING: lambda self, _: self._parse_substring(), - TokenType.TRIM: lambda self, _: self._parse_trim(), - **{ - token_type: lambda self, token_type: self._parse_cast( - self.STRICT_CAST and token_type == TokenType.CAST - ) - for token_type in CASTS - }, + "CONVERT": lambda self: self._parse_convert(), + "EXTRACT": lambda self: self._parse_extract(), + "SUBSTRING": lambda self: self._parse_substring(), + "TRIM": lambda self: self._parse_trim(), + "CAST": lambda self: self._parse_cast(self.STRICT_CAST), + "TRY_CAST": lambda self: self._parse_cast(False), } QUERY_MODIFIER_PARSERS = { @@ -1653,13 +1638,16 @@ def _parse_function(self): if token_type not in self.FUNC_TOKENS: return None - if self._match_set(self.FUNCTION_PARSERS): - self._advance() - this = self.FUNCTION_PARSERS[token_type](self, token_type) + this = self._curr.text + upper = this.upper() + self._advance(2) + + parser = self.FUNCTION_PARSERS.get(upper) + + if parser: + this = parser(self) else: subquery_predicate = self.SUBQUERY_PREDICATES.get(token_type) - this = self._curr.text - self._advance(2) if subquery_predicate and self._curr.token_type in ( TokenType.SELECT, @@ -1669,7 +1657,7 @@ def _parse_function(self): self._match_r_paren() return this - function = self.FUNCTIONS.get(this.upper()) + function = self.FUNCTIONS.get(upper) args = self._parse_csv(self._parse_lambda) if function: diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index 7fa150f089..0f1930b56b 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -98,7 +98,6 @@ class TokenType(AutoName): CACHE = auto() CALL = auto() CASE = auto() - CAST = auto() CHARACTER_SET = auto() CHECK = auto() CLUSTER_BY = auto() @@ -106,7 +105,6 @@ class TokenType(AutoName): COMMENT = auto() COMMIT = auto() CONSTRAINT = auto() - CONVERT = auto() CREATE = auto() CROSS = auto() CUBE = auto() @@ -129,7 +127,6 @@ class TokenType(AutoName): EXCEPT = auto() EXISTS = auto() EXPLAIN = auto() - EXTRACT = auto() FALSE = auto() FETCH = auto() FILTER = auto() @@ -208,24 +205,22 @@ class TokenType(AutoName): ROWS = auto() SCHEMA_COMMENT = auto() SELECT = auto() + SEPARATOR = auto() SET = auto() SHOW = auto() SOME = auto() SORT_BY = auto() STORED = auto() STRUCT = auto() - SUBSTRING = auto() TABLE_FORMAT = auto() TABLE_SAMPLE = auto() TEMPORARY = auto() TIME = auto() TOP = auto() THEN = auto() - TRIM = auto() TRUE = auto() TRAILING = auto() TRUNCATE = auto() - TRY_CAST = auto() UNBOUNDED = auto() UNCACHE = auto() UNION = auto() @@ -388,7 +383,6 @@ class Tokenizer(metaclass=_Tokenizer): "CACHE": TokenType.CACHE, "UNCACHE": TokenType.UNCACHE, "CASE": TokenType.CASE, - "CAST": TokenType.CAST, "CHARACTER SET": TokenType.CHARACTER_SET, "CHECK": TokenType.CHECK, "CLUSTER BY": TokenType.CLUSTER_BY, @@ -396,7 +390,6 @@ class Tokenizer(metaclass=_Tokenizer): "COMMENT": TokenType.SCHEMA_COMMENT, "COMMIT": TokenType.COMMIT, "CONSTRAINT": TokenType.CONSTRAINT, - "CONVERT": TokenType.CONVERT, "CREATE": TokenType.CREATE, "CROSS": TokenType.CROSS, "CUBE": TokenType.CUBE, @@ -417,7 +410,6 @@ class Tokenizer(metaclass=_Tokenizer): "EXCEPT": TokenType.EXCEPT, "EXISTS": TokenType.EXISTS, "EXPLAIN": TokenType.EXPLAIN, - "EXTRACT": TokenType.EXTRACT, "FALSE": TokenType.FALSE, "FETCH": TokenType.FETCH, "FILTER": TokenType.FILTER, @@ -492,7 +484,6 @@ class Tokenizer(metaclass=_Tokenizer): "SOME": TokenType.SOME, "SORT BY": TokenType.SORT_BY, "STORED": TokenType.STORED, - "SUBSTRING": TokenType.SUBSTRING, "TABLE": TokenType.TABLE, "TABLE_FORMAT": TokenType.TABLE_FORMAT, "TBLPROPERTIES": TokenType.PROPERTIES, @@ -502,9 +493,7 @@ class Tokenizer(metaclass=_Tokenizer): "THEN": TokenType.THEN, "TRUE": TokenType.TRUE, "TRAILING": TokenType.TRAILING, - "TRIM": TokenType.TRIM, "TRUNCATE": TokenType.TRUNCATE, - "TRY_CAST": TokenType.TRY_CAST, "UNBOUNDED": TokenType.UNBOUNDED, "UNION": TokenType.UNION, "UNNEST": TokenType.UNNEST, diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index e32e151d01..acc91b5db1 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -81,3 +81,19 @@ def test_hash_comments(self): "mysql": "SELECT 1", }, ) + + def test_mysql(self): + self.validate_all( + "GROUP_CONCAT(DISTINCT x ORDER BY y DESC)", + write={ + "mysql": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC SEPARATOR ',')", + "sqlite": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC)", + }, + ) + self.validate_all( + "GROUP_CONCAT(DISTINCT x ORDER BY y DESC SEPARATOR '')", + write={ + "mysql": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC SEPARATOR '')", + "sqlite": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC, '')", + }, + )