Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: #543 more properly identify CREATE TABLE ... LIKE ... statements #767

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 36 additions & 21 deletions sqlparse/engine/grouping.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def group_typecasts(tlist):
def match(token):
return token.match(T.Punctuation, '::')

def valid(token):
def valid(token, idx):
return token is not None

def post(tlist, pidx, tidx, nidx):
Expand All @@ -91,10 +91,10 @@ def group_tzcasts(tlist):
def match(token):
return token.ttype == T.Keyword.TZCast

def valid_prev(token):
def valid_prev(token, idx):
return token is not None

def valid_next(token):
def valid_next(token, idx):
return token is not None and (
token.is_whitespace
or token.match(T.Keyword, 'AS')
Expand All @@ -119,13 +119,13 @@ def match(token):
def match_to_extend(token):
return isinstance(token, sql.TypedLiteral)

def valid_prev(token):
def valid_prev(token, idx):
return token is not None

def valid_next(token):
def valid_next(token, idx):
return token is not None and token.match(*sql.TypedLiteral.M_CLOSE)

def valid_final(token):
def valid_final(token, idx):
return token is not None and token.match(*sql.TypedLiteral.M_EXTEND)

def post(tlist, pidx, tidx, nidx):
Expand All @@ -141,12 +141,12 @@ def group_period(tlist):
def match(token):
return token.match(T.Punctuation, '.')

def valid_prev(token):
def valid_prev(token, idx):
sqlcls = sql.SquareBrackets, sql.Identifier
ttypes = T.Name, T.String.Symbol
return imt(token, i=sqlcls, t=ttypes)

def valid_next(token):
def valid_next(token, idx):
# issue261, allow invalid next token
return True

Expand All @@ -166,10 +166,10 @@ def group_as(tlist):
def match(token):
return token.is_keyword and token.normalized == 'AS'

def valid_prev(token):
def valid_prev(token, idx):
return token.normalized == 'NULL' or not token.is_keyword

def valid_next(token):
def valid_next(token, idx):
ttypes = T.DML, T.DDL, T.CTE
return not imt(token, t=ttypes) and token is not None

Expand All @@ -183,7 +183,7 @@ def group_assignment(tlist):
def match(token):
return token.match(T.Assignment, ':=')

def valid(token):
def valid(token, idx):
return token is not None and token.ttype not in (T.Keyword,)

def post(tlist, pidx, tidx, nidx):
Expand All @@ -202,9 +202,9 @@ def group_comparison(tlist):
ttypes = T_NUMERICAL + T_STRING + T_NAME

def match(token):
return token.ttype == T.Operator.Comparison
return imt(token, t=(T.Operator.Comparison), m=(T.Keyword, 'LIKE'))

def valid(token):
def valid(token, idx):
if imt(token, t=ttypes, i=sqlcls):
return True
elif token and token.is_keyword and token.normalized == 'NULL':
Expand All @@ -215,7 +215,22 @@ def valid(token):
def post(tlist, pidx, tidx, nidx):
return pidx, nidx

valid_prev = valid_next = valid
def valid_next(token, idx):
return valid(token, idx)

def valid_prev(token, idx):
# https://dev.mysql.com/doc/refman/8.0/en/create-table-like.html
# LIKE is usually a comparator, except when used in
# `CREATE TABLE x LIKE y` statements, Check if we are
# constructing a table - otherwise assume it is indeed a comparator
two_tokens_back_idx = idx - 3
if two_tokens_back_idx >= 0:
_, two_tokens_back = tlist.token_next(two_tokens_back_idx)
if imt(two_tokens_back, m=(T.Keyword, 'TABLE')):
return False

return valid(token, idx)

_group(tlist, sql.Comparison, match,
valid_prev, valid_next, post, extend=False)

Expand All @@ -237,10 +252,10 @@ def group_arrays(tlist):
def match(token):
return isinstance(token, sql.SquareBrackets)

def valid_prev(token):
def valid_prev(token, idx):
return imt(token, i=sqlcls, t=ttypes)

def valid_next(token):
def valid_next(token, idx):
return True

def post(tlist, pidx, tidx, nidx):
Expand All @@ -258,7 +273,7 @@ def group_operator(tlist):
def match(token):
return imt(token, t=(T.Operator, T.Wildcard))

def valid(token):
def valid(token, idx):
return imt(token, i=sqlcls, t=ttypes) \
or (token and token.match(
T.Keyword,
Expand All @@ -283,7 +298,7 @@ def group_identifier_list(tlist):
def match(token):
return token.match(T.Punctuation, ',')

def valid(token):
def valid(token, idx):
return imt(token, i=sqlcls, m=m_role, t=ttypes)

def post(tlist, pidx, tidx, nidx):
Expand Down Expand Up @@ -431,8 +446,8 @@ def group(stmt):


def _group(tlist, cls, match,
valid_prev=lambda t: True,
valid_next=lambda t: True,
valid_prev=lambda t, idx: True,
valid_next=lambda t, idx: True,
post=None,
extend=True,
recurse=True
Expand All @@ -454,7 +469,7 @@ def _group(tlist, cls, match,

if match(token):
nidx, next_ = tlist.token_next(tidx)
if prev_ and valid_prev(prev_) and valid_next(next_):
if prev_ and valid_prev(prev_, pidx) and valid_next(next_, nidx):
from_idx, to_idx = post(tlist, pidx, tidx, nidx)
grp = tlist.group_tokens(cls, from_idx, to_idx, extend=extend)

Expand Down
3 changes: 2 additions & 1 deletion sqlparse/keywords.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@
r'(EXPLODE|INLINE|PARSE_URL_TUPLE|POSEXPLODE|STACK)\b',
tokens.Keyword),
(r"(AT|WITH')\s+TIME\s+ZONE\s+'[^']+'", tokens.Keyword.TZCast),
(r'(NOT\s+)?(LIKE|ILIKE|RLIKE)\b', tokens.Operator.Comparison),
(r'(NOT\s+)(LIKE|ILIKE|RLIKE)\b', tokens.Operator.Comparison),
(r'(ILIKE|RLIKE)\b', tokens.Operator.Comparison),
markjm marked this conversation as resolved.
Show resolved Hide resolved
(r'(NOT\s+)?(REGEXP)\b', tokens.Operator.Comparison),
# Check for keywords, also returns tokens.Name if regex matches
# but the match isn't a keyword.
Expand Down
13 changes: 8 additions & 5 deletions tests/test_grouping.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,10 @@ def test_comparison_with_strings(operator):
assert p.tokens[0].right.ttype == T.String.Single


def test_like_and_ilike_comparison():
@pytest.mark.parametrize('operator', (
markjm marked this conversation as resolved.
Show resolved Hide resolved
'LIKE', 'NOT LIKE', 'ILIKE', 'NOT ILIKE', 'RLIKE', 'NOT RLIKE'
))
def test_like_and_ilike_comparison(operator):
def validate_where_clause(where_clause, expected_tokens):
assert len(where_clause.tokens) == len(expected_tokens)
for where_token, expected_token in zip(where_clause, expected_tokens):
Expand All @@ -513,22 +516,22 @@ def validate_where_clause(where_clause, expected_tokens):
assert (isinstance(where_token, expected_ttype)
and re.match(expected_value, where_token.value))

[p1] = sqlparse.parse("select * from mytable where mytable.mycolumn LIKE 'expr%' limit 5;")
[p1] = sqlparse.parse(f"select * from mytable where mytable.mycolumn {operator} 'expr%' limit 5;")
[p1_where] = [token for token in p1 if isinstance(token, sql.Where)]
validate_where_clause(p1_where, [
(T.Keyword, "where"),
(T.Whitespace, None),
(sql.Comparison, r"mytable.mycolumn LIKE.*"),
(sql.Comparison, f"mytable.mycolumn {operator}.*"),
(T.Whitespace, None),
])

[p2] = sqlparse.parse(
"select * from mytable where mycolumn NOT ILIKE '-expr' group by othercolumn;")
f"select * from mytable where mycolumn {operator} '-expr' group by othercolumn;")
[p2_where] = [token for token in p2 if isinstance(token, sql.Where)]
validate_where_clause(p2_where, [
(T.Keyword, "where"),
(T.Whitespace, None),
(sql.Comparison, r"mycolumn NOT ILIKE.*"),
(sql.Comparison, f"mycolumn {operator}.*"),
(T.Whitespace, None),
])

Expand Down
15 changes: 15 additions & 0 deletions tests/test_regressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,3 +444,18 @@ def test_copy_issue672():
p = sqlparse.parse('select * from foo')[0]
copied = copy.deepcopy(p)
assert str(p) == str(copied)


def test_copy_issue543():
tokens = sqlparse.parse('create table tab1.b like tab2')[0].tokens
assert [(t.ttype, t.value) for t in tokens if t.ttype != T.Whitespace] == \
[
(T.DDL, 'create'),
(T.Keyword, 'table'),
(None, 'tab1.b'),
(T.Keyword, 'like'),
(None, 'tab2')
]

comparison = sqlparse.parse('a LIKE "b"')[0].tokens[0]
assert isinstance(comparison, sql.Comparison)
4 changes: 2 additions & 2 deletions tests/test_tokenize.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,10 +209,10 @@ def test_parse_window_as():


@pytest.mark.parametrize('s', (
"LIKE", "ILIKE", "NOT LIKE", "NOT ILIKE",
"ILIKE", "NOT LIKE", "NOT ILIKE",
"NOT LIKE", "NOT ILIKE",
))
def test_like_and_ilike_parsed_as_comparisons(s):
def test_likeish_but_not_like_parsed_as_comparisons(s):
p = sqlparse.parse(s)[0]
assert len(p.tokens) == 1
assert p.tokens[0].ttype == T.Operator.Comparison
Expand Down
Loading