Skip to content

Commit

Permalink
feat(flink): implement translation rules and tests for over aggregation in Flink backend
Browse files Browse the repository at this point in the history
  • Loading branch information
chloeh13q authored and jcrist committed Jul 26, 2023
1 parent 4b57f7f commit e173cd7
Show file tree
Hide file tree
Showing 6 changed files with 292 additions and 78 deletions.
79 changes: 2 additions & 77 deletions ibis/backends/flink/compiler/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,13 @@

from __future__ import annotations

import datetime
import functools
import math

import ibis.common.exceptions as com
import ibis.expr.operations as ops
from ibis.backends.base.sql.compiler import Compiler, Select, SelectBuilder
from ibis.backends.flink.translator import FlinkExprTranslator
from ibis.backends.flink.utils import (
DaysToSecondsInterval,
YearsToMonthsInterval,
format_precision,
)
from ibis.backends.flink.utils import translate_literal


class FlinkSelectBuilder(SelectBuilder):
Expand Down Expand Up @@ -52,7 +46,6 @@ class FlinkCompiler(Compiler):


def translate(op: ops.TableNode) -> str:
# TODO(chloeh13q): support translation of non-select exprs (e.g. literals)
return translate_op(op)


Expand All @@ -63,79 +56,11 @@ def translate_op(op: ops.TableNode) -> str:

@translate_op.register(ops.Literal)
def _literal(op: ops.Literal) -> str:
value = op.value
dtype = op.output_dtype

if dtype.is_boolean():
# TODO(chloeh13q): Flink supports a third boolean called "UNKNOWN"
return 'TRUE' if value else 'FALSE'
elif dtype.is_string():
quoted = value.replace("'", "''").replace("\\", "\\\\")
return f"'{quoted}'"
elif dtype.is_date():
if isinstance(value, datetime.date):
value = value.strftime('%Y-%m-%d')
return repr(value)
elif dtype.is_numeric():
if math.isnan(value):
raise ValueError("NaN is not supported in Flink SQL")
elif math.isinf(value):
raise ValueError("Infinity is not supported in Flink SQL")
return repr(value)
elif dtype.is_timestamp():
# TODO(chloeh13q): support timestamp with local timezone
if isinstance(value, datetime.datetime):
fmt = '%Y-%m-%d %H:%M:%S'
# datetime.datetime only supports resolution up to microseconds, even
# though Flink supports fractional precision up to 9 digits. We will
# need to use numpy or pandas datetime types for higher resolutions.
if value.microsecond:
fmt += '.%f'
return 'TIMESTAMP ' + repr(value.strftime(fmt))
raise NotImplementedError(f'No translation rule for timestamp {value}')
elif dtype.is_time():
return f"TIME '{value}'"
elif dtype.is_interval():
return f"INTERVAL {translate_interval(value, dtype)}"
raise NotImplementedError(f'No translation rule for {dtype}')
return translate_literal(op)


@translate_op.register(ops.Selection)
@translate_op.register(ops.Aggregation)
@translate_op.register(ops.Limit)
def _(op: ops.Selection | ops.Aggregation | ops.Limit) -> str:
return FlinkCompiler.to_sql(op) # to_sql uses to_ast, which builds a select tree


def translate_interval(value, dtype):
"""Convert interval to Flink SQL type.
Flink supports only two types of temporal intervals: day-time intervals with up to nanosecond
granularity or year-month intervals with up to month granularity.
An interval of year-month consists of +years-months with values ranging from -9999-11 to +9999-11.
An interval of day-time consists of +days hours:minutes:seconds.fractional with values ranging from
-999999 23:59:59.999999999 to +999999 23:59:59.999999999.
The value representation is the same for all types of resolutions.
For example, an interval of months of 50 is always represented in an interval-of-years-to-months
format (with default year precision): +04-02; an interval of seconds of 70 is always represented in
an interval-of-days-to-seconds format (with default precisions): +00 00:01:10.000000.
"""
if dtype.unit in YearsToMonthsInterval.units:
interval = YearsToMonthsInterval(value, dtype.unit.value)
else:
interval = DaysToSecondsInterval(value, dtype.unit.value)

interval_segments = interval.interval_segments
nonzero_interval_segments = {k: v for k, v in interval_segments.items() if v != 0}

# YEAR, MONTH, DAY, HOUR, MINUTE, SECOND
if len(nonzero_interval_segments) == 1:
unit = next(iter(nonzero_interval_segments))
value = nonzero_interval_segments[unit]
return f"'{value}' {unit.value}{format_precision(value, unit)}"

# YEAR TO MONTH, DAY TO SECOND
return interval.format_as_string()
120 changes: 119 additions & 1 deletion ibis/backends/flink/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

from typing import TYPE_CHECKING

import ibis.common.exceptions as com
import ibis.expr.operations as ops
from ibis.backends.base.sql.registry import helpers
from ibis.backends.base.sql.registry import helpers, window
from ibis.backends.base.sql.registry import (
operation_registry as base_operation_registry,
)
from ibis.backends.flink.utils import translate_literal
from ibis.common.temporal import TimestampUnit

if TYPE_CHECKING:
Expand Down Expand Up @@ -59,6 +61,120 @@ def _filter(translator: ExprTranslator, op: ops.Node) -> str:
return f"CASE WHEN {bool_expr} THEN {true_expr} ELSE {false_null_expr} END"


def _literal(translator: ExprTranslator, op: ops.Literal) -> str:
    # Registry adapter: literal rendering lives in the shared helper
    # `translate_literal` so the compiler and registry stay consistent;
    # the translator argument is unused but required by the registry signature.
    return translate_literal(op)


def _format_window_start(translator: ExprTranslator, boundary):
if boundary is None:
return 'UNBOUNDED PRECEDING'

if isinstance(boundary.value, ops.Literal) and boundary.value.value == 0:
return "CURRENT ROW"

value = translator.translate(boundary.value)
return f'{value} PRECEDING'


def _format_window_end(translator: ExprTranslator, boundary):
    """Render the upper bound of an OVER frame as Flink SQL.

    Flink only supports frames ending at the current row, so any end
    bound other than a (possibly cast) literal zero is rejected.
    """
    if boundary is None:
        # An open-ended upper bound would mean FOLLOWING, which Flink lacks.
        raise com.UnsupportedOperationError(
            "OVER RANGE FOLLOWING windows are not supported in Flink yet"
        )

    end = boundary.value
    # Peel off an implicit cast so the underlying literal can be inspected.
    if isinstance(end, ops.Cast):
        end = end.arg
    if isinstance(end, ops.Literal) and end.value != 0:
        raise com.UnsupportedOperationError(
            "OVER RANGE FOLLOWING windows are not supported in Flink yet"
        )

    return "CURRENT ROW"


def _format_window_frame(translator: ExprTranslator, func, frame):
    """Assemble the full ``OVER (...)`` clause for *frame* as Flink SQL."""
    parts = []

    if frame.group_by:
        keys = ', '.join(map(translator.translate, frame.group_by))
        parts.append(f'PARTITION BY {keys}')

    # Exactly one ordering key is expected here; callers validate the count.
    (order_by,) = frame.order_by
    if order_by.descending is True:
        raise com.UnsupportedOperationError(
            "Flink only supports windows ordered in ASCENDING mode"
        )
    parts.append(f'ORDER BY {translator.translate(order_by)}')

    has_bounds = frame.start is not None or frame.end is not None
    if has_bounds and not isinstance(func, translator._forbids_frame_clause):
        # [NOTE] Flink allows
        # "ROWS BETWEEN INTERVAL [...] PRECEDING AND CURRENT ROW"
        # but not
        # "RANGE BETWEEN [...] PRECEDING AND CURRENT ROW",
        # but `.over(rows=(-ibis.interval(...), 0)` is not allowed in Ibis
        if isinstance(frame, ops.RangeWindowFrame):
            if not frame.start.value.output_dtype.is_interval():
                # [TODO] need to expand support for range-based interval
                # windowing on expr side; for now only ibis intervals work
                raise com.UnsupportedOperationError(
                    "Data Type mismatch between ORDER BY and RANGE clause"
                )

        lower = _format_window_start(translator, frame.start)
        upper = _format_window_end(translator, frame.end)
        parts.append(f'{frame.how.upper()} BETWEEN {lower} AND {upper}')

    return 'OVER ({})'.format(' '.join(parts))


def _window(translator: ExprTranslator, op: ops.Node) -> str:
    """Translate a window (OVER) expression into Flink SQL.

    Validates the frame against Flink's restrictions (must have exactly
    one ORDER BY key; the wrapped reduction must be supported in window
    context) before delegating frame rendering to _format_window_frame.
    """
    frame = op.frame
    if not frame.order_by:
        raise com.UnsupportedOperationError(
            "Flink engine does not support generic window clause with no order by"
        )
    if len(frame.order_by) > 1:
        raise com.UnsupportedOperationError(
            "Windows in Flink can only be ordered by a single time column"
        )

    _unsupported_reductions = translator._unsupported_reductions

    func = op.func.__window_op__

    if isinstance(func, _unsupported_reductions):
        raise com.UnsupportedOperationError(
            f'{type(func)} is not supported in window functions'
        )

    if isinstance(func, ops.CumulativeOp):
        # Cumulative ops are rewritten into an equivalent windowed
        # aggregation and translated from scratch.
        arg = window.cumulative_to_window(translator, func, op.frame)
        return translator.translate(arg)

    if isinstance(frame, ops.RowsWindowFrame):
        if frame.max_lookback is not None:
            raise NotImplementedError(
                'Rows with max lookback is not implemented for SQL-based backends.'
            )

    window_formatted = _format_window_frame(translator, func, frame)

    # NOTE(review): `func` is already `op.func.__window_op__`, so accessing
    # `.__window_op__` again assumes the attribute is idempotent (an
    # unwrapped op returns itself) — confirm against the ops definitions.
    arg_formatted = translator.translate(func.__window_op__)
    result = f'{arg_formatted} {window_formatted}'

    if isinstance(func, ops.RankBase):
        # Subtract 1, presumably to convert SQL's 1-based rank results to
        # ibis's 0-based convention — confirm against other backends.
        return f'({result} - 1)'
    else:
        return result


operation_registry.update(
{
ops.CountStar: _count_star,
Expand All @@ -71,7 +187,9 @@ def _filter(translator: ExprTranslator, op: ops.Node) -> str:
ops.ExtractHour: _extract_field("hour"), # equivalent to HOUR(timestamp)
ops.ExtractMinute: _extract_field("minute"), # equivalent to MINUTE(timestamp)
ops.ExtractSecond: _extract_field("second"), # equivalent to SECOND(timestamp)
ops.Literal: _literal,
ops.TimestampFromUNIX: _timestamp_from_unix,
ops.Where: _filter,
ops.Window: _window,
}
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
SELECT sum(t0.`f`) OVER (ORDER BY t0.`f` ASC RANGE BETWEEN INTERVAL '00 08:20:00.000000' DAY TO SECOND PRECEDING AND CURRENT ROW) AS `Sum(f)`
FROM table t0
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
SELECT sum(t0.`f`) OVER (ORDER BY t0.`f` ASC ROWS BETWEEN 1000 PRECEDING AND CURRENT ROW) AS `Sum(f)`
FROM table t0
88 changes: 88 additions & 0 deletions ibis/backends/flink/tests/test_window.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from __future__ import annotations

import pytest
from pytest import param

import ibis
from ibis.backends.flink.compiler.core import translate
from ibis.common.exceptions import UnsupportedOperationError


def test_window_requires_order_by(simple_table):
    """A window expression with no ORDER BY must be rejected."""
    demeaned = simple_table.c - simple_table.c.mean()
    expr = simple_table.mutate(demeaned).as_table()
    err = "Flink engine does not support generic window clause with no order by"
    with pytest.raises(UnsupportedOperationError, match=err):
        translate(expr.op())


def test_window_does_not_support_multiple_order_by(simple_table):
    """More than one ORDER BY key in a window frame must be rejected."""
    windowed = simple_table.f.sum().over(
        order_by=[simple_table.f, simple_table.d],
        group_by=[simple_table.g, simple_table.a],
        rows=(-1, 1),
    )
    err = "Windows in Flink can only be ordered by a single time column"
    with pytest.raises(UnsupportedOperationError, match=err):
        translate(windowed.as_table().op())


def test_window_does_not_support_desc_order(simple_table):
    """A descending ORDER BY key in a window frame must be rejected."""
    windowed = simple_table.f.sum().over(
        order_by=[simple_table.f.desc()],
        group_by=[simple_table.g, simple_table.a],
        rows=(-1, 1),
    )
    err = "Flink only supports windows ordered in ASCENDING mode"
    with pytest.raises(UnsupportedOperationError, match=err):
        translate(windowed.as_table().op())


@pytest.mark.parametrize(
    ("window", "err"),
    [
        # Any FOLLOWING end bound (explicit, unbounded, or hidden behind a
        # cast) is unsupported; integer RANGE bounds are a type mismatch.
        param(
            {"rows": (-1, 1)},
            "OVER RANGE FOLLOWING windows are not supported in Flink yet",
            id="bounded_rows_following",
        ),
        param(
            {"rows": (-1, None)},
            "OVER RANGE FOLLOWING windows are not supported in Flink yet",
            id="unbounded_rows_following",
        ),
        param(
            {"rows": (-500, 1)},
            "OVER RANGE FOLLOWING windows are not supported in Flink yet",
            id="casted_bounded_rows_following",
        ),
        param(
            {"range": (-1000, 0)},
            "Data Type mismatch between ORDER BY and RANGE clause",
            id="int_range",
        ),
    ],
)
def test_window_invalid_start_end(simple_table, window, err):
    """Window frames that Flink cannot express raise UnsupportedOperationError."""
    windowed = simple_table.f.sum().over(order_by=simple_table.f, **window)
    with pytest.raises(UnsupportedOperationError, match=err):
        translate(windowed.as_table().op())


def test_range_window(snapshot, simple_table):
    """A RANGE frame with an interval lower bound translates to SQL."""
    lower = -ibis.interval(minutes=500)
    windowed = simple_table.f.sum().over(range=(lower, 0), order_by=simple_table.f)
    sql = translate(windowed.as_table().op())
    snapshot.assert_match(sql, "out.sql")


def test_rows_window(snapshot, simple_table):
    """A bounded ROWS frame translates to SQL."""
    windowed = simple_table.f.sum().over(rows=(-1000, 0), order_by=simple_table.f)
    sql = translate(windowed.as_table().op())
    snapshot.assert_match(sql, "out.sql")
Loading

0 comments on commit e173cd7

Please sign in to comment.