Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(datafusion): use pyarrow for type conversion #9299

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 28 additions & 22 deletions ibis/backends/datafusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from ibis.backends.sql.compiler import C
from ibis.common.dispatch import lazy_singledispatch
from ibis.expr.operations.udf import InputType
from ibis.formats.pyarrow import PyArrowType
from ibis.formats.pyarrow import PyArrowSchema, PyArrowType
from ibis.util import deprecated, gen_name, normalize_filename

try:
Expand All @@ -43,6 +43,25 @@
import polars as pl


def as_nullable(dtype: dt.DataType) -> dt.DataType:
"""Recursively convert a possibly non-nullable datatype to a nullable one."""
if dtype.is_struct():
return dtype.copy(
fields={name: as_nullable(typ) for name, typ in dtype.items()},
nullable=True,
)
elif dtype.is_array():
return dtype.copy(value_type=as_nullable(dtype.value_type), nullable=True)
elif dtype.is_map():
return dtype.copy(
key_type=as_nullable(dtype.key_type),
value_type=as_nullable(dtype.value_type),
nullable=True,
)
else:
return dtype.copy(nullable=True)


class Backend(SQLBackend, CanCreateCatalog, CanCreateDatabase, CanCreateSchema, NoUrl):
name = "datafusion"
supports_in_memory_tables = True
Expand Down Expand Up @@ -114,23 +133,11 @@ def _get_schema_using_query(self, query: str) -> sch.Schema:
pass

try:
result = (
self.raw_sql(f"DESCRIBE {table.sql(self.name)}")
.to_arrow_table()
.to_pydict()
)
df = self.con.table(name)
finally:
self.drop_view(name)
return sch.Schema(
{
name: self.compiler.type_mapper.from_string(
type_string, nullable=is_nullable == "YES"
)
for name, type_string, is_nullable in zip(
result["column_name"], result["data_type"], result["is_nullable"]
)
}
)

return PyArrowSchema.to_ibis(df.schema())

def _register_builtin_udfs(self):
from ibis.backends.datafusion import udfs
Expand Down Expand Up @@ -523,7 +530,9 @@ def to_pyarrow_batches(

frame = self.con.sql(raw_sql)

schema = table_expr.schema()
schema = sch.Schema(
{name: as_nullable(typ) for name, typ in table_expr.schema().items()}
)
names = schema.names

struct_schema = schema.as_struct().to_pyarrow()
Expand All @@ -537,15 +546,12 @@ def make_gen():
# cast the struct array to the desired types to work around
# https://github.com/apache/arrow-datafusion-python/issues/534
.to_struct_array()
.cast(struct_schema)
.cast(struct_schema, safe=False)
)
for batch in frame.collect()
)

return pa.ipc.RecordBatchReader.from_batches(
schema.to_pyarrow(),
make_gen(),
)
return pa.ipc.RecordBatchReader.from_batches(schema.to_pyarrow(), make_gen())

def to_pyarrow(self, expr: ir.Expr, **kwargs: Any) -> pa.Table:
batch_reader = self.to_pyarrow_batches(expr, **kwargs)
Expand Down
24 changes: 24 additions & 0 deletions ibis/backends/datafusion/tests/test_datatypes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from __future__ import annotations

import hypothesis as h

import ibis.tests.strategies as its
from ibis.backends.datafusion import as_nullable


def is_nullable(dtype):
    """Return True when *dtype* is nullable at every nesting level.

    Containers are checked recursively; a non-container datatype is
    nullable exactly when its ``nullable`` flag is True.
    """
    if dtype.is_array():
        return is_nullable(dtype.value_type)
    if dtype.is_map():
        return is_nullable(dtype.key_type) and is_nullable(dtype.value_type)
    if dtype.is_struct():
        return all(is_nullable(field) for field in dtype.values())
    return dtype.nullable is True


@h.given(its.all_dtypes())
def test_as_nullable(dtype):
    """Property test: ``as_nullable`` produces a datatype that is nullable
    both at the top level and recursively at every nested level."""
    converted = as_nullable(dtype)
    assert converted.nullable is True
    assert is_nullable(converted)
2 changes: 1 addition & 1 deletion ibis/backends/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def test_query_schema(ddl_backend, expr_fn, expected):
}


@pytest.mark.notimpl(["datafusion", "mssql"])
@pytest.mark.notimpl(["mssql"])
@pytest.mark.never(["dask", "pandas"], reason="dask and pandas do not support SQL")
def test_sql(backend, con):
# execute the expression using SQL query
Expand Down
2 changes: 0 additions & 2 deletions ibis/tests/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,6 @@ def all_dtypes(nullable=_nullable):
| temporal_dtypes(nullable=nullable)
| interval_dtype(nullable=nullable)
| geospatial_dtypes(nullable=nullable)
| variadic_dtypes(nullable=nullable)
| struct_dtypes(nullable=nullable)
| array_dtypes(recursive, nullable=nullable)
| map_dtypes(recursive, recursive, nullable=nullable)
| struct_dtypes(recursive, nullable=nullable)
Expand Down