From b632063dc9b7fcc722dcb4e5bb98b10aa4f7d54e Mon Sep 17 00:00:00 2001
From: Phillip Cloud <417981+cpcloud@users.noreply.github.com>
Date: Thu, 29 Jun 2023 11:56:14 -0400
Subject: [PATCH] fix(druid): handle conversion issues from string, binary, and timestamp

---
 ibis/backends/druid/__init__.py      | 40 +++++++++++++-------
 ibis/backends/druid/compiler.py      | 13 +++++++
 ibis/backends/druid/datatypes.py     | 55 ++++++++++++++++++++++++++++
 ibis/backends/tests/test_binary.py   |  2 -
 ibis/backends/tests/test_export.py   |  8 ++--
 ibis/backends/tests/test_temporal.py |  5 +++
 6 files changed, 102 insertions(+), 21 deletions(-)

diff --git a/ibis/backends/druid/__init__.py b/ibis/backends/druid/__init__.py
index 4f2db7f06521..0a63954eed9f 100644
--- a/ibis/backends/druid/__init__.py
+++ b/ibis/backends/druid/__init__.py
@@ -8,7 +8,6 @@ from typing import Any, Iterable
 
 import sqlalchemy as sa
-from pydruid.db.sqlalchemy import DruidDialect
 
 import ibis.backends.druid.datatypes as ddt
 import ibis.expr.datatypes as dt
@@ -50,14 +49,27 @@ def do_connect(
         # workaround a broken pydruid `has_table` implementation
         engine.dialect.has_table = self._has_table
 
+    @staticmethod
+    def _new_sa_metadata():
+        meta = sa.MetaData()
+
+        @sa.event.listens_for(meta, "column_reflect")
+        def column_reflect(inspector, table, column_info):
+            if isinstance(typ := column_info["type"], sa.DateTime):
+                column_info["type"] = ddt.DruidDateTime()
+            elif isinstance(typ, (sa.LargeBinary, sa.BINARY, sa.VARBINARY)):
+                column_info["type"] = ddt.DruidBinary()
+            elif isinstance(typ, sa.String):
+                column_info["type"] = ddt.DruidString()
+
+        return meta
+
     @contextlib.contextmanager
     def _safe_raw_sql(self, query, *args, **kwargs):
-        if not isinstance(query, str):
-            query = str(
-                query.compile(
-                    dialect=DruidDialect(), compile_kwargs=dict(literal_binds=True)
-                )
-            )
+        query = query.compile(
+            dialect=self.con.dialect, compile_kwargs=dict(literal_binds=True)
+        )
+
         with warnings.catch_warnings():
             warnings.filterwarnings(
                 "ignore",
@@ -65,7 +77,7 @@ def _safe_raw_sql(self, query, *args, **kwargs):
                 category=sa.exc.SAWarning,
             )
             with self.begin() as con:
-                yield con.exec_driver_sql(query, *args, **kwargs)
+                yield con.execute(query, *args, **kwargs)
 
     def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]:
         query = f"EXPLAIN PLAN FOR {query}"
@@ -87,12 +99,12 @@ def _get_temp_view_definition(
         raise NotImplementedError()
 
     def _has_table(self, connection, table_name: str, schema) -> bool:
-        query = sa.text(
-            """\
-SELECT COUNT(*) > 0 as c
-FROM INFORMATION_SCHEMA.TABLES
-WHERE TABLE_NAME = :table_name"""
-        ).bindparams(table_name=table_name)
+        t = sa.table(
+            "TABLES", sa.column("TABLE_NAME", sa.TEXT), schema="INFORMATION_SCHEMA"
+        )
+        query = sa.select(
+            sa.func.sum(sa.cast(t.c.TABLE_NAME == table_name, sa.INTEGER))
+        ).compile(dialect=self.con.dialect)
 
         return bool(connection.execute(query).scalar())
 
diff --git a/ibis/backends/druid/compiler.py b/ibis/backends/druid/compiler.py
index de6cb0ffe0c1..c7e8f1cab668 100644
--- a/ibis/backends/druid/compiler.py
+++ b/ibis/backends/druid/compiler.py
@@ -1,5 +1,10 @@
 from __future__ import annotations
 
+import contextlib
+
+import sqlalchemy as sa
+
+import ibis.backends.druid.datatypes as ddt
 from ibis.backends.base.sql.alchemy import AlchemyCompiler, AlchemyExprTranslator
 from ibis.backends.druid.registry import operation_registry
 
@@ -9,6 +14,14 @@ class DruidExprTranslator(AlchemyExprTranslator):
     _rewrites = AlchemyExprTranslator._rewrites.copy()
 
     _dialect_name = "druid"
+    type_mapper = ddt.DruidType
+
+    def translate(self, op):
+        result = super().translate(op)
+        with contextlib.suppress(AttributeError):
+            result = result.scalar_subquery()
+        return sa.type_coerce(result, self.type_mapper.from_ibis(op.output_dtype))
 
 
 rewrites = DruidExprTranslator.rewrites
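
[Note] The `translate` override above wraps every translated expression in
`sa.type_coerce`, so the Druid-specific column types defined in datatypes.py
get a chance to post-process values as they are fetched. A minimal standalone
sketch of that mechanism (illustrative only, not part of the patch; the column
name "ts" is made up):

    import sqlalchemy as sa

    # type_coerce does not change the emitted SQL (no CAST is added);
    # it only tags the expression with a type whose result processor
    # runs over values on fetch.
    expr = sa.type_coerce(sa.column("ts", sa.TEXT), sa.TIMESTAMP)
    print(sa.select(expr))  # SELECT ts
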
diff --git a/ibis/backends/druid/datatypes.py b/ibis/backends/druid/datatypes.py
index 97c514e2bdbe..c760131bc0e6 100644
--- a/ibis/backends/druid/datatypes.py
+++ b/ibis/backends/druid/datatypes.py
@@ -2,9 +2,12 @@
 
 import parsy
 import sqlalchemy as sa
+import sqlalchemy.types as sat
+from dateutil.parser import parse as timestamp_parse
 from sqlalchemy.ext.compiler import compiles
 
 import ibis.expr.datatypes as dt
+from ibis.backends.base.sql.alchemy.datatypes import AlchemyType
 from ibis.common.parsing import (
     LANGLE,
     RANGLE,
@@ -12,6 +15,32 @@
 )
 
 
+class DruidDateTime(sat.TypeDecorator):
+    impl = sa.TIMESTAMP
+
+    cache_ok = True
+
+    def process_result_value(self, value, dialect):
+        return None if value is None else timestamp_parse(value)
+
+
+class DruidBinary(sa.LargeBinary):
+    def result_processor(self, dialect, coltype):
+        def process(value):
+            return None if value is None else value.encode("utf-8")
+
+        return process
+
+
+class DruidString(sat.TypeDecorator):
+    impl = sa.String
+
+    cache_ok = True
+
+    def process_result_value(self, value, dialect):
+        return value
+
+
 @compiles(sa.BIGINT, "druid")
 @compiles(sa.BigInteger, "druid")
 def _bigint(element, compiler, **kw):
@@ -47,3 +76,29 @@ def parse(text: str) -> dt.DataType:
     ty.become(primitive | array | json)
 
     return ty.parse(text)
+
+
+class DruidType(AlchemyType):
+    dialect = "hive"
+
+    @classmethod
+    def to_ibis(cls, typ, nullable=True):
+        if isinstance(typ, DruidDateTime):
+            return dt.Timestamp(nullable=nullable)
+        elif isinstance(typ, DruidBinary):
+            return dt.Binary(nullable=nullable)
+        elif isinstance(typ, DruidString):
+            return dt.String(nullable=nullable)
+        else:
+            return super().to_ibis(typ, nullable=nullable)
+
+    @classmethod
+    def from_ibis(cls, dtype):
+        if dtype.is_timestamp():
+            return DruidDateTime()
+        elif dtype.is_binary():
+            return DruidBinary()
+        elif dtype.is_string():
+            return DruidString()
+        else:
+            return super().from_ibis(dtype)
diff --git a/ibis/backends/tests/test_binary.py b/ibis/backends/tests/test_binary.py
index 42dc2a13b6b7..d3ece6101b5a 100644
--- a/ibis/backends/tests/test_binary.py
+++ b/ibis/backends/tests/test_binary.py
@@ -16,8 +16,6 @@
     "postgres": "bytea",
 }
 
-pytestmark = pytest.mark.broken(["druid"], raises=AssertionError)
-
 
 @pytest.mark.broken(
     ['trino'],
diff --git a/ibis/backends/tests/test_export.py b/ibis/backends/tests/test_export.py
index 25cd6bb3a5c9..0bb52482c161 100644
--- a/ibis/backends/tests/test_export.py
+++ b/ibis/backends/tests/test_export.py
@@ -50,7 +50,6 @@
 
 
 @pytest.mark.parametrize("limit", limit_no_limit)
-@pytest.mark.notimpl(["druid"])
 def test_table_to_pyarrow_batches(limit, awards_players):
     with awards_players.to_pyarrow_batches(limit=limit) as batch_reader:
         assert isinstance(batch_reader, pa.ipc.RecordBatchReader)
@@ -73,7 +72,6 @@ def test_column_to_pyarrow_batches(limit, awards_players):
 
 
 @pytest.mark.parametrize("limit", limit_no_limit)
-@pytest.mark.notimpl(["druid"])
 def test_table_to_pyarrow_table(limit, awards_players):
     table = awards_players.to_pyarrow(limit=limit)
     assert isinstance(table, pa.Table)
@@ -144,7 +142,7 @@ def test_column_to_pyarrow_table_schema(awards_players):
     assert array.type == pa.string() or array.type == pa.large_string()
 
 
-@pytest.mark.notimpl(["pandas", "dask", "impala", "pyspark", "datafusion", "druid"])
+@pytest.mark.notimpl(["pandas", "dask", "impala", "pyspark", "datafusion"])
 @pytest.mark.notyet(
     ["clickhouse"],
     raises=AssertionError,
@@ -176,7 +174,7 @@ def test_column_pyarrow_batch_chunk_size(awards_players):
         util.consume(batch_reader)
 
 
-@pytest.mark.notimpl(["pandas", "dask", "impala", "pyspark", "datafusion", "druid"])
+@pytest.mark.notimpl(["pandas", "dask", "impala", "pyspark", "datafusion"])
 @pytest.mark.broken(
     ["sqlite"],
     raises=pa.ArrowException,
@@ -212,7 +210,7 @@ def test_to_pyarrow_batches_memtable(con):
     assert n == 3
 
 
-@pytest.mark.notimpl(["dask", "impala", "pyspark", "druid"])
+@pytest.mark.notimpl(["dask", "impala", "pyspark"])
 def test_table_to_parquet(tmp_path, backend, awards_players):
     outparquet = tmp_path / "out.parquet"
     awards_players.to_parquet(outparquet)
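
[Note] The three SQLAlchemy types added above exist because pydruid's DB-API
hands back timestamps and binary columns as plain strings. A quick sketch of
the conversions the result hooks perform (standalone and illustrative; the
sample values are made up):

    from dateutil.parser import parse as timestamp_parse

    raw = "2023-06-29T11:56:14.000Z"  # how Druid's SQL API returns a timestamp
    print(timestamp_parse(raw))       # 2023-06-29 11:56:14+00:00

    # DruidBinary performs the analogous str -> bytes conversion:
    print("deadbeef".encode("utf-8"))  # b'deadbeef'
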
"datafusion", "druid"]) +@pytest.mark.notimpl(["pandas", "dask", "impala", "pyspark", "datafusion"]) @pytest.mark.notyet( ["clickhouse"], raises=AssertionError, @@ -176,7 +174,7 @@ def test_column_pyarrow_batch_chunk_size(awards_players): util.consume(batch_reader) -@pytest.mark.notimpl(["pandas", "dask", "impala", "pyspark", "datafusion", "druid"]) +@pytest.mark.notimpl(["pandas", "dask", "impala", "pyspark", "datafusion"]) @pytest.mark.broken( ["sqlite"], raises=pa.ArrowException, @@ -212,7 +210,7 @@ def test_to_pyarrow_batches_memtable(con): assert n == 3 -@pytest.mark.notimpl(["dask", "impala", "pyspark", "druid"]) +@pytest.mark.notimpl(["dask", "impala", "pyspark"]) def test_table_to_parquet(tmp_path, backend, awards_players): outparquet = tmp_path / "out.parquet" awards_players.to_parquet(outparquet) diff --git a/ibis/backends/tests/test_temporal.py b/ibis/backends/tests/test_temporal.py index 175e23dac178..96e5fdfec959 100644 --- a/ibis/backends/tests/test_temporal.py +++ b/ibis/backends/tests/test_temporal.py @@ -1938,6 +1938,11 @@ def test_extract_time_from_timestamp(con, microsecond): "invalid type [CAST(INTERVAL_LITERAL('second', '1') AS VARIANT)] for parameter 'TO_VARIANT'", raises=sa.exc.ProgrammingError, ) +@pytest.mark.broken( + ['druid'], + 'No literal value renderer is available for literal value "1" with datatype DATETIME', + raises=sa.exc.CompileError, +) @pytest.mark.broken( ['impala'], 'AnalysisException: Syntax error in line 1: SELECT typeof(INTERVAL 1 SECOND) AS `TypeOf(1)` '