Skip to content

Commit

Permalink
fix(flink): rewrite ops.Clip using if statements
Browse files Browse the repository at this point in the history
  • Loading branch information
deepyaman authored and cpcloud committed Sep 14, 2023
1 parent c383f62 commit b7153ea
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
17 changes: 16 additions & 1 deletion ibis/backends/flink/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from ibis.backends.base.sql.registry import (
operation_registry as base_operation_registry,
)
from ibis.backends.flink.utils import translate_literal
from ibis.backends.flink.utils import _to_pyflink_types, translate_literal
from ibis.common.temporal import TimestampUnit

if TYPE_CHECKING:
Expand Down Expand Up @@ -185,6 +185,20 @@ def _window(translator: ExprTranslator, op: ops.Node) -> str:
return result


def _clip(translator: ExprTranslator, op: ops.Node) -> str:
arg = translator.translate(op.arg)

if op.upper is not None:
upper = translator.translate(op.upper)
arg = f"IF({arg} > {upper}, {upper}, {arg})"

if op.lower is not None:
lower = translator.translate(op.lower)
arg = f"IF({arg} < {lower}, {lower}, {arg})"

return f"CAST({arg} AS {_to_pyflink_types[type(op.dtype)]!s})"


def _floor_divide(translator: ExprTranslator, op: ops.Node) -> str:
left = translator.translate(op.left)
right = translator.translate(op.right)
Expand Down Expand Up @@ -219,6 +233,7 @@ def _floor_divide(translator: ExprTranslator, op: ops.Node) -> str:
ops.Where: _filter,
ops.TimestampFromUNIX: _timestamp_from_unix,
ops.Window: _window,
ops.Clip: _clip,
# Binary operations
ops.Power: fixed_arity("power", 2),
ops.FloorDivide: _floor_divide,
Expand Down
6 changes: 6 additions & 0 deletions ibis/backends/flink/translator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import ibis.expr.operations as ops
from ibis.backends.base.sql.compiler import ExprTranslator
from ibis.backends.flink.registry import operation_registry

Expand All @@ -9,3 +10,8 @@ class FlinkExprTranslator(ExprTranslator):
"hive" # TODO: neither sqlglot nor sqlalchemy supports flink dialect
)
_registry = operation_registry


@FlinkExprTranslator.rewrites(ops.Clip)
def _clip_no_op(op):
return op

0 comments on commit b7153ea

Please sign in to comment.