Skip to content

Commit

Permalink
fix(flink): fix compilation of memtable with nested data (#8751)
Browse files Browse the repository at this point in the history
## Description of changes

This PR aims to fix the compilation of memtables with nested data.

### What was broken

In particular, [Flink does not support the ``STRUCT(1 AS `a`)`` aliasing
syntax to define named
STRUCTs](https://issues.apache.org/jira/browse/FLINK-9161). In order to
do so, we must use a workaround using `CAST`, e.g.,
```sql
SELECT CAST(('a', 1) as ROW<a STRING, b INT>);
```

However, Flink also does not allow you to directly construct ARRAYs of
named STRUCTs using the `ARRAY[]` constructor. This is a bug that I
identified and I have filed it with the Flink community (JIRA ticket
ref: https://issues.apache.org/jira/browse/FLINK-34898).

For the time being, we will need to use another `CAST` workaound that
casts the entire nested array, e.g.,
```sql
SELECT cast(ARRAY[ROW(1)] as ARRAY<ROW<a INT>>);  -- instead of ARRAY[CAST(ROW(1) AS ROW<a INT>)]
```

### How to fix

To summarize,
- if it’s an array of named structs
`CAST(ARRAY[] AS ARRAY<ROW<>, ROW<>>)`
- if it’s named structs
`CAST(ROW() AS ROW<datatype of each field>)`
- if it’s unnamed structs (but I'm not sure how to write this in Ibis)
`ROW()`

I thought of two approaches to this:
1. Rewrite the operator mapping in the Flink backend (i.e., change the
`visit_NonNullLiteral()` method)
2. Rewrite the translation rule in Flink's `Generator`

I tried out both implementations in different scenarios and decided to go
with option (2).

## Issues closed

#8516
  • Loading branch information
chloeh13q authored Mar 27, 2024
1 parent 3623788 commit 364a6ee
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 16 deletions.
3 changes: 3 additions & 0 deletions ibis/backends/flink/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,3 +566,6 @@ def visit_MapMerge(self, op: ops.MapMerge, *, left, right):
values = self.f.array_concat(left_values, right_values)

return self.cast(self.f.map_from_arrays(keys, values), op.dtype)

def visit_StructColumn(self, op, *, names, values):
    # Flink lacks the `STRUCT(1 AS a)` aliasing syntax (FLINK-9161), so a
    # named struct is emitted as an unnamed ROW constructor cast to the
    # target ROW<...> type; the field names are carried by `op.dtype`
    # rather than by `names`, which is why `names` goes unused here.
    return self.cast(sge.Struct(expressions=list(values)), op.dtype)
48 changes: 48 additions & 0 deletions ibis/backends/flink/tests/test_memtable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from __future__ import annotations

import pytest
from pyflink.common.types import Row

import ibis
from ibis.backends.tests.errors import Py4JJavaError


@pytest.mark.parametrize(
    "data,schema,expected",
    [
        pytest.param(
            {"value": [{"a": 1}, {"a": 2}]},
            {"value": "!struct<a: !int>"},
            [Row(Row([1])), Row(Row([2]))],
            id="simple_named_struct",
        ),
        pytest.param(
            {"value": [[{"a": 1}, {"a": 2}], [{"a": 3}, {"a": 4}]]},
            {"value": "!array<!struct<a: !int>>"},
            [Row([Row([1]), Row([2])]), Row([Row([3]), Row([4])])],
            id="single_field_named_struct_array",
        ),
        pytest.param(
            {"value": [[{"a": 1, "b": 2}, {"a": 2, "b": 2}]]},
            {"value": "!array<!struct<a: !int, b: !int>>"},
            [Row([Row([1, 2]), Row([2, 2])])],
            id="named_struct_array",
        ),
    ],
)
def test_create_memtable(con, data, schema, expected):
    """Memtables holding named structs (and arrays thereof) must compile
    to SQL that Flink accepts and round-trip to the expected `Row` values.
    """
    t = ibis.memtable(data, schema=ibis.schema(schema))
    # cannot use con.execute(t) directly because of some behavioral discrepancy between
    # `TableEnvironment.execute_sql()` and `TableEnvironment.sql_query()`
    result = con.raw_sql(con.compile(t))
    # raw_sql() returns a `TableResult` object and doesn't natively convert to pandas
    assert list(result.collect()) == expected


@pytest.mark.notyet(
    ["flink"],
    raises=Py4JJavaError,
    reason="cannot create an ARRAY of named STRUCTs directly from the ARRAY[] constructor; https://issues.apache.org/jira/browse/FLINK-34898",
)
def test_create_named_struct_array_with_array_constructor(con):
    """Upstream-bug canary: Flink rejects ARRAY[] of named-STRUCT elements.

    This is expected to raise until FLINK-34898 is fixed upstream; once it
    starts passing, the CAST workaround in the compiler can be removed.
    """
    con.raw_sql("SELECT ARRAY[cast(ROW(1) as ROW<a INT>)];")
60 changes: 56 additions & 4 deletions ibis/backends/sql/dialects.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import contextlib
import math
from copy import deepcopy

import sqlglot.expressions as sge
from sqlglot import transforms
Expand All @@ -18,6 +19,7 @@
Trino,
)
from sqlglot.dialects.dialect import rename_func
from sqlglot.helper import seq_get

ClickHouse.Generator.TRANSFORMS |= {
sge.ArraySize: rename_func("length"),
Expand Down Expand Up @@ -113,6 +115,7 @@ class Flink(Hive):
class Generator(Hive.Generator):
TYPE_MAPPING = Hive.Generator.TYPE_MAPPING.copy() | {
sge.DataType.Type.TIME: "TIME",
sge.DataType.Type.STRUCT: "ROW",
}

TRANSFORMS = Hive.Generator.TRANSFORMS.copy() | {
Expand All @@ -121,10 +124,6 @@ class Generator(Hive.Generator):
sge.StddevSamp: rename_func("stddev_samp"),
sge.Variance: rename_func("var_samp"),
sge.VariancePop: rename_func("var_pop"),
sge.Array: (
lambda self,
e: f"ARRAY[{', '.join(arg.sql(self.dialect) for arg in e.expressions)}]"
),
sge.ArrayConcat: rename_func("array_concat"),
sge.Length: rename_func("char_length"),
sge.TryCast: lambda self,
Expand All @@ -135,6 +134,59 @@ class Generator(Hive.Generator):
sge.Interval: _interval_with_precision,
}

def struct_sql(self, expression: sge.Struct) -> str:
    """Render a struct constructor for Flink.

    Flink has no `STRUCT(1 AS a)` aliasing syntax, so a fully named
    struct is emitted as `CAST(ROW(values...) AS ROW(name type, ...))`;
    anything with unnamed (or untypable) fields falls back to a plain
    `ROW(...)` call.
    """
    # NOTE(review): imported locally, presumably to defer the sqlglot
    # optimizer import until first use — confirm before hoisting.
    from sqlglot.optimizer.annotate_types import annotate_types

    # The CAST target needs each field's type, so annotate the tree first.
    expression = annotate_types(expression)

    values = []   # rendered field value expressions, in order
    schema = []   # rendered "name type" pairs for the CAST target

    for e in expression.expressions:
        if isinstance(e, sge.PropertyEQ):
            # Normalize `name := value` pairs to `value AS name` aliases so
            # both spellings hit the named-field branch below.
            e = sge.alias_(e.expression, e.this)
        # named structs
        if isinstance(e, sge.Alias):
            if e.type and e.type.is_type(sge.DataType.Type.UNKNOWN):
                self.unsupported(
                    "Cannot convert untyped key-value definitions (try annotate_types)."
                )
            else:
                schema.append(f"{self.sql(e, 'alias')} {self.sql(e.type)}")
            values.append(self.sql(e, "this"))
        else:
            values.append(self.sql(e))

    # Empty struct, or at least one field without a usable name/type:
    # emit a plain ROW(...) with no cast.
    if not (size := len(expression.expressions)) or len(schema) != size:
        return self.func("ROW", *values)
    return f"CAST(ROW({', '.join(values)}) AS ROW({', '.join(schema)}))"

def array_sql(self, expression: sge.Array) -> str:
    """Render an ARRAY constructor for Flink.

    Flink cannot build an ARRAY of *named* STRUCTs with the plain
    ``ARRAY[]`` constructor (https://issues.apache.org/jira/browse/FLINK-34898),
    so for struct elements we emit unnamed ``ROW(...)`` values and CAST the
    whole array to ``ARRAY<ROW<...>>`` instead; the field names live in the
    cast target. Non-struct arrays use the plain constructor.
    """
    from sqlglot.optimizer.annotate_types import annotate_types

    # Annotate so the first element carries the ROW<...> type for the cast.
    expression = annotate_types(expression)
    first_arg = seq_get(expression.expressions, 0)
    # it's an array of structs
    if isinstance(first_arg, sge.Struct):
        # Copy before mutating so the caller's expression tree is untouched,
        # then strip the per-field aliases: we compile via CAST instead.
        # (The original code re-ran `arg.set` once per child inside a
        # redundant inner loop, replacing the list it was iterating; one
        # `set` per struct is sufficient and equivalent.)
        args = deepcopy(expression.expressions)
        for arg in args:
            arg.set("expressions", [e.unalias() for e in arg.expressions])

        format_values = ", ".join(self.sql(arg) for arg in args)
        # all elements of the array should have the same type, so the first
        # element's annotated type is used as the target element type
        format_dtypes = self.sql(first_arg.type)

        return f"CAST(ARRAY[{format_values}] AS ARRAY<{format_dtypes}>)"

    return f"ARRAY[{', '.join(self.sql(arg) for arg in expression.expressions)}]"

class Tokenizer(Hive.Tokenizer):
# In Flink, embedded single quotes are escaped like most other SQL
# dialects: doubling up the single quote
Expand Down
7 changes: 5 additions & 2 deletions ibis/backends/tests/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,11 @@ def test_literal_map_getitem_broadcast(backend, alltypes, df):
),
pytest.mark.notyet(["pandas", "dask"]),
mark_notyet_postgres,
pytest.mark.notimpl("flink"),
pytest.mark.notyet(
["flink"],
raises=Py4JJavaError,
reason="does not support selecting struct key from map",
),
mark_notyet_snowflake,
],
id="struct",
Expand Down Expand Up @@ -304,7 +308,6 @@ def test_literal_map_getitem_broadcast(backend, alltypes, df):
marks=[
pytest.mark.notyet("clickhouse", reason="nested types can't be null"),
mark_notyet_postgres,
pytest.mark.notimpl("flink", reason="can't construct structs"),
],
id="struct",
),
Expand Down
1 change: 0 additions & 1 deletion ibis/backends/tests/test_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ def test_scalar_param_array(con):
["mysql", "sqlite", "mssql"],
reason="mysql and sqlite will never implement struct types",
)
@pytest.mark.notimpl(["flink"], "WIP")
def test_scalar_param_struct(con):
value = dict(a=1, b="abc", c=3.0)
param = ibis.param("struct<a: int64, b: string, c: float64>")
Expand Down
9 changes: 0 additions & 9 deletions ibis/backends/tests/test_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,6 @@ def test_all_fields(struct, struct_df):

@pytest.mark.notimpl(["postgres", "risingwave"])
@pytest.mark.parametrize("field", ["a", "b", "c"])
@pytest.mark.notyet(
["flink"], reason="flink doesn't support creating struct columns from literals"
)
def test_literal(backend, con, field):
query = _STRUCT_LITERAL[field]
dtype = query.type().to_pandas()
Expand All @@ -89,9 +86,6 @@ def test_literal(backend, con, field):
@pytest.mark.notyet(
["clickhouse"], reason="clickhouse doesn't support nullable nested types"
)
@pytest.mark.notyet(
["flink"], reason="flink doesn't support creating struct columns from literals"
)
def test_null_literal(backend, con, field):
query = _NULL_STRUCT_LITERAL[field]
result = pd.Series([con.execute(query)])
Expand All @@ -101,9 +95,6 @@ def test_null_literal(backend, con, field):


@pytest.mark.notimpl(["dask", "pandas", "postgres", "risingwave"])
@pytest.mark.notyet(
["flink"], reason="flink doesn't support creating struct columns from literals"
)
def test_struct_column(alltypes, df):
t = alltypes
expr = t.select(s=ibis.struct(dict(a=t.string_col, b=1, c=t.bigint_col)))
Expand Down

0 comments on commit 364a6ee

Please sign in to comment.