CLN: more accurate is_scalar checks (pandas-dev#52971)
* REF: avoid is_scalar

* comment

* infer_dtype->is_bool_array

* fix invalid refs
jbrockmendel authored and NumanIjaz committed May 1, 2023
1 parent e253c80 commit 9dd2faf
Showing 17 changed files with 53 additions and 65 deletions.
5 changes: 2 additions & 3 deletions pandas/core/algorithms.py
@@ -47,7 +47,6 @@
     is_integer_dtype,
     is_list_like,
     is_object_dtype,
-    is_scalar,
     is_signed_integer_dtype,
     needs_i8_conversion,
 )
@@ -1321,15 +1320,15 @@ def searchsorted(
         # Before searching below, we therefore try to give `value` the
         # same dtype as `arr`, while guarding against integer overflows.
         iinfo = np.iinfo(arr.dtype.type)
-        value_arr = np.array([value]) if is_scalar(value) else np.array(value)
+        value_arr = np.array([value]) if is_integer(value) else np.array(value)
         if (value_arr >= iinfo.min).all() and (value_arr <= iinfo.max).all():
             # value within bounds, so no overflow, so can convert value dtype
             # to dtype of arr
             dtype = arr.dtype
         else:
             dtype = value_arr.dtype

-        if is_scalar(value):
+        if is_integer(value):
             # We know that value is int
             value = cast(int, dtype.type(value))
         else:
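
Note (illustration, not part of the diff): this branch is only reached when `value` is integer-like (the comment above says "We know that value is int"), so the narrower `is_integer` check is sufficient; a minimal sketch of the difference:

# Illustration only: contrast pandas.api.types.is_scalar with is_integer.
import numpy as np
from pandas.api.types import is_integer, is_scalar

print(is_scalar("2021-01-01"), is_integer("2021-01-01"))  # True False (strings count as scalars)
print(is_scalar(np.int64(3)), is_integer(np.int64(3)))    # True True  (Python and NumPy ints qualify)
print(is_scalar(3.5), is_integer(3.5))                    # True False (floats are excluded)
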
4 changes: 2 additions & 2 deletions pandas/core/array_algos/replace.py
@@ -14,9 +14,9 @@
 import numpy as np

 from pandas.core.dtypes.common import (
+    is_bool,
     is_re,
     is_re_compilable,
-    is_scalar,
 )
 from pandas.core.dtypes.missing import isna

@@ -72,7 +72,7 @@ def _check_comparison_types(
     Raises an error if the two arrays (a,b) cannot be compared.
     Otherwise, returns the comparison result as expected.
     """
-    if is_scalar(result) and isinstance(a, np.ndarray):
+    if is_bool(result) and isinstance(a, np.ndarray):
         type_names = [type(a).__name__, type(b).__name__]

         type_names[0] = f"ndarray(dtype={a.dtype})"
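
Note (illustration, not part of the diff): the only scalar result this helper needs to flag is a bare bool from a comparison that could not broadcast, so `is_bool` is the right check; a small sketch:

# Illustration only: is_bool is narrower than is_scalar.
import numpy as np
from pandas.api.types import is_bool, is_scalar

print(is_bool(False), is_bool(np.bool_(True)))  # True True  (the case the guard targets)
print(is_bool(np.nan), is_scalar(np.nan))       # False True (is_scalar was broader than needed)
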
2 changes: 0 additions & 2 deletions pandas/core/arrays/sparse/array.py
@@ -1120,8 +1120,6 @@ def searchsorted(
     ) -> npt.NDArray[np.intp] | np.intp:
         msg = "searchsorted requires high memory usage."
         warnings.warn(msg, PerformanceWarning, stacklevel=find_stack_level())
-        if not is_scalar(v):
-            v = np.asarray(v)
         v = np.asarray(v)
         return np.asarray(self, dtype=self.dtype.subtype).searchsorted(v, side, sorter)
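
Note (illustration, not part of the diff): `np.asarray` already accepts scalars, so the dropped branch was redundant; a quick sketch:

# Illustration only: np.asarray handles scalars and array-likes alike.
import numpy as np

print(np.asarray(3))          # array(3), a 0-d array, which searchsorted accepts
print(np.asarray([1, 2, 3]))  # array([1, 2, 3])
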
3 changes: 1 addition & 2 deletions pandas/core/common.py
@@ -44,7 +44,6 @@
     ABCSeries,
 )
 from pandas.core.dtypes.inference import iterable_not_string
-from pandas.core.dtypes.missing import isna

 if TYPE_CHECKING:
     from pandas._typing import (
@@ -129,7 +128,7 @@ def is_bool_indexer(key: Any) -> bool:

             if not lib.is_bool_array(key_array):
                 na_msg = "Cannot mask with non-boolean array containing NA / NaN values"
-                if lib.infer_dtype(key_array) == "boolean" and isna(key_array).any():
+                if lib.is_bool_array(key_array, skipna=True):
                     # Don't raise on e.g. ["A", "B", np.nan], see
                     # test_loc_getitem_list_of_labels_categoricalindex_with_na
                     raise ValueError(na_msg)
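
Note (my sketch, not part of the diff; `is_bool_indexer` is internal API, used here purely for illustration): the one-line replacement preserves the three cases the object-dtype branch distinguishes:

# Illustration only.
import numpy as np
from pandas.core.common import is_bool_indexer

print(is_bool_indexer(np.array([True, False, True], dtype=object)))  # True
print(is_bool_indexer(np.array(["A", "B", np.nan], dtype=object)))   # False, not boolean at all
# np.array([True, False, np.nan], dtype=object) raises ValueError:
# "Cannot mask with non-boolean array containing NA / NaN values"
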
2 changes: 0 additions & 2 deletions pandas/core/dtypes/cast.py
@@ -455,8 +455,6 @@ def maybe_cast_pointwise_result(
         result maybe casted to the dtype.
     """

-    assert not is_scalar(result)
-
     if isinstance(dtype, ExtensionDtype):
         if not isinstance(dtype, (CategoricalDtype, DatetimeTZDtype)):
             # TODO: avoid this special-casing
2 changes: 1 addition & 1 deletion pandas/core/frame.py
@@ -3914,7 +3914,7 @@ def isetitem(self, loc, value) -> None:
         ``frame[frame.columns[i]] = value``.
         """
         if isinstance(value, DataFrame):
-            if is_scalar(loc):
+            if is_integer(loc):
                 loc = [loc]

             if len(loc) != len(value.columns):
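
Note (usage sketch, mine, not part of the diff): `isetitem` takes positional column locations, so only an integer `loc` needs to be wrapped in a list:

# Illustration only.
import pandas as pd

df = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
df.isetitem(0, pd.DataFrame({"x": [10, 20]}))  # single integer position; loc becomes [0]
print(df)
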
4 changes: 1 addition & 3 deletions pandas/core/generic.py
@@ -8346,9 +8346,7 @@ def clip(
             lower, upper = min(lower, upper), max(lower, upper)

         # fast-path for scalars
-        if (lower is None or (is_scalar(lower) and is_number(lower))) and (
-            upper is None or (is_scalar(upper) and is_number(upper))
-        ):
+        if (lower is None or is_number(lower)) and (upper is None or is_number(upper)):
             return self._clip_with_scalar(lower, upper, inplace=inplace)

         result = self
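
Note (illustration, not part of the diff): the simplified condition still separates the two paths, with numeric scalar bounds hitting the scalar fast path and list-like bounds taking the elementwise path:

# Illustration only.
import pandas as pd

s = pd.Series([1, 5, 10])
print(s.clip(lower=2, upper=8).tolist())  # [2, 5, 8], scalar bounds take the fast path
print(s.clip(lower=[0, 6, 4]).tolist())   # [1, 6, 10], list-like bound takes the elementwise path
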
11 changes: 6 additions & 5 deletions pandas/core/indexes/base.py
@@ -916,6 +916,9 @@ def __array_ufunc__(self, ufunc: np.ufunc, method: str_t, *inputs, **kwargs):
         if ufunc.nout == 2:
             # i.e. np.divmod, np.modf, np.frexp
             return tuple(self.__array_wrap__(x) for x in result)
+        elif method == "reduce":
+            result = lib.item_from_zerodim(result)
+            return result

         if result.dtype == np.float16:
             result = result.astype(np.float32)
@@ -928,11 +931,9 @@ def __array_wrap__(self, result, context=None):
         Gets called after a ufunc and other functions e.g. np.split.
         """
         result = lib.item_from_zerodim(result)
-        if (
-            (not isinstance(result, Index) and is_bool_dtype(result.dtype))
-            or lib.is_scalar(result)
-            or np.ndim(result) > 1
-        ):
+        if (not isinstance(result, Index) and is_bool_dtype(result.dtype)) or np.ndim(
+            result
+        ) > 1:
             # exclude Index to avoid warning from is_bool_dtype deprecation;
             # in the Index case it doesn't matter which path we go down.
             # reached in plotting tests with e.g. np.nonzero(index)
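
Note (my sketch, not part of the diff): a ufunc reduction over an Index yields a zero-dimensional result, which the new `reduce` branch unboxes to a scalar instead of wrapping back into an Index:

# Illustration only.
import numpy as np
import pandas as pd

idx = pd.Index([1, 2, 3])
print(np.add.reduce(idx))  # 6, a plain scalar via item_from_zerodim
print(np.maximum(idx, 2))  # Index([2, 2, 3], ...), elementwise ufuncs still wrap
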
17 changes: 9 additions & 8 deletions pandas/core/nanops.py
@@ -40,7 +40,6 @@
     is_integer,
     is_numeric_dtype,
     is_object_dtype,
-    is_scalar,
     needs_i8_conversion,
     pandas_dtype,
 )
@@ -291,7 +290,6 @@ def _get_values(
     # In _get_values is only called from within nanops, and in all cases
     # with scalar fill_value. This guarantee is important for the
     # np.where call below
-    assert is_scalar(fill_value)

     mask = _maybe_get_mask(values, skipna, mask)

@@ -876,12 +874,15 @@ def _get_counts_nanvar(
     d = count - dtype.type(ddof)

     # always return NaN, never inf
-    if is_scalar(count):
+    if is_float(count):
         if count <= ddof:
-            count = np.nan
+            # error: Incompatible types in assignment (expression has type
+            # "float", variable has type "Union[floating[Any], ndarray[Any,
+            # dtype[floating[Any]]]]")
+            count = np.nan  # type: ignore[assignment]
             d = np.nan
     else:
-        # count is not narrowed by is_scalar check
+        # count is not narrowed by is_float check
         count = cast(np.ndarray, count)
         mask = count <= ddof
         if mask.any():
@@ -1444,8 +1445,8 @@ def _get_counts(
     values_shape: Shape,
     mask: npt.NDArray[np.bool_] | None,
     axis: AxisInt | None,
-    dtype: np.dtype = np.dtype(np.float64),
-) -> float | np.ndarray:
+    dtype: np.dtype[np.floating] = np.dtype(np.float64),
+) -> np.floating | npt.NDArray[np.floating]:
     """
     Get the count of non-null values along an axis
@@ -1476,7 +1477,7 @@ def _get_counts(
     else:
         count = values_shape[axis]

-    if is_scalar(count):
+    if is_integer(count):
         return dtype.type(count)
     return count.astype(dtype, copy=False)
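
Note (illustration, not part of the diff): the counts handled above are either NumPy scalars or ndarrays, so the float/integer checks are enough to separate the two branches:

# Illustration only.
import numpy as np
from pandas.api.types import is_float, is_integer

print(is_float(np.float64(3.0)), is_float(np.array([3.0])))  # True False
print(is_integer(7), is_integer(np.array([7])))              # True False
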
1 change: 1 addition & 0 deletions pandas/core/ops/missing.py
@@ -53,6 +53,7 @@ def _fill_zeros(result, x, y):
     is_scalar_type = is_scalar(y)

     if not is_variable_type and not is_scalar_type:
+        # e.g. test_series_ops_name_retention with mod we get here with list/tuple
         return result

     if is_scalar_type:
5 changes: 1 addition & 4 deletions pandas/core/series.py
@@ -1019,10 +1019,7 @@ def _get_with(self, key):
         if not isinstance(key, (list, np.ndarray, ExtensionArray, Series, Index)):
             key = list(key)

-        if isinstance(key, Index):
-            key_type = key.inferred_type
-        else:
-            key_type = lib.infer_dtype(key, skipna=False)
+        key_type = lib.infer_dtype(key, skipna=False)

         # Note: The key_type == "boolean" case should be caught by the
         # com.is_bool_indexer check in __getitem__
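
Note (illustration, not part of the diff): the Index special case could go because `infer_dtype` accepts an Index directly and agrees with `Index.inferred_type` (shown via the public `pandas.api.types.infer_dtype`):

# Illustration only.
import pandas as pd
from pandas.api.types import infer_dtype

print(infer_dtype(["a", "b"], skipna=False))            # 'string'
print(infer_dtype(pd.Index(["a", "b"]), skipna=False))  # 'string'
print(pd.Index(["a", "b"]).inferred_type)               # 'string'
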
3 changes: 2 additions & 1 deletion pandas/core/strings/base.py
@@ -5,6 +5,7 @@
     TYPE_CHECKING,
     Callable,
     Literal,
+    Sequence,
 )

 import numpy as np
@@ -79,7 +80,7 @@ def _str_replace(
         pass

     @abc.abstractmethod
-    def _str_repeat(self, repeats):
+    def _str_repeat(self, repeats: int | Sequence[int]):
         pass

     @abc.abstractmethod
19 changes: 12 additions & 7 deletions pandas/core/strings/object_array.py
@@ -8,6 +8,8 @@
     TYPE_CHECKING,
     Callable,
     Literal,
+    Sequence,
+    cast,
 )
 import unicodedata

@@ -17,7 +19,6 @@
 import pandas._libs.missing as libmissing
 import pandas._libs.ops as libops

-from pandas.core.dtypes.common import is_scalar
 from pandas.core.dtypes.missing import isna

 from pandas.core.strings.base import BaseStringArrayMethods
@@ -177,14 +178,15 @@ def _str_replace(

         return self._str_map(f, dtype=str)

-    def _str_repeat(self, repeats):
-        if is_scalar(repeats):
+    def _str_repeat(self, repeats: int | Sequence[int]):
+        if lib.is_integer(repeats):
+            rint = cast(int, repeats)

             def scalar_rep(x):
                 try:
-                    return bytes.__mul__(x, repeats)
+                    return bytes.__mul__(x, rint)
                 except TypeError:
-                    return str.__mul__(x, repeats)
+                    return str.__mul__(x, rint)

             return self._str_map(scalar_rep, dtype=str)
         else:
@@ -198,8 +200,11 @@ def rep(x, r):
                 except TypeError:
                     return str.__mul__(x, r)

-            repeats = np.asarray(repeats, dtype=object)
-            result = libops.vec_binop(np.asarray(self), repeats, rep)
+            result = libops.vec_binop(
+                np.asarray(self),
+                np.asarray(repeats, dtype=object),
+                rep,
+            )
             if isinstance(self, BaseStringArray):
                 # Not going through map, so we have to do this here.
                 result = type(self)._from_sequence(result)
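
Note (usage sketch, mine, not part of the diff): the two `_str_repeat` branches correspond to integer vs. sequence arguments of `Series.str.repeat`:

# Illustration only.
import pandas as pd

s = pd.Series(["a", "b", "c"])
print(s.str.repeat(3).tolist())          # ['aaa', 'bbb', 'ccc'], integer branch
print(s.str.repeat([1, 2, 3]).tolist())  # ['a', 'bb', 'ccc'], sequence branch
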
4 changes: 1 addition & 3 deletions pandas/core/tools/datetimes.py
@@ -56,7 +56,6 @@
     is_integer_dtype,
     is_list_like,
     is_numeric_dtype,
-    is_scalar,
 )
 from pandas.core.dtypes.dtypes import DatetimeTZDtype
 from pandas.core.dtypes.generic import (
@@ -599,8 +598,7 @@ def _adjust_to_origin(arg, origin, unit):
     else:
         # arg must be numeric
         if not (
-            (is_scalar(arg) and (is_integer(arg) or is_float(arg)))
-            or is_numeric_dtype(np.asarray(arg))
+            (is_integer(arg) or is_float(arg)) or is_numeric_dtype(np.asarray(arg))
         ):
             raise ValueError(
                 f"'{arg}' is not compatible with origin='{origin}'; "
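
Note (usage sketch, mine, not part of the diff): with a timestamp-like `origin`, the values being converted must be numeric, which is what the simplified condition enforces:

# Illustration only.
import pandas as pd

print(pd.to_datetime([1, 2, 3], unit="D", origin="2020-01-01"))  # days offset from the origin
# pd.to_datetime("abc", unit="D", origin="2020-01-01") raises ValueError:
# "'abc' is not compatible with origin='2020-01-01'; ..."
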
28 changes: 11 additions & 17 deletions pandas/io/parsers/base_parser.py
@@ -250,23 +250,18 @@ def _has_complex_date_col(self) -> bool:

     @final
     def _should_parse_dates(self, i: int) -> bool:
-        if isinstance(self.parse_dates, bool):
-            return self.parse_dates
+        if lib.is_bool(self.parse_dates):
+            return bool(self.parse_dates)
         else:
             if self.index_names is not None:
                 name = self.index_names[i]
             else:
                 name = None
             j = i if self.index_col is None else self.index_col[i]

-            if is_scalar(self.parse_dates):
-                return (j == self.parse_dates) or (
-                    name is not None and name == self.parse_dates
-                )
-            else:
-                return (j in self.parse_dates) or (
-                    name is not None and name in self.parse_dates
-                )
+            return (j in self.parse_dates) or (
+                name is not None and name in self.parse_dates
+            )

     @final
     def _extract_multi_indexer_columns(
@@ -1370,13 +1365,12 @@ def _validate_parse_dates_arg(parse_dates):
         "for the 'parse_dates' parameter"
     )

-    if parse_dates is not None:
-        if is_scalar(parse_dates):
-            if not lib.is_bool(parse_dates):
-                raise TypeError(msg)
-
-        elif not isinstance(parse_dates, (list, dict)):
-            raise TypeError(msg)
+    if not (
+        parse_dates is None
+        or lib.is_bool(parse_dates)
+        or isinstance(parse_dates, (list, dict))
+    ):
+        raise TypeError(msg)

     return parse_dates
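
Note (usage sketch, mine, not part of the diff): the flattened validation accepts exactly what it did before, namely booleans, lists, and dicts, and rejects other scalars:

# Illustration only.
from io import StringIO
import pandas as pd

data = "date,value\n2021-01-01,1\n2021-01-02,2\n"
pd.read_csv(StringIO(data), parse_dates=["date"])                # list: accepted
pd.read_csv(StringIO(data), parse_dates=True, index_col="date")  # bool: accepted
# parse_dates="date" raises TypeError: "Only booleans, lists, and dictionaries
# are accepted for the 'parse_dates' parameter"
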
3 changes: 1 addition & 2 deletions pandas/tests/frame/methods/test_reindex.py
@@ -23,7 +23,6 @@
 )
 import pandas._testing as tm
 from pandas.api.types import CategoricalDtype as CDT
-import pandas.core.common as com


 class TestReindexSetIndex:
@@ -355,7 +354,7 @@ def test_reindex_frame_add_nat(self):
         result = df.reindex(range(15))
         assert np.issubdtype(result["B"].dtype, np.dtype("M8[ns]"))

-        mask = com.isna(result)["B"]
+        mask = isna(result)["B"]
         assert mask[-5:].all()
         assert not mask[:-5].any()
5 changes: 2 additions & 3 deletions pandas/tests/frame/test_arithmetic.py
@@ -21,7 +21,6 @@
     Series,
 )
 import pandas._testing as tm
-import pandas.core.common as com
 from pandas.core.computation import expressions as expr
 from pandas.core.computation.expressions import (
     _MIN_ELEMENTS,
@@ -1246,12 +1245,12 @@ def test_operators_none_as_na(self, op):
         filled = df.fillna(np.nan)
         result = op(df, 3)
         expected = op(filled, 3).astype(object)
-        expected[com.isna(expected)] = None
+        expected[pd.isna(expected)] = None
         tm.assert_frame_equal(result, expected)

         result = op(df, df)
         expected = op(filled, filled).astype(object)
-        expected[com.isna(expected)] = None
+        expected[pd.isna(expected)] = None
         tm.assert_frame_equal(result, expected)

         result = op(df, df.fillna(7))
