Skip to content

Commit

Permalink
Optimizer: case insensitivity (#785)
Browse files Browse the repository at this point in the history
  • Loading branch information
barakalon authored Nov 30, 2022
1 parent 9aa4f89 commit 638ed26
Show file tree
Hide file tree
Showing 5 changed files with 144 additions and 2 deletions.
92 changes: 92 additions & 0 deletions sqlglot/optimizer/lower_identities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from sqlglot import exp
from sqlglot.helper import ensure_collection


def lower_identities(expression):
"""
Convert all unquoted identifiers to lower case.
Assuming the schema is all lower case, this essentially makes identifiers case-insensitive.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one('SELECT Bar.A AS A FROM "Foo".Bar')
>>> lower_identities(expression).sql()
'SELECT bar.a AS A FROM "Foo".bar'
Args:
expression (sqlglot.Expression): expression to quote
Returns:
sqlglot.Expression: quoted expression
"""
# We need to leave the output aliases unchanged, so the selects need special handling
_lower_selects(expression)

# These clauses can reference output aliases and also need special handling
_lower_order(expression)
_lower_having(expression)

# We've already handled these args, so don't traverse into them
traversed = {"expressions", "order", "having"}

if isinstance(expression, exp.Subquery):
# Root subquery, e.g. (SELECT A AS A FROM X) LIMIT 1
lower_identities(expression.this)
traversed |= {"this"}

if isinstance(expression, exp.Union):
# Union, e.g. SELECT A AS A FROM X UNION SELECT A AS A FROM X
lower_identities(expression.left)
lower_identities(expression.right)
traversed |= {"this", "expression"}

for k, v in expression.args.items():
if k in traversed:
continue

for child in ensure_collection(v):
if isinstance(child, exp.Expression):
child.transform(_lower, copy=False)

return expression


def _lower_selects(expression):
for e in expression.expressions:
# Leave output aliases as-is
e.unalias().transform(_lower, copy=False)


def _lower_order(expression):
order = expression.args.get("order")

if not order:
return

output_aliases = {e.alias for e in expression.expressions if isinstance(e, exp.Alias)}

for ordered in order.expressions:
# Don't lower references to output aliases
if not (
isinstance(ordered.this, exp.Column)
and not ordered.this.table
and ordered.this.name in output_aliases
):
ordered.transform(_lower, copy=False)


def _lower_having(expression):
having = expression.args.get("having")

if not having:
return

# Don't lower references to output aliases
for agg in having.find_all(exp.AggFunc):
agg.transform(_lower, copy=False)


def _lower(node):
if isinstance(node, exp.Identifier) and not node.quoted:
node.set("this", node.this.lower())
return node
2 changes: 2 additions & 0 deletions sqlglot/optimizer/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects
from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
from sqlglot.optimizer.lower_identities import lower_identities
from sqlglot.optimizer.merge_subqueries import merge_subqueries
from sqlglot.optimizer.normalize import normalize
from sqlglot.optimizer.optimize_joins import optimize_joins
Expand All @@ -17,6 +18,7 @@
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries

RULES = (
lower_identities,
qualify_tables,
isolate_table_selects,
qualify_columns,
Expand Down
41 changes: 41 additions & 0 deletions tests/fixtures/optimizer/lower_identities.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
SELECT a FROM x;
SELECT a FROM x;

SELECT "A" FROM "X";
SELECT "A" FROM "X";

SELECT a AS A FROM x;
SELECT a AS A FROM x;

SELECT * FROM x;
SELECT * FROM x;

SELECT A FROM x;
SELECT a FROM x;

SELECT a FROM X;
SELECT a FROM x;

SELECT A AS A FROM (SELECT a AS A FROM x);
SELECT a AS A FROM (SELECT a AS a FROM x);

SELECT a AS B FROM x ORDER BY B;
SELECT a AS B FROM x ORDER BY B;

SELECT A FROM x ORDER BY A;
SELECT a FROM x ORDER BY a;

SELECT A AS B FROM X GROUP BY A HAVING SUM(B) > 0;
SELECT a AS B FROM x GROUP BY a HAVING SUM(b) > 0;

SELECT A AS B, SUM(B) AS C FROM X GROUP BY A HAVING C > 0;
SELECT a AS B, SUM(b) AS C FROM x GROUP BY a HAVING C > 0;

SELECT A FROM X UNION SELECT A FROM X;
SELECT a FROM x UNION SELECT a FROM x;

SELECT A AS A FROM X UNION SELECT A AS A FROM X;
SELECT a AS A FROM x UNION SELECT a AS A FROM x;

(SELECT A AS A FROM X);
(SELECT a AS A FROM x);
5 changes: 5 additions & 0 deletions tests/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,3 +528,8 @@ def test_scalar_functions(self):
with self.subTest(sql):
result = execute(f"SELECT {sql}")
self.assertEqual(result.rows, [(expected,)])

def test_case_sensitivity(self):
result = execute("SELECT A AS A FROM X", tables={"x": [{"a": 1}]})
self.assertEqual(result.columns, ("A",))
self.assertEqual(result.rows, [(1,)])
6 changes: 4 additions & 2 deletions tests/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,8 @@ def check_file(self, file, func, pretty=False, execute=False, **kwargs):
if leave_tables_isolated is not None:
func_kwargs["leave_tables_isolated"] = string_to_bool(leave_tables_isolated)

optimized = func(parse_one(sql, read=dialect), **func_kwargs)

with self.subTest(title):
optimized = func(parse_one(sql, read=dialect), **func_kwargs)
self.assertEqual(
expected,
optimized.sql(pretty=pretty, dialect=dialect),
Expand Down Expand Up @@ -168,6 +167,9 @@ def test_qualify_columns__invalid(self):
def test_quote_identities(self):
self.check_file("quote_identities", optimizer.quote_identities.quote_identities)

def test_lower_identities(self):
self.check_file("lower_identities", optimizer.lower_identities.lower_identities)

def test_pushdown_projection(self):
def pushdown_projections(expression, **kwargs):
expression = optimizer.qualify_tables.qualify_tables(expression)
Expand Down

0 comments on commit 638ed26

Please sign in to comment.