Skip to content

Commit

Permalink
feat(datatypes): add _as_nullable/_as_non_nullable private conversion…
Browse files Browse the repository at this point in the history
… methods to datatype and schema
  • Loading branch information
cpcloud committed Jun 3, 2024
1 parent 61b0b04 commit 02f1bfc
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 3 deletions.
7 changes: 4 additions & 3 deletions ibis/backends/datafusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,8 @@ def to_pyarrow_batches(
schema = table_expr.schema()
names = schema.names

struct_schema = schema.as_struct().to_pyarrow()
nullable_schema = schema._as_nullable()
struct_schema = nullable_schema.as_struct().to_pyarrow()

def make_gen():
yield from (
Expand All @@ -525,13 +526,13 @@ def make_gen():
# cast the struct array to the desired types to work around
# https://github.com/apache/arrow-datafusion-python/issues/534
.to_struct_array()
.cast(struct_schema)
.cast(struct_schema, safe=False)
)
for batch in frame.collect()
)

return pa.ipc.RecordBatchReader.from_batches(
schema.to_pyarrow(),
nullable_schema.to_pyarrow(),
make_gen(),
)

Expand Down
38 changes: 38 additions & 0 deletions ibis/expr/datatypes/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,12 @@ def is_variadic(self) -> bool:
"""Return true if an instance of a Variadic type."""
return isinstance(self, Variadic)

def _as_nullable(self):
return self.copy(nullable=True)

def _as_non_nullable(self):
return self.copy(nullable=False)


@public
class Unknown(DataType, Singleton):
Expand Down Expand Up @@ -886,6 +892,18 @@ def _pretty_piece(self) -> str:
pairs = ", ".join(map("{}: {}".format, self.names, self.types))
return f"<{pairs}>"

def _as_nullable(self):
return self.copy(
fields={name: typ._as_nullable() for name, typ in self.items()},
nullable=True,
)

def _as_non_nullable(self):
return self.copy(
fields={name: typ._as_non_nullable() for name, typ in self.items()},
nullable=False,
)


T = TypeVar("T", bound=DataType, covariant=True)

Expand All @@ -903,6 +921,12 @@ class Array(Variadic, Parametric, Generic[T]):
def _pretty_piece(self) -> str:
return f"<{self.value_type}>"

def _as_nullable(self):
return self.copy(value_type=self.value_type._as_nullable(), nullable=True)

def _as_non_nullable(self):
return self.copy(value_type=self.value_type._as_non_nullable(), nullable=False)


K = TypeVar("K", bound=DataType, covariant=True)
V = TypeVar("V", bound=DataType, covariant=True)
Expand All @@ -922,6 +946,20 @@ class Map(Variadic, Parametric, Generic[K, V]):
def _pretty_piece(self) -> str:
return f"<{self.key_type}, {self.value_type}>"

def _as_nullable(self):
return self.copy(
key_type=self.key_type._as_nullable(),
value_type=self.value_type._as_nullable(),
nullable=True,
)

def _as_non_nullable(self):
return self.copy(
key_type=self.key_type._as_non_nullable(),
value_type=self.value_type._as_non_nullable(),
nullable=False,
)


@public
class JSON(Variadic):
Expand Down
75 changes: 75 additions & 0 deletions ibis/expr/datatypes/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,18 @@

import datetime # noqa: TCH003
import decimal # noqa: TCH003
import string
import uuid # noqa: TCH003
from dataclasses import dataclass
from typing import Annotated, NamedTuple

import hypothesis as h
import hypothesis.strategies as st
import pytest
from pytest import param

import ibis.expr.datatypes as dt
import ibis.tests.strategies as its
from ibis.common.annotations import ValidationError
from ibis.common.patterns import As, Attrs, NoMatch, Pattern
from ibis.common.temporal import TimestampUnit, TimeUnit
Expand Down Expand Up @@ -709,3 +713,74 @@ def test_type_roundtrip(dtype, fmt):
def test_dtype_from_polars():
    """Check that ``dt.dtype`` maps the polars ``Int64`` type to ``dt.int64``."""
    # Skip rather than fail when polars is not installed.
    polars = pytest.importorskip("polars")
    inferred = dt.dtype(polars.Int64)
    assert inferred == dt.int64


@pytest.mark.parametrize(
    ("dtype", "ndtype"),
    [
        # Each case pairs a fully nullable type with its fully
        # non-nullable counterpart.
        param(dt.int8, dt.Int8(nullable=False), id="primitive"),
        param(
            dt.Array(dt.int8),
            dt.Array(dt.Int8(nullable=False), nullable=False),
            id="array",
        ),
        param(
            dt.Map(dt.int8, dt.string),
            dt.Map(dt.Int8(nullable=False), dt.String(nullable=False), nullable=False),
            id="map",
        ),
        param(
            dt.Struct(
                {
                    "a": "int8",
                    "b": "array<array<string>>",
                    "c": "struct<d: array<float32>, e: map<string, array<array<int64>>>>",
                }
            ),
            dt.Struct(
                {
                    "a": "!int8",
                    "b": "!array<!array<!string>>",
                    "c": "!struct<d: !array<!float32>, e: !map<!string, !array<!array<!int64>>>>",
                },
                nullable=False,
            ),
            id="struct",
        ),
    ],
)
def test_as_nullable_as_non_nullable(dtype, ndtype):
    """Check both conversions against hand-written expected types.

    ``dtype`` is the nullable form and ``ndtype`` the non-nullable form;
    converting either one must produce the other.
    """
    made_non_nullable = dtype._as_non_nullable()
    assert made_non_nullable == ndtype

    made_nullable = ndtype._as_nullable()
    assert made_nullable == dtype


# Struct field names are restricted to ASCII letters and digits: no
# character categories are whitelisted, only the explicit characters.
field_names = st.text(
    alphabet=st.characters(
        whitelist_categories=(),
        whitelist_characters=string.ascii_letters + string.digits,
    )
)

# Deferred so the strategy can refer to itself for the nested
# struct/array/map element types.
roundtrippable_dtypes = st.deferred(
    lambda: st.one_of(
        its.primitive_dtypes(),
        its.string_like_dtypes(),
        its.temporal_dtypes(),
        its.interval_dtype(),
        its.variadic_dtypes(),
        its.struct_dtypes(names=field_names, types=roundtrippable_dtypes),
        its.array_dtypes(roundtrippable_dtypes),
        its.map_dtypes(roundtrippable_dtypes, roundtrippable_dtypes),
    )
)


@h.given(roundtrippable_dtypes)
def test_as_nullable_as_non_nullable_simple(dtype):
    """Property test for both conversions on generated dtypes.

    The nullable result must report ``nullable is True`` and contain no
    ``!`` in its string form; the non-nullable result must report
    ``nullable is False`` and contain ``!`` in its string form.
    """
    made_nullable = dtype._as_nullable()
    assert made_nullable.nullable is True
    assert "!" not in str(made_nullable)

    made_non_nullable = dtype._as_non_nullable()
    assert made_non_nullable.nullable is False
    assert "!" in str(made_non_nullable)
10 changes: 10 additions & 0 deletions ibis/expr/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,16 @@ def equals(self, other: Schema) -> bool:
)
return self == other

def _as_nullable(self):
"""Recursively convert non-nullable types to be nullable."""
return self.__class__({name: typ._as_nullable() for name, typ in self.items()})

def _as_non_nullable(self):
"""Recursively convert nullable types to be non-nullable."""
return self.__class__(
{name: typ._as_non_nullable() for name, typ in self.items()}
)

@classmethod
def from_tuples(
cls,
Expand Down

0 comments on commit 02f1bfc

Please sign in to comment.