Skip to content

Commit

Permalink
perf(sql): avoid parenthesizing chains of commutative operators
Browse files Browse the repository at this point in the history
  • Loading branch information
jcrist committed Sep 3, 2024
1 parent 46eee14 commit a84285c
Show file tree
Hide file tree
Showing 31 changed files with 331 additions and 156 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
SELECT
(
"t0"."int_col" + "t0"."tinyint_col"
) + "t0"."double_col" AS "Add(Add(int_col, tinyint_col), double_col)"
"t0"."int_col" + "t0"."tinyint_col" + "t0"."double_col" AS "Add(Add(int_col, tinyint_col), double_col)"
FROM "functional_alltypes" AS "t0"
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
SELECT
"t1"."key" AS "key",
SUM((
(
"t1"."value" + 1
) + 2
) + 3) AS "abc"
SUM("t1"."value" + 1 + 2 + 3) AS "abc"
FROM (
SELECT
*
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
SELECT
"t1"."key" AS "key",
SUM((
(
"t1"."value" + 1
) + 2
) + 3) AS "foo"
SUM("t1"."value" + 1 + 2 + 3) AS "foo"
FROM (
SELECT
*
Expand Down
Original file line number Diff line number Diff line change
@@ -1 +1 @@
SELECT `t1`.`key`, SUM(((`t1`.`value` + 1) + 2) + 3) AS `abc` FROM (SELECT * FROM `t0` AS `t0` WHERE `t0`.`value` = 42) AS `t1` GROUP BY 1
SELECT `t1`.`key`, SUM(`t1`.`value` + 1 + 2 + 3) AS `abc` FROM (SELECT * FROM `t0` AS `t0` WHERE `t0`.`value` = 42) AS `t1` GROUP BY 1
Original file line number Diff line number Diff line change
@@ -1 +1 @@
SELECT `t1`.`key`, SUM(((`t1`.`value` + 1) + 2) + 3) AS `foo` FROM (SELECT * FROM `t0` AS `t0` WHERE `t0`.`value` = 42) AS `t1` GROUP BY 1
SELECT `t1`.`key`, SUM(`t1`.`value` + 1 + 2 + 3) AS `foo` FROM (SELECT * FROM `t0` AS `t0` WHERE `t0`.`value` = 42) AS `t1` GROUP BY 1
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
SELECT
(
`t0`.`a` + `t0`.`b`
) + `t0`.`c` AS `Add(Add(a, b), c)`
`t0`.`a` + `t0`.`b` + `t0`.`c` AS `Add(Add(a, b), c)`
FROM `alltypes` AS `t0`
203 changes: 65 additions & 138 deletions ibis/backends/sql/compilers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,20 +239,6 @@ def __getitem__(self, key: str) -> sge.Column:
STAR = sge.Star()


def parenthesize_inputs(f):
"""Decorate a translation rule to parenthesize inputs."""

def wrapper(self, op, *, left, right):
return f(
self,
op,
left=self._add_parens(op.left, left),
right=self._add_parens(op.right, right),
)

return wrapper


@public
class SQLGlotCompiler(abc.ABC):
__slots__ = "f", "v"
Expand Down Expand Up @@ -390,45 +376,50 @@ class SQLGlotCompiler(abc.ABC):
ops.Uppercase: "upper",
}

BINARY_INFIX_OPS = (
# Binary operations
ops.Add,
ops.Subtract,
ops.Multiply,
ops.Divide,
ops.Modulus,
ops.Power,
BINARY_INFIX_OPS = {
# Numeric
ops.Add: (sge.Add, True),
ops.Subtract: (sge.Sub, False),
ops.Multiply: (sge.Mul, True),
ops.Divide: (sge.Div, False),
ops.Modulus: (sge.Mod, False),
ops.Power: (sge.Pow, False),
# Comparisons
ops.GreaterEqual,
ops.Greater,
ops.LessEqual,
ops.Less,
ops.Equals,
ops.NotEquals,
# Boolean comparisons
ops.And,
ops.Or,
ops.Xor,
# Bitwise business
ops.BitwiseLeftShift,
ops.BitwiseRightShift,
ops.BitwiseAnd,
ops.BitwiseOr,
ops.BitwiseXor,
# Time arithmetic
ops.DateAdd,
ops.DateSub,
ops.DateDiff,
ops.TimestampAdd,
ops.TimestampSub,
ops.TimestampDiff,
# Interval Marginalia
ops.IntervalAdd,
ops.IntervalMultiply,
ops.IntervalSubtract,
)
ops.GreaterEqual: (sge.GTE, False),
ops.Greater: (sge.GT, False),
ops.LessEqual: (sge.LTE, False),
ops.Less: (sge.LT, False),
ops.Equals: (sge.EQ, False),
ops.NotEquals: (sge.NEQ, False),
# Logical
ops.And: (sge.And, True),
ops.Or: (sge.Or, True),
ops.Xor: (sge.Xor, True),
# Bitwise
ops.BitwiseLeftShift: (sge.BitwiseLeftShift, False),
ops.BitwiseRightShift: (sge.BitwiseRightShift, False),
ops.BitwiseAnd: (sge.BitwiseAnd, True),
ops.BitwiseOr: (sge.BitwiseOr, True),
ops.BitwiseXor: (sge.BitwiseXor, True),
# Date
ops.DateAdd: (sge.Add, True),
ops.DateSub: (sge.Sub, False),
ops.DateDiff: (sge.Sub, False),
# Time
ops.TimeAdd: (sge.Add, True),
ops.TimeSub: (sge.Sub, False),
ops.TimeDiff: (sge.Sub, False),
# Timestamp
ops.TimestampAdd: (sge.Add, True),
ops.TimestampSub: (sge.Sub, False),
ops.TimestampDiff: (sge.Sub, False),
# Interval
ops.IntervalAdd: (sge.Add, True),
ops.IntervalMultiply: (sge.Mul, True),
ops.IntervalSubtract: (sge.Sub, False),
}

NEEDS_PARENS = BINARY_INFIX_OPS + (ops.IsNull,)
NEEDS_PARENS = tuple(BINARY_INFIX_OPS) + (ops.IsNull,)

# Constructed dynamically in `__init_subclass__` from their respective
# UPPERCASE values to handle inheritance, do not modify directly here.
Expand Down Expand Up @@ -466,6 +457,19 @@ def impl(self, _, *, _name: str = target_name, **kw):
for op, target_name in cls.SIMPLE_OPS.items():
setattr(cls, methodname(op), make_impl(op, target_name))

# Define binary op methods, only if BINARY_INFIX_OPS is set on the
# compiler class.
if binops := cls.__dict__.get("BINARY_INFIX_OPS", {}):

def make_binop(sge_cls, associative):
def impl(self, op, *, left, right):
return self.binop(sge_cls, op, left, right, associative=associative)

return impl

for op, (sge_cls, associative) in binops.items():
setattr(cls, methodname(op), make_binop(sge_cls, associative))

# unconditionally raise an exception for unsupported operations
#
# these *must* be defined after SIMPLE_OPS to handle compilers that
Expand Down Expand Up @@ -1501,93 +1505,16 @@ def visit_SQLQueryResult(self, op, *, query, schema, source):
def visit_RegexExtract(self, op, *, arg, pattern, index):
return self.f.regexp_extract(arg, pattern, index, dialect=self.dialect)

@parenthesize_inputs
def visit_Add(self, op, *, left, right):
return sge.Add(this=left, expression=right)

visit_DateAdd = visit_TimestampAdd = visit_IntervalAdd = visit_Add

@parenthesize_inputs
def visit_Subtract(self, op, *, left, right):
return sge.Sub(this=left, expression=right)

visit_DateSub = visit_DateDiff = visit_TimestampSub = visit_TimestampDiff = (
visit_IntervalSubtract
) = visit_Subtract

@parenthesize_inputs
def visit_Multiply(self, op, *, left, right):
return sge.Mul(this=left, expression=right)

visit_IntervalMultiply = visit_Multiply

@parenthesize_inputs
def visit_Divide(self, op, *, left, right):
return sge.Div(this=left, expression=right)

@parenthesize_inputs
def visit_Modulus(self, op, *, left, right):
return sge.Mod(this=left, expression=right)

@parenthesize_inputs
def visit_Power(self, op, *, left, right):
return sge.Pow(this=left, expression=right)

@parenthesize_inputs
def visit_GreaterEqual(self, op, *, left, right):
return sge.GTE(this=left, expression=right)

@parenthesize_inputs
def visit_Greater(self, op, *, left, right):
return sge.GT(this=left, expression=right)

@parenthesize_inputs
def visit_LessEqual(self, op, *, left, right):
return sge.LTE(this=left, expression=right)

@parenthesize_inputs
def visit_Less(self, op, *, left, right):
return sge.LT(this=left, expression=right)

@parenthesize_inputs
def visit_Equals(self, op, *, left, right):
return sge.EQ(this=left, expression=right)

@parenthesize_inputs
def visit_NotEquals(self, op, *, left, right):
return sge.NEQ(this=left, expression=right)

@parenthesize_inputs
def visit_And(self, op, *, left, right):
return sge.And(this=left, expression=right)

@parenthesize_inputs
def visit_Or(self, op, *, left, right):
return sge.Or(this=left, expression=right)

@parenthesize_inputs
def visit_Xor(self, op, *, left, right):
return sge.Xor(this=left, expression=right)

@parenthesize_inputs
def visit_BitwiseLeftShift(self, op, *, left, right):
return sge.BitwiseLeftShift(this=left, expression=right)

@parenthesize_inputs
def visit_BitwiseRightShift(self, op, *, left, right):
return sge.BitwiseRightShift(this=left, expression=right)

@parenthesize_inputs
def visit_BitwiseAnd(self, op, *, left, right):
return sge.BitwiseAnd(this=left, expression=right)

@parenthesize_inputs
def visit_BitwiseOr(self, op, *, left, right):
return sge.BitwiseOr(this=left, expression=right)

@parenthesize_inputs
def visit_BitwiseXor(self, op, *, left, right):
return sge.BitwiseXor(this=left, expression=right)
def binop(self, sg_expr, op, left, right, *, associative=False):
    """Build the SQLGlot expression for a binary infix operation.

    Parameters
    ----------
    sg_expr
        Callable invoked as ``sg_expr(this=..., expression=...)`` to build
        the resulting node.
    op
        The operation being compiled; its ``left`` and ``right`` inputs are
        handed to ``self._add_parens`` alongside the compiled operands.
    left, right
        The already-compiled left and right operands.
    associative
        When true, a left operand whose operation has exactly the same type
        as ``op`` is used as-is instead of going through
        ``self._add_parens``.

    Returns
    -------
    Whatever ``sg_expr`` returns for the prepared operands.
    """
    # A left-nested chain of one associative op evaluates the same with or
    # without parentheses around the left side. Leaving them out lets
    # SQLGlot's non-recursive generation of long same-op chains apply, so
    # large operator chains can be handled.
    same_op_on_left = associative and type(op) is type(op.left)
    lhs = left if same_op_on_left else self._add_parens(op.left, left)
    # The right operand always goes through `_add_parens`.
    rhs = self._add_parens(op.right, right)
    return sg_expr(this=lhs, expression=rhs)

def visit_Undefined(self, op, **_):
raise com.OperationNotDefinedError(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
SELECT
"t0"."a" + "t0"."b" + "t0"."c" AS "x"
FROM "t" AS "t0" --- op(op(a, b), c);
SELECT
"t0"."a" + (
"t0"."b" + "t0"."c"
) AS "x"
FROM "t" AS "t0" --- op(a, op(b, c));
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
SELECT
"t0"."a" + "t0"."b" + "t0"."c" AS "x"
FROM "t" AS "t0" --- op(op(a, b), c);
SELECT
"t0"."a" + (
"t0"."b" + "t0"."c"
) AS "x"
FROM "t" AS "t0" --- op(a, op(b, c));
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
SELECT
"t0"."a" + "t0"."b" + "t0"."c" AS "x"
FROM "t" AS "t0" --- op(op(a, b), c);
SELECT
"t0"."a" + (
"t0"."b" + "t0"."c"
) AS "x"
FROM "t" AS "t0" --- op(a, op(b, c));
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
SELECT
"t0"."a" + "t0"."b" + "t0"."c" AS "x"
FROM "t" AS "t0" --- op(op(a, b), c);
SELECT
"t0"."a" + (
"t0"."b" + "t0"."c"
) AS "x"
FROM "t" AS "t0" --- op(a, op(b, c));
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
SELECT
"t0"."a" + "t0"."b" + "t0"."c" AS "x"
FROM "t" AS "t0" --- op(op(a, b), c);
SELECT
"t0"."a" + (
"t0"."b" + "t0"."c"
) AS "x"
FROM "t" AS "t0" --- op(a, op(b, c));
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
SELECT
"t0"."a" AND "t0"."b" AND "t0"."c" AS "x"
FROM "t" AS "t0" --- op(op(a, b), c);
SELECT
"t0"."a" AND (
"t0"."b" AND "t0"."c"
) AS "x"
FROM "t" AS "t0" --- op(a, op(b, c));
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
SELECT
"t0"."a" & "t0"."b" & "t0"."c" AS "x"
FROM "t" AS "t0" --- op(op(a, b), c);
SELECT
"t0"."a" & (
"t0"."b" & "t0"."c"
) AS "x"
FROM "t" AS "t0" --- op(a, op(b, c));
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
SELECT
(
"t0"."a" << "t0"."b"
) << "t0"."c" AS "x"
FROM "t" AS "t0" --- op(op(a, b), c);
SELECT
"t0"."a" << (
"t0"."b" << "t0"."c"
) AS "x"
FROM "t" AS "t0" --- op(a, op(b, c));
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
SELECT
(
"t0"."a" % "t0"."b"
) % "t0"."c" AS "x"
FROM "t" AS "t0" --- op(op(a, b), c);
SELECT
"t0"."a" % (
"t0"."b" % "t0"."c"
) AS "x"
FROM "t" AS "t0" --- op(a, op(b, c));
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
SELECT
"t0"."a" * "t0"."b" * "t0"."c" AS "x"
FROM "t" AS "t0" --- op(op(a, b), c);
SELECT
"t0"."a" * (
"t0"."b" * "t0"."c"
) AS "x"
FROM "t" AS "t0" --- op(a, op(b, c));
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
SELECT
"t0"."a" * "t0"."b" * "t0"."c" AS "x"
FROM "t" AS "t0" --- op(op(a, b), c);
SELECT
"t0"."a" * (
"t0"."b" * "t0"."c"
) AS "x"
FROM "t" AS "t0" --- op(a, op(b, c));
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
SELECT
"t0"."a" OR "t0"."b" OR "t0"."c" AS "x"
FROM "t" AS "t0" --- op(op(a, b), c);
SELECT
"t0"."a" OR (
"t0"."b" OR "t0"."c"
) AS "x"
FROM "t" AS "t0" --- op(a, op(b, c));
Loading

0 comments on commit a84285c

Please sign in to comment.