diff --git a/ibis/backends/pyarrow/datatypes.py b/ibis/backends/pyarrow/datatypes.py index cb7069f4fd7d..7b189bf54299 100644 --- a/ibis/backends/pyarrow/datatypes.py +++ b/ibis/backends/pyarrow/datatypes.py @@ -30,28 +30,36 @@ @functools.singledispatch -def to_pyarrow_type(dtype): +def to_pyarrow_type(dtype: dt.DataType): return _to_pyarrow_types[dtype.__class__] @to_pyarrow_type.register(dt.Array) -def from_ibis_array(dtype): - return pa.list_(to_pyarrow_type(dtype.value_type)) - - @to_pyarrow_type.register(dt.Set) -def from_ibis_set(dtype): +def from_ibis_collection(dtype: dt.Array | dt.Set): return pa.list_(to_pyarrow_type(dtype.value_type)) -@to_pyarrow_type.register(dt.Interval) -def from_ibis_interval(dtype): +@to_pyarrow_type.register +def from_ibis_interval(dtype: dt.Interval): try: return pa.duration(dtype.unit) except ValueError: raise com.IbisTypeError(f"Unsupported interval unit: {dtype.unit}") +@to_pyarrow_type.register +def from_ibis_struct(dtype: dt.Struct): + return pa.struct( + pa.field(name, to_pyarrow_type(typ)) for name, typ in dtype.pairs.items() + ) + + +@to_pyarrow_type.register +def from_ibis_map(dtype: dt.Map): + return pa.map_(to_pyarrow_type(dtype.key_type), to_pyarrow_type(dtype.value_type)) + + _to_ibis_dtypes = { pa.int8(): dt.Int8, pa.int16(): dt.Int16,