Skip to content

Commit

Permalink
Backport PR #48443 on branch 1.5.x (BUG: Fix pyarrow groupby tests) (#48494)
Browse files Browse the repository at this point in the history

* BUG: Fix pyarrow groupby tests (#48443)

# Conflicts:
#	pandas/tests/extension/test_arrow.py

* CI: Fix failing tests (#48493)

Co-authored-by: jbrockmendel <jbrockmendel@gmail.com>
  • Loading branch information
phofl and jbrockmendel committed Sep 12, 2022
1 parent ad087f5 commit e6a014f
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 50 deletions.
5 changes: 4 additions & 1 deletion pandas/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,10 @@ def _set_axis(self, axis: int, labels: AnyArrayLike | list) -> None:
"""
labels = ensure_index(labels)

if labels._is_all_dates:
if labels._is_all_dates and not (
type(labels) is Index and not isinstance(labels.dtype, np.dtype)
):
# exclude e.g. timestamp[ns][pyarrow] dtype from this casting
deep_labels = labels
if isinstance(labels, CategoricalIndex):
deep_labels = labels.categories
Expand Down
60 changes: 11 additions & 49 deletions pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
pa_version_under8p0,
pa_version_under9p0,
)
from pandas.errors import PerformanceWarning

import pandas as pd
import pandas._testing as tm
Expand Down Expand Up @@ -515,15 +516,6 @@ def test_groupby_extension_no_sort(self, data_for_grouping, request):
reason=f"pyarrow doesn't support factorizing {pa_dtype}",
)
)
elif pa.types.is_date(pa_dtype) or (
pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None
):
request.node.add_marker(
pytest.mark.xfail(
raises=AttributeError,
reason="GH 34986",
)
)
super().test_groupby_extension_no_sort(data_for_grouping)

def test_groupby_extension_transform(self, data_for_grouping, request):
Expand All @@ -547,8 +539,7 @@ def test_groupby_extension_apply(
self, data_for_grouping, groupby_apply_op, request
):
pa_dtype = data_for_grouping.dtype.pyarrow_dtype
# Is there a better way to get the "series" ID for groupby_apply_op?
is_series = "series" in request.node.nodeid
# TODO: Is there a better way to get the "object" ID for groupby_apply_op?
is_object = "object" in request.node.nodeid
if pa.types.is_duration(pa_dtype):
request.node.add_marker(
Expand All @@ -567,14 +558,10 @@ def test_groupby_extension_apply(
reason="GH 47514: _concat_datetime expects axis arg.",
)
)
elif not is_series:
request.node.add_marker(
pytest.mark.xfail(
raises=AttributeError,
reason="GH 34986",
)
)
super().test_groupby_extension_apply(data_for_grouping, groupby_apply_op)
with tm.maybe_produces_warning(
PerformanceWarning, pa_version_under7p0, check_stacklevel=False
):
super().test_groupby_extension_apply(data_for_grouping, groupby_apply_op)

def test_in_numeric_groupby(self, data_for_grouping, request):
pa_dtype = data_for_grouping.dtype.pyarrow_dtype
Expand Down Expand Up @@ -603,17 +590,10 @@ def test_groupby_extension_agg(self, as_index, data_for_grouping, request):
reason=f"pyarrow doesn't support factorizing {pa_dtype}",
)
)
elif as_index is True and (
pa.types.is_date(pa_dtype)
or (pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None)
with tm.maybe_produces_warning(
PerformanceWarning, pa_version_under7p0, check_stacklevel=False
):
request.node.add_marker(
pytest.mark.xfail(
raises=AttributeError,
reason="GH 34986",
)
)
super().test_groupby_extension_agg(as_index, data_for_grouping)
super().test_groupby_extension_agg(as_index, data_for_grouping)


class TestBaseDtype(base.BaseDtypeTests):
Expand Down Expand Up @@ -1443,16 +1423,7 @@ def test_diff(self, data, periods, request):
@pytest.mark.parametrize("dropna", [True, False])
def test_value_counts(self, all_data, dropna, request):
pa_dtype = all_data.dtype.pyarrow_dtype
if pa.types.is_date(pa_dtype) or (
pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None
):
request.node.add_marker(
pytest.mark.xfail(
raises=AttributeError,
reason="GH 34986",
)
)
elif pa.types.is_duration(pa_dtype):
if pa.types.is_duration(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
raises=pa.ArrowNotImplementedError,
Expand All @@ -1463,16 +1434,7 @@ def test_value_counts(self, all_data, dropna, request):

def test_value_counts_with_normalize(self, data, request):
pa_dtype = data.dtype.pyarrow_dtype
if pa.types.is_date(pa_dtype) or (
pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None
):
request.node.add_marker(
pytest.mark.xfail(
raises=AttributeError,
reason="GH 34986",
)
)
elif pa.types.is_duration(pa_dtype):
if pa.types.is_duration(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
raises=pa.ArrowNotImplementedError,
Expand Down

0 comments on commit e6a014f

Please sign in to comment.