diff --git a/doc/source/whatsnew/v1.0.0.rst b/doc/source/whatsnew/v1.0.0.rst index 19fb4bdcd9536..cb29e291f28b0 100644 --- a/doc/source/whatsnew/v1.0.0.rst +++ b/doc/source/whatsnew/v1.0.0.rst @@ -840,6 +840,7 @@ ExtensionArray ^^^^^^^^^^^^^^ - Bug in :class:`arrays.PandasArray` when setting a scalar string (:issue:`28118`, :issue:`28150`). +- Bug where nullable integers could not be compared to strings (:issue:`28930`) - diff --git a/pandas/conftest.py b/pandas/conftest.py index 3553a411a27f8..6b43bf58b5046 100644 --- a/pandas/conftest.py +++ b/pandas/conftest.py @@ -654,6 +654,24 @@ def any_int_dtype(request): return request.param +@pytest.fixture(params=ALL_EA_INT_DTYPES) +def any_nullable_int_dtype(request): + """ + Parameterized fixture for any nullable integer dtype. + + * 'UInt8' + * 'Int8' + * 'UInt16' + * 'Int16' + * 'UInt32' + * 'Int32' + * 'UInt64' + * 'Int64' + """ + + return request.param + + @pytest.fixture(params=ALL_REAL_DTYPES) def any_real_dtype(request): """ diff --git a/pandas/core/arrays/integer.py b/pandas/core/arrays/integer.py index 2bfb53aa1c800..08a3eca1e9055 100644 --- a/pandas/core/arrays/integer.py +++ b/pandas/core/arrays/integer.py @@ -26,6 +26,7 @@ from pandas.core import nanops, ops from pandas.core.algorithms import take from pandas.core.arrays import ExtensionArray, ExtensionOpsMixin +from pandas.core.ops import invalid_comparison from pandas.core.ops.common import unpack_zerodim_and_defer from pandas.core.tools.numeric import to_numeric @@ -646,7 +647,11 @@ def cmp_method(self, other): with warnings.catch_warnings(): warnings.filterwarnings("ignore", "elementwise", FutureWarning) with np.errstate(all="ignore"): - result = op(self._data, other) + method = getattr(self._data, f"__{op_name}__") + result = method(other) + + if result is NotImplemented: + result = invalid_comparison(self._data, other, op) # nans propagate if mask is None: diff --git a/pandas/tests/extension/test_integer.py b/pandas/tests/extension/test_integer.py index d051345fdd12d..f94dbfcc3ec6c 100644 --- a/pandas/tests/extension/test_integer.py +++ b/pandas/tests/extension/test_integer.py @@ -168,6 +168,27 @@ def check_opname(self, s, op_name, other, exc=None): def _compare_other(self, s, data, op_name, other): self.check_opname(s, op_name, other) + def test_compare_to_string(self, any_nullable_int_dtype): + # GH 28930 + s = pd.Series([1, None], dtype=any_nullable_int_dtype) + result = s == "a" + expected = pd.Series([False, False]) + + self.assert_series_equal(result, expected) + + def test_compare_to_int(self, any_nullable_int_dtype, all_compare_operators): + # GH 28930 + s1 = pd.Series([1, 2, 3], dtype=any_nullable_int_dtype) + s2 = pd.Series([1, 2, 3], dtype="int") + + method = getattr(s1, all_compare_operators) + result = method(2) + + method = getattr(s2, all_compare_operators) + expected = method(2) + + self.assert_series_equal(result, expected) + class TestInterface(base.BaseInterfaceTests): pass