Skip to content

Commit

Permalink
feat(datafusion): use pyarrow for type conversion (#9299)
Browse files: browse the repository at this point in the history
  • Loading branch information
cpcloud authored Jun 4, 2024
1 parent 62e48f5 commit 5bef96a
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 25 deletions.
50 changes: 28 additions & 22 deletions ibis/backends/datafusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from ibis.backends.sql.compiler import C
from ibis.common.dispatch import lazy_singledispatch
from ibis.expr.operations.udf import InputType
from ibis.formats.pyarrow import PyArrowType
from ibis.formats.pyarrow import PyArrowSchema, PyArrowType
from ibis.util import deprecated, gen_name, normalize_filename

try:
Expand All @@ -43,6 +43,25 @@
import polars as pl


def as_nullable(dtype: dt.DataType) -> dt.DataType:
    """Recursively convert a possibly non-nullable datatype to a nullable one.

    Struct fields, array element types, and map key/value types are converted
    as well, so every level of a nested type ends up nullable.

    Parameters
    ----------
    dtype
        The datatype to convert.

    Returns
    -------
    dt.DataType
        The result of ``dtype.copy(...)`` with ``nullable=True`` set on the
        type itself and on all of its nested types.
    """
    if dtype.is_struct():
        return dtype.copy(
            fields={name: as_nullable(typ) for name, typ in dtype.items()},
            nullable=True,
        )
    elif dtype.is_array():
        return dtype.copy(value_type=as_nullable(dtype.value_type), nullable=True)
    elif dtype.is_map():
        return dtype.copy(
            key_type=as_nullable(dtype.key_type),
            value_type=as_nullable(dtype.value_type),
            nullable=True,
        )
    else:
        # Non-nested type: only its own flag needs to change.
        return dtype.copy(nullable=True)


class Backend(SQLBackend, CanCreateCatalog, CanCreateDatabase, CanCreateSchema, NoUrl):
name = "datafusion"
supports_in_memory_tables = True
Expand Down Expand Up @@ -114,23 +133,11 @@ def _get_schema_using_query(self, query: str) -> sch.Schema:
pass

try:
result = (
self.raw_sql(f"DESCRIBE {table.sql(self.name)}")
.to_arrow_table()
.to_pydict()
)
df = self.con.table(name)
finally:
self.drop_view(name)
return sch.Schema(
{
name: self.compiler.type_mapper.from_string(
type_string, nullable=is_nullable == "YES"
)
for name, type_string, is_nullable in zip(
result["column_name"], result["data_type"], result["is_nullable"]
)
}
)

return PyArrowSchema.to_ibis(df.schema())

def _register_builtin_udfs(self):
from ibis.backends.datafusion import udfs
Expand Down Expand Up @@ -523,7 +530,9 @@ def to_pyarrow_batches(

frame = self.con.sql(raw_sql)

schema = table_expr.schema()
schema = sch.Schema(
{name: as_nullable(typ) for name, typ in table_expr.schema().items()}
)
names = schema.names

struct_schema = schema.as_struct().to_pyarrow()
Expand All @@ -537,15 +546,12 @@ def make_gen():
# cast the struct array to the desired types to work around
# https://github.com/apache/arrow-datafusion-python/issues/534
.to_struct_array()
.cast(struct_schema)
.cast(struct_schema, safe=False)
)
for batch in frame.collect()
)

return pa.ipc.RecordBatchReader.from_batches(
schema.to_pyarrow(),
make_gen(),
)
return pa.ipc.RecordBatchReader.from_batches(schema.to_pyarrow(), make_gen())

def to_pyarrow(self, expr: ir.Expr, **kwargs: Any) -> pa.Table:
batch_reader = self.to_pyarrow_batches(expr, **kwargs)
Expand Down
24 changes: 24 additions & 0 deletions ibis/backends/datafusion/tests/test_datatypes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from __future__ import annotations

import hypothesis as h

import ibis.tests.strategies as its
from ibis.backends.datafusion import as_nullable


def is_nullable(dtype):
    """Return whether `dtype` and every type nested inside it is nullable.

    A nested type (struct, array, or map) counts as nullable only when its own
    ``nullable`` flag is ``True`` *and* all of its field, element, key, and
    value types are nullable as well.
    """
    # Check the container's own flag first; inspecting only the children would
    # let a non-nullable struct/array/map of nullable parts slip through.
    if dtype.nullable is not True:
        return False

    if dtype.is_struct():
        return all(map(is_nullable, dtype.values()))
    elif dtype.is_array():
        return is_nullable(dtype.value_type)
    elif dtype.is_map():
        return is_nullable(dtype.key_type) and is_nullable(dtype.value_type)
    else:
        return True


@h.given(its.all_dtypes())
def test_as_nullable(dtype):
    """Check that `as_nullable` yields a nullable type for any generated dtype."""
    converted = as_nullable(dtype)
    # Assert the outermost flag explicitly, then let `is_nullable` recurse
    # into any nested field, element, key, and value types.
    assert converted.nullable is True
    assert is_nullable(converted)
2 changes: 1 addition & 1 deletion ibis/backends/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def test_query_schema(ddl_backend, expr_fn, expected):
}


@pytest.mark.notimpl(["datafusion", "mssql"])
@pytest.mark.notimpl(["mssql"])
@pytest.mark.never(["dask", "pandas"], reason="dask and pandas do not support SQL")
def test_sql(backend, con):
# execute the expression using SQL query
Expand Down
2 changes: 0 additions & 2 deletions ibis/tests/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,6 @@ def all_dtypes(nullable=_nullable):
| temporal_dtypes(nullable=nullable)
| interval_dtype(nullable=nullable)
| geospatial_dtypes(nullable=nullable)
| variadic_dtypes(nullable=nullable)
| struct_dtypes(nullable=nullable)
| array_dtypes(recursive, nullable=nullable)
| map_dtypes(recursive, recursive, nullable=nullable)
| struct_dtypes(recursive, nullable=nullable)
Expand Down

0 comments on commit 5bef96a

Please sign in to comment.