diff --git a/sqlparse/engine/grouping.py b/sqlparse/engine/grouping.py index c486318a..5dc148b2 100644 --- a/sqlparse/engine/grouping.py +++ b/sqlparse/engine/grouping.py @@ -77,7 +77,7 @@ def group_typecasts(tlist): def match(token): return token.match(T.Punctuation, '::') - def valid(token): + def valid(token, idx): return token is not None def post(tlist, pidx, tidx, nidx): @@ -91,10 +91,10 @@ def group_tzcasts(tlist): def match(token): return token.ttype == T.Keyword.TZCast - def valid_prev(token): + def valid_prev(token, idx): return token is not None - def valid_next(token): + def valid_next(token, idx): return token is not None and ( token.is_whitespace or token.match(T.Keyword, 'AS') @@ -119,13 +119,13 @@ def match(token): def match_to_extend(token): return isinstance(token, sql.TypedLiteral) - def valid_prev(token): + def valid_prev(token, idx): return token is not None - def valid_next(token): + def valid_next(token, idx): return token is not None and token.match(*sql.TypedLiteral.M_CLOSE) - def valid_final(token): + def valid_final(token, idx): return token is not None and token.match(*sql.TypedLiteral.M_EXTEND) def post(tlist, pidx, tidx, nidx): @@ -141,12 +141,12 @@ def group_period(tlist): def match(token): return token.match(T.Punctuation, '.') - def valid_prev(token): + def valid_prev(token, idx): sqlcls = sql.SquareBrackets, sql.Identifier ttypes = T.Name, T.String.Symbol return imt(token, i=sqlcls, t=ttypes) - def valid_next(token): + def valid_next(token, idx): # issue261, allow invalid next token return True @@ -166,10 +166,10 @@ def group_as(tlist): def match(token): return token.is_keyword and token.normalized == 'AS' - def valid_prev(token): + def valid_prev(token, idx): return token.normalized == 'NULL' or not token.is_keyword - def valid_next(token): + def valid_next(token, idx): ttypes = T.DML, T.DDL, T.CTE return not imt(token, t=ttypes) and token is not None @@ -183,7 +183,7 @@ def group_assignment(tlist): def match(token): return token.match(T.Assignment, ':=') - def valid(token): + def valid(token, idx): return token is not None and token.ttype not in (T.Keyword,) def post(tlist, pidx, tidx, nidx): @@ -202,9 +202,9 @@ def group_comparison(tlist): ttypes = T_NUMERICAL + T_STRING + T_NAME def match(token): - return token.ttype == T.Operator.Comparison + return imt(token, t=(T.Operator.Comparison), m=(T.Keyword, 'LIKE')) - def valid(token): + def valid(token, idx): if imt(token, t=ttypes, i=sqlcls): return True elif token and token.is_keyword and token.normalized == 'NULL': @@ -215,7 +215,22 @@ def valid(token): def post(tlist, pidx, tidx, nidx): return pidx, nidx - valid_prev = valid_next = valid + def valid_next(token, idx): + return valid(token, idx) + + def valid_prev(token, idx): + # https://dev.mysql.com/doc/refman/8.0/en/create-table-like.html + # LIKE is usually a compatarator, except when used in + # `CREATE TABLE x LIKE y` statements, Check if we are + # constructing a table - otherwise assume it is indeed a comparator + two_tokens_back_idx = idx - 3 + if two_tokens_back_idx >= 0: + _, two_tokens_back = tlist.token_next(two_tokens_back_idx) + if imt(two_tokens_back, m=(T.Keyword, 'TABLE')): + return False + + return valid(token, idx) + _group(tlist, sql.Comparison, match, valid_prev, valid_next, post, extend=False) @@ -237,10 +252,10 @@ def group_arrays(tlist): def match(token): return isinstance(token, sql.SquareBrackets) - def valid_prev(token): + def valid_prev(token, idx): return imt(token, i=sqlcls, t=ttypes) - def valid_next(token): + def valid_next(token, idx): return True def post(tlist, pidx, tidx, nidx): @@ -258,7 +273,7 @@ def group_operator(tlist): def match(token): return imt(token, t=(T.Operator, T.Wildcard)) - def valid(token): + def valid(token, idx): return imt(token, i=sqlcls, t=ttypes) \ or (token and token.match( T.Keyword, @@ -283,7 +298,7 @@ def group_identifier_list(tlist): def match(token): return token.match(T.Punctuation, ',') - def valid(token): + def valid(token, idx): return imt(token, i=sqlcls, m=m_role, t=ttypes) def post(tlist, pidx, tidx, nidx): @@ -431,8 +446,8 @@ def group(stmt): def _group(tlist, cls, match, - valid_prev=lambda t: True, - valid_next=lambda t: True, + valid_prev=lambda t, idx: True, + valid_next=lambda t, idx: True, post=None, extend=True, recurse=True @@ -454,7 +469,7 @@ def _group(tlist, cls, match, if match(token): nidx, next_ = tlist.token_next(tidx) - if prev_ and valid_prev(prev_) and valid_next(next_): + if prev_ and valid_prev(prev_, pidx) and valid_next(next_, nidx): from_idx, to_idx = post(tlist, pidx, tidx, nidx) grp = tlist.group_tokens(cls, from_idx, to_idx, extend=extend) diff --git a/sqlparse/keywords.py b/sqlparse/keywords.py index d3794fd3..fa8b6e65 100644 --- a/sqlparse/keywords.py +++ b/sqlparse/keywords.py @@ -82,7 +82,8 @@ r'(EXPLODE|INLINE|PARSE_URL_TUPLE|POSEXPLODE|STACK)\b', tokens.Keyword), (r"(AT|WITH')\s+TIME\s+ZONE\s+'[^']+'", tokens.Keyword.TZCast), - (r'(NOT\s+)?(LIKE|ILIKE|RLIKE)\b', tokens.Operator.Comparison), + (r'(NOT\s+)(LIKE|ILIKE|RLIKE)\b', tokens.Operator.Comparison), + (r'(ILIKE|RLIKE)\b', tokens.Operator.Comparison), (r'(NOT\s+)?(REGEXP)\b', tokens.Operator.Comparison), # Check for keywords, also returns tokens.Name if regex matches # but the match isn't a keyword. diff --git a/tests/test_grouping.py b/tests/test_grouping.py index e90243b5..5e71dca5 100644 --- a/tests/test_grouping.py +++ b/tests/test_grouping.py @@ -498,7 +498,10 @@ def test_comparison_with_strings(operator): assert p.tokens[0].right.ttype == T.String.Single -def test_like_and_ilike_comparison(): +@pytest.mark.parametrize('operator', ( + 'LIKE', 'NOT LIKE', 'ILIKE', 'NOT ILIKE', 'RLIKE', 'NOT RLIKE' +)) +def test_like_and_ilike_comparison(operator): def validate_where_clause(where_clause, expected_tokens): assert len(where_clause.tokens) == len(expected_tokens) for where_token, expected_token in zip(where_clause, expected_tokens): @@ -513,22 +516,22 @@ def validate_where_clause(where_clause, expected_tokens): assert (isinstance(where_token, expected_ttype) and re.match(expected_value, where_token.value)) - [p1] = sqlparse.parse("select * from mytable where mytable.mycolumn LIKE 'expr%' limit 5;") + [p1] = sqlparse.parse(f"select * from mytable where mytable.mycolumn {operator} 'expr%' limit 5;") [p1_where] = [token for token in p1 if isinstance(token, sql.Where)] validate_where_clause(p1_where, [ (T.Keyword, "where"), (T.Whitespace, None), - (sql.Comparison, r"mytable.mycolumn LIKE.*"), + (sql.Comparison, f"mytable.mycolumn {operator}.*"), (T.Whitespace, None), ]) [p2] = sqlparse.parse( - "select * from mytable where mycolumn NOT ILIKE '-expr' group by othercolumn;") + f"select * from mytable where mycolumn {operator} '-expr' group by othercolumn;") [p2_where] = [token for token in p2 if isinstance(token, sql.Where)] validate_where_clause(p2_where, [ (T.Keyword, "where"), (T.Whitespace, None), - (sql.Comparison, r"mycolumn NOT ILIKE.*"), + (sql.Comparison, f"mycolumn {operator}.*"), (T.Whitespace, None), ]) diff --git a/tests/test_regressions.py b/tests/test_regressions.py index 961adc17..3a49ba6a 100644 --- a/tests/test_regressions.py +++ b/tests/test_regressions.py @@ -444,3 +444,18 @@ def test_copy_issue672(): p = sqlparse.parse('select * from foo')[0] copied = copy.deepcopy(p) assert str(p) == str(copied) + + +def test_copy_issue543(): + tokens = sqlparse.parse('create table tab1.b like tab2')[0].tokens + assert [(t.ttype, t.value) for t in tokens if t.ttype != T.Whitespace] == \ + [ + (T.DDL, 'create'), + (T.Keyword, 'table'), + (None, 'tab1.b'), + (T.Keyword, 'like'), + (None, 'tab2') + ] + + comparison = sqlparse.parse('a LIKE "b"')[0].tokens[0] + assert isinstance(comparison, sql.Comparison) diff --git a/tests/test_tokenize.py b/tests/test_tokenize.py index af0ba163..ce1b4af0 100644 --- a/tests/test_tokenize.py +++ b/tests/test_tokenize.py @@ -209,10 +209,10 @@ def test_parse_window_as(): @pytest.mark.parametrize('s', ( - "LIKE", "ILIKE", "NOT LIKE", "NOT ILIKE", + "ILIKE", "NOT LIKE", "NOT ILIKE", "NOT LIKE", "NOT ILIKE", )) -def test_like_and_ilike_parsed_as_comparisons(s): +def test_likeish_but_not_like_parsed_as_comparisons(s): p = sqlparse.parse(s)[0] assert len(p.tokens) == 1 assert p.tokens[0].ttype == T.Operator.Comparison