Skip to content

Commit

Permalink
fix: #543 more properly identify CREATE TABLE ... LIKE ... statements
Browse files Browse the repository at this point in the history
  • Loading branch information
markjm committed Mar 12, 2024
1 parent 60486b9 commit 945c80c
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 29 deletions.
57 changes: 36 additions & 21 deletions sqlparse/engine/grouping.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def group_typecasts(tlist):
def match(token):
return token.match(T.Punctuation, '::')

def valid(token):
def valid(token, idx):
return token is not None

def post(tlist, pidx, tidx, nidx):
Expand All @@ -91,10 +91,10 @@ def group_tzcasts(tlist):
def match(token):
return token.ttype == T.Keyword.TZCast

def valid_prev(token):
def valid_prev(token, idx):
return token is not None

def valid_next(token):
def valid_next(token, idx):
return token is not None and (
token.is_whitespace
or token.match(T.Keyword, 'AS')
Expand All @@ -119,13 +119,13 @@ def match(token):
def match_to_extend(token):
return isinstance(token, sql.TypedLiteral)

def valid_prev(token):
def valid_prev(token, idx):
return token is not None

def valid_next(token):
def valid_next(token, idx):
return token is not None and token.match(*sql.TypedLiteral.M_CLOSE)

def valid_final(token):
def valid_final(token, idx):
return token is not None and token.match(*sql.TypedLiteral.M_EXTEND)

def post(tlist, pidx, tidx, nidx):
Expand All @@ -141,12 +141,12 @@ def group_period(tlist):
def match(token):
return token.match(T.Punctuation, '.')

def valid_prev(token):
def valid_prev(token, idx):
sqlcls = sql.SquareBrackets, sql.Identifier
ttypes = T.Name, T.String.Symbol
return imt(token, i=sqlcls, t=ttypes)

def valid_next(token):
def valid_next(token, idx):
# issue261, allow invalid next token
return True

Expand All @@ -166,10 +166,10 @@ def group_as(tlist):
def match(token):
return token.is_keyword and token.normalized == 'AS'

def valid_prev(token):
def valid_prev(token, idx):
return token.normalized == 'NULL' or not token.is_keyword

def valid_next(token):
def valid_next(token, idx):
ttypes = T.DML, T.DDL, T.CTE
return not imt(token, t=ttypes) and token is not None

Expand All @@ -183,7 +183,7 @@ def group_assignment(tlist):
def match(token):
return token.match(T.Assignment, ':=')

def valid(token):
def valid(token, idx):
return token is not None and token.ttype not in (T.Keyword,)

def post(tlist, pidx, tidx, nidx):
Expand All @@ -202,9 +202,9 @@ def group_comparison(tlist):
ttypes = T_NUMERICAL + T_STRING + T_NAME

def match(token):
return token.ttype == T.Operator.Comparison
return imt(token, t=(T.Operator.Comparison), m=(T.Keyword, 'LIKE'))

def valid(token):
def valid(token, idx):
if imt(token, t=ttypes, i=sqlcls):
return True
elif token and token.is_keyword and token.normalized == 'NULL':
Expand All @@ -215,7 +215,22 @@ def valid(token):
def post(tlist, pidx, tidx, nidx):
return pidx, nidx

valid_prev = valid_next = valid
def valid_next(token, idx):
return valid(token, idx)

def valid_prev(token, idx):
# https://dev.mysql.com/doc/refman/8.0/en/create-table-like.html
# LIKE is usually a compatarator, except when used in
# `CREATE TABLE x LIKE y` statements, Check if we are
# constructing a table - otherwise assume it is indeed a comparator
two_tokens_back_idx = idx - 3
if two_tokens_back_idx >= 0:
_, two_tokens_back = tlist.token_next(two_tokens_back_idx)
if imt(two_tokens_back, m=(T.Keyword, 'TABLE')):
return False

return valid(token, idx)

_group(tlist, sql.Comparison, match,
valid_prev, valid_next, post, extend=False)

Expand All @@ -237,10 +252,10 @@ def group_arrays(tlist):
def match(token):
return isinstance(token, sql.SquareBrackets)

def valid_prev(token):
def valid_prev(token, idx):
return imt(token, i=sqlcls, t=ttypes)

def valid_next(token):
def valid_next(token, idx):
return True

def post(tlist, pidx, tidx, nidx):
Expand All @@ -258,7 +273,7 @@ def group_operator(tlist):
def match(token):
return imt(token, t=(T.Operator, T.Wildcard))

def valid(token):
def valid(token, idx):
return imt(token, i=sqlcls, t=ttypes) \
or (token and token.match(
T.Keyword,
Expand All @@ -283,7 +298,7 @@ def group_identifier_list(tlist):
def match(token):
return token.match(T.Punctuation, ',')

def valid(token):
def valid(token, idx):
return imt(token, i=sqlcls, m=m_role, t=ttypes)

def post(tlist, pidx, tidx, nidx):
Expand Down Expand Up @@ -431,8 +446,8 @@ def group(stmt):


def _group(tlist, cls, match,
valid_prev=lambda t: True,
valid_next=lambda t: True,
valid_prev=lambda t, idx: True,
valid_next=lambda t, idx: True,
post=None,
extend=True,
recurse=True
Expand All @@ -454,7 +469,7 @@ def _group(tlist, cls, match,

if match(token):
nidx, next_ = tlist.token_next(tidx)
if prev_ and valid_prev(prev_) and valid_next(next_):
if prev_ and valid_prev(prev_, pidx) and valid_next(next_, nidx):
from_idx, to_idx = post(tlist, pidx, tidx, nidx)
grp = tlist.group_tokens(cls, from_idx, to_idx, extend=extend)

Expand Down
3 changes: 2 additions & 1 deletion sqlparse/keywords.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@
r'(EXPLODE|INLINE|PARSE_URL_TUPLE|POSEXPLODE|STACK)\b',
tokens.Keyword),
(r"(AT|WITH')\s+TIME\s+ZONE\s+'[^']+'", tokens.Keyword.TZCast),
(r'(NOT\s+)?(LIKE|ILIKE|RLIKE)\b', tokens.Operator.Comparison),
(r'(NOT\s+)(LIKE|ILIKE|RLIKE)\b', tokens.Operator.Comparison),
(r'(ILIKE|RLIKE)\b', tokens.Operator.Comparison),
(r'(NOT\s+)?(REGEXP)\b', tokens.Operator.Comparison),
# Check for keywords, also returns tokens.Name if regex matches
# but the match isn't a keyword.
Expand Down
13 changes: 8 additions & 5 deletions tests/test_grouping.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,10 @@ def test_comparison_with_strings(operator):
assert p.tokens[0].right.ttype == T.String.Single


def test_like_and_ilike_comparison():
@pytest.mark.parametrize('operator', (
'LIKE', 'NOT LIKE', 'ILIKE', 'NOT ILIKE', 'RLIKE', 'NOT RLIKE'
))
def test_like_and_ilike_comparison(operator):
def validate_where_clause(where_clause, expected_tokens):
assert len(where_clause.tokens) == len(expected_tokens)
for where_token, expected_token in zip(where_clause, expected_tokens):
Expand All @@ -513,22 +516,22 @@ def validate_where_clause(where_clause, expected_tokens):
assert (isinstance(where_token, expected_ttype)
and re.match(expected_value, where_token.value))

[p1] = sqlparse.parse("select * from mytable where mytable.mycolumn LIKE 'expr%' limit 5;")
[p1] = sqlparse.parse(f"select * from mytable where mytable.mycolumn {operator} 'expr%' limit 5;")
[p1_where] = [token for token in p1 if isinstance(token, sql.Where)]
validate_where_clause(p1_where, [
(T.Keyword, "where"),
(T.Whitespace, None),
(sql.Comparison, r"mytable.mycolumn LIKE.*"),
(sql.Comparison, f"mytable.mycolumn {operator}.*"),
(T.Whitespace, None),
])

[p2] = sqlparse.parse(
"select * from mytable where mycolumn NOT ILIKE '-expr' group by othercolumn;")
f"select * from mytable where mycolumn {operator} '-expr' group by othercolumn;")
[p2_where] = [token for token in p2 if isinstance(token, sql.Where)]
validate_where_clause(p2_where, [
(T.Keyword, "where"),
(T.Whitespace, None),
(sql.Comparison, r"mycolumn NOT ILIKE.*"),
(sql.Comparison, f"mycolumn {operator}.*"),
(T.Whitespace, None),
])

Expand Down
15 changes: 15 additions & 0 deletions tests/test_regressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,3 +444,18 @@ def test_copy_issue672():
p = sqlparse.parse('select * from foo')[0]
copied = copy.deepcopy(p)
assert str(p) == str(copied)


def test_copy_issue543():
tokens = sqlparse.parse('create table tab1.b like tab2')[0].tokens
assert [(t.ttype, t.value) for t in tokens if t.ttype != T.Whitespace] == \
[
(T.DDL, 'create'),
(T.Keyword, 'table'),
(None, 'tab1.b'),
(T.Keyword, 'like'),
(None, 'tab2')
]

comparison = sqlparse.parse('a LIKE "b"')[0].tokens[0]
assert isinstance(comparison, sql.Comparison)
4 changes: 2 additions & 2 deletions tests/test_tokenize.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,10 +209,10 @@ def test_parse_window_as():


@pytest.mark.parametrize('s', (
"LIKE", "ILIKE", "NOT LIKE", "NOT ILIKE",
"ILIKE", "NOT LIKE", "NOT ILIKE",
"NOT LIKE", "NOT ILIKE",
))
def test_like_and_ilike_parsed_as_comparisons(s):
def test_likeish_but_not_like_parsed_as_comparisons(s):
p = sqlparse.parse(s)[0]
assert len(p.tokens) == 1
assert p.tokens[0].ttype == T.Operator.Comparison
Expand Down

0 comments on commit 945c80c

Please sign in to comment.