add mysql group_concat separator
tobymao committed Sep 20, 2022
1 parent 45603f1 commit 49a4099
Showing 5 changed files with 47 additions and 39 deletions.
11 changes: 11 additions & 0 deletions sqlglot/dialects/mysql.py
@@ -106,6 +106,7 @@ class Tokenizer(Tokenizer):

         KEYWORDS = {
             **Tokenizer.KEYWORDS,
+            "SEPARATOR": TokenType.SEPARATOR,
             "_ARMSCII8": TokenType.INTRODUCER,
             "_ASCII": TokenType.INTRODUCER,
             "_BIG5": TokenType.INTRODUCER,
@@ -160,6 +161,15 @@ class Parser(Parser):
             "STR_TO_DATE": _str_to_date,
         }

+        FUNCTION_PARSERS = {
+            **Parser.FUNCTION_PARSERS,
+            "GROUP_CONCAT": lambda self: self.expression(
+                exp.GroupConcat,
+                this=self._parse_lambda(),
+                separator=self._match(TokenType.SEPARATOR) and self._parse_field(),
+            ),
+        }
+
     class Generator(Generator):
         NULL_ORDERING_SUPPORTED = False

@@ -173,6 +183,7 @@ class Generator(Generator):
             exp.DateAdd: _date_add_sql("ADD"),
             exp.DateSub: _date_add_sql("SUB"),
             exp.DateTrunc: _date_trunc_sql,
+            exp.GroupConcat: lambda self, e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""",
            exp.StrToDate: _str_to_date_sql,
             exp.StrToTime: _str_to_date_sql,
             exp.Trim: _trim_sql,
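Taken together, the three mysql.py hunks tokenize the SEPARATOR keyword, parse it into the new separator argument, and print it back out with a ',' default. A minimal usage sketch against this revision (the query strings are illustrative, not from the commit):

import sqlglot

# Round-trip a MySQL GROUP_CONCAT with an explicit separator: the parser
# stores the string literal under GroupConcat's "separator" arg, and the
# MySQL generator prints it back as a SEPARATOR clause.
sql = "SELECT GROUP_CONCAT(name SEPARATOR '; ') FROM users"
print(sqlglot.transpile(sql, read="mysql", write="mysql")[0])
# expected: SELECT GROUP_CONCAT(name SEPARATOR '; ') FROM users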
4 changes: 4 additions & 0 deletions sqlglot/expressions.py
@@ -2199,6 +2199,10 @@ class Greatest(Func):
     is_var_len_args = True


+class GroupConcat(Func):
+    arg_types = {"this": True, "separator": False}
+
+
 class If(Func):
     arg_types = {"this": True, "true": True, "false": False}

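The new node is plain data, so it can be inspected directly. A small sketch assuming the MySQL dialect hooks above; since arg_types marks separator as optional, the key is simply absent when no SEPARATOR clause was written:

from sqlglot import exp, parse_one

node = parse_one("GROUP_CONCAT(x SEPARATOR '|')", read="mysql")
assert isinstance(node, exp.GroupConcat)
print(node.args["this"])           # the aggregated expression: x
print(node.args.get("separator"))  # the '|' literal, or None without SEPARATOR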
42 changes: 15 additions & 27 deletions sqlglot/parser.py
@@ -108,7 +108,6 @@ class Parser:
         TokenType.COLLATE,
         TokenType.COMMIT,
         TokenType.CONSTRAINT,
-        TokenType.CONVERT,
         TokenType.DEFAULT,
         TokenType.DELETE,
         TokenType.ENGINE,
@@ -155,20 +154,13 @@
         *TYPE_TOKENS,
     }

-    CASTS = {
-        TokenType.CAST,
-        TokenType.TRY_CAST,
-    }
-
     TRIM_TYPES = {TokenType.LEADING, TokenType.TRAILING, TokenType.BOTH}

     FUNC_TOKENS = {
-        TokenType.CONVERT,
         TokenType.CURRENT_DATE,
         TokenType.CURRENT_DATETIME,
         TokenType.CURRENT_TIMESTAMP,
         TokenType.CURRENT_TIME,
-        TokenType.EXTRACT,
         TokenType.FILTER,
         TokenType.FIRST,
         TokenType.FORMAT,
@@ -177,8 +169,6 @@
         TokenType.PRIMARY_KEY,
         TokenType.REPLACE,
         TokenType.ROW,
-        TokenType.SUBSTRING,
-        TokenType.TRIM,
         TokenType.UNNEST,
         TokenType.VAR,
         TokenType.LEFT,
@@ -187,7 +177,6 @@
         TokenType.DATETIME,
         TokenType.TIMESTAMP,
         TokenType.TIMESTAMPTZ,
-        *CASTS,
         *NESTED_TYPE_TOKENS,
         *SUBQUERY_PREDICATES,
     }
@@ -373,16 +362,12 @@
     }

     FUNCTION_PARSERS = {
-        TokenType.CONVERT: lambda self, _: self._parse_convert(),
-        TokenType.EXTRACT: lambda self, _: self._parse_extract(),
-        TokenType.SUBSTRING: lambda self, _: self._parse_substring(),
-        TokenType.TRIM: lambda self, _: self._parse_trim(),
-        **{
-            token_type: lambda self, token_type: self._parse_cast(
-                self.STRICT_CAST and token_type == TokenType.CAST
-            )
-            for token_type in CASTS
-        },
+        "CONVERT": lambda self: self._parse_convert(),
+        "EXTRACT": lambda self: self._parse_extract(),
+        "SUBSTRING": lambda self: self._parse_substring(),
+        "TRIM": lambda self: self._parse_trim(),
+        "CAST": lambda self: self._parse_cast(self.STRICT_CAST),
+        "TRY_CAST": lambda self: self._parse_cast(False),
     }

     QUERY_MODIFIER_PARSERS = {
@@ -1653,13 +1638,16 @@ def _parse_function(self):
         if token_type not in self.FUNC_TOKENS:
             return None

-        if self._match_set(self.FUNCTION_PARSERS):
-            self._advance()
-            this = self.FUNCTION_PARSERS[token_type](self, token_type)
+        this = self._curr.text
+        upper = this.upper()
+        self._advance(2)
+
+        parser = self.FUNCTION_PARSERS.get(upper)
+
+        if parser:
+            this = parser(self)
         else:
             subquery_predicate = self.SUBQUERY_PREDICATES.get(token_type)
-            this = self._curr.text
-            self._advance(2)

             if subquery_predicate and self._curr.token_type in (
                 TokenType.SELECT,
@@ -1669,7 +1657,7 @@
                 self._match_r_paren()
                 return this

-        function = self.FUNCTIONS.get(this.upper())
+        function = self.FUNCTIONS.get(upper)
         args = self._parse_csv(self._parse_lambda)

         if function:
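The parser.py diff is the enabling refactor: FUNCTION_PARSERS is now keyed by the uppercased function name rather than by TokenType, so CAST, TRY_CAST, CONVERT, EXTRACT, SUBSTRING, and TRIM no longer need dedicated tokens, and a dialect can hook any function by name, which is exactly what the MySQL GROUP_CONCAT entry does. A hypothetical sketch of the extension point (MY_FUNC and MyParser are made-up names, not part of the commit):

from sqlglot import exp
from sqlglot.parser import Parser

class MyParser(Parser):
    FUNCTION_PARSERS = {
        **Parser.FUNCTION_PARSERS,
        # The callable runs after _parse_function has consumed the function
        # name and the opening paren; the closing paren is matched afterwards.
        "MY_FUNC": lambda self: self.expression(
            exp.Anonymous, this="MY_FUNC", expressions=[self._parse_lambda()]
        ),
    }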
13 changes: 1 addition & 12 deletions sqlglot/tokens.py
@@ -98,15 +98,13 @@ class TokenType(AutoName):
     CACHE = auto()
     CALL = auto()
     CASE = auto()
-    CAST = auto()
     CHARACTER_SET = auto()
     CHECK = auto()
     CLUSTER_BY = auto()
     COLLATE = auto()
     COMMENT = auto()
     COMMIT = auto()
     CONSTRAINT = auto()
-    CONVERT = auto()
     CREATE = auto()
     CROSS = auto()
     CUBE = auto()
@@ -129,7 +127,6 @@
     EXCEPT = auto()
     EXISTS = auto()
     EXPLAIN = auto()
-    EXTRACT = auto()
     FALSE = auto()
     FETCH = auto()
     FILTER = auto()
@@ -208,24 +205,22 @@
     ROWS = auto()
     SCHEMA_COMMENT = auto()
     SELECT = auto()
+    SEPARATOR = auto()
     SET = auto()
     SHOW = auto()
     SOME = auto()
     SORT_BY = auto()
     STORED = auto()
     STRUCT = auto()
-    SUBSTRING = auto()
     TABLE_FORMAT = auto()
     TABLE_SAMPLE = auto()
     TEMPORARY = auto()
     TIME = auto()
     TOP = auto()
     THEN = auto()
-    TRIM = auto()
     TRUE = auto()
     TRAILING = auto()
     TRUNCATE = auto()
-    TRY_CAST = auto()
     UNBOUNDED = auto()
     UNCACHE = auto()
     UNION = auto()
@@ -388,15 +383,13 @@ class Tokenizer(metaclass=_Tokenizer):
         "CACHE": TokenType.CACHE,
         "UNCACHE": TokenType.UNCACHE,
         "CASE": TokenType.CASE,
-        "CAST": TokenType.CAST,
         "CHARACTER SET": TokenType.CHARACTER_SET,
         "CHECK": TokenType.CHECK,
         "CLUSTER BY": TokenType.CLUSTER_BY,
         "COLLATE": TokenType.COLLATE,
         "COMMENT": TokenType.SCHEMA_COMMENT,
         "COMMIT": TokenType.COMMIT,
         "CONSTRAINT": TokenType.CONSTRAINT,
-        "CONVERT": TokenType.CONVERT,
         "CREATE": TokenType.CREATE,
         "CROSS": TokenType.CROSS,
         "CUBE": TokenType.CUBE,
@@ -417,7 +410,6 @@
         "EXCEPT": TokenType.EXCEPT,
         "EXISTS": TokenType.EXISTS,
         "EXPLAIN": TokenType.EXPLAIN,
-        "EXTRACT": TokenType.EXTRACT,
         "FALSE": TokenType.FALSE,
         "FETCH": TokenType.FETCH,
         "FILTER": TokenType.FILTER,
@@ -492,7 +484,6 @@
         "SOME": TokenType.SOME,
         "SORT BY": TokenType.SORT_BY,
         "STORED": TokenType.STORED,
-        "SUBSTRING": TokenType.SUBSTRING,
         "TABLE": TokenType.TABLE,
         "TABLE_FORMAT": TokenType.TABLE_FORMAT,
         "TBLPROPERTIES": TokenType.PROPERTIES,
@@ -502,9 +493,7 @@
         "THEN": TokenType.THEN,
         "TRUE": TokenType.TRUE,
         "TRAILING": TokenType.TRAILING,
-        "TRIM": TokenType.TRIM,
         "TRUNCATE": TokenType.TRUNCATE,
-        "TRY_CAST": TokenType.TRY_CAST,
         "UNBOUNDED": TokenType.UNBOUNDED,
         "UNION": TokenType.UNION,
         "UNNEST": TokenType.UNNEST,
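SEPARATOR is added to TokenType here but deliberately not to the base KEYWORDS map; only the MySQL tokenizer (first hunk of this commit) maps the word to it, so other dialects keep reading SEPARATOR as an ordinary identifier. A quick sketch to confirm, assuming this revision:

from sqlglot.dialects.mysql import MySQL
from sqlglot.tokens import Tokenizer

print(MySQL.Tokenizer().tokenize("SEPARATOR")[0].token_type)  # TokenType.SEPARATOR
print(Tokenizer().tokenize("SEPARATOR")[0].token_type)        # TokenType.VAR, a plain identifier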
16 changes: 16 additions & 0 deletions tests/dialects/test_mysql.py
@@ -81,3 +81,19 @@ def test_hash_comments(self):
                 "mysql": "SELECT 1",
             },
         )
+
+    def test_mysql(self):
+        self.validate_all(
+            "GROUP_CONCAT(DISTINCT x ORDER BY y DESC)",
+            write={
+                "mysql": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC SEPARATOR ',')",
+                "sqlite": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC)",
+            },
+        )
+        self.validate_all(
+            "GROUP_CONCAT(DISTINCT x ORDER BY y DESC SEPARATOR '')",
+            write={
+                "mysql": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC SEPARATOR '')",
+                "sqlite": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC, '')",
+            },
+        )
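Outside the test harness, the second assertion corresponds to a transpile call like the sketch below: SQLite has no SEPARATOR keyword, so the separator travels as a plain second argument instead.

import sqlglot

sql = "GROUP_CONCAT(DISTINCT x ORDER BY y DESC SEPARATOR '')"
print(sqlglot.transpile(sql, read="mysql", write="sqlite")[0])
# expected per the test above: GROUP_CONCAT(DISTINCT x ORDER BY y DESC, '')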
