Skip to content

Commit

Permalink
feat: implement window table valued functions
Browse files Browse the repository at this point in the history
  • Loading branch information
chloeh13q authored and jcrist committed Oct 18, 2023
1 parent ff2ab08 commit a35a756
Show file tree
Hide file tree
Showing 35 changed files with 480 additions and 73 deletions.
7 changes: 7 additions & 0 deletions ibis/backends/base/sql/compiler/select_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ def _collect(self, op, toplevel=False):
self._collect_PhysicalTable(op, toplevel=toplevel)
elif isinstance(op, ops.Join):
self._collect_Join(op, toplevel=toplevel)
elif isinstance(op, ops.WindowingTVF):
self._collect_WindowingTVF(op, toplevel=toplevel)
else:
raise NotImplementedError(type(op))

Expand Down Expand Up @@ -231,6 +233,11 @@ def _collect_SelfReference(self, op, toplevel=False):
if toplevel:
self._collect(op.table, toplevel=toplevel)

def _collect_WindowingTVF(self, op, toplevel=False):
    """Collect a windowing TVF node into the select builder.

    At the top level the TVF itself becomes both the table set and the
    sole member of the select set; nested occurrences need no collection.
    """
    if not toplevel:
        return
    self.table_set = op
    self.select_set = [op]

# --------------------------------------------------------------------
# Subquery analysis / extraction

Expand Down
9 changes: 0 additions & 9 deletions ibis/backends/flink/compiler/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +0,0 @@
# NOTE(review): this is the full content of
# ibis/backends/flink/compiler/__init__.py, which the commit DELETES
# (shown as 9 removed lines).  It re-exported ``translate`` as the
# package's public API via the ``public`` helper.
from __future__ import annotations

from public import public

from ibis.backends.flink.compiler.core import translate

public(
translate=translate,
)
99 changes: 81 additions & 18 deletions ibis/backends/flink/compiler/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@

import ibis.common.exceptions as com
import ibis.expr.operations as ops
import ibis.expr.types as ir
from ibis.backends.base.sql.compiler import (
Compiler,
Select,
SelectBuilder,
TableSetFormatter,
)
from ibis.backends.base.sql.registry import quote_identifier
from ibis.backends.flink.translator import FlinkExprTranslator


Expand All @@ -28,8 +30,22 @@ def _format_in_memory_table(self, op):
rows = ", ".join(f"({raw_row})" for raw_row in raw_rows)
return f"(VALUES {rows})"

def _format_window_tvf(self, op) -> str:
    """Render a windowing TVF node as ``TABLE(<FN>(<params>))`` SQL.

    Parameters
    ----------
    op
        A windowing TVF operation (tumble, hop, or cumulate).

    Returns
    -------
    str
        SQL of the form ``TABLE(TUMBLE(...))`` / ``TABLE(HOP(...))`` /
        ``TABLE(CUMULATE(...))``.

    Raises
    ------
    com.OperationNotDefinedError
        If ``op`` is a windowing TVF subtype with no known SQL name.
    """
    if isinstance(op, ops.TumbleWindowingTVF):
        function_type = "TUMBLE"
    elif isinstance(op, ops.HopWindowingTVF):
        function_type = "HOP"
    elif isinstance(op, ops.CumulateWindowingTVF):
        function_type = "CUMULATE"
    else:
        # Previously an unrecognized subtype fell through every branch and
        # raised UnboundLocalError on ``function_type``; fail with a clear,
        # catchable error instead.
        raise com.OperationNotDefinedError(
            f"No windowing TVF name for {type(op)}"
        )
    return f"TABLE({function_type}({format_windowing_tvf_params(op, self)}))"

def _format_table(self, op) -> str:
result = super()._format_table(op)
ctx = self.context
if isinstance(op, ops.WindowingTVF):
formatted_table = self._format_window_tvf(op)
return f"{formatted_table} {ctx.get_ref(op)}"
else:
result = super()._format_table(op)

ref_op = op
if isinstance(op, ops.SelfReference):
Expand Down Expand Up @@ -77,25 +93,72 @@ class FlinkCompiler(Compiler):

cheap_in_memory_tables = True

# NOTE(review): this span is a unified-diff render that interleaves REMOVED
# module-level translate helpers with ADDED ``FlinkCompiler.to_sql`` lines;
# it is not valid Python as shown.  The tags below mark what each fragment
# appears to be — confirm against the actual commit before relying on them.
# ADDED: new classmethod on FlinkCompiler; unwraps expressions to ops.
@classmethod
def to_sql(cls, node, context=None, params=None):
if isinstance(node, ir.Expr):
node = node.op()

# REMOVED: old public entry point dispatching through ``translate_op``.
def translate(op: ops.TableNode) -> str:
return translate_op(op)


# REMOVED: old singledispatch base raising for unhandled table nodes.
@functools.singledispatch
def translate_op(op: ops.TableNode) -> str:
raise com.OperationNotDefinedError(f"No translation rule for {type(op)}")

# ADDED: continuation of ``to_sql`` — special-case literal nodes.
if isinstance(node, ops.Literal):
from ibis.backends.flink.utils import translate_literal

# REMOVED: old literal translation rule (its lines interleave with the
# ADDED ``return translate_literal(node)`` belonging to ``to_sql``).
@translate_op.register(ops.Literal)
def _literal(op: ops.Literal) -> str:
from ibis.backends.flink.utils import translate_literal
return translate_literal(node)

return translate_literal(op)
return super().to_sql(node, context, params)


# REMOVED: old rules routing select-tree ops back through the compiler.
@translate_op.register(ops.Selection)
@translate_op.register(ops.Aggregation)
@translate_op.register(ops.Limit)
def _(op: ops.Selection | ops.Aggregation | ops.Limit) -> str:
return FlinkCompiler.to_sql(op)  # to_sql uses to_ast, which builds a select tree
@functools.singledispatch
def format_windowing_tvf_params(op: ops.WindowingTVF, formatter: TableSetFormatter) -> str:
    """Render the comma-separated argument list of a windowing TVF call.

    Concrete windowing TVF types register their own implementations below;
    the base case rejects any type without a registered formatting rule.
    """
    raise com.OperationNotDefinedError(f"No formatting rule for {type(op)}")


@format_windowing_tvf_params.register(ops.TumbleWindowingTVF)
def _tumble_window_params(
    op: ops.TumbleWindowingTVF, formatter: TableSetFormatter
) -> str:
    """Format TUMBLE(...) arguments: table, time descriptor, size[, offset]."""
    parts = [
        f"TABLE {quote_identifier(op.table.name)}",
        f"DESCRIPTOR({formatter._translate(op.time_col)})",
        formatter._translate(op.window_size),
    ]
    if op.offset:
        parts.append(formatter._translate(op.offset))
    # Drop any falsy entries, matching the original filter(None, ...) behavior.
    return ", ".join(p for p in parts if p)


@format_windowing_tvf_params.register(ops.HopWindowingTVF)
def _hop_window_params(op: ops.HopWindowingTVF, formatter: TableSetFormatter) -> str:
    """Format HOP(...) arguments: table, time descriptor, slide, size[, offset]."""
    parts = [
        f"TABLE {quote_identifier(op.table.name)}",
        f"DESCRIPTOR({formatter._translate(op.time_col)})",
        formatter._translate(op.window_slide),
        formatter._translate(op.window_size),
    ]
    if op.offset:
        parts.append(formatter._translate(op.offset))
    # Drop any falsy entries, matching the original filter(None, ...) behavior.
    return ", ".join(p for p in parts if p)


@format_windowing_tvf_params.register(ops.CumulateWindowingTVF)
def _cumulate_window_params(
    op: ops.CumulateWindowingTVF, formatter: TableSetFormatter
) -> str:
    """Format CUMULATE(...) arguments: table, time descriptor, step, size[, offset]."""
    parts = [
        f"TABLE {quote_identifier(op.table.name)}",
        f"DESCRIPTOR({formatter._translate(op.time_col)})",
        formatter._translate(op.window_step),
        formatter._translate(op.window_size),
    ]
    if op.offset:
        parts.append(formatter._translate(op.offset))
    # Drop any falsy entries, matching the original filter(None, ...) behavior.
    return ", ".join(p for p in parts if p)
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
SELECT t0.*
FROM TABLE(CUMULATE(TABLE `table`, DESCRIPTOR(`i`), INTERVAL '10' SECOND, INTERVAL '1' MINUTE)) t0
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
SELECT t0.*
FROM TABLE(HOP(TABLE `table`, DESCRIPTOR(`i`), INTERVAL '1' MINUTE, INTERVAL '15' MINUTE)) t0
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
SELECT t0.*
FROM TABLE(TUMBLE(TABLE `table`, DESCRIPTOR(`i`), INTERVAL '15' MINUTE)) t0
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,18 @@
import pytest
from pytest import param

from ibis.backends.flink.compiler.core import translate
import ibis


def test_sum(con, snapshot, simple_table):
    """Compile a simple column sum and check it against the SQL snapshot."""
    # NOTE(review): the rendered span duplicated old (translate-based) and
    # new (con.compile) lines from the diff; this is the post-commit version.
    expr = simple_table.a.sum()
    result = con.compile(expr)
    snapshot.assert_match(str(result), "out.sql")


def test_count_star(con, snapshot, simple_table):
    """Compile a group-by count(*) query and check the SQL snapshot."""
    # Post-commit version: the diff render interleaved old and new lines.
    expr = simple_table.group_by(simple_table.i).size()
    result = con.compile(expr)
    snapshot.assert_match(str(result), "out.sql")


Expand All @@ -25,28 +25,28 @@ def test_translate_count_star(snapshot, simple_table):
param("s", id="timestamp_s"),
],
)
def test_timestamp_from_unix(con, snapshot, simple_table, unit):
    """Compile a unix-epoch-to-timestamp conversion for each unit fixture."""
    # Post-commit version: the diff render interleaved old and new lines.
    expr = simple_table.d.to_timestamp(unit=unit)
    result = con.compile(expr)
    snapshot.assert_match(result, "out.sql")


def test_complex_projections(con, snapshot, simple_table):
    """Compile nested group-by/aggregate projections and check the snapshot."""
    # Post-commit version: the diff render interleaved old and new lines.
    expr = (
        simple_table.group_by(["a", "c"])
        .aggregate(the_sum=simple_table.b.sum())
        .group_by("a")
        .aggregate(mad=lambda x: x.the_sum.abs().mean())
    )
    result = con.compile(expr)
    snapshot.assert_match(result, "out.sql")


def test_filter(con, snapshot, simple_table):
    """Compile a compound boolean filter (or/and/isin) and check the snapshot."""
    # Post-commit version: the diff render interleaved old and new lines.
    expr = simple_table[
        ((simple_table.c > 0) | (simple_table.c < 0)) & simple_table.g.isin(["A", "B"])
    ]
    result = con.compile(expr)
    snapshot.assert_match(result, "out.sql")


Expand All @@ -64,50 +64,82 @@ def test_translate_filter(snapshot, simple_table):
"second",
],
)
def test_extract_fields(con, snapshot, simple_table, kind):
    """Compile a temporal field extraction (year, month, ...) per fixture."""
    # Post-commit version: the diff render interleaved old and new lines.
    expr = getattr(simple_table.i, kind)().name("tmp")
    result = con.compile(expr)
    snapshot.assert_match(result, "out.sql")


def test_complex_groupby_aggregation(con, snapshot, simple_table):
    """Compile a group-by on derived keys with count and nunique aggregates."""
    # Post-commit version: the diff render interleaved old and new lines.
    keys = [simple_table.i.year().name("year"), simple_table.i.month().name("month")]
    b_unique = simple_table.b.nunique()
    expr = simple_table.group_by(keys).aggregate(
        total=simple_table.count(), b_unique=b_unique
    )
    result = con.compile(expr)
    snapshot.assert_match(result, "out.sql")


def test_simple_filtered_agg(con, snapshot, simple_table):
    """Compile a filtered nunique aggregate and check the SQL snapshot."""
    # Post-commit version: the diff render interleaved old and new lines.
    expr = simple_table.b.nunique(where=simple_table.g == "A")
    result = con.compile(expr)
    snapshot.assert_match(result, "out.sql")


def test_complex_filtered_agg(con, snapshot, simple_table):
    """Compile multiple filtered/unfiltered aggregates in one group-by."""
    # Post-commit version: the diff render interleaved old and new lines.
    expr = simple_table.group_by("b").aggregate(
        total=simple_table.count(),
        avg_a=simple_table.a.mean(),
        avg_a_A=simple_table.a.mean(where=simple_table.g == "A"),
        avg_a_B=simple_table.a.mean(where=simple_table.g == "B"),
    )
    result = con.compile(expr)
    snapshot.assert_match(result, "out.sql")


def test_value_counts(con, snapshot, simple_table):
    """Compile a value_counts over an extracted year and check the snapshot."""
    # Post-commit version: the diff render interleaved old and new lines.
    expr = simple_table.i.year().value_counts()
    result = con.compile(expr)
    snapshot.assert_match(result, "out.sql")


def test_having(con, snapshot, simple_table):
    """Compile a group-by with a HAVING clause and check the SQL snapshot."""
    # Post-commit version: the diff render interleaved old and new lines.
    expr = (
        simple_table.group_by("g")
        .having(simple_table.count() >= 1000)
        .aggregate(simple_table.b.sum().name("b_sum"))
    )
    result = con.compile(expr)
    snapshot.assert_match(result, "out.sql")


@pytest.mark.parametrize(
    "function_type,params",
    [
        pytest.param(
            "tumble", {"window_size": ibis.interval(minutes=15)}, id="tumble_window"
        ),
        pytest.param(
            "hop",
            {
                "window_size": ibis.interval(minutes=15),
                "window_slide": ibis.interval(minutes=1),
            },
            id="hop_window",
        ),
        pytest.param(
            "cumulate",
            {
                "window_size": ibis.interval(minutes=1),
                "window_step": ibis.interval(seconds=10),
            },
            id="cumulate_window",
        ),
    ],
)
def test_tvf(con, snapshot, simple_table, function_type, params):
    """Compile each windowing TVF (tumble/hop/cumulate) and snapshot the SQL."""
    windowed = simple_table.window_by(time_col=simple_table.i)
    expr = getattr(windowed, function_type)(**params)
    snapshot.assert_match(con.compile(expr), "out.sql")
17 changes: 8 additions & 9 deletions ibis/backends/flink/tests/test_literals.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import ibis
import ibis.expr.datatypes as dt
from ibis.backends.flink.compiler.core import translate


@pytest.mark.parametrize(
Expand All @@ -20,9 +19,9 @@
param(False, "FALSE", id="false"),
],
)
def test_simple_literals(con, value, expected):
    """Compile numeric/boolean literals to their exact SQL text."""
    # Post-commit version: the diff render interleaved old and new lines.
    expr = ibis.literal(value)
    result = con.compile(expr)
    assert result == expected


Expand All @@ -34,9 +33,9 @@ def test_simple_literals(value, expected):
param('An "escape"', """'An "escape"'""", id="nested_token"),
],
)
def test_string_literals(con, value, expected):
    """Compile string literals, including quote escaping, to exact SQL text."""
    # Post-commit version: the diff render interleaved old and new lines.
    expr = ibis.literal(value)
    result = con.compile(expr)
    assert result == expected


Expand All @@ -54,9 +53,9 @@ def test_string_literals(value, expected):
param(ibis.interval(seconds=5), "INTERVAL '5' SECOND", id="5seconds"),
],
)
def test_translate_interval_literal(con, value, expected):
    """Compile interval literals to their exact INTERVAL SQL text."""
    # Post-commit version: the diff render interleaved old and new lines.
    expr = ibis.literal(value)
    result = con.compile(expr)
    assert result == expected


Expand All @@ -75,7 +74,7 @@ def test_translate_interval_literal(value, expected):
param("04:55:59", dt.time, id="string_time"),
],
)
def test_literal_timestamp_or_time(con, snapshot, case, dtype):
    """Compile timestamp/time literals and check against the SQL snapshot."""
    # Post-commit version: the rendered span still contained the stale
    # ``translate(expr.op())`` call even though ``translate`` is no longer
    # imported by this module; ``con.compile`` is the correct replacement.
    expr = ibis.literal(case, type=dtype)
    result = con.compile(expr)
    snapshot.assert_match(result, "out.sql")
Loading

0 comments on commit a35a756

Please sign in to comment.