fix(druid): handle conversion issues from string, binary, and timestamp

cpcloud authored and jcrist committed Jul 5, 2023
1 parent 310db2b commit b632063

Showing 6 changed files with 102 additions and 21 deletions.
40 changes: 26 additions & 14 deletions ibis/backends/druid/__init__.py
@@ -8,7 +8,6 @@
 from typing import Any, Iterable
 
 import sqlalchemy as sa
-from pydruid.db.sqlalchemy import DruidDialect
 
 import ibis.backends.druid.datatypes as ddt
 import ibis.expr.datatypes as dt
@@ -50,22 +49,35 @@ def do_connect(
         # workaround a broken pydruid `has_table` implementation
         engine.dialect.has_table = self._has_table
 
+    @staticmethod
+    def _new_sa_metadata():
+        meta = sa.MetaData()
+
+        @sa.event.listens_for(meta, "column_reflect")
+        def column_reflect(inspector, table, column_info):
+            if isinstance(typ := column_info["type"], sa.DateTime):
+                column_info["type"] = ddt.DruidDateTime()
+            elif isinstance(typ, (sa.LargeBinary, sa.BINARY, sa.VARBINARY)):
+                column_info["type"] = ddt.DruidBinary()
+            elif isinstance(typ, sa.String):
+                column_info["type"] = ddt.DruidString()
+
+        return meta
+
     @contextlib.contextmanager
     def _safe_raw_sql(self, query, *args, **kwargs):
-        if not isinstance(query, str):
-            query = str(
-                query.compile(
-                    dialect=DruidDialect(), compile_kwargs=dict(literal_binds=True)
-                )
-            )
+        query = query.compile(
+            dialect=self.con.dialect, compile_kwargs=dict(literal_binds=True)
+        )
 
         with warnings.catch_warnings():
             warnings.filterwarnings(
                 "ignore",
                 message="Dialect druid:rest will not make use of SQL compilation caching",
                 category=sa.exc.SAWarning,
             )
             with self.begin() as con:
-                yield con.exec_driver_sql(query, *args, **kwargs)
+                yield con.execute(query, *args, **kwargs)
 
     def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]:
         query = f"EXPLAIN PLAN FOR {query}"
@@ -87,12 +99,12 @@ def _get_temp_view_definition(
         raise NotImplementedError()
 
     def _has_table(self, connection, table_name: str, schema) -> bool:
-        query = sa.text(
-            """\
-SELECT COUNT(*) > 0 as c
-FROM INFORMATION_SCHEMA.TABLES
-WHERE TABLE_NAME = :table_name"""
-        ).bindparams(table_name=table_name)
+        t = sa.table(
+            "TABLES", sa.column("TABLE_NAME", sa.TEXT), schema="INFORMATION_SCHEMA"
+        )
+        query = sa.select(
+            sa.func.sum(sa.cast(t.c.TABLE_NAME == table_name, sa.INTEGER))
+        ).compile(dialect=self.con.dialect)
 
         return bool(connection.execute(query).scalar())

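The `column_reflect` listener registered in `_new_sa_metadata` is the reflection half of the fix: it fires once per reflected column and may replace `column_info["type"]` before the `Table` object is built. A minimal, self-contained sketch of the same pattern, using an in-memory SQLite table (the table and type swap here are illustrative, not Druid's):

```python
import sqlalchemy as sa

engine = sa.create_engine("sqlite://")
with engine.begin() as con:
    con.exec_driver_sql("CREATE TABLE t (ts TIMESTAMP, s VARCHAR)")

meta = sa.MetaData()

@sa.event.listens_for(meta, "column_reflect")
def column_reflect(inspector, table, column_info):
    # Swap the reflected type, as _new_sa_metadata does for
    # DateTime/LargeBinary/String -> the Druid* decorators.
    if isinstance(column_info["type"], sa.DateTime):
        column_info["type"] = sa.String()

t = sa.Table("t", meta, autoload_with=engine)
print({c.name: type(c.type).__name__ for c in t.columns})
# {'ts': 'String', 's': 'VARCHAR'}
```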
13 changes: 13 additions & 0 deletions ibis/backends/druid/compiler.py
@@ -1,5 +1,10 @@
 from __future__ import annotations
 
+import contextlib
+
+import sqlalchemy as sa
+
+import ibis.backends.druid.datatypes as ddt
 from ibis.backends.base.sql.alchemy import AlchemyCompiler, AlchemyExprTranslator
 from ibis.backends.druid.registry import operation_registry

@@ -9,6 +14,14 @@ class DruidExprTranslator(AlchemyExprTranslator):
     _rewrites = AlchemyExprTranslator._rewrites.copy()
     _dialect_name = "druid"
 
+    type_mapper = ddt.DruidType
+
+    def translate(self, op):
+        result = super().translate(op)
+        with contextlib.suppress(AttributeError):
+            result = result.scalar_subquery()
+        return sa.type_coerce(result, self.type_mapper.from_ibis(op.output_dtype))
+
 
 rewrites = DruidExprTranslator.rewrites

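Two details in `translate` are worth calling out: the `contextlib.suppress(AttributeError)` block turns selectables into scalar subqueries while leaving plain column expressions untouched, and `sa.type_coerce`, unlike `sa.cast`, changes only Python-side result processing; no CAST is rendered, so Druid is never asked to perform a cast it does not support. A quick illustration of that difference (column name invented):

```python
import sqlalchemy as sa

expr = sa.column("ts")
print(sa.cast(expr, sa.TIMESTAMP()))         # CAST(ts AS TIMESTAMP)
print(sa.type_coerce(expr, sa.TIMESTAMP()))  # ts  -- no CAST emitted
```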
55 changes: 55 additions & 0 deletions ibis/backends/druid/datatypes.py
@@ -2,16 +2,45 @@
 
 import parsy
 import sqlalchemy as sa
+import sqlalchemy.types as sat
+from dateutil.parser import parse as timestamp_parse
 from sqlalchemy.ext.compiler import compiles
 
 import ibis.expr.datatypes as dt
+from ibis.backends.base.sql.alchemy.datatypes import AlchemyType
 from ibis.common.parsing import (
     LANGLE,
     RANGLE,
     spaceless_string,
 )
 
 
+class DruidDateTime(sat.TypeDecorator):
+    impl = sa.TIMESTAMP
+
+    cache_ok = True
+
+    def process_result_value(self, value, dialect):
+        return None if value is None else timestamp_parse(value)
+
+
+class DruidBinary(sa.LargeBinary):
+    def result_processor(self, dialect, coltype):
+        def process(value):
+            return None if value is None else value.encode("utf-8")
+
+        return process
+
+
+class DruidString(sat.TypeDecorator):
+    impl = sa.String
+
+    cache_ok = True
+
+    def process_result_value(self, value, dialect):
+        return value
+
+
 @compiles(sa.BIGINT, "druid")
 @compiles(sa.BigInteger, "druid")
 def _bigint(element, compiler, **kw):
@@ -47,3 +76,29 @@ def parse(text: str) -> dt.DataType:
 
     ty.become(primitive | array | json)
     return ty.parse(text)
+
+
+class DruidType(AlchemyType):
+    dialect = "hive"
+
+    @classmethod
+    def to_ibis(cls, typ, nullable=True):
+        if isinstance(typ, DruidDateTime):
+            return dt.Timestamp(nullable=nullable)
+        elif isinstance(typ, DruidBinary):
+            return dt.Binary(nullable=nullable)
+        elif isinstance(typ, DruidString):
+            return dt.String(nullable=nullable)
+        else:
+            return super().to_ibis(typ, nullable=nullable)
+
+    @classmethod
+    def from_ibis(cls, dtype):
+        if dtype.is_timestamp():
+            return DruidDateTime()
+        elif dtype.is_binary():
+            return DruidBinary()
+        elif dtype.is_string():
+            return DruidString()
+        else:
+            return super().from_ibis(dtype)
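These decorators encode what the commit title fixes: as their result processors imply, pydruid hands timestamps back as ISO-8601 text and binary data as `str`, so values must be normalized on the way out. A short demonstration of the two non-trivial conversions (sample values invented):

```python
from dateutil.parser import parse as timestamp_parse

# DruidDateTime.process_result_value on a non-None value:
print(timestamp_parse("2023-07-05T12:34:56.000Z"))  # 2023-07-05 12:34:56+00:00

# DruidBinary's result processor on a non-None value:
print("abc".encode("utf-8"))  # b'abc'
```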
2 changes: 0 additions & 2 deletions ibis/backends/tests/test_binary.py
@@ -16,8 +16,6 @@
     "postgres": "bytea",
 }
 
-pytestmark = pytest.mark.broken(["druid"], raises=AssertionError)
-
 
 @pytest.mark.broken(
     ['trino'],
8 changes: 3 additions & 5 deletions ibis/backends/tests/test_export.py
@@ -50,7 +50,6 @@
 
 
 @pytest.mark.parametrize("limit", limit_no_limit)
-@pytest.mark.notimpl(["druid"])
 def test_table_to_pyarrow_batches(limit, awards_players):
     with awards_players.to_pyarrow_batches(limit=limit) as batch_reader:
         assert isinstance(batch_reader, pa.ipc.RecordBatchReader)
@@ -73,7 +72,6 @@ def test_column_to_pyarrow_batches(limit, awards_players):
 
 
 @pytest.mark.parametrize("limit", limit_no_limit)
-@pytest.mark.notimpl(["druid"])
 def test_table_to_pyarrow_table(limit, awards_players):
     table = awards_players.to_pyarrow(limit=limit)
     assert isinstance(table, pa.Table)
@@ -144,7 +142,7 @@ def test_column_to_pyarrow_table_schema(awards_players):
     assert array.type == pa.string() or array.type == pa.large_string()
 
 
-@pytest.mark.notimpl(["pandas", "dask", "impala", "pyspark", "datafusion", "druid"])
+@pytest.mark.notimpl(["pandas", "dask", "impala", "pyspark", "datafusion"])
 @pytest.mark.notyet(
     ["clickhouse"],
     raises=AssertionError,
@@ -176,7 +174,7 @@ def test_column_pyarrow_batch_chunk_size(awards_players):
     util.consume(batch_reader)
 
 
-@pytest.mark.notimpl(["pandas", "dask", "impala", "pyspark", "datafusion", "druid"])
+@pytest.mark.notimpl(["pandas", "dask", "impala", "pyspark", "datafusion"])
 @pytest.mark.broken(
     ["sqlite"],
     raises=pa.ArrowException,
@@ -212,7 +210,7 @@ def test_to_pyarrow_batches_memtable(con):
     assert n == 3
 
 
-@pytest.mark.notimpl(["dask", "impala", "pyspark", "druid"])
+@pytest.mark.notimpl(["dask", "impala", "pyspark"])
 def test_table_to_parquet(tmp_path, backend, awards_players):
     outparquet = tmp_path / "out.parquet"
     awards_players.to_parquet(outparquet)
5 changes: 5 additions & 0 deletions ibis/backends/tests/test_temporal.py
@@ -1938,6 +1938,11 @@ def test_extract_time_from_timestamp(con, microsecond):
     "invalid type [CAST(INTERVAL_LITERAL('second', '1') AS VARIANT)] for parameter 'TO_VARIANT'",
     raises=sa.exc.ProgrammingError,
 )
+@pytest.mark.broken(
+    ['druid'],
+    'No literal value renderer is available for literal value "1" with datatype DATETIME',
+    raises=sa.exc.CompileError,
+)
 @pytest.mark.broken(
     ['impala'],
     'AnalysisException: Syntax error in line 1: SELECT typeof(INTERVAL 1 SECOND) AS `TypeOf(1)` '

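The new druid marker follows from `_safe_raw_sql` compiling with `literal_binds=True`: every bound parameter must then be rendered as an inline SQL literal, and a type with no literal renderer on the dialect, here DATETIME, fails with exactly this `CompileError`. A small sketch of the mechanism on the generic dialect (values invented):

```python
import sqlalchemy as sa

stmt = sa.select(sa.bindparam("x", 41) + sa.literal(1))
print(stmt.compile(compile_kwargs=dict(literal_binds=True)))
# SELECT 41 + 1 AS anon_1  (label name varies by version)
```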