From b7153ea8b00e2e8230a3ed23f09b73075f774b89 Mon Sep 17 00:00:00 2001 From: Deepyaman Datta Date: Thu, 14 Sep 2023 11:53:23 -0600 Subject: [PATCH] fix(flink): rewrite `ops.Clip` using if statements --- ibis/backends/flink/registry.py | 17 ++++++++++++++++- ibis/backends/flink/translator.py | 6 ++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/ibis/backends/flink/registry.py b/ibis/backends/flink/registry.py index db7f55ab8a54..76b80126b41d 100644 --- a/ibis/backends/flink/registry.py +++ b/ibis/backends/flink/registry.py @@ -13,7 +13,7 @@ from ibis.backends.base.sql.registry import ( operation_registry as base_operation_registry, ) -from ibis.backends.flink.utils import translate_literal +from ibis.backends.flink.utils import _to_pyflink_types, translate_literal from ibis.common.temporal import TimestampUnit if TYPE_CHECKING: @@ -185,6 +185,20 @@ def _window(translator: ExprTranslator, op: ops.Node) -> str: return result +def _clip(translator: ExprTranslator, op: ops.Node) -> str: + arg = translator.translate(op.arg) + + if op.upper is not None: + upper = translator.translate(op.upper) + arg = f"IF({arg} > {upper}, {upper}, {arg})" + + if op.lower is not None: + lower = translator.translate(op.lower) + arg = f"IF({arg} < {lower}, {lower}, {arg})" + + return f"CAST({arg} AS {_to_pyflink_types[type(op.dtype)]!s})" + + def _floor_divide(translator: ExprTranslator, op: ops.Node) -> str: left = translator.translate(op.left) right = translator.translate(op.right) @@ -219,6 +233,7 @@ def _floor_divide(translator: ExprTranslator, op: ops.Node) -> str: ops.Where: _filter, ops.TimestampFromUNIX: _timestamp_from_unix, ops.Window: _window, + ops.Clip: _clip, # Binary operations ops.Power: fixed_arity("power", 2), ops.FloorDivide: _floor_divide, diff --git a/ibis/backends/flink/translator.py b/ibis/backends/flink/translator.py index ad88921b48bd..cae7c3395241 100644 --- a/ibis/backends/flink/translator.py +++ b/ibis/backends/flink/translator.py @@ -1,5 +1,6 @@ from __future__ import annotations +import ibis.expr.operations as ops from ibis.backends.base.sql.compiler import ExprTranslator from ibis.backends.flink.registry import operation_registry @@ -9,3 +10,8 @@ class FlinkExprTranslator(ExprTranslator): "hive" # TODO: neither sqlglot nor sqlalchemy supports flink dialect ) _registry = operation_registry + + +@FlinkExprTranslator.rewrites(ops.Clip) +def _clip_no_op(op): + return op