Skip to content

Commit

Permalink
simplify comparisons (#793)
Browse files Browse the repository at this point in the history
* simplify comparisons

* fix bug and cleanup
  • Loading branch information
tobymao authored Dec 5, 2022
1 parent c935b59 commit e24d031
Show file tree
Hide file tree
Showing 4 changed files with 363 additions and 87 deletions.
8 changes: 8 additions & 0 deletions sqlglot/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,3 +385,11 @@ def dict_depth(d: t.Dict) -> int:
except StopIteration:
# d.values() returns an empty sequence
return 1


def first(it: t.Iterable[T]) -> T:
    """Return the first item produced by ``it``.

    Handy for pulling a member out of a collection that cannot be indexed,
    such as a set.

    Raises:
        StopIteration: If ``it`` yields nothing.
    """
    return next(iter(it))
235 changes: 162 additions & 73 deletions sqlglot/optimizer/simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from sqlglot import exp
from sqlglot.expressions import FALSE, NULL, TRUE
from sqlglot.generator import Generator
from sqlglot.helper import while_changing
from sqlglot.helper import first, while_changing

# Module-level Generator instance built with normalize=True and identify=True.
# NOTE(review): presumably used elsewhere in this module to render expressions
# to a canonical SQL string -- its call sites are not visible here; confirm.
GENERATOR = Generator(normalize=True, identify=True)

Expand All @@ -30,6 +30,7 @@ def simplify(expression):

def _simplify(expression, root=True):
node = expression
node = rewrite_between(node)
node = uniq_sort(node)
node = absorb_and_eliminate(node)
exp.replace_children(node, lambda e: _simplify(e, False))
Expand All @@ -49,6 +50,19 @@ def _simplify(expression, root=True):
return expression


def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Rewrite ``x BETWEEN y AND z`` to ``x >= y AND x <= z``.

    This is done because comparison simplification is only done on
    lt/lte/gt/gte.

    When the BETWEEN is the direct operand of a NOT, the resulting conjunction
    is wrapped in parentheses. Without them ``NOT x BETWEEN 1 AND 2`` would
    turn into ``NOT x >= 1 AND x <= 2``, where the NOT only applies to the
    first comparison, and simplify_not (which applies De Morgan's law only to
    a parenthesized operand) would not be able to push the negation inwards.

    Args:
        expression: Any expression node.

    Returns:
        The replacement node when ``expression`` is a Between, otherwise
        ``expression`` itself, untouched.
    """
    if not isinstance(expression, exp.Between):
        return expression

    # Look at the parent before building the replacement nodes.
    negated = isinstance(expression.parent, exp.Not)

    # The compared operand is needed on both sides, so each side gets its own
    # copy; the low/high bounds are each used exactly once.
    conjunction = exp.and_(
        exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
        exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
    )

    if negated:
        return exp.Paren(this=conjunction)
    return conjunction


def simplify_not(expression):
"""
Demorgan's Law
Expand All @@ -57,19 +71,19 @@ def simplify_not(expression):
"""
if isinstance(expression, exp.Not):
if isinstance(expression.this, exp.Null):
return NULL
return exp.null()
if isinstance(expression.this, exp.Paren):
condition = expression.this.unnest()
if isinstance(condition, exp.And):
return exp.or_(exp.not_(condition.left), exp.not_(condition.right))
if isinstance(condition, exp.Or):
return exp.and_(exp.not_(condition.left), exp.not_(condition.right))
if isinstance(condition, exp.Null):
return NULL
return exp.null()
if always_true(expression.this):
return FALSE
return exp.false()
if expression.this == FALSE:
return TRUE
return exp.true()
if isinstance(expression.this, exp.Not):
# double negation
# NOT NOT x -> x
Expand All @@ -91,40 +105,119 @@ def flatten(expression):


def simplify_connectors(expression):
    """Simplify an AND / OR chain by folding pairs of its operands.

    The chain is flattened by _flat_simplify and every pair of operands is
    offered to the nested pairwise simplifier, so a redundant operand is found
    even when it is not adjacent to its counterpart (for example the two
    ``a`` in ``a AND b AND a``).

    Pairwise rules:
      * identical operands collapse into one;
      * AND: a FALSE operand gives FALSE, a NULL operand gives NULL, operands
        that are always true are dropped;
      * OR: an always-true operand gives TRUE, FALSE operands are dropped,
        NULL/FALSE combinations give NULL;
      * otherwise the pair is handed to _simplify_comparison.

    Args:
        expression: Any expression node.

    Returns:
        The simplified expression. Nodes that are not connectors are returned
        unchanged.
    """
    # The pairwise rules only apply to connectors; bail out early instead of
    # flattening a node for which every pair would be rejected anyway.
    if not isinstance(expression, exp.Connector):
        return expression

    def _simplify_connectors(expression, left, right):
        # Returns the node replacing both operands, or None to keep the pair.
        if left == right:
            return left

        if isinstance(expression, exp.And):
            if FALSE in (left, right):
                return exp.false()
            if NULL in (left, right):
                return exp.null()
            if always_true(left) and always_true(right):
                return exp.true()
            if always_true(left):
                return right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)

        if isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            if left == FALSE and right == FALSE:
                return exp.false()
            if (
                (left == NULL and right == NULL)
                or (left == NULL and right == FALSE)
                or (left == FALSE and right == NULL)
            ):
                return exp.null()
            if left == FALSE:
                return right
            if right == FALSE:
                return left
            return _simplify_comparison(expression, left, right, or_=True)

        return None

    return _flat_simplify(expression, _simplify_connectors)


# Ordering comparisons grouped by direction, each as (strict, inclusive):
# "less than" style upper bounds and "greater than" style lower bounds.
LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)

# Every comparison class that _simplify_comparison accepts as an operand.
COMPARISONS = LT_LTE + GT_GTE + (exp.EQ, exp.NEQ)

# Maps an ordering comparison to the class expressing the same predicate with
# its operands swapped (a < b is b > a, a <= b is b >= a, ...).
INVERSE_COMPARISONS = {
    exp.LT: exp.GT,
    exp.GT: exp.LT,
    exp.LTE: exp.GTE,
    exp.GTE: exp.LTE,
}


def _simplify_comparison(expression, left, right, or_=False):
    """Merge two comparisons of the same column against literals, if possible.

    Used as the last pairwise rule for the operands of an AND (default) or of
    an OR (``or_=True``). Both operands must compare one shared column with a
    literal, and the two literals must both be numbers or both be strings.

    AND examples::

        x < 1 AND x < 2    ->  x < 1      (tighter bound wins)
        x <= 1 AND x < 1   ->  x < 1      (same value: strict bound wins)
        x < 1 AND x > 1    ->  FALSE
        x = 1 AND x < 2    ->  x = 1
        x = 1 AND x <> 1   ->  FALSE

    OR examples::

        x < 1 OR x < 2     ->  x < 2      (looser bound wins)
        x <= 1 OR x < 1    ->  x <= 1     (same value: inclusive bound wins)

    Args:
        expression: The connector whose operands are being folded. Not used
            by this rule; accepted so the signature matches the other
            pairwise simplifiers.
        left: One operand of the connector.
        right: Another operand of the connector.
        or_: True when the operands are joined by OR instead of AND.

    Returns:
        A single node that replaces both operands, or None when this rule
        cannot simplify the pair.
    """
    if not (isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS)):
        return None

    ll, lr = left.args.values()
    rl, rr = right.args.values()

    largs = {ll, lr}
    rargs = {rl, rr}

    # The two comparisons are only relatable through a column they share.
    columns = {m for m in largs & rargs if isinstance(m, exp.Column)}
    if not columns:
        return None

    try:
        lother = first(largs - columns)
        rother = first(rargs - columns)
    except StopIteration:
        # One side has no operand besides the shared column(s), e.g. x = x.
        # Report "nothing to simplify": the caller treats any returned node
        # as the replacement for both operands, so handing back the whole
        # connector here would splice it into its own operand list.
        return None

    if lother.is_number and rother.is_number:
        lvalue = float(lother.name)
        rvalue = float(rother.name)
    elif lother.is_string and rother.is_string:
        lvalue = lother.name
        rvalue = rother.name
    else:
        return None

    # make sure the comparison is always of the form x > 1 instead of 1 < x
    if left.__class__ in INVERSE_COMPARISONS and lother == ll:
        left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll)
    if right.__class__ in INVERSE_COMPARISONS and rother == rl:
        right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl)

    both_upper = isinstance(left, LT_LTE) and isinstance(right, LT_LTE)
    both_lower = isinstance(left, GT_GTE) and isinstance(right, GT_GTE)

    if both_upper or both_lower:
        if lvalue == rvalue:
            # Same bound value: the strict operator is the tighter one.
            left_is_tighter = isinstance(left, (exp.LT, exp.GT))
        elif both_upper:
            left_is_tighter = lvalue < rvalue
        else:
            left_is_tighter = lvalue > rvalue

        # AND keeps the tighter bound, OR keeps the looser one.
        return left if left_is_tighter != bool(or_) else right

    if or_:
        # we can't ever shortcut to true because the column could be null,
        # and the contradiction / equality rules below only hold for AND
        return None

    for (a, av), (b, bv) in itertools.permutations(((left, lvalue), (right, rvalue))):
        if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
            if av <= bv:
                return exp.false()
        elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
            if av >= bv:
                return exp.false()
        elif isinstance(a, exp.EQ):
            # An equality either contradicts the other comparison or makes
            # it redundant.
            if isinstance(b, exp.LT):
                return exp.false() if av >= bv else a
            if isinstance(b, exp.LTE):
                return exp.false() if av > bv else a
            if isinstance(b, exp.GT):
                return exp.false() if av <= bv else a
            if isinstance(b, exp.GTE):
                return exp.false() if av < bv else a
            if isinstance(b, exp.NEQ):
                return exp.false() if av == bv else a

    return None


def remove_compliments(expression):
Expand All @@ -135,7 +228,7 @@ def remove_compliments(expression):
A OR NOT A -> TRUE
"""
if isinstance(expression, exp.Connector):
compliment = FALSE if isinstance(expression, exp.And) else TRUE
compliment = exp.false() if isinstance(expression, exp.And) else exp.true()

for a, b in itertools.permutations(expression.flatten(), 2):
if is_complement(a, b):
Expand Down Expand Up @@ -211,27 +304,7 @@ def absorb_and_eliminate(expression):

def simplify_literals(expression):
if isinstance(expression, exp.Binary):
operands = []
queue = deque(expression.flatten(unnest=False))
size = len(queue)

while queue:
a = queue.popleft()

for b in queue:
result = _simplify_binary(expression, a, b)

if result:
queue.remove(b)
queue.append(result)
break
else:
operands.append(a)

if len(operands) < size:
return functools.reduce(
lambda a, b: expression.__class__(this=a, expression=b), operands
)
return _flat_simplify(expression, _simplify_binary)
elif isinstance(expression, exp.Neg):
this = expression.this
if this.is_number:
Expand All @@ -254,20 +327,13 @@ def _simplify_binary(expression, a, b):

if c == NULL:
if isinstance(a, exp.Literal):
return TRUE if not_ else FALSE
return exp.true() if not_ else exp.false()
if a == NULL:
return FALSE if not_ else TRUE
elif isinstance(expression, exp.NullSafeEQ):
if a == b:
return TRUE
elif isinstance(expression, exp.NullSafeNEQ):
if a == b:
return FALSE
return exp.false() if not_ else exp.true()
elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
return None
elif NULL in (a, b):
return NULL

if isinstance(expression, exp.EQ) and a == b:
return TRUE
return exp.null()

if a.is_number and b.is_number:
a = int(a.name) if a.is_int else Decimal(a.name)
Expand Down Expand Up @@ -388,4 +454,27 @@ def date_literal(date):


def boolean_literal(condition):
    """Return a boolean literal node matching the truthiness of ``condition``.

    The node comes from exp.true() / exp.false() rather than from the
    module-level TRUE / FALSE constants, consistent with the other
    simplifiers in this module, so callers never receive the shared constant
    objects themselves.

    Args:
        condition: Any value; only its truthiness is used.

    Returns:
        The result of exp.true() when ``condition`` is truthy, otherwise the
        result of exp.false().
    """
    return exp.true() if condition else exp.false()


def _flat_simplify(expression, simplifier):
operands = []
queue = deque(expression.flatten(unnest=False))
size = len(queue)

while queue:
a = queue.popleft()

for b in queue:
result = simplifier(expression, a, b)

if result:
queue.remove(b)
queue.append(result)
break
else:
operands.append(a)

if len(operands) < size:
return functools.reduce(lambda a, b: expression.__class__(this=a, expression=b), operands)
return expression
Loading

0 comments on commit e24d031

Please sign in to comment.