diff --git a/ibis/backends/datafusion/__init__.py b/ibis/backends/datafusion/__init__.py index 37851690ac06..a2ed95d5d54f 100644 --- a/ibis/backends/datafusion/__init__.py +++ b/ibis/backends/datafusion/__init__.py @@ -25,7 +25,7 @@ from ibis.backends.sql.compiler import C from ibis.common.dispatch import lazy_singledispatch from ibis.expr.operations.udf import InputType -from ibis.formats.pyarrow import PyArrowType +from ibis.formats.pyarrow import PyArrowSchema, PyArrowType from ibis.util import deprecated, gen_name, normalize_filename try: @@ -43,6 +43,25 @@ import polars as pl +def as_nullable(dtype: dt.DataType) -> dt.DataType: + """Recursively convert a possibly non-nullable datatype to a nullable one.""" + if dtype.is_struct(): + return dtype.copy( + fields={name: as_nullable(typ) for name, typ in dtype.items()}, + nullable=True, + ) + elif dtype.is_array(): + return dtype.copy(value_type=as_nullable(dtype.value_type), nullable=True) + elif dtype.is_map(): + return dtype.copy( + key_type=as_nullable(dtype.key_type), + value_type=as_nullable(dtype.value_type), + nullable=True, + ) + else: + return dtype.copy(nullable=True) + + class Backend(SQLBackend, CanCreateCatalog, CanCreateDatabase, CanCreateSchema, NoUrl): name = "datafusion" supports_in_memory_tables = True @@ -114,23 +133,11 @@ def _get_schema_using_query(self, query: str) -> sch.Schema: pass try: - result = ( - self.raw_sql(f"DESCRIBE {table.sql(self.name)}") - .to_arrow_table() - .to_pydict() - ) + df = self.con.table(name) finally: self.drop_view(name) - return sch.Schema( - { - name: self.compiler.type_mapper.from_string( - type_string, nullable=is_nullable == "YES" - ) - for name, type_string, is_nullable in zip( - result["column_name"], result["data_type"], result["is_nullable"] - ) - } - ) + + return PyArrowSchema.to_ibis(df.schema()) def _register_builtin_udfs(self): from ibis.backends.datafusion import udfs @@ -523,7 +530,9 @@ def to_pyarrow_batches( frame = self.con.sql(raw_sql) - schema = table_expr.schema() + schema = sch.Schema( + {name: as_nullable(typ) for name, typ in table_expr.schema().items()} + ) names = schema.names struct_schema = schema.as_struct().to_pyarrow() @@ -537,15 +546,12 @@ def make_gen(): # cast the struct array to the desired types to work around # https://github.com/apache/arrow-datafusion-python/issues/534 .to_struct_array() - .cast(struct_schema) + .cast(struct_schema, safe=False) ) for batch in frame.collect() ) - return pa.ipc.RecordBatchReader.from_batches( - schema.to_pyarrow(), - make_gen(), - ) + return pa.ipc.RecordBatchReader.from_batches(schema.to_pyarrow(), make_gen()) def to_pyarrow(self, expr: ir.Expr, **kwargs: Any) -> pa.Table: batch_reader = self.to_pyarrow_batches(expr, **kwargs) diff --git a/ibis/backends/datafusion/tests/test_datatypes.py b/ibis/backends/datafusion/tests/test_datatypes.py new file mode 100644 index 000000000000..846684b2b74f --- /dev/null +++ b/ibis/backends/datafusion/tests/test_datatypes.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +import hypothesis as h + +import ibis.tests.strategies as its +from ibis.backends.datafusion import as_nullable + + +def is_nullable(dtype): + if dtype.is_struct(): + return all(map(is_nullable, dtype.values())) + elif dtype.is_array(): + return is_nullable(dtype.value_type) + elif dtype.is_map(): + return is_nullable(dtype.key_type) and is_nullable(dtype.value_type) + else: + return dtype.nullable is True + + +@h.given(its.all_dtypes()) +def test_as_nullable(dtype): + nullable_dtype = as_nullable(dtype) + assert nullable_dtype.nullable is True + assert is_nullable(nullable_dtype) diff --git a/ibis/backends/tests/test_client.py b/ibis/backends/tests/test_client.py index 3ad0ec961225..5bfd29f52d7d 100644 --- a/ibis/backends/tests/test_client.py +++ b/ibis/backends/tests/test_client.py @@ -246,7 +246,7 @@ def test_query_schema(ddl_backend, expr_fn, expected): } -@pytest.mark.notimpl(["datafusion", "mssql"]) +@pytest.mark.notimpl(["mssql"]) @pytest.mark.never(["dask", "pandas"], reason="dask and pandas do not support SQL") def test_sql(backend, con): # execute the expression using SQL query diff --git a/ibis/tests/strategies.py b/ibis/tests/strategies.py index 8d867eb58281..f417b3719a12 100644 --- a/ibis/tests/strategies.py +++ b/ibis/tests/strategies.py @@ -223,8 +223,6 @@ def all_dtypes(nullable=_nullable): | temporal_dtypes(nullable=nullable) | interval_dtype(nullable=nullable) | geospatial_dtypes(nullable=nullable) - | variadic_dtypes(nullable=nullable) - | struct_dtypes(nullable=nullable) | array_dtypes(recursive, nullable=nullable) | map_dtypes(recursive, recursive, nullable=nullable) | struct_dtypes(recursive, nullable=nullable)