Skip to content

Commit

Permalink
Refactor comments so they're stored in a list (#750)
Browse files Browse the repository at this point in the history
* Refactor comments so that we store them in a list

* Add a test for multiple leading comments

* Multi-comment generating, cleanup & more tests

* Cleanup

* Cleanup

* Update README
  • Loading branch information
georgesittas committed Nov 23, 2022
1 parent e98f4d9 commit 6b0da1e
Show file tree
Hide file tree
Showing 11 changed files with 133 additions and 104 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,9 @@ print(sqlglot.transpile(sql, read='mysql', pretty=True)[0])
*/
SELECT
tbl.cola /* comment 1 */ + tbl.colb /* comment 2 */,
CAST(x AS INT), -- comment 3
y -- comment 4
FROM bar /* comment 5 */, tbl /* comment 6*/
CAST(x AS INT), /* comment 3 */
y /* comment 4 */
FROM bar /* comment 5 */, tbl /* comment 6 */
```


Expand Down
21 changes: 4 additions & 17 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,14 @@ class Expression(metaclass=_Expression):

key = "Expression"
arg_types = {"this": True}
__slots__ = ("args", "parent", "arg_key", "type", "comment")
__slots__ = ("args", "parent", "arg_key", "type", "comments")

def __init__(self, **args):
self.args = args
self.parent = None
self.arg_key = None
self.type = None
self.comment = None
self.comments = None

for arg_key, value in self.args.items():
self._set_parent(arg_key, value)
Expand Down Expand Up @@ -88,19 +88,6 @@ def text(self, key):
return field.this
return ""

def find_comment(self, key: str) -> str:
"""
Finds the comment that is attached to a specified child node.
Args:
key: the key of the target child node (e.g. "this", "expression", etc).
Returns:
The comment attached to the child node, or the empty string, if it doesn't exist.
"""
field = self.args.get(key)
return field.comment if isinstance(field, Expression) else ""

@property
def is_string(self):
return isinstance(self, Literal) and self.args["is_string"]
Expand Down Expand Up @@ -137,7 +124,7 @@ def alias_or_name(self):

def __deepcopy__(self, memo):
copy = self.__class__(**deepcopy(self.args))
copy.comment = self.comment
copy.comments = self.comments
copy.type = self.type
return copy

Expand Down Expand Up @@ -369,7 +356,7 @@ def to_s(self, hide_missing: bool = True, level: int = 0) -> str:
)
for k, vs in self.args.items()
}
args["comment"] = self.comment
args["comments"] = self.comments
args["type"] = self.type
args = {k: v for k, v in args.items() if v or not hide_missing}

Expand Down
36 changes: 16 additions & 20 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import logging
import re
import typing as t

from sqlglot import exp
Expand All @@ -12,8 +11,6 @@

logger = logging.getLogger("sqlglot")

NEWLINE_RE = re.compile("\r\n?|\n")


class Generator:
"""
Expand Down Expand Up @@ -226,25 +223,24 @@ def sep(self, sep=" "):
def seg(self, sql, sep=" "):
return f"{self.sep(sep)}{sql}"

def maybe_comment(self, sql, expression, single_line=False):
comment = expression.comment if self._comments else None

if not comment:
return sql

def pad_comment(self, comment):
comment = " " + comment if comment[0].strip() else comment
comment = comment + " " if comment[-1].strip() else comment
return comment

if isinstance(expression, self.WITH_SEPARATED_COMMENTS):
return f"/*{comment}*/{self.sep()}{sql}"
def maybe_comment(self, sql, expression):
comments = expression.comments if self._comments else None

if not self.pretty:
return f"{sql} /*{comment}*/"
if not comments:
return sql

if not NEWLINE_RE.search(comment):
return f"{sql} --{comment.rstrip()}" if single_line else f"{sql} /*{comment}*/"
sep = "\n" if self.pretty else " "
comments = sep.join(f"/*{self.pad_comment(comment)}*/" for comment in comments)

if isinstance(expression, self.WITH_SEPARATED_COMMENTS):
return f"{comments}{self.sep()}{sql}"

return f"/*{comment}*/\n{sql}" if sql else f" /*{comment}*/"
return f"{sql} {comments}"

def wrap(self, expression):
this_sql = self.indent(
Expand Down Expand Up @@ -1337,15 +1333,15 @@ def expressions(self, expression, key=None, flat=False, indent=True, sep=", "):
result_sqls = []
for i, e in enumerate(expressions):
sql = self.sql(e, comment=False)
comment = self.maybe_comment("", e, single_line=True)
comments = self.maybe_comment("", e)

if self.pretty:
if self._leading_comma:
result_sqls.append(f"{sep if i > 0 else pad}{sql}{comment}")
result_sqls.append(f"{sep if i > 0 else pad}{sql}{comments}")
else:
result_sqls.append(f"{sql}{stripped_sep if i + 1 < num_sqls else ''}{comment}")
result_sqls.append(f"{sql}{stripped_sep if i + 1 < num_sqls else ''}{comments}")
else:
result_sqls.append(f"{sql}{comment}{sep if i + 1 < num_sqls else ''}")
result_sqls.append(f"{sql}{comments}{sep if i + 1 < num_sqls else ''}")

result_sqls = "\n".join(result_sqls) if self.pretty else "".join(result_sqls)
return self.indent(result_sqls, skip_first=False) if indent else result_sqls
Expand Down
38 changes: 19 additions & 19 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,7 @@ class Parser(metaclass=_Parser):
"_curr",
"_next",
"_prev",
"_prev_comment",
"_prev_comments",
"_show_trie",
"_set_trie",
)
Expand Down Expand Up @@ -600,7 +600,7 @@ def reset(self):
self._curr = None
self._next = None
self._prev = None
self._prev_comment = None
self._prev_comments = None

def parse(self, raw_tokens, sql=None):
"""
Expand Down Expand Up @@ -694,9 +694,9 @@ def raise_error(self, message, token=None):

def expression(self, exp_class, **kwargs):
instance = exp_class(**kwargs)
if self._prev_comment:
instance.comment = self._prev_comment
self._prev_comment = None
if self._prev_comments:
instance.comments = self._prev_comments
self._prev_comments = None
self.validate_expression(instance)
return instance

Expand Down Expand Up @@ -739,10 +739,10 @@ def _advance(self, times=1):
self._next = seq_get(self._tokens, self._index + 1)
if self._index > 0:
self._prev = self._tokens[self._index - 1]
self._prev_comment = self._prev.comment
self._prev_comments = self._prev.comments
else:
self._prev = None
self._prev_comment = None
self._prev_comments = None

def _retreat(self, index):
self._advance(index - self._index)
Expand Down Expand Up @@ -1088,7 +1088,7 @@ def _parse_select(self, nested=False, table=False):
self.raise_error(f"{this.key} does not support CTE")
this = cte
elif self._match(TokenType.SELECT):
comment = self._prev_comment
comments = self._prev_comments

hint = self._parse_hint()
all_ = self._match(TokenType.ALL)
Expand All @@ -1113,7 +1113,7 @@ def _parse_select(self, nested=False, table=False):
expressions=expressions,
limit=limit,
)
this.comment = comment
this.comments = comments
from_ = self._parse_from()
if from_:
this.set("from", from_)
Expand Down Expand Up @@ -1872,7 +1872,7 @@ def _parse_primary(self):
return exp.Literal.number(f"0.{self._prev.text}")

if self._match(TokenType.L_PAREN):
comment = self._prev_comment
comments = self._prev_comments
query = self._parse_select()

if query:
Expand All @@ -1892,8 +1892,8 @@ def _parse_primary(self):
this = self.expression(exp.Tuple, expressions=expressions)
else:
this = self.expression(exp.Paren, this=this)
if comment:
this.comment = comment
if comments:
this.comments = comments
return this

return None
Expand Down Expand Up @@ -2160,7 +2160,7 @@ def _parse_bracket(self, this):
if not self._match(TokenType.R_BRACKET):
self.raise_error("Expected ]")

this.comment = self._prev_comment
this.comments = self._prev_comments
return self._parse_bracket(this)

def _parse_case(self):
Expand Down Expand Up @@ -2482,8 +2482,8 @@ def _parse_csv(self, parse_method, sep=TokenType.COMMA):
items = [parse_result] if parse_result is not None else []

while self._match(sep):
if parse_result and self._prev_comment is not None:
parse_result.comment = self._prev_comment
if parse_result and self._prev_comments:
parse_result.comments = self._prev_comments

parse_result = parse_method()
if parse_result is not None:
Expand Down Expand Up @@ -2622,14 +2622,14 @@ def _match_pair(self, token_type_a, token_type_b, advance=True):
def _match_l_paren(self, expression=None):
if not self._match(TokenType.L_PAREN):
self.raise_error("Expecting (")
if expression and self._prev_comment:
expression.comment = self._prev_comment
if expression and self._prev_comments:
expression.comments = self._prev_comments

def _match_r_paren(self, expression=None):
if not self._match(TokenType.R_PAREN):
self.raise_error("Expecting )")
if expression and self._prev_comment:
expression.comment = self._prev_comment
if expression and self._prev_comments:
expression.comments = self._prev_comments

def _match_texts(self, texts):
if self._curr and self._curr.text.upper() in texts:
Expand Down
36 changes: 16 additions & 20 deletions sqlglot/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ class TokenType(AutoName):


class Token:
__slots__ = ("token_type", "text", "line", "col", "comment")
__slots__ = ("token_type", "text", "line", "col", "comments")

@classmethod
def number(cls, number: int) -> Token:
Expand Down Expand Up @@ -322,13 +322,13 @@ def __init__(
text: str,
line: int = 1,
col: int = 1,
comment: t.Optional[str] = None,
comments: t.List[str] = [],
) -> None:
self.token_type = token_type
self.text = text
self.line = line
self.col = max(col - len(text), 1)
self.comment = comment
self.comments = comments

def __repr__(self) -> str:
attributes = ", ".join(f"{k}: {getattr(self, k)}" for k in self.__slots__)
Expand Down Expand Up @@ -690,12 +690,12 @@ class Tokenizer(metaclass=_Tokenizer):
"_current",
"_line",
"_col",
"_comment",
"_comments",
"_char",
"_end",
"_peek",
"_prev_token_line",
"_prev_token_comment",
"_prev_token_comments",
"_prev_token_type",
"_replace_backslash",
)
Expand All @@ -712,13 +712,13 @@ def reset(self) -> None:
self._current = 0
self._line = 1
self._col = 1
self._comment = None
self._comments: t.List[str] = []

self._char = None
self._end = None
self._peek = None
self._prev_token_line = -1
self._prev_token_comment = None
self._prev_token_comments: t.List[str] = []
self._prev_token_type = None

def tokenize(self, sql: str) -> t.List[Token]:
Expand Down Expand Up @@ -771,18 +771,18 @@ def _text(self) -> str:

def _add(self, token_type: TokenType, text: t.Optional[str] = None) -> None:
self._prev_token_line = self._line
self._prev_token_comment = self._comment
self._prev_token_comments = self._comments
self._prev_token_type = token_type # type: ignore
self.tokens.append(
Token(
token_type,
self._text if text is None else text,
self._line,
self._col,
self._comment,
self._comments,
)
)
self._comment = None
self._comments = []

if token_type in self.COMMANDS and (
len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON
Expand Down Expand Up @@ -861,22 +861,18 @@ def _scan_comment(self, comment_start: str) -> bool:
while not self._end and self._chars(comment_end_size) != comment_end:
self._advance()

self._comment = self._text[comment_start_size : -comment_end_size + 1] # type: ignore
self._comments.append(self._text[comment_start_size : -comment_end_size + 1]) # type: ignore
self._advance(comment_end_size - 1)
else:
while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK: # type: ignore
self._advance()
self._comment = self._text[comment_start_size:] # type: ignore

# Leading comment is attached to the succeeding token, whilst trailing comment to the preceding. If both
# types of comment can be attached to a token, the trailing one is discarded in favour of the leading one.
self._comments.append(self._text[comment_start_size:]) # type: ignore

# Leading comment is attached to the succeeding token, whilst trailing comment to the preceding.
# Multiple consecutive comments are preserved by appending them to the current comments list.
if comment_start_line == self._prev_token_line:
if self._prev_token_comment is None:
self.tokens[-1].comment = self._comment
self._prev_token_comment = self._comment

self._comment = None
self.tokens[-1].comments.extend(self._comments)
self._comments = []

return True

Expand Down
4 changes: 2 additions & 2 deletions tests/dialects/test_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -1270,8 +1270,8 @@ def test_hash_comments(self):
self.validate_all(
"""/* comment1 */
SELECT
x, -- comment2
y -- comment3""",
x, /* comment2 */
y /* comment3 */""",
read={
"mysql": """SELECT # comment1
x, # comment2
Expand Down
1 change: 1 addition & 0 deletions tests/fixtures/identity.sql
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,7 @@ SELECT * FROM (tbl1 JOIN (tbl2 JOIN tbl3) ON bla = foo)
SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl)
SELECT CAST(x AS INT) /* comment */ FROM foo
SELECT a /* x */, b /* x */
SELECT a /* x */ /* y */ /* z */, b /* k */ /* m */
SELECT * FROM foo /* x */, bla /* x */
SELECT 1 /* comment */ + 1
SELECT 1 /* c1 */ + 2 /* c2 */
Expand Down
Loading

0 comments on commit 6b0da1e

Please sign in to comment.