Skip to content

Commit

Permalink
refactor(sql): simplify paren handling for binary ops
Browse files Browse the repository at this point in the history
  • Loading branch information
jcrist committed Sep 11, 2024
1 parent bac76ff commit 192be96
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 45 deletions.
99 changes: 56 additions & 43 deletions ibis/backends/sql/compilers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,48 +374,61 @@ class SQLGlotCompiler(abc.ABC):

BINARY_INFIX_OPS = {
# Numeric
ops.Add: (sge.Add, True),
ops.Subtract: (sge.Sub, False),
ops.Multiply: (sge.Mul, True),
ops.Divide: (sge.Div, False),
ops.Modulus: (sge.Mod, False),
ops.Power: (sge.Pow, False),
ops.Add: sge.Add,
ops.Subtract: sge.Sub,
ops.Multiply: sge.Mul,
ops.Divide: sge.Div,
ops.Modulus: sge.Mod,
ops.Power: sge.Pow,
# Comparisons
ops.GreaterEqual: (sge.GTE, False),
ops.Greater: (sge.GT, False),
ops.LessEqual: (sge.LTE, False),
ops.Less: (sge.LT, False),
ops.Equals: (sge.EQ, False),
ops.NotEquals: (sge.NEQ, False),
ops.GreaterEqual: sge.GTE,
ops.Greater: sge.GT,
ops.LessEqual: sge.LTE,
ops.Less: sge.LT,
ops.Equals: sge.EQ,
ops.NotEquals: sge.NEQ,
# Logical
ops.And: (sge.And, True),
ops.Or: (sge.Or, True),
ops.Xor: (sge.Xor, True),
ops.And: sge.And,
ops.Or: sge.Or,
ops.Xor: sge.Xor,
# Bitwise
ops.BitwiseLeftShift: (sge.BitwiseLeftShift, False),
ops.BitwiseRightShift: (sge.BitwiseRightShift, False),
ops.BitwiseAnd: (sge.BitwiseAnd, True),
ops.BitwiseOr: (sge.BitwiseOr, True),
ops.BitwiseXor: (sge.BitwiseXor, True),
ops.BitwiseLeftShift: sge.BitwiseLeftShift,
ops.BitwiseRightShift: sge.BitwiseRightShift,
ops.BitwiseAnd: sge.BitwiseAnd,
ops.BitwiseOr: sge.BitwiseOr,
ops.BitwiseXor: sge.BitwiseXor,
# Date
ops.DateAdd: (sge.Add, True),
ops.DateSub: (sge.Sub, False),
ops.DateDiff: (sge.Sub, False),
ops.DateAdd: sge.Add,
ops.DateSub: sge.Sub,
ops.DateDiff: sge.Sub,
# Time
ops.TimeAdd: (sge.Add, True),
ops.TimeSub: (sge.Sub, False),
ops.TimeDiff: (sge.Sub, False),
ops.TimeAdd: sge.Add,
ops.TimeSub: sge.Sub,
ops.TimeDiff: sge.Sub,
# Timestamp
ops.TimestampAdd: (sge.Add, True),
ops.TimestampSub: (sge.Sub, False),
ops.TimestampDiff: (sge.Sub, False),
ops.TimestampAdd: sge.Add,
ops.TimestampSub: sge.Sub,
ops.TimestampDiff: sge.Sub,
# Interval
ops.IntervalAdd: (sge.Add, True),
ops.IntervalMultiply: (sge.Mul, True),
ops.IntervalSubtract: (sge.Sub, False),
ops.IntervalAdd: sge.Add,
ops.IntervalMultiply: sge.Mul,
ops.IntervalSubtract: sge.Sub,
}

NEEDS_PARENS = tuple(BINARY_INFIX_OPS) + (ops.IsNull,)
# A set of SQLGlot classes that may need to be parenthesized
SQLGLOT_NEEDS_PARENS = set(BINARY_INFIX_OPS.values()).union((sge.Is,))

# A set of SQLGlot classes that are associative operations
SQLGLOT_ASSOCIATIVE_OPS = {
sge.Add,
sge.Mul,
sge.And,
sge.Or,
sge.Xor,
sge.BitwiseAnd,
sge.BitwiseOr,
sge.BitwiseXor,
}

# Constructed dynamically in `__init_subclass__` from their respective
# UPPERCASE values to handle inheritance, do not modify directly here.
Expand Down Expand Up @@ -457,14 +470,14 @@ def impl(self, _, *, _name: str = target_name, **kw):
# compiler class.
if binops := cls.__dict__.get("BINARY_INFIX_OPS", {}):

def make_binop(sge_cls, associative):
def make_binop(sge_cls):
def impl(self, op, *, left, right):
return self.binop(sge_cls, op, left, right, associative=associative)
return self.binop(sge_cls, left, right)

return impl

for op, (sge_cls, associative) in binops.items():
setattr(cls, methodname(op), make_binop(sge_cls, associative))
for op, sge_cls in binops.items():
setattr(cls, methodname(op), make_binop(sge_cls))

# unconditionally raise an exception for unsupported operations
#
Expand Down Expand Up @@ -1384,8 +1397,8 @@ def visit_Aggregate(self, op, *, parent, groups, metrics):
return sel

@classmethod
def _add_parens(cls, op, sg_expr):
if isinstance(op, cls.NEEDS_PARENS):
def _add_parens(cls, sg_expr):
if type(sg_expr) in cls.SQLGLOT_NEEDS_PARENS:
return sge.paren(sg_expr, copy=False)
return sg_expr

Expand Down Expand Up @@ -1499,16 +1512,16 @@ def visit_SQLQueryResult(self, op, *, query, schema, source):
def visit_RegexExtract(self, op, *, arg, pattern, index):
return self.f.regexp_extract(arg, pattern, index, dialect=self.dialect)

def binop(self, sg_expr, op, left, right, *, associative=False):
def binop(self, sg_cls, left, right):
# If the op is associative we can skip parenthesizing ops of the same
# type if they're on the left, since they would evaluate the same.
# SQLGlot has an optimizer for generating long sql chains of the same
# op of this form without recursion, by avoiding parenthesis in this
# common case we can make use of this optimization to handle large
# operator chains.
if not associative or type(op) is not type(op.left):
left = self._add_parens(op.left, left)
return sg_expr(this=left, expression=self._add_parens(op.right, right))
if not (sg_cls in self.SQLGLOT_ASSOCIATIVE_OPS and type(left) is sg_cls):
left = self._add_parens(left)
return sg_cls(this=left, expression=self._add_parens(right))

def visit_Undefined(self, op, **_):
raise com.OperationNotDefinedError(
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/sql/compilers/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,11 +162,11 @@ def visit_ArrayRepeat(self, op, *, arg, times):
return self.f.arrayFlatten(self.f.arrayMap(func, self.f.range(times)))

def visit_ArraySlice(self, op, *, arg, start, stop):
start = self._add_parens(op.start, start)
start = self._add_parens(start)
start_correct = self.if_(start < 0, start, start + 1)

if stop is not None:
stop = self._add_parens(op.stop, stop)
stop = self._add_parens(stop)

length = self.if_(
stop < 0,
Expand Down

0 comments on commit 192be96

Please sign in to comment.