Skip to content

Commit

Permalink
feat(type-system): infer pandas' string dtype
Browse files: browse the repository at this point in the history
  • Loading branch information
cpcloud committed Sep 8, 2022
1 parent 786a50f commit 5f0eb5d
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
9 changes: 7 additions & 2 deletions ibis/backends/pandas/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,15 @@ def from_pandas_tzdtype(value):


@dt.dtype.register(CategoricalDtype)
def from_pandas_categorical(value):
def from_pandas_categorical(_):
return dt.Category()


# Register against the public ``pd.StringDtype`` alias rather than the private
# ``pd.core.arrays.string_`` module path, which pandas may move between
# releases; both names refer to the same class.
@dt.dtype.register(pd.StringDtype)
def from_pandas_string(_):
    """Convert pandas' ``StringDtype`` extension dtype to an ibis string type.

    The dtype instance passed in is ignored: every ``StringDtype`` maps to
    ``dt.String()``.

    Returns
    -------
    dt.String
        A new ibis string datatype instance.
    """
    return dt.String()


@dt.infer.register(np.generic)
def infer_numpy_scalar(value):
    """Infer the ibis datatype of a NumPy scalar from its ``dtype`` attribute."""
    numpy_dtype = value.dtype
    return dt.dtype(numpy_dtype)
Expand Down Expand Up @@ -206,7 +211,7 @@ def infer_pandas_schema(df, schema=None):
schema = schema if schema is not None else {}

pairs = []
for column_name, pandas_dtype in df.dtypes.iteritems():
for column_name in df.dtypes.keys():
if not isinstance(column_name, str):
raise TypeError(
'Column names must be strings to use the pandas backend'
Expand Down
9 changes: 9 additions & 0 deletions ibis/backends/pandas/tests/test_datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def test_numpy_dtype(numpy_dtype, ibis_dtype):
dt.Timestamp('US/Eastern'),
),
(CategoricalDtype(), dt.Category()),
(pd.Series([], dtype="string").dtype, dt.String()),
],
)
def test_pandas_dtype(pandas_dtype, ibis_dtype):
Expand Down Expand Up @@ -206,6 +207,7 @@ def test_pandas_dtype(pandas_dtype, ibis_dtype):
(pd.Series([b'1', '2', 3.0]), dt.binary),
# empty
(pd.Series([], dtype='object'), dt.binary),
(pd.Series([], dtype="string"), dt.string),
],
)
def test_schema_infer(col_data, schema_type):
Expand All @@ -214,3 +216,10 @@ def test_schema_infer(col_data, schema_type):
inferred = sch.infer(df)
expected = ibis.schema([('col', schema_type)])
assert inferred == expected


def test_pyarrow_string():
    """Check that a ``string[pyarrow]`` pandas dtype converts to ``dt.String()``."""
    # Skip the test when pyarrow is not installed.
    pytest.importorskip("pyarrow")

    arrow_backed = pd.Series([], dtype="string[pyarrow]")
    converted = dt.dtype(arrow_backed.dtype)
    assert converted == dt.String()

0 comments on commit 5f0eb5d

Please sign in to comment.