Skip to content

Commit

Permalink
feat(flink): add map support (#8425)
Browse files Browse the repository at this point in the history
  • Loading branch information
mfatihaktas authored Feb 24, 2024
1 parent 965b6d9 commit 68739a2
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 53 deletions.
22 changes: 20 additions & 2 deletions ibis/backends/flink/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ class FlinkCompiler(SQLGlotCompiler):
ops.IsInf,
ops.IsNan,
ops.Levenshtein,
ops.MapMerge,
ops.Median,
ops.MultiQuantile,
ops.NthValue,
Expand All @@ -81,7 +80,8 @@ class FlinkCompiler(SQLGlotCompiler):
ops.ExtractDayOfYear: "dayofyear",
ops.First: "first_value",
ops.Last: "last_value",
ops.Map: "map_from_arrays",
ops.MapKeys: "map_keys",
ops.MapValues: "map_values",
ops.Power: "power",
ops.RandomScalar: "rand",
ops.RegexSearch: "regexp",
Expand Down Expand Up @@ -548,3 +548,21 @@ def visit_CountDistinct(self, op, *, arg, where):
if where is not None:
arg = self.if_(where, arg, self.f.array(arg)[2])
return self.f.count(sge.Distinct(expressions=[arg]))

def visit_MapContains(self, op: ops.MapContains, *, arg, key):
return self.f.array_contains(self.f.map_keys(arg), key)

def visit_Map(self, op: ops.Map, *, keys, values):
return self.cast(self.f.map_from_arrays(keys, values), op.dtype)

def visit_MapMerge(self, op: ops.MapMerge, *, left, right):
left_keys = self.f.map_keys(left)
left_values = self.f.map_values(left)

right_keys = self.f.map_keys(right)
right_values = self.f.map_values(right)

keys = self.f.array_concat(left_keys, right_keys)
values = self.f.array_concat(left_values, right_values)

return self.cast(self.f.map_from_arrays(keys, values), op.dtype)
5 changes: 3 additions & 2 deletions ibis/backends/flink/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,9 @@ def from_ibis(cls, dtype: dt.DataType) -> DataType:
return DataTypes.ARRAY(cls.from_ibis(dtype.value_type), nullable=nullable)
elif dtype.is_map():
return DataTypes.MAP(
key_type=cls.from_ibis(dtype.key_type),
value_type=cls.from_ibis(dtype.key_type),
# keys *must* be non-nullable
key_type=cls.from_ibis(dtype.key_type.copy(nullable=False)),
value_type=cls.from_ibis(dtype.value_type),
nullable=nullable,
)
elif dtype.is_struct():
Expand Down
12 changes: 12 additions & 0 deletions ibis/backends/flink/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
class TestConf(BackendTest):
force_sort = True
stateful = False
supports_map = True
deps = "pandas", "pyflink"

@staticmethod
Expand Down Expand Up @@ -63,6 +64,17 @@ def _load_data(self, **_: Any) -> None:
con.create_table("json_t", json_types, temp=True)
con.create_table("struct", struct_types, temp=True)
con.create_table("win", win, temp=True)
con.create_table(
"map",
pd.DataFrame(
{
"idx": [1, 2],
"kv": [{"a": 1, "b": 2, "c": 3}, {"d": 4, "e": 5, "f": 6}],
}
),
schema=ibis.schema({"idx": "int64", "kv": "map<string, int64>"}),
temp=True,
)


class TestConfForStreaming(TestConf):
Expand Down
18 changes: 17 additions & 1 deletion ibis/backends/sql/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1030,7 +1030,9 @@ def _from_ibis_Map(cls, dtype: dt.Map) -> sge.DataType:
# key cannot be nullable in clickhouse
key_type = cls.from_ibis(dtype.key_type.copy(nullable=False))
value_type = cls.from_ibis(dtype.value_type)
return sge.DataType(this=typecode.MAP, expressions=[key_type, value_type])
return sge.DataType(
this=typecode.MAP, expressions=[key_type, value_type], nested=True
)


class FlinkType(SqlglotType):
Expand All @@ -1041,3 +1043,17 @@ class FlinkType(SqlglotType):
@classmethod
def _from_ibis_Binary(cls, dtype: dt.Binary) -> sge.DataType:
return sge.DataType(this=sge.DataType.Type.VARBINARY)

@classmethod
def _from_ibis_Map(cls, dtype: dt.Map) -> sge.DataType:
# key cannot be nullable in clickhouse
key_type = cls.from_ibis(dtype.key_type.copy(nullable=False))
value_type = cls.from_ibis(dtype.value_type)
return sge.DataType(
this=typecode.MAP,
expressions=[
sge.Var(this=key_type.sql(cls.dialect) + " NOT NULL"),
value_type,
],
nested=True,
)
53 changes: 8 additions & 45 deletions ibis/backends/tests/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import numpy as np
import pandas as pd
import pandas.testing as tm
import pyarrow as pa
import pytest
from pytest import param

Expand Down Expand Up @@ -56,20 +58,15 @@ def test_column_map_merge(backend):
table = backend.map
expr = table.select(
"idx",
merged=table.kv.cast("map<string, int8>") + ibis.map({"d": 1}),
merged=table.kv + ibis.map({"d": np.int64(1)}),
).order_by("idx")
result = expr.execute().merged
expected = pd.Series(
[{"a": 1, "b": 2, "c": 3, "d": 1}, {"d": 1, "e": 5, "f": 6}], name="merged"
)
backend.assert_series_equal(result, expected)
tm.assert_series_equal(result, expected)


@pytest.mark.notimpl(
["flink"],
raises=exc.OperationNotDefinedError,
reason="No translation rule for <class 'ibis.expr.operations.maps.MapKeys'>",
)
@pytest.mark.notimpl(
["risingwave"],
raises=PsycoPg2InternalError,
Expand All @@ -85,11 +82,6 @@ def test_literal_map_keys(con):
assert np.array_equal(result, ["1", "2"])


@pytest.mark.notimpl(
["flink"],
raises=exc.OperationNotDefinedError,
reason="No translation rule for <class 'ibis.expr.operations.maps.MapValues'>",
)
@pytest.mark.notimpl(
["risingwave"],
raises=PsycoPg2InternalError,
Expand All @@ -104,11 +96,6 @@ def test_literal_map_values(con):


@pytest.mark.notimpl(["postgres", "risingwave"])
@pytest.mark.notimpl(
["flink"],
raises=exc.OperationNotDefinedError,
reason="No translation rule for <class 'ibis.expr.operations.arrays.ArrayContains'>",
)
def test_scalar_isin_literal_map_keys(con):
mapping = ibis.literal({"a": 1, "b": 2})
a = ibis.literal("a")
Expand All @@ -122,11 +109,6 @@ def test_scalar_isin_literal_map_keys(con):
@pytest.mark.notyet(
["postgres", "risingwave"], reason="only support maps of string -> string"
)
@pytest.mark.notimpl(
["flink"],
raises=exc.OperationNotDefinedError,
reason="No translation rule for <class 'ibis.expr.operations.maps.MapContains'>",
)
def test_map_scalar_contains_key_scalar(con):
mapping = ibis.literal({"a": 1, "b": 2})
a = ibis.literal("a")
Expand All @@ -137,11 +119,6 @@ def test_map_scalar_contains_key_scalar(con):
assert con.execute(false) == False # noqa: E712


@pytest.mark.notimpl(
["flink"],
raises=exc.OperationNotDefinedError,
reason="No translation rule for <class 'ibis.expr.operations.maps.MapContains'>",
)
@pytest.mark.notimpl(
["risingwave"],
raises=PsycoPg2InternalError,
Expand All @@ -159,11 +136,6 @@ def test_map_scalar_contains_key_column(backend, alltypes, df):
@pytest.mark.notyet(
["postgres", "risingwave"], reason="only support maps of string -> string"
)
@pytest.mark.notimpl(
["flink"],
raises=exc.OperationNotDefinedError,
reason=("No translation rule for <class 'ibis.expr.operations.maps.MapContains'>"),
)
def test_map_column_contains_key_scalar(backend, alltypes, df):
expr = ibis.map(ibis.array([alltypes.string_col]), ibis.array([alltypes.int_col]))
series = df.apply(lambda row: {row["string_col"]: row["int_col"]}, axis=1)
Expand All @@ -177,11 +149,6 @@ def test_map_column_contains_key_scalar(backend, alltypes, df):
@pytest.mark.notyet(
["postgres", "risingwave"], reason="only support maps of string -> string"
)
@pytest.mark.notimpl(
["flink"],
raises=exc.OperationNotDefinedError,
reason="No translation rule for <class 'ibis.expr.operations.maps.MapContains'>",
)
def test_map_column_contains_key_column(alltypes):
map_expr = ibis.map(
ibis.array([alltypes.string_col]), ibis.array([alltypes.int_col])
Expand All @@ -194,11 +161,6 @@ def test_map_column_contains_key_column(alltypes):
@pytest.mark.notyet(
["postgres", "risingwave"], reason="only support maps of string -> string"
)
@pytest.mark.notimpl(
["flink"],
raises=exc.OperationNotDefinedError,
reason="No translation rule for <class 'ibis.expr.operations.maps.MapMerge'>",
)
def test_literal_map_merge(con):
a = ibis.literal({"a": 0, "b": 2})
b = ibis.literal({"a": 1, "c": 3})
Expand Down Expand Up @@ -270,10 +232,10 @@ def test_map_construct_dict(con, keys, values):
@pytest.mark.notyet(
["postgres", "risingwave"], reason="only support maps of string -> string"
)
@pytest.mark.notimpl(
@pytest.mark.broken(
["flink"],
raises=Py4JJavaError,
reason="Map key type should be non-nullable",
raises=pa.lib.ArrowInvalid,
reason="Map array child array should have no nulls",
)
def test_map_construct_array_column(con, alltypes, df):
expr = ibis.map(ibis.array([alltypes.string_col]), ibis.array([alltypes.int_col]))
Expand Down Expand Up @@ -383,6 +345,7 @@ def test_map_length(con):
assert con.execute(expr) == 2


@pytest.mark.notimpl(["flink"], raises=exc.OperationNotDefinedError)
def test_map_keys_unnest(backend):
expr = backend.map.kv.keys().unnest()
result = expr.to_pandas()
Expand Down
9 changes: 6 additions & 3 deletions ibis/backends/tests/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import ibis.common.exceptions as com
from ibis import _, udf
from ibis.backends.tests.errors import Py4JJavaError

no_python_udfs = mark.notimpl(
[
Expand Down Expand Up @@ -54,7 +53,9 @@ def num_vowels(s: str, include_y: bool = False) -> int:
["postgres"], raises=TypeError, reason="postgres only supports map<string, string>"
)
@mark.notimpl(["polars"])
@mark.notimpl(["flink"], raises=Py4JJavaError)
@mark.never(
["flink"], strict=False, reason="broken with Python 3.9; works in Python 3.10"
)
@mark.notyet(["datafusion"], raises=NotImplementedError)
@mark.notyet(
["sqlite"], raises=com.IbisTypeError, reason="sqlite doesn't support map types"
Expand Down Expand Up @@ -84,7 +85,9 @@ def num_vowels_map(s: str, include_y: bool = False) -> dict[str, int]:
["postgres"], raises=TypeError, reason="postgres only supports map<string, string>"
)
@mark.notimpl(["polars"])
@mark.notimpl(["flink"], raises=Py4JJavaError)
@mark.never(
["flink"], strict=False, reason="broken with Python 3.9; works in Python 3.10"
)
@mark.notyet(["datafusion"], raises=NotImplementedError)
@mark.notyet(["sqlite"], raises=TypeError, reason="sqlite doesn't support map types")
def test_map_merge_udf(batting):
Expand Down

0 comments on commit 68739a2

Please sign in to comment.