Skip to content

Commit

Permalink
chore(datatypes): make annotations more precise
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Apr 13, 2023
1 parent b217cde commit 7c747ae
Showing 1 changed file with 24 additions and 24 deletions.
48 changes: 24 additions & 24 deletions ibis/backends/pyarrow/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@
dt.String: pa.string(),
dt.Binary: pa.binary(),
dt.Boolean: pa.bool_(),
dt.Timestamp: pa.timestamp('ns'),
dt.Date: pa.date64(),
dt.Time: pa.time64("us"),
dt.Timestamp: pa.timestamp('ns'),
dt.JSON: pa.string(),
dt.Null: pa.null(),
# assume unknown types can be converted into strings
Expand All @@ -34,36 +35,37 @@


@functools.singledispatch
def to_pyarrow_type(dtype: dt.DataType):
arrow_type = _to_pyarrow_types.get(dtype.__class__)
if not arrow_type:
raise NotImplementedError(f'Unsupported type: {dtype!r}')
def to_pyarrow_type(dtype: dt.DataType) -> pa.DataType:
if (arrow_type := _to_pyarrow_types.get(dtype.__class__)) is None:
raise NotImplementedError(
f"Unsupported conversion from ibis type to pyarrow type: {dtype!r}"
)
return arrow_type


@to_pyarrow_type.register(dt.Array)
@to_pyarrow_type.register(dt.Set)
def from_ibis_collection(dtype: dt.Array | dt.Set):
def from_ibis_collection(dtype: dt.Array | dt.Set) -> pa.ListType:
return pa.list_(to_pyarrow_type(dtype.value_type))


@to_pyarrow_type.register
def from_ibis_interval(dtype: dt.Interval):
def from_ibis_interval(dtype: dt.Interval) -> pa.DurationType:
try:
return pa.duration(dtype.unit)
except ValueError:
raise com.IbisTypeError(f"Unsupported interval unit: {dtype.unit}")


@to_pyarrow_type.register
def from_ibis_struct(dtype: dt.Struct):
def from_ibis_struct(dtype: dt.Struct) -> pa.StructType:
return pa.struct(
pa.field(name, to_pyarrow_type(typ)) for name, typ in dtype.fields.items()
)


@to_pyarrow_type.register
def from_ibis_map(dtype: dt.Map):
def from_ibis_map(dtype: dt.Map) -> pa.MapType:
return pa.map_(to_pyarrow_type(dtype.key_type), to_pyarrow_type(dtype.value_type))


Expand Down Expand Up @@ -96,12 +98,11 @@ def from_ibis_map(dtype: dt.Map):
def from_pyarrow_primitive(
arrow_type: pa.DataType, nullable: bool = True
) -> dt.DataType:
dtype = _to_ibis_dtypes.get(arrow_type, dt.unknown)
dtype = _to_ibis_dtypes.get(arrow_type, dt.Unknown)
return dtype(nullable=nullable)


@dt.dtype.register(pa.Decimal128Type)
@dt.dtype.register(pa.Decimal256Type)
@dt.dtype.register((pa.Decimal128Type, pa.Decimal256Type))
def from_pyarrow_decimal(
arrow_type: pa.Decimal128Type | pa.Decimal256Type, nullable: bool = True
) -> dt.Decimal:
Expand All @@ -110,21 +111,20 @@ def from_pyarrow_decimal(
)


@dt.dtype.register(pa.Time32Type) # type: ignore[misc]
@dt.dtype.register(pa.Time64Type) # type: ignore[misc]
@dt.dtype.register((pa.Time32Type, pa.Time64Type)) # type: ignore[misc]
def from_pyarrow_time(
_: pa.Time32Type | pa.Time64Type, nullable: bool = True
) -> dt.DataType:
) -> dt.Time:
return dt.Time(nullable=nullable)


@dt.dtype.register(pa.ListType) # type: ignore[misc]
def from_pyarrow_list(arrow_type: pa.ListType, nullable: bool = True) -> dt.DataType:
def from_pyarrow_list(arrow_type: pa.ListType, nullable: bool = True) -> dt.Array:
return dt.Array(dt.dtype(arrow_type.value_type), nullable=nullable)


@dt.dtype.register(pa.MapType) # type: ignore[misc]
def from_pyarrow_map(arrow_type: pa.MapType, nullable: bool = True) -> dt.DataType:
def from_pyarrow_map(arrow_type: pa.MapType, nullable: bool = True) -> dt.Map:
return dt.Map(
dt.dtype(arrow_type.key_type),
dt.dtype(arrow_type.item_type),
Expand All @@ -133,9 +133,7 @@ def from_pyarrow_map(arrow_type: pa.MapType, nullable: bool = True) -> dt.DataTy


@dt.dtype.register(pa.StructType) # type: ignore[misc]
def from_pyarrow_struct(
arrow_type: pa.StructType, nullable: bool = True
) -> dt.DataType:
def from_pyarrow_struct(arrow_type: pa.StructType, nullable: bool = True) -> dt.Struct:
return dt.Struct.from_tuples(
((field.name, dt.dtype(field.type)) for field in arrow_type),
nullable=nullable,
Expand All @@ -145,18 +143,20 @@ def from_pyarrow_struct(
@dt.dtype.register(pa.TimestampType) # type: ignore[misc]
def from_pyarrow_timestamp(
arrow_type: pa.TimestampType, nullable: bool = True
) -> dt.DataType:
) -> dt.Timestamp:
return dt.Timestamp(timezone=arrow_type.tz, nullable=nullable)


@sch.schema.register(pa.Schema) # type: ignore[misc]
def from_pyarrow_schema(schema: pa.Schema) -> sch.Schema:
return sch.schema([(f.name, dt.dtype(f.type, nullable=f.nullable)) for f in schema])
return sch.schema((f.name, dt.dtype(f.type, nullable=f.nullable)) for f in schema)


def _schema_to_pyarrow_schema_fields(schema: sch.Schema) -> Iterable[pa.Field]:
for name, dtype in schema.items():
yield pa.field(name, dtype.to_pyarrow(), nullable=dtype.nullable)
return (
pa.field(name, dtype.to_pyarrow(), nullable=dtype.nullable)
for name, dtype in schema.items()
)


def ibis_to_pyarrow_struct(schema: sch.Schema) -> pa.StructType:
Expand Down

0 comments on commit 7c747ae

Please sign in to comment.