TST: ArrowExtensionArray with string and binary types (#49172)
mroeschke authored Oct 19, 2022
1 parent 91de4cc commit a04754e
Showing 4 changed files with 114 additions and 14 deletions.
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v2.0.0.rst
@@ -219,7 +219,7 @@ Conversion
- Bug in :meth:`DataFrame.eval` incorrectly raising an ``AttributeError`` when there are negative values in function call (:issue:`46471`)
- Bug in :meth:`Series.convert_dtypes` not converting dtype to nullable dtype when :class:`Series` contains ``NA`` and has dtype ``object`` (:issue:`48791`)
- Bug where any :class:`ExtensionDtype` subclass with ``kind="M"`` would be interpreted as a timezone type (:issue:`34986`)
-
- Bug in :class:`.arrays.ArrowExtensionArray` that would raise ``NotImplementedError`` when passed a sequence of strings or binary (:issue:`49172`)

Strings
^^^^^^^
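For readers of the whatsnew entry, a minimal sketch of the user-visible effect, assuming pandas 2.0+ where pd.ArrowDtype is public API and that read_csv routes extension dtypes through _from_sequence_of_strings (the column name and data are illustrative):

from io import StringIO

import pandas as pd
import pyarrow as pa

# Requesting an Arrow-backed string column from the CSV parser; before this
# fix the string/binary case raised NotImplementedError.
buf = StringIO("col\na\nb\n")
df = pd.read_csv(buf, dtype={"col": pd.ArrowDtype(pa.string())})
print(df["col"].dtype)  # an ArrowDtype wrapping pa.string()
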
7 changes: 6 additions & 1 deletion pandas/_testing/__init__.py
@@ -201,8 +201,11 @@
SIGNED_INT_PYARROW_DTYPES = [pa.int8(), pa.int16(), pa.int32(), pa.int64()]
ALL_INT_PYARROW_DTYPES = UNSIGNED_INT_PYARROW_DTYPES + SIGNED_INT_PYARROW_DTYPES

# pa.float16 doesn't seem supported
# https://github.com/apache/arrow/blob/master/python/pyarrow/src/arrow/python/helpers.cc#L86
FLOAT_PYARROW_DTYPES = [pa.float32(), pa.float64()]
STRING_PYARROW_DTYPES = [pa.string(), pa.utf8()]
STRING_PYARROW_DTYPES = [pa.string()]
BINARY_PYARROW_DTYPES = [pa.binary()]

TIME_PYARROW_DTYPES = [
pa.time32("s"),
@@ -225,6 +228,8 @@
ALL_PYARROW_DTYPES = (
ALL_INT_PYARROW_DTYPES
+ FLOAT_PYARROW_DTYPES
+ STRING_PYARROW_DTYPES
+ BINARY_PYARROW_DTYPES
+ TIME_PYARROW_DTYPES
+ DATE_PYARROW_DTYPES
+ DATETIME_PYARROW_DTYPES
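As context for the new STRING_PYARROW_DTYPES and BINARY_PYARROW_DTYPES lists, this is roughly how tests/extension/test_arrow.py consumes ALL_PYARROW_DTYPES; a paraphrased sketch, not code from this commit, assuming the pandas 2.0+ pd.ArrowDtype alias:

import pytest

import pandas as pd
import pandas._testing as tm

# Every pyarrow type in the combined list (now including string and binary)
# becomes one parametrization of the extension-test dtype fixture.
@pytest.fixture(params=tm.ALL_PYARROW_DTYPES, ids=str)
def dtype(request):
    return pd.ArrowDtype(pyarrow_dtype=request.param)
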
9 changes: 7 additions & 2 deletions pandas/core/arrays/arrow/array.py
@@ -220,8 +220,13 @@ def _from_sequence_of_strings(
Construct a new ExtensionArray from a sequence of strings.
"""
pa_type = to_pyarrow_type(dtype)
if pa_type is None:
# Let pyarrow try to infer or raise
if (
pa_type is None
or pa.types.is_binary(pa_type)
or pa.types.is_string(pa_type)
):
# pa_type is None: Let pa.array infer
# pa_type is string/binary: scalars already correct type
scalars = strings
elif pa.types.is_timestamp(pa_type):
from pandas.core.tools.datetimes import to_datetime
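To illustrate the new branch above: with a string or binary target type the scalars are already the right Python type, so they are handed straight to pa.array instead of hitting the NotImplementedError fallback. A hedged sketch that calls the private classmethod directly, the way the CSV parser does internally (assumes pandas 2.0+ where pd.ArrowDtype and pd.arrays.ArrowExtensionArray are exposed):

import pandas as pd
import pyarrow as pa

# Strings pass through unchanged and are converted with pa.array.
arr = pd.arrays.ArrowExtensionArray._from_sequence_of_strings(
    ["a", "b", None], dtype=pd.ArrowDtype(pa.string())
)
print(arr.dtype)  # ArrowDtype backed by pa.string()
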
110 changes: 100 additions & 10 deletions pandas/tests/extension/test_arrow.py
@@ -16,6 +16,10 @@
time,
timedelta,
)
from io import (
BytesIO,
StringIO,
)

import numpy as np
import pytest
@@ -90,6 +94,10 @@ def data(dtype):
+ [None]
+ [time(0, 5), time(5, 0)]
)
elif pa.types.is_string(pa_dtype):
data = ["a", "b"] * 4 + [None] + ["1", "2"] * 44 + [None] + ["!", ">"]
elif pa.types.is_binary(pa_dtype):
data = [b"a", b"b"] * 4 + [None] + [b"1", b"2"] * 44 + [None] + [b"!", b">"]
else:
raise NotImplementedError
return pd.array(data, dtype=dtype)
@@ -155,6 +163,14 @@ def data_for_grouping(dtype):
A = time(0, 0)
B = time(0, 12)
C = time(12, 12)
elif pa.types.is_string(pa_dtype):
A = "a"
B = "b"
C = "c"
elif pa.types.is_binary(pa_dtype):
A = b"a"
B = b"b"
C = b"c"
else:
raise NotImplementedError
return pd.array([B, B, None, None, A, A, B, C], dtype=dtype)
@@ -203,17 +219,30 @@ def na_value():


class TestBaseCasting(base.BaseCastingTests):
pass
def test_astype_str(self, data, request):
pa_dtype = data.dtype.pyarrow_dtype
if pa.types.is_binary(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
reason=f"For {pa_dtype} .astype(str) decodes.",
)
)
super().test_astype_str(data)


class TestConstructors(base.BaseConstructorsTests):
def test_from_dtype(self, data, request):
pa_dtype = data.dtype.pyarrow_dtype
if pa.types.is_timestamp(pa_dtype) and pa_dtype.tz:
if (pa.types.is_timestamp(pa_dtype) and pa_dtype.tz) or pa.types.is_string(
pa_dtype
):
if pa.types.is_string(pa_dtype):
reason = "ArrowDtype(pa.string()) != StringDtype('pyarrow')"
else:
reason = f"pyarrow.type_for_alias cannot infer {pa_dtype}"
request.node.add_marker(
pytest.mark.xfail(
raises=NotImplementedError,
reason=f"pyarrow.type_for_alias cannot infer {pa_dtype}",
reason=reason,
)
)
super().test_from_dtype(data)
@@ -302,7 +331,7 @@ class TestGetitemTests(base.BaseGetitemTests):
reason=(
"data.dtype.type return pyarrow.DataType "
"but this (intentionally) returns "
"Python scalars or pd.Na"
"Python scalars or pd.NA"
)
)
def test_getitem_scalar(self, data):
@@ -361,7 +390,11 @@ def test_reduce_series(self, data, all_numeric_reductions, skipna, request):
or pa.types.is_boolean(pa_dtype)
) and not (
all_numeric_reductions in {"min", "max"}
and (pa.types.is_temporal(pa_dtype) and not pa.types.is_duration(pa_dtype))
and (
(pa.types.is_temporal(pa_dtype) and not pa.types.is_duration(pa_dtype))
or pa.types.is_string(pa_dtype)
or pa.types.is_binary(pa_dtype)
)
):
request.node.add_marker(xfail_mark)
elif pa.types.is_boolean(pa_dtype) and all_numeric_reductions in {
@@ -494,6 +527,16 @@ def test_construct_from_string_own_name(self, dtype, request):
reason=f"pyarrow.type_for_alias cannot infer {pa_dtype}",
)
)
elif pa.types.is_string(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
raises=TypeError,
reason=(
"Still support StringDtype('pyarrow') "
"over ArrowDtype(pa.string())"
),
)
)
super().test_construct_from_string_own_name(dtype)

def test_is_dtype_from_name(self, dtype, request):
@@ -505,6 +548,15 @@ def test_is_dtype_from_name(self, dtype, request):
reason=f"pyarrow.type_for_alias cannot infer {pa_dtype}",
)
)
elif pa.types.is_string(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
reason=(
"Still support StringDtype('pyarrow') "
"over ArrowDtype(pa.string())"
),
)
)
super().test_is_dtype_from_name(dtype)

def test_construct_from_string(self, dtype, request):
@@ -516,6 +568,16 @@ def test_construct_from_string(self, dtype, request):
reason=f"pyarrow.type_for_alias cannot infer {pa_dtype}",
)
)
elif pa.types.is_string(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
raises=TypeError,
reason=(
"Still support StringDtype('pyarrow') "
"over ArrowDtype(pa.string())"
),
)
)
super().test_construct_from_string(dtype)

def test_construct_from_string_another_type_raises(self, dtype):
@@ -533,6 +595,8 @@ def test_get_common_dtype(self, dtype, request):
and (pa_dtype.unit != "ns" or pa_dtype.tz is not None)
)
or (pa.types.is_duration(pa_dtype) and pa_dtype.unit != "ns")
or pa.types.is_string(pa_dtype)
or pa.types.is_binary(pa_dtype)
):
request.node.add_marker(
pytest.mark.xfail(
@@ -592,7 +656,21 @@ def test_EA_types(self, engine, data, request):
reason=f"Parameterized types with tz={pa_dtype.tz} not supported.",
)
)
super().test_EA_types(engine, data)
elif pa.types.is_binary(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(reason="CSV parsers don't correctly handle binary")
)
df = pd.DataFrame({"with_dtype": pd.Series(data, dtype=str(data.dtype))})
csv_output = df.to_csv(index=False, na_rep=np.nan)
if pa.types.is_binary(pa_dtype):
csv_output = BytesIO(csv_output)
else:
csv_output = StringIO(csv_output)
result = pd.read_csv(
csv_output, dtype={"with_dtype": str(data.dtype)}, engine=engine
)
expected = df
self.assert_frame_equal(result, expected)


class TestBaseUnaryOps(base.BaseUnaryOpsTests):
@@ -899,7 +977,11 @@ def test_arith_series_with_scalar(
or all_arithmetic_operators in ("__sub__", "__rsub__")
and pa.types.is_temporal(pa_dtype)
)
if all_arithmetic_operators in {
if all_arithmetic_operators == "__rmod__" and (
pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype)
):
pytest.skip("Skip testing Python string formatting")
elif all_arithmetic_operators in {
"__mod__",
"__rmod__",
}:
@@ -965,7 +1047,11 @@ def test_arith_frame_with_scalar(
or all_arithmetic_operators in ("__sub__", "__rsub__")
and pa.types.is_temporal(pa_dtype)
)
if all_arithmetic_operators in {
if all_arithmetic_operators == "__rmod__" and (
pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype)
):
pytest.skip("Skip testing Python string formatting")
elif all_arithmetic_operators in {
"__mod__",
"__rmod__",
}:
@@ -1224,7 +1310,11 @@ def test_quantile(data, interpolation, quantile, request):
)
def test_mode(data_for_grouping, dropna, take_idx, exp_idx, request):
pa_dtype = data_for_grouping.dtype.pyarrow_dtype
if pa.types.is_temporal(pa_dtype):
if (
pa.types.is_temporal(pa_dtype)
or pa.types.is_string(pa_dtype)
or pa.types.is_binary(pa_dtype)
):
request.node.add_marker(
pytest.mark.xfail(
raises=pa.ArrowNotImplementedError,
