diff --git a/ibis/backends/flink/registry.py b/ibis/backends/flink/registry.py index 04d58de4a243..6283269c5973 100644 --- a/ibis/backends/flink/registry.py +++ b/ibis/backends/flink/registry.py @@ -39,6 +39,31 @@ def extract_field_formatter(translator: ExprTranslator, op: ops.Node) -> str: return extract_field_formatter +def _cast(translator: ExprTranslator, op: ops.generic.Cast) -> str: + if op.to.is_timestamp() and op.to.timezone: + arg_translated = translator.translate(op.arg) + return f"TO_TIMESTAMP(CONVERT_TZ(CAST({arg_translated} AS STRING), 'UTC+0', '{op.to.timezone}'))" + + from ibis.backends.base.sql.registry.main import cast + + return cast(translator=translator, op=op) + + +def _left_op_right(translator: ExprTranslator, op_node: ops.Node, op_sign: str) -> str: + """Utility to be used in operators that perform '{op.left} {op_sign} {op.right}'.""" + return f"{translator.translate(op_node.left)} {op_sign} {translator.translate(op_node.right)}" + + +def _interval_add(translator: ExprTranslator, op: ops.temporal.IntervalSubtract) -> str: + return _left_op_right(translator=translator, op_node=op, op_sign="+") + + +def _interval_subtract( + translator: ExprTranslator, op: ops.temporal.IntervalSubtract +) -> str: + return _left_op_right(translator=translator, op_node=op, op_sign="-") + + def _literal(translator: ExprTranslator, op: ops.Literal) -> str: from ibis.backends.flink.utils import translate_literal @@ -76,10 +101,21 @@ def _filter(translator: ExprTranslator, op: ops.Node) -> str: return f"CASE WHEN {bool_expr} THEN {true_expr} ELSE {false_null_expr} END" -def _timestamp_diff(translator: ExprTranslator, op: ops.Node) -> str: - left = translator.translate(op.left) - right = translator.translate(op.right) - return f"timestampdiff(second, {left}, {right})" +def _timestamp_add(translator: ExprTranslator, op: ops.temporal.TimestampAdd) -> str: + return _left_op_right(translator=translator, op_node=op, op_sign="+") + + +def _timestamp_diff(translator: ExprTranslator, op: ops.temporal.TimestampDiff) -> str: + return _left_op_right(translator=translator, op_node=op, op_sign="-") + + +def _timestamp_sub(translator: ExprTranslator, op: ops.temporal.TimestampSub) -> str: + table_column = op.left + interval = op.right + + table_column_translated = translator.translate(table_column) + interval_translated = translator.translate(interval) + return f"{table_column_translated} - {interval_translated}" def _timestamp_from_unix(translator: ExprTranslator, op: ops.Node) -> str: @@ -96,6 +132,17 @@ def _timestamp_from_unix(translator: ExprTranslator, op: ops.Node) -> str: return f"TO_TIMESTAMP_LTZ({numeric}, {precision})" +def _timestamp_from_ymdhms( + translator: ExprTranslator, op: ops.temporal.TimestampFromYMDHMS +) -> str: + year, month, day, hours, minutes, seconds = ( + f"CAST({translator.translate(e)} AS STRING)" + for e in [op.year, op.month, op.day, op.hours, op.minutes, op.seconds] + ) + concat_string = f"CONCAT({year}, '-', {month}, '-', {day}, ' ', {hours}, ':', {minutes}, ':', {seconds})" + return f"CAST({concat_string} AS TIMESTAMP)" + + def _format_window_start(translator: ExprTranslator, boundary): if boundary is None: return "UNBOUNDED PRECEDING" @@ -224,6 +271,54 @@ def _floor_divide(translator: ExprTranslator, op: ops.Node) -> str: return f"FLOOR(({left}) / ({right}))" +def _array_index(translator: ExprTranslator, op: ops.arrays.ArrayIndex): + table_column = op.arg + index = op.index + + table_column_translated = translator.translate(table_column) + index_translated = translator.translate(index) + + return f"{table_column_translated} [ {index_translated} + 1 ]" + + +def _day_of_week_index( + translator: ExprTranslator, op: ops.temporal.DayOfWeekIndex +) -> str: + arg = translator.translate(op.arg) + return f"MOD(DAYOFWEEK({arg}) + 5, 7)" + + +def _date_add(translator: ExprTranslator, op: ops.temporal.DateAdd) -> str: + return _left_op_right(translator=translator, op_node=op, op_sign="+") + + +def _date_diff(translator: ExprTranslator, op: ops.temporal.DateDiff) -> str: + raise com.UnsupportedOperationError("DATE_DIFF is not supported in Flink.") + + +def _date_from_ymd(translator: ExprTranslator, op: ops.temporal.DateFromYMD) -> str: + year, month, day = op.year, op.month, op.day + date_string = f"{year.value}-{month.value}-{day.value}" + return f"CAST('{date_string}' AS DATE)" + + +def _date_sub(translator: ExprTranslator, op: ops.temporal.DateSub) -> str: + return _left_op_right(translator=translator, op_node=op, op_sign="-") + + +def extract_epoch_seconds(translator: ExprTranslator, op: ops.Node) -> str: + arg = translator.translate(op.arg) + return f"UNIX_TIMESTAMP(CAST({arg} AS STRING))" + + +def _string_to_timestamp( + translator: ExprTranslator, op: ops.temporal.StringToTimestamp +) -> str: + arg = translator.translate(op.arg) + format_string = translator.translate(op.format_str) + return f"TO_TIMESTAMP({arg}, {format_string})" + + operation_registry.update( { # Unary operations @@ -240,6 +335,7 @@ def _floor_divide(translator: ExprTranslator, op: ops.Node) -> str: ops.RegexSearch: fixed_arity("regexp", 2), # Timestamp operations ops.Date: _date, + ops.ExtractEpochSeconds: extract_epoch_seconds, ops.ExtractYear: _extract_field("year"), # equivalent to YEAR(date) ops.ExtractMonth: _extract_field("month"), # equivalent to MONTH(date) ops.ExtractDay: _extract_field("day"), # equivalent to DAYOFMONTH(date) @@ -249,16 +345,32 @@ def _floor_divide(translator: ExprTranslator, op: ops.Node) -> str: ops.ExtractHour: _extract_field("hour"), # equivalent to HOUR(timestamp) ops.ExtractMinute: _extract_field("minute"), # equivalent to MINUTE(timestamp) ops.ExtractSecond: _extract_field("second"), # equivalent to SECOND(timestamp) + ops.ExtractMillisecond: _extract_field("millisecond"), + ops.ExtractMicrosecond: _extract_field("microsecond"), # Other operations + ops.Cast: _cast, + ops.IntervalAdd: _interval_add, + ops.IntervalSubtract: _interval_subtract, ops.Literal: _literal, ops.TryCast: _try_cast, ops.IfElse: _filter, + ops.TimestampAdd: _timestamp_add, ops.TimestampDiff: _timestamp_diff, ops.TimestampFromUNIX: _timestamp_from_unix, + ops.TimestampFromYMDHMS: _timestamp_from_ymdhms, + ops.TimestampSub: _timestamp_sub, ops.Window: _window, ops.Clip: _clip, # Binary operations ops.Power: fixed_arity("power", 2), ops.FloorDivide: _floor_divide, + # Temporal functions + ops.ArrayIndex: _array_index, + ops.DateAdd: _date_add, + ops.DateDiff: _date_diff, + ops.DateFromYMD: _date_from_ymd, + ops.DateSub: _date_sub, + ops.DayOfWeekIndex: _day_of_week_index, + ops.StringToTimestamp: _string_to_timestamp, } )