Skip to content

Commit

Permalink
ENH: Implement rounding for floating dtype array #38844 (#39751)
Browse files Browse the repository at this point in the history
  • Loading branch information
benoit9126 committed Mar 8, 2021
1 parent 8c62fbb commit ee2b8b5
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 10 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.3.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ Other enhancements
- :meth:`pandas.read_stata` and :class:`StataReader` support reading data from compressed files.
- Add support for parsing ``ISO 8601``-like timestamps with negative signs to :meth:`pandas.Timedelta` (:issue:`37172`)
- Add support for unary operators in :class:`FloatingArray` (:issue:`38749`)
- :meth:`round` being enabled for the nullable integer and floating dtypes (:issue:`38844`)

.. ---------------------------------------------------------------------------
Expand Down
32 changes: 32 additions & 0 deletions pandas/core/arrays/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
TYPE_CHECKING,
Any,
List,
TypeVar,
Union,
)

Expand All @@ -15,6 +16,7 @@
Timedelta,
missing as libmissing,
)
from pandas.compat.numpy import function as nv
from pandas.errors import AbstractMethodError

from pandas.core.dtypes.common import (
Expand All @@ -34,6 +36,8 @@
if TYPE_CHECKING:
import pyarrow

T = TypeVar("T", bound="NumericArray")


class NumericDtype(BaseMaskedDtype):
def __from_arrow__(
Expand Down Expand Up @@ -208,3 +212,31 @@ def __pos__(self):

def __abs__(self):
return type(self)(abs(self._data), self._mask.copy())

def round(self: T, decimals: int = 0, *args, **kwargs) -> T:
"""
Round each value in the array a to the given number of decimals.
Parameters
----------
decimals : int, default 0
Number of decimal places to round to. If decimals is negative,
it specifies the number of positions to the left of the decimal point.
*args, **kwargs
Additional arguments and keywords have no effect but might be
accepted for compatibility with NumPy.
Returns
-------
NumericArray
Rounded values of the NumericArray.
See Also
--------
numpy.around : Round values of an np.array.
DataFrame.round : Round values of a DataFrame.
Series.round : Round values of a Series.
"""
nv.validate_round(args, kwargs)
values = np.round(self._data, decimals=decimals, **kwargs)
return type(self)(values, self._mask.copy())
44 changes: 44 additions & 0 deletions pandas/tests/arrays/masked/test_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import numpy as np
import pytest

from pandas.core.dtypes.common import is_integer_dtype

import pandas as pd
import pandas._testing as tm

arrays = [pd.array([1, 2, 3, None], dtype=dtype) for dtype in tm.ALL_EA_INT_DTYPES]
arrays += [
pd.array([0.141, -0.268, 5.895, None], dtype=dtype) for dtype in tm.FLOAT_EA_DTYPES
]


@pytest.fixture(params=arrays, ids=[a.dtype.name for a in arrays])
def data(request):
return request.param


@pytest.fixture()
def numpy_dtype(data):
# For integer dtype, the numpy conversion must be done to float
if is_integer_dtype(data):
numpy_dtype = float
else:
numpy_dtype = data.dtype.type
return numpy_dtype


def test_round(data, numpy_dtype):
# No arguments
result = data.round()
expected = pd.array(
np.round(data.to_numpy(dtype=numpy_dtype, na_value=None)), dtype=data.dtype
)
tm.assert_extension_array_equal(result, expected)

# Decimals argument
result = data.round(decimals=2)
expected = pd.array(
np.round(data.to_numpy(dtype=numpy_dtype, na_value=None), decimals=2),
dtype=data.dtype,
)
tm.assert_extension_array_equal(result, expected)
28 changes: 18 additions & 10 deletions pandas/tests/series/methods/test_round.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,33 +16,41 @@ def test_round(self, datetime_series):
tm.assert_series_equal(result, expected)
assert result.name == datetime_series.name

def test_round_numpy(self):
def test_round_numpy(self, any_float_allowed_nullable_dtype):
# See GH#12600
ser = Series([1.53, 1.36, 0.06])
ser = Series([1.53, 1.36, 0.06], dtype=any_float_allowed_nullable_dtype)
out = np.round(ser, decimals=0)
expected = Series([2.0, 1.0, 0.0])
expected = Series([2.0, 1.0, 0.0], dtype=any_float_allowed_nullable_dtype)
tm.assert_series_equal(out, expected)

msg = "the 'out' parameter is not supported"
with pytest.raises(ValueError, match=msg):
np.round(ser, decimals=0, out=ser)

def test_round_numpy_with_nan(self):
def test_round_numpy_with_nan(self, any_float_allowed_nullable_dtype):
# See GH#14197
ser = Series([1.53, np.nan, 0.06])
ser = Series([1.53, np.nan, 0.06], dtype=any_float_allowed_nullable_dtype)
with tm.assert_produces_warning(None):
result = ser.round()
expected = Series([2.0, np.nan, 0.0])
expected = Series([2.0, np.nan, 0.0], dtype=any_float_allowed_nullable_dtype)
tm.assert_series_equal(result, expected)

def test_round_builtin(self):
ser = Series([1.123, 2.123, 3.123], index=range(3))
def test_round_builtin(self, any_float_allowed_nullable_dtype):
ser = Series(
[1.123, 2.123, 3.123],
index=range(3),
dtype=any_float_allowed_nullable_dtype,
)
result = round(ser)
expected_rounded0 = Series([1.0, 2.0, 3.0], index=range(3))
expected_rounded0 = Series(
[1.0, 2.0, 3.0], index=range(3), dtype=any_float_allowed_nullable_dtype
)
tm.assert_series_equal(result, expected_rounded0)

decimals = 2
expected_rounded = Series([1.12, 2.12, 3.12], index=range(3))
expected_rounded = Series(
[1.12, 2.12, 3.12], index=range(3), dtype=any_float_allowed_nullable_dtype
)
result = round(ser, decimals)
tm.assert_series_equal(result, expected_rounded)

Expand Down

0 comments on commit ee2b8b5

Please sign in to comment.