diff --git a/README.md b/README.md index 592b12e75a..218d86cf0b 100644 --- a/README.md +++ b/README.md @@ -148,9 +148,9 @@ print(sqlglot.transpile(sql, read='mysql', pretty=True)[0]) */ SELECT tbl.cola /* comment 1 */ + tbl.colb /* comment 2 */, - CAST(x AS INT), -- comment 3 - y -- comment 4 -FROM bar /* comment 5 */, tbl /* comment 6*/ + CAST(x AS INT), /* comment 3 */ + y /* comment 4 */ +FROM bar /* comment 5 */, tbl /* comment 6 */ ``` diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index ae26db945b..81e36a88f0 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -43,14 +43,14 @@ class Expression(metaclass=_Expression): key = "Expression" arg_types = {"this": True} - __slots__ = ("args", "parent", "arg_key", "type", "comment") + __slots__ = ("args", "parent", "arg_key", "type", "comments") def __init__(self, **args): self.args = args self.parent = None self.arg_key = None self.type = None - self.comment = None + self.comments = None for arg_key, value in self.args.items(): self._set_parent(arg_key, value) @@ -88,19 +88,6 @@ def text(self, key): return field.this return "" - def find_comment(self, key: str) -> str: - """ - Finds the comment that is attached to a specified child node. - - Args: - key: the key of the target child node (e.g. "this", "expression", etc). - - Returns: - The comment attached to the child node, or the empty string, if it doesn't exist. - """ - field = self.args.get(key) - return field.comment if isinstance(field, Expression) else "" - @property def is_string(self): return isinstance(self, Literal) and self.args["is_string"] @@ -137,7 +124,7 @@ def alias_or_name(self): def __deepcopy__(self, memo): copy = self.__class__(**deepcopy(self.args)) - copy.comment = self.comment + copy.comments = self.comments copy.type = self.type return copy @@ -369,7 +356,7 @@ def to_s(self, hide_missing: bool = True, level: int = 0) -> str: ) for k, vs in self.args.items() } - args["comment"] = self.comment + args["comments"] = self.comments args["type"] = self.type args = {k: v for k, v in args.items() if v or not hide_missing} diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 156e6abfcb..2792dada09 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -1,7 +1,6 @@ from __future__ import annotations import logging -import re import typing as t from sqlglot import exp @@ -12,8 +11,6 @@ logger = logging.getLogger("sqlglot") -NEWLINE_RE = re.compile("\r\n?|\n") - class Generator: """ @@ -226,25 +223,24 @@ def sep(self, sep=" "): def seg(self, sql, sep=" "): return f"{self.sep(sep)}{sql}" - def maybe_comment(self, sql, expression, single_line=False): - comment = expression.comment if self._comments else None - - if not comment: - return sql - + def pad_comment(self, comment): comment = " " + comment if comment[0].strip() else comment comment = comment + " " if comment[-1].strip() else comment + return comment - if isinstance(expression, self.WITH_SEPARATED_COMMENTS): - return f"/*{comment}*/{self.sep()}{sql}" + def maybe_comment(self, sql, expression): + comments = expression.comments if self._comments else None - if not self.pretty: - return f"{sql} /*{comment}*/" + if not comments: + return sql - if not NEWLINE_RE.search(comment): - return f"{sql} --{comment.rstrip()}" if single_line else f"{sql} /*{comment}*/" + sep = "\n" if self.pretty else " " + comments = sep.join(f"/*{self.pad_comment(comment)}*/" for comment in comments) + + if isinstance(expression, self.WITH_SEPARATED_COMMENTS): + return f"{comments}{self.sep()}{sql}" - return f"/*{comment}*/\n{sql}" if sql else f" /*{comment}*/" + return f"{sql} {comments}" def wrap(self, expression): this_sql = self.indent( @@ -1337,15 +1333,15 @@ def expressions(self, expression, key=None, flat=False, indent=True, sep=", "): result_sqls = [] for i, e in enumerate(expressions): sql = self.sql(e, comment=False) - comment = self.maybe_comment("", e, single_line=True) + comments = self.maybe_comment("", e) if self.pretty: if self._leading_comma: - result_sqls.append(f"{sep if i > 0 else pad}{sql}{comment}") + result_sqls.append(f"{sep if i > 0 else pad}{sql}{comments}") else: - result_sqls.append(f"{sql}{stripped_sep if i + 1 < num_sqls else ''}{comment}") + result_sqls.append(f"{sql}{stripped_sep if i + 1 < num_sqls else ''}{comments}") else: - result_sqls.append(f"{sql}{comment}{sep if i + 1 < num_sqls else ''}") + result_sqls.append(f"{sql}{comments}{sep if i + 1 < num_sqls else ''}") result_sqls = "\n".join(result_sqls) if self.pretty else "".join(result_sqls) return self.indent(result_sqls, skip_first=False) if indent else result_sqls diff --git a/sqlglot/parser.py b/sqlglot/parser.py index c480a1196a..0656b5f534 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -567,7 +567,7 @@ class Parser(metaclass=_Parser): "_curr", "_next", "_prev", - "_prev_comment", + "_prev_comments", "_show_trie", "_set_trie", ) @@ -600,7 +600,7 @@ def reset(self): self._curr = None self._next = None self._prev = None - self._prev_comment = None + self._prev_comments = None def parse(self, raw_tokens, sql=None): """ @@ -694,9 +694,9 @@ def raise_error(self, message, token=None): def expression(self, exp_class, **kwargs): instance = exp_class(**kwargs) - if self._prev_comment: - instance.comment = self._prev_comment - self._prev_comment = None + if self._prev_comments: + instance.comments = self._prev_comments + self._prev_comments = None self.validate_expression(instance) return instance @@ -739,10 +739,10 @@ def _advance(self, times=1): self._next = seq_get(self._tokens, self._index + 1) if self._index > 0: self._prev = self._tokens[self._index - 1] - self._prev_comment = self._prev.comment + self._prev_comments = self._prev.comments else: self._prev = None - self._prev_comment = None + self._prev_comments = None def _retreat(self, index): self._advance(index - self._index) @@ -1088,7 +1088,7 @@ def _parse_select(self, nested=False, table=False): self.raise_error(f"{this.key} does not support CTE") this = cte elif self._match(TokenType.SELECT): - comment = self._prev_comment + comments = self._prev_comments hint = self._parse_hint() all_ = self._match(TokenType.ALL) @@ -1113,7 +1113,7 @@ def _parse_select(self, nested=False, table=False): expressions=expressions, limit=limit, ) - this.comment = comment + this.comments = comments from_ = self._parse_from() if from_: this.set("from", from_) @@ -1872,7 +1872,7 @@ def _parse_primary(self): return exp.Literal.number(f"0.{self._prev.text}") if self._match(TokenType.L_PAREN): - comment = self._prev_comment + comments = self._prev_comments query = self._parse_select() if query: @@ -1892,8 +1892,8 @@ def _parse_primary(self): this = self.expression(exp.Tuple, expressions=expressions) else: this = self.expression(exp.Paren, this=this) - if comment: - this.comment = comment + if comments: + this.comments = comments return this return None @@ -2160,7 +2160,7 @@ def _parse_bracket(self, this): if not self._match(TokenType.R_BRACKET): self.raise_error("Expected ]") - this.comment = self._prev_comment + this.comments = self._prev_comments return self._parse_bracket(this) def _parse_case(self): @@ -2482,8 +2482,8 @@ def _parse_csv(self, parse_method, sep=TokenType.COMMA): items = [parse_result] if parse_result is not None else [] while self._match(sep): - if parse_result and self._prev_comment is not None: - parse_result.comment = self._prev_comment + if parse_result and self._prev_comments: + parse_result.comments = self._prev_comments parse_result = parse_method() if parse_result is not None: @@ -2622,14 +2622,14 @@ def _match_pair(self, token_type_a, token_type_b, advance=True): def _match_l_paren(self, expression=None): if not self._match(TokenType.L_PAREN): self.raise_error("Expecting (") - if expression and self._prev_comment: - expression.comment = self._prev_comment + if expression and self._prev_comments: + expression.comments = self._prev_comments def _match_r_paren(self, expression=None): if not self._match(TokenType.R_PAREN): self.raise_error("Expecting )") - if expression and self._prev_comment: - expression.comment = self._prev_comment + if expression and self._prev_comments: + expression.comments = self._prev_comments def _match_texts(self, texts): if self._curr and self._curr.text.upper() in texts: diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index fa006268a5..dda0e012b6 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -294,7 +294,7 @@ class TokenType(AutoName): class Token: - __slots__ = ("token_type", "text", "line", "col", "comment") + __slots__ = ("token_type", "text", "line", "col", "comments") @classmethod def number(cls, number: int) -> Token: @@ -322,13 +322,13 @@ def __init__( text: str, line: int = 1, col: int = 1, - comment: t.Optional[str] = None, + comments: t.List[str] = [], ) -> None: self.token_type = token_type self.text = text self.line = line self.col = max(col - len(text), 1) - self.comment = comment + self.comments = comments def __repr__(self) -> str: attributes = ", ".join(f"{k}: {getattr(self, k)}" for k in self.__slots__) @@ -690,12 +690,12 @@ class Tokenizer(metaclass=_Tokenizer): "_current", "_line", "_col", - "_comment", + "_comments", "_char", "_end", "_peek", "_prev_token_line", - "_prev_token_comment", + "_prev_token_comments", "_prev_token_type", "_replace_backslash", ) @@ -712,13 +712,13 @@ def reset(self) -> None: self._current = 0 self._line = 1 self._col = 1 - self._comment = None + self._comments: t.List[str] = [] self._char = None self._end = None self._peek = None self._prev_token_line = -1 - self._prev_token_comment = None + self._prev_token_comments: t.List[str] = [] self._prev_token_type = None def tokenize(self, sql: str) -> t.List[Token]: @@ -771,7 +771,7 @@ def _text(self) -> str: def _add(self, token_type: TokenType, text: t.Optional[str] = None) -> None: self._prev_token_line = self._line - self._prev_token_comment = self._comment + self._prev_token_comments = self._comments self._prev_token_type = token_type # type: ignore self.tokens.append( Token( @@ -779,10 +779,10 @@ def _add(self, token_type: TokenType, text: t.Optional[str] = None) -> None: self._text if text is None else text, self._line, self._col, - self._comment, + self._comments, ) ) - self._comment = None + self._comments = [] if token_type in self.COMMANDS and ( len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON @@ -861,22 +861,18 @@ def _scan_comment(self, comment_start: str) -> bool: while not self._end and self._chars(comment_end_size) != comment_end: self._advance() - self._comment = self._text[comment_start_size : -comment_end_size + 1] # type: ignore + self._comments.append(self._text[comment_start_size : -comment_end_size + 1]) # type: ignore self._advance(comment_end_size - 1) else: while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK: # type: ignore self._advance() - self._comment = self._text[comment_start_size:] # type: ignore - - # Leading comment is attached to the succeeding token, whilst trailing comment to the preceding. If both - # types of comment can be attached to a token, the trailing one is discarded in favour of the leading one. + self._comments.append(self._text[comment_start_size:]) # type: ignore + # Leading comment is attached to the succeeding token, whilst trailing comment to the preceding. + # Multiple consecutive comments are preserved by appending them to the current comments list. if comment_start_line == self._prev_token_line: - if self._prev_token_comment is None: - self.tokens[-1].comment = self._comment - self._prev_token_comment = self._comment - - self._comment = None + self.tokens[-1].comments.extend(self._comments) + self._comments = [] return True diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 09acc42326..6033570511 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -1270,8 +1270,8 @@ def test_hash_comments(self): self.validate_all( """/* comment1 */ SELECT - x, -- comment2 - y -- comment3""", + x, /* comment2 */ + y /* comment3 */""", read={ "mysql": """SELECT # comment1 x, # comment2 diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index 79fcfabddb..e162d1dfc9 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -581,6 +581,7 @@ SELECT * FROM (tbl1 JOIN (tbl2 JOIN tbl3) ON bla = foo) SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl) SELECT CAST(x AS INT) /* comment */ FROM foo SELECT a /* x */, b /* x */ +SELECT a /* x */ /* y */ /* z */, b /* k */ /* m */ SELECT * FROM foo /* x */, bla /* x */ SELECT 1 /* comment */ + 1 SELECT 1 /* c1 */ + 2 /* c2 */ diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 5f9023d3b6..0e13adeb64 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -599,9 +599,9 @@ def test_comment_alias(self): """SELECT a, b AS B, - c, -- comment - d AS D, -- another comment - CAST(x AS INT) -- final comment + c, /* comment */ + d AS D, /* another comment */ + CAST(x AS INT) /* final comment */ FROM foo""", ) diff --git a/tests/test_parser.py b/tests/test_parser.py index 68cbf19361..fa7b589ad9 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -205,8 +205,9 @@ def test_var(self): def test_comments(self): expression = parse_one( """ - --comment1 - SELECT /* this won't be used */ + --comment1.1 + --comment1.2 + SELECT /*comment1.3*/ a, --comment2 b as B, --comment3:testing "test--annotation", @@ -217,13 +218,13 @@ def test_comments(self): """ ) - self.assertEqual(expression.comment, "comment1") - self.assertEqual(expression.expressions[0].comment, "comment2") - self.assertEqual(expression.expressions[1].comment, "comment3:testing") - self.assertEqual(expression.expressions[2].comment, None) - self.assertEqual(expression.expressions[3].comment, "comment4 --foo") - self.assertEqual(expression.expressions[4].comment, "") - self.assertEqual(expression.expressions[5].comment, " space") + self.assertEqual(expression.comments, ["comment1.1", "comment1.2", "comment1.3"]) + self.assertEqual(expression.expressions[0].comments, ["comment2"]) + self.assertEqual(expression.expressions[1].comments, ["comment3:testing"]) + self.assertEqual(expression.expressions[2].comments, None) + self.assertEqual(expression.expressions[3].comments, ["comment4 --foo"]) + self.assertEqual(expression.expressions[4].comments, [""]) + self.assertEqual(expression.expressions[5].comments, [" space"]) def test_type_literals(self): self.assertEqual(parse_one("int 1"), parse_one("CAST(1 AS INT)")) diff --git a/tests/test_tokens.py b/tests/test_tokens.py index d4772ba9e3..1d1b966566 100644 --- a/tests/test_tokens.py +++ b/tests/test_tokens.py @@ -7,13 +7,13 @@ class TestTokens(unittest.TestCase): def test_comment_attachment(self): tokenizer = Tokenizer() sql_comment = [ - ("/*comment*/ foo", "comment"), - ("/*comment*/ foo --test", "comment"), - ("--comment\nfoo --test", "comment"), - ("foo --comment", "comment"), - ("foo", None), - ("foo /*comment 1*/ /*comment 2*/", "comment 1"), + ("/*comment*/ foo", ["comment"]), + ("/*comment*/ foo --test", ["comment", "test"]), + ("--comment\nfoo --test", ["comment", "test"]), + ("foo --comment", ["comment"]), + ("foo", []), + ("foo /*comment 1*/ /*comment 2*/", ["comment 1", "comment 2"]), ] for sql, comment in sql_comment: - self.assertEqual(tokenizer.tokenize(sql)[0].comment, comment) + self.assertEqual(tokenizer.tokenize(sql)[0].comments, comment) diff --git a/tests/test_transpile.py b/tests/test_transpile.py index fb8546054a..7bf53e5fef 100644 --- a/tests/test_transpile.py +++ b/tests/test_transpile.py @@ -64,7 +64,7 @@ def test_leading_comma(self): ) self.validate( "SELECT FOO, /*x*/\nBAR, /*y*/\nBAZ", - "SELECT\n FOO -- x\n , BAR -- y\n , BAZ", + "SELECT\n FOO /* x */\n , BAR /* y */\n , BAZ", leading_comma=True, pretty=True, ) @@ -84,7 +84,8 @@ def test_space(self): def test_comments(self): self.validate("SELECT */*comment*/", "SELECT * /* comment */") self.validate( - "SELECT * FROM table /*comment 1*/ /*comment 2*/", "SELECT * FROM table /* comment 1 */" + "SELECT * FROM table /*comment 1*/ /*comment 2*/", + "SELECT * FROM table /* comment 1 */ /* comment 2 */", ) self.validate("SELECT 1 FROM foo -- comment", "SELECT 1 FROM foo /* comment */") self.validate("SELECT --+5\nx FROM foo", "/* +5 */ SELECT x FROM foo") @@ -118,6 +119,53 @@ def test_comments(self): ) self.validate( """ +-- comment 1 +-- comment 2 +-- comment 3 +SELECT * FROM foo + """, + "/* comment 1 */ /* comment 2 */ /* comment 3 */ SELECT * FROM foo", + ) + self.validate( + """ +-- comment 1 +-- comment 2 +-- comment 3 +SELECT * FROM foo""", + """/* comment 1 */ +/* comment 2 */ +/* comment 3 */ +SELECT + * +FROM foo""", + pretty=True, + ) + self.validate( + """ +SELECT * FROM tbl /*line1 +line2 +line3*/ /*another comment*/ where 1=1 -- comment at the end""", + """SELECT * FROM tbl /* line1 +line2 +line3 */ /* another comment */ WHERE 1 = 1 /* comment at the end */""", + ) + self.validate( + """ +SELECT * FROM tbl /*line1 +line2 +line3*/ /*another comment*/ where 1=1 -- comment at the end""", + """SELECT + * +FROM tbl /* line1 +line2 +line3 */ +/* another comment */ +WHERE + 1 = 1 /* comment at the end */""", + pretty=True, + ) + self.validate( + """ /* multi line comment @@ -136,8 +184,8 @@ def test_comments(self): */ SELECT tbl.cola /* comment 1 */ + tbl.colb /* comment 2 */, - CAST(x AS INT), -- comment 3 - y -- comment 4 + CAST(x AS INT), /* comment 3 */ + y /* comment 4 */ FROM bar /* comment 5 */, tbl /* comment 6 */""", read="mysql", pretty=True,