From f983bfa1cf4234cbf1100ee24e9544a262217f1c Mon Sep 17 00:00:00 2001
From: Zhenzhong Xu
Date: Thu, 30 Nov 2023 11:25:43 -0800
Subject: [PATCH] fix(flink): implement TypeMapper and SchemaMapper for Flink backend

---
 ibis/backends/flink/__init__.py  |  4 +-
 ibis/backends/flink/datatypes.py | 94 ++++++++++++++++++++++++++++++++
 ibis/backends/flink/registry.py  |  5 +-
 ibis/backends/flink/utils.py     | 40 ++------------
 4 files changed, 101 insertions(+), 42 deletions(-)
 create mode 100644 ibis/backends/flink/datatypes.py

diff --git a/ibis/backends/flink/__init__.py b/ibis/backends/flink/__init__.py
index babe6521e786..f7936b798354 100644
--- a/ibis/backends/flink/__init__.py
+++ b/ibis/backends/flink/__init__.py
@@ -12,6 +12,7 @@
 from ibis.backends.base import BaseBackend, CanCreateDatabase
 from ibis.backends.base.sql.ddl import fully_qualified_re, is_fully_qualified
 from ibis.backends.flink.compiler.core import FlinkCompiler
+from ibis.backends.flink.datatypes import FlinkRowSchema
 from ibis.backends.flink.ddl import (
     CreateDatabase,
     CreateTableFromConnector,
@@ -19,7 +20,6 @@
     DropTable,
     InsertSelect,
 )
-from ibis.backends.flink.utils import ibis_schema_to_flink_schema
 
 if TYPE_CHECKING:
     from collections.abc import Mapping
@@ -354,7 +354,7 @@ def create_table(
                 obj = obj.to_pandas()
             if isinstance(obj, pd.DataFrame):
                 table = self._table_env.from_pandas(
-                    obj, ibis_schema_to_flink_schema(schema)
+                    obj, FlinkRowSchema.from_ibis(schema)
                 )
             if isinstance(obj, ir.Table):
                 table = obj
diff --git a/ibis/backends/flink/datatypes.py b/ibis/backends/flink/datatypes.py
new file mode 100644
index 000000000000..95e9a3fe3d7d
--- /dev/null
+++ b/ibis/backends/flink/datatypes.py
@@ -0,0 +1,94 @@
+from __future__ import annotations
+
+import pyflink.table.types as fl
+
+import ibis.expr.datatypes as dt
+import ibis.expr.schema as sch
+from ibis.formats import SchemaMapper, TypeMapper
+
+
+class FlinkRowSchema(SchemaMapper):
+    @classmethod
+    def from_ibis(cls, schema: sch.Schema | None) -> fl.RowType | None:
+        """Convert an ibis schema to a flink row type; None passes through."""
+        if schema is None:
+            return None
+
+        return fl.DataTypes.ROW(
+            [
+                fl.DataTypes.FIELD(k, FlinkType.from_ibis(v))
+                for k, v in schema.fields.items()
+            ]
+        )
+
+
+class FlinkType(TypeMapper):
+    @classmethod
+    def to_ibis(cls, typ: fl.DataType, nullable=True) -> dt.DataType:
+        """Convert a flink type to an ibis type."""
+        if typ == fl.DataTypes.STRING():
+            return dt.String(nullable=nullable)
+        elif typ == fl.DataTypes.BOOLEAN():
+            return dt.Boolean(nullable=nullable)
+        elif typ == fl.DataTypes.BYTES():
+            return dt.Binary(nullable=nullable)
+        elif typ == fl.DataTypes.TINYINT():
+            return dt.Int8(nullable=nullable)
+        elif typ == fl.DataTypes.SMALLINT():
+            return dt.Int16(nullable=nullable)
+        elif typ == fl.DataTypes.INT():
+            return dt.Int32(nullable=nullable)
+        elif typ == fl.DataTypes.BIGINT():
+            return dt.Int64(nullable=nullable)
+        elif typ == fl.DataTypes.FLOAT():
+            return dt.Float32(nullable=nullable)
+        elif typ == fl.DataTypes.DOUBLE():
+            return dt.Float64(nullable=nullable)
+        elif typ == fl.DataTypes.DATE():
+            return dt.Date(nullable=nullable)
+        elif typ == fl.DataTypes.TIME():
+            return dt.Time(nullable=nullable)
+        elif typ == fl.DataTypes.TIMESTAMP():
+            return dt.Timestamp(nullable=nullable)
+        else:
+            return super().to_ibis(typ, nullable=nullable)
+
+    @classmethod
+    def from_ibis(cls, dtype: dt.DataType) -> fl.DataType:
+        """Convert an ibis type to a flink type."""
+        if dtype.is_string():
+            return fl.DataTypes.STRING()
+        elif dtype.is_boolean():
+            return fl.DataTypes.BOOLEAN()
+        elif dtype.is_binary():
+            return fl.DataTypes.BYTES()
+        elif dtype.is_int8():
+            return fl.DataTypes.TINYINT()
+        elif dtype.is_int16():
+            return fl.DataTypes.SMALLINT()
+        elif dtype.is_int32():
+            return fl.DataTypes.INT()
+        elif dtype.is_int64():
+            return fl.DataTypes.BIGINT()
+        elif dtype.is_uint8():
+            return fl.DataTypes.TINYINT()
+        elif dtype.is_uint16():
+            return fl.DataTypes.SMALLINT()
+        elif dtype.is_uint32():
+            return fl.DataTypes.INT()
+        elif dtype.is_uint64():
+            return fl.DataTypes.BIGINT()
+        elif dtype.is_float16():
+            return fl.DataTypes.FLOAT()
+        elif dtype.is_float32():
+            return fl.DataTypes.FLOAT()
+        elif dtype.is_float64():
+            return fl.DataTypes.DOUBLE()
+        elif dtype.is_date():
+            return fl.DataTypes.DATE()
+        elif dtype.is_time():
+            return fl.DataTypes.TIME()
+        elif dtype.is_timestamp():
+            return fl.DataTypes.TIMESTAMP()
+        else:
+            return super().from_ibis(dtype)
diff --git a/ibis/backends/flink/registry.py b/ibis/backends/flink/registry.py
index f8123a4cdb1c..76ea1002fe43 100644
--- a/ibis/backends/flink/registry.py
+++ b/ibis/backends/flink/registry.py
@@ -9,6 +9,7 @@
     operation_registry as base_operation_registry,
 )
 from ibis.backends.base.sql.registry.main import varargs
+from ibis.backends.flink.datatypes import FlinkType
 from ibis.common.temporal import TimestampUnit
 
 if TYPE_CHECKING:
@@ -221,8 +222,6 @@ def _window(translator: ExprTranslator, op: ops.Node) -> str:
 
 
 def _clip(translator: ExprTranslator, op: ops.Node) -> str:
-    from ibis.backends.flink.utils import _to_pyflink_types
-
     arg = translator.translate(op.arg)
 
     if op.upper is not None:
@@ -233,7 +232,7 @@ def _clip(translator: ExprTranslator, op: ops.Node) -> str:
         lower = translator.translate(op.lower)
         arg = f"IF({arg} < {lower} AND {arg} IS NOT NULL, {lower}, {arg})"
 
-    return f"CAST({arg} AS {_to_pyflink_types[type(op.dtype)]!s})"
+    return f"CAST({arg} AS {FlinkType.from_ibis(op.dtype)!s})"
 
 
 def _floor_divide(translator: ExprTranslator, op: ops.Node) -> str:
diff --git a/ibis/backends/flink/utils.py b/ibis/backends/flink/utils.py
index f794cb48ae05..c27bdbf6737e 100644
--- a/ibis/backends/flink/utils.py
+++ b/ibis/backends/flink/utils.py
@@ -5,11 +5,9 @@
 from abc import ABC, abstractmethod
 from collections import defaultdict
 
-from pyflink.table.types import DataTypes, RowType
-
 import ibis.expr.datatypes as dt
 import ibis.expr.operations as ops
-import ibis.expr.schema as sch
+from ibis.backends.flink.datatypes import FlinkType
 from ibis.common.temporal import IntervalUnit
 from ibis.util import convert_unit
 
@@ -247,27 +245,6 @@ def _translate_interval(value, dtype):
     return interval.format_as_string()
 
 
-_to_pyflink_types = {
-    dt.String: DataTypes.STRING(),
-    dt.Boolean: DataTypes.BOOLEAN(),
-    dt.Binary: DataTypes.BYTES(),
-    dt.Int8: DataTypes.TINYINT(),
-    dt.Int16: DataTypes.SMALLINT(),
-    dt.Int32: DataTypes.INT(),
-    dt.Int64: DataTypes.BIGINT(),
-    dt.UInt8: DataTypes.TINYINT(),
-    dt.UInt16: DataTypes.SMALLINT(),
-    dt.UInt32: DataTypes.INT(),
-    dt.UInt64: DataTypes.BIGINT(),
-    dt.Float16: DataTypes.FLOAT(),
-    dt.Float32: DataTypes.FLOAT(),
-    dt.Float64: DataTypes.DOUBLE(),
-    dt.Date: DataTypes.DATE(),
-    dt.Time: DataTypes.TIME(),
-    dt.Timestamp: DataTypes.TIMESTAMP(),
-}
-
-
 def translate_literal(op: ops.Literal) -> str:
     value = op.value
     dtype = op.dtype
@@ -275,7 +252,7 @@ def translate_literal(op: ops.Literal) -> str:
     if value is None:
         if dtype.is_null():
             return "NULL"
-        return f"CAST(NULL AS {_to_pyflink_types[type(dtype)]!s})"
+        return f"CAST(NULL AS {FlinkType.from_ibis(dtype)!s})"
 
     if dtype.is_boolean():
         # TODO(chloeh13q): Flink supports a third boolean called "UNKNOWN"
@@ -305,7 +282,7 @@ def translate_literal(op: ops.Literal) -> str:
                 raise ValueError("The precision can be up to 38 in Flink")
             return f"CAST({value} AS DECIMAL({precision}, {scale}))"
 
-        return f"CAST({value} AS {_to_pyflink_types[type(dtype)]!s})"
+        return f"CAST({value} AS {FlinkType.from_ibis(dtype)!s})"
     elif dtype.is_timestamp():
         # TODO(chloeh13q): support timestamp with local timezone
         if isinstance(value, datetime.datetime):
@@ -327,14 +304,3 @@ def translate_literal(op: ops.Literal) -> str:
         return f"ARRAY{list(value)}"
 
     raise NotImplementedError(f"No translation rule for {dtype}")
-
-
-def ibis_schema_to_flink_schema(schema: sch.Schema) -> RowType:
-    if schema is None:
-        return None
-    return DataTypes.ROW(
-        [
-            DataTypes.FIELD(key, _to_pyflink_types[type(value)])
-            for key, value in schema.fields.items()
-        ]
-    )