Skip to content

Commit

Permalink
fix Series.split with expand=True (#199)
Browse files Browse the repository at this point in the history
* fix Series.split with expand=True

* align asterisk in split params
  • Loading branch information
Dr-Irv authored Aug 17, 2022
1 parent 3d24a9a commit 30a87ca
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 6 deletions.
3 changes: 2 additions & 1 deletion pandas-stubs/core/indexes/base.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ from typing import (
import numpy as np
from pandas import (
DataFrame,
MultiIndex,
Series,
)
from pandas.core.arrays import ExtensionArray
Expand Down Expand Up @@ -58,7 +59,7 @@ class Index(IndexOpsMixin, PandasObject):
tupleize_cols: bool = ...,
): ...
@property
def str(self) -> StringMethods[Index]: ...
def str(self) -> StringMethods[Index, MultiIndex]: ...
@property
def asi8(self) -> np_ndarray_int64: ...
def is_(self, other) -> bool: ...
Expand Down
2 changes: 1 addition & 1 deletion pandas-stubs/core/series.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -854,7 +854,7 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]):
) -> Series[S1]: ...
def to_period(self, freq: _str | None = ..., copy: _bool = ...) -> DataFrame: ...
@property
def str(self) -> StringMethods[Series]: ...
def str(self) -> StringMethods[Series, DataFrame]: ...
@property
def dt(self) -> CombinedDatetimelikeProperties: ...
@property
Expand Down
26 changes: 22 additions & 4 deletions pandas-stubs/core/strings.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,25 @@ from typing import (
Generic,
Literal,
Sequence,
TypeVar,
overload,
)

import numpy as np
import pandas as pd
from pandas import Series
from pandas import (
DataFrame,
MultiIndex,
Series,
)
from pandas.core.base import NoNewAttributesMixin

from pandas._typing import T

class StringMethods(NoNewAttributesMixin, Generic[T]):
# The _TS type is what is used for the result of str.split with expand=True
_TS = TypeVar("_TS", DataFrame, MultiIndex)

class StringMethods(NoNewAttributesMixin, Generic[T, _TS]):
def __init__(self, data: T) -> None: ...
def __getitem__(self, key: slice | int) -> T: ...
def __iter__(self) -> T: ...
Expand Down Expand Up @@ -44,11 +52,21 @@ class StringMethods(NoNewAttributesMixin, Generic[T]):
na_rep: str | None = ...,
join: Literal["left", "right", "outer", "inner"] = ...,
) -> T: ...
# Overload chosen when the caller passes the literal ``expand=True`` by keyword:
# the result is typed as _TS (the second type parameter of StringMethods,
# constrained to DataFrame or MultiIndex) instead of the accessor's own type T.
@overload
def split(
self, pat: str = ..., n: int = ..., *, expand: Literal[True], regex: bool = ...
) -> _TS: ...
@overload
def split(
self, pat: str = ..., n: int = ..., expand: bool = ..., *, regex: bool = ...
self, pat: str = ..., n: int = ..., *, expand: bool = ..., regex: bool = ...
) -> T: ...
# Overload chosen when the caller passes the literal ``expand=True`` by keyword.
# Mirrors the matching ``split`` overload: expanding produces the "expanded"
# container _TS (DataFrame for a Series accessor, MultiIndex for an Index
# accessor), not the accessor's own type T.  Returning T here made this
# overload indistinguishable from the general one and gave callers the wrong
# static type for ``s.str.rsplit(..., expand=True)``.
@overload
def rsplit(
    self, pat: str = ..., n: int = ..., *, expand: Literal[True], regex: bool = ...
) -> _TS: ...
@overload
def rsplit(
self, pat: str = ..., n: int = ..., expand: bool = ..., *, regex: bool = ...
self, pat: str = ..., n: int = ..., *, expand: bool = ..., regex: bool = ...
) -> T: ...
@overload
def partition(self, sep: str = ...) -> pd.DataFrame: ...
Expand Down
7 changes: 7 additions & 0 deletions tests/test_indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,10 @@ def test_difference_none() -> None:
# https://github.com/pandas-dev/pandas-stubs/issues/17
ind = pd.Index([1, 2, 3])
check(assert_type(ind.difference([1, None]), "pd.Index"), pd.Index, int)


def test_str_split() -> None:
    """Check the static and runtime result types of ``Index.str.split``."""
    # GH 194
    index = pd.Index(["a-b", "c-d"])

    # Without ``expand`` the result is expected to stay an Index.
    unexpanded = index.str.split("-")
    check(assert_type(unexpanded, pd.Index), pd.Index)

    # With ``expand=True`` the result is expected to be a MultiIndex.
    expanded = index.str.split("-", expand=True)
    check(assert_type(expanded, pd.MultiIndex), pd.MultiIndex)
2 changes: 2 additions & 0 deletions tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -977,6 +977,8 @@ def test_string_accessors():
check(assert_type(s.str.slice(0, 4, 2), pd.Series), pd.Series)
check(assert_type(s.str.slice_replace(0, 2, "XX"), pd.Series), pd.Series)
check(assert_type(s.str.split("a"), pd.Series), pd.Series)
# GH 194
check(assert_type(s.str.split("a", expand=True), pd.DataFrame), pd.DataFrame)
check(assert_type(s.str.startswith("a"), "pd.Series[bool]"), pd.Series, bool)
check(assert_type(s.str.strip(), pd.Series), pd.Series)
check(assert_type(s.str.swapcase(), pd.Series), pd.Series)
Expand Down

0 comments on commit 30a87ca

Please sign in to comment.