From ee2b8b564c10292a86202129295f1ebc47c9bafe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Vinot?= Date: Mon, 8 Mar 2021 15:32:31 +0100 Subject: [PATCH] ENH: Implement rounding for floating dtype array #38844 (#39751) --- doc/source/whatsnew/v1.3.0.rst | 1 + pandas/core/arrays/numeric.py | 32 +++++++++++++++ pandas/tests/arrays/masked/test_function.py | 44 +++++++++++++++++++++ pandas/tests/series/methods/test_round.py | 28 ++++++++----- 4 files changed, 95 insertions(+), 10 deletions(-) create mode 100644 pandas/tests/arrays/masked/test_function.py diff --git a/doc/source/whatsnew/v1.3.0.rst b/doc/source/whatsnew/v1.3.0.rst index 880e1051c87bf..370ea28832758 100644 --- a/doc/source/whatsnew/v1.3.0.rst +++ b/doc/source/whatsnew/v1.3.0.rst @@ -140,6 +140,7 @@ Other enhancements - :meth:`pandas.read_stata` and :class:`StataReader` support reading data from compressed files. - Add support for parsing ``ISO 8601``-like timestamps with negative signs to :meth:`pandas.Timedelta` (:issue:`37172`) - Add support for unary operators in :class:`FloatingArray` (:issue:`38749`) +- :meth:`round` being enabled for the nullable integer and floating dtypes (:issue:`38844`) .. --------------------------------------------------------------------------- diff --git a/pandas/core/arrays/numeric.py b/pandas/core/arrays/numeric.py index 0dd98c5e3d3f2..f06099a642833 100644 --- a/pandas/core/arrays/numeric.py +++ b/pandas/core/arrays/numeric.py @@ -6,6 +6,7 @@ TYPE_CHECKING, Any, List, + TypeVar, Union, ) @@ -15,6 +16,7 @@ Timedelta, missing as libmissing, ) +from pandas.compat.numpy import function as nv from pandas.errors import AbstractMethodError from pandas.core.dtypes.common import ( @@ -34,6 +36,8 @@ if TYPE_CHECKING: import pyarrow +T = TypeVar("T", bound="NumericArray") + class NumericDtype(BaseMaskedDtype): def __from_arrow__( @@ -208,3 +212,31 @@ def __pos__(self): def __abs__(self): return type(self)(abs(self._data), self._mask.copy()) + + def round(self: T, decimals: int = 0, *args, **kwargs) -> T: + """ + Round each value in the array a to the given number of decimals. + + Parameters + ---------- + decimals : int, default 0 + Number of decimal places to round to. If decimals is negative, + it specifies the number of positions to the left of the decimal point. + *args, **kwargs + Additional arguments and keywords have no effect but might be + accepted for compatibility with NumPy. + + Returns + ------- + NumericArray + Rounded values of the NumericArray. + + See Also + -------- + numpy.around : Round values of an np.array. + DataFrame.round : Round values of a DataFrame. + Series.round : Round values of a Series. + """ + nv.validate_round(args, kwargs) + values = np.round(self._data, decimals=decimals, **kwargs) + return type(self)(values, self._mask.copy()) diff --git a/pandas/tests/arrays/masked/test_function.py b/pandas/tests/arrays/masked/test_function.py new file mode 100644 index 0000000000000..1c0e0820f7dcc --- /dev/null +++ b/pandas/tests/arrays/masked/test_function.py @@ -0,0 +1,44 @@ +import numpy as np +import pytest + +from pandas.core.dtypes.common import is_integer_dtype + +import pandas as pd +import pandas._testing as tm + +arrays = [pd.array([1, 2, 3, None], dtype=dtype) for dtype in tm.ALL_EA_INT_DTYPES] +arrays += [ + pd.array([0.141, -0.268, 5.895, None], dtype=dtype) for dtype in tm.FLOAT_EA_DTYPES +] + + +@pytest.fixture(params=arrays, ids=[a.dtype.name for a in arrays]) +def data(request): + return request.param + + +@pytest.fixture() +def numpy_dtype(data): + # For integer dtype, the numpy conversion must be done to float + if is_integer_dtype(data): + numpy_dtype = float + else: + numpy_dtype = data.dtype.type + return numpy_dtype + + +def test_round(data, numpy_dtype): + # No arguments + result = data.round() + expected = pd.array( + np.round(data.to_numpy(dtype=numpy_dtype, na_value=None)), dtype=data.dtype + ) + tm.assert_extension_array_equal(result, expected) + + # Decimals argument + result = data.round(decimals=2) + expected = pd.array( + np.round(data.to_numpy(dtype=numpy_dtype, na_value=None), decimals=2), + dtype=data.dtype, + ) + tm.assert_extension_array_equal(result, expected) diff --git a/pandas/tests/series/methods/test_round.py b/pandas/tests/series/methods/test_round.py index 88d5c428712dc..7ab19a05159a4 100644 --- a/pandas/tests/series/methods/test_round.py +++ b/pandas/tests/series/methods/test_round.py @@ -16,33 +16,41 @@ def test_round(self, datetime_series): tm.assert_series_equal(result, expected) assert result.name == datetime_series.name - def test_round_numpy(self): + def test_round_numpy(self, any_float_allowed_nullable_dtype): # See GH#12600 - ser = Series([1.53, 1.36, 0.06]) + ser = Series([1.53, 1.36, 0.06], dtype=any_float_allowed_nullable_dtype) out = np.round(ser, decimals=0) - expected = Series([2.0, 1.0, 0.0]) + expected = Series([2.0, 1.0, 0.0], dtype=any_float_allowed_nullable_dtype) tm.assert_series_equal(out, expected) msg = "the 'out' parameter is not supported" with pytest.raises(ValueError, match=msg): np.round(ser, decimals=0, out=ser) - def test_round_numpy_with_nan(self): + def test_round_numpy_with_nan(self, any_float_allowed_nullable_dtype): # See GH#14197 - ser = Series([1.53, np.nan, 0.06]) + ser = Series([1.53, np.nan, 0.06], dtype=any_float_allowed_nullable_dtype) with tm.assert_produces_warning(None): result = ser.round() - expected = Series([2.0, np.nan, 0.0]) + expected = Series([2.0, np.nan, 0.0], dtype=any_float_allowed_nullable_dtype) tm.assert_series_equal(result, expected) - def test_round_builtin(self): - ser = Series([1.123, 2.123, 3.123], index=range(3)) + def test_round_builtin(self, any_float_allowed_nullable_dtype): + ser = Series( + [1.123, 2.123, 3.123], + index=range(3), + dtype=any_float_allowed_nullable_dtype, + ) result = round(ser) - expected_rounded0 = Series([1.0, 2.0, 3.0], index=range(3)) + expected_rounded0 = Series( + [1.0, 2.0, 3.0], index=range(3), dtype=any_float_allowed_nullable_dtype + ) tm.assert_series_equal(result, expected_rounded0) decimals = 2 - expected_rounded = Series([1.12, 2.12, 3.12], index=range(3)) + expected_rounded = Series( + [1.12, 2.12, 3.12], index=range(3), dtype=any_float_allowed_nullable_dtype + ) result = round(ser, decimals) tm.assert_series_equal(result, expected_rounded)