Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gh-624 : Added index parameter to to_dict and added axis argument to Series.add_suffix(), DataFrame.add_suffix(), Series.add_prefix() and DataFrame.add_prefix() #638

Merged
merged 9 commits into from
Apr 12, 2023
53 changes: 41 additions & 12 deletions pandas-stubs/core/frame.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -268,32 +268,61 @@ class DataFrame(NDFrame, OpsMixin):
@overload
def to_dict(
self,
orient: Literal["dict", "list", "series", "split", "tight", "index"],
orient: Literal["records"],
into: Mapping | type[Mapping],
index: Literal[True] = ...,
) -> list[Mapping[Hashable, Any]]: ...
@overload
def to_dict(
self,
orient: Literal["records"],
into: None = ...,
index: Literal[True] = ...,
) -> list[dict[Hashable, Any]]: ...
@overload
def to_dict(
self,
orient: Literal["dict", "list", "series", "index"],
into: Mapping | type[Mapping],
index: Literal[True] = ...,
) -> Mapping[Hashable, Any]: ...
@overload
def to_dict(
self,
orient: Literal["dict", "list", "series", "split", "tight", "index"] = ...,
*,
orient: Literal["split", "tight"],
ramvikrams marked this conversation as resolved.
Show resolved Hide resolved
into: Mapping | type[Mapping],
index: bool = ...,
) -> Mapping[Hashable, Any]: ...
@overload
def to_dict(
self,
orient: Literal["dict", "list", "series", "split", "tight", "index"] = ...,
into: None = ...,
) -> dict[Hashable, Any]: ...
orient: Literal["dict", "list", "series", "index"] = ...,
*,
into: Mapping | type[Mapping],
index: Literal[True] = ...,
) -> Mapping[Hashable, Any]: ...
@overload
def to_dict(
self,
orient: Literal["records"],
orient: Literal["split", "tight"] = ...,
*,
into: Mapping | type[Mapping],
) -> list[Mapping[Hashable, Any]]: ...
index: bool = ...,
) -> Mapping[Hashable, Any]: ...
@overload
def to_dict(
self, orient: Literal["records"], into: None = ...
) -> list[dict[Hashable, Any]]: ...
self,
orient: Literal["dict", "list", "series", "index"] = ...,
into: None = ...,
index: Literal[True] = ...,
) -> dict[Hashable, Any]: ...
@overload
def to_dict(
self,
orient: Literal["split", "tight"] = ...,
into: None = ...,
index: bool = ...,
) -> dict[Hashable, Any]: ...
ramvikrams marked this conversation as resolved.
Show resolved Hide resolved
def to_gbq(
self,
destination_table: str,
Expand Down Expand Up @@ -1400,8 +1429,8 @@ class DataFrame(NDFrame, OpsMixin):
level: Level | None = ...,
fill_value: float | None = ...,
) -> DataFrame: ...
def add_prefix(self, prefix: _str) -> DataFrame: ...
def add_suffix(self, suffix: _str) -> DataFrame: ...
def add_prefix(self, prefix: _str, axis: Axis | None = None) -> DataFrame: ...
def add_suffix(self, suffix: _str, axis: Axis | None = None) -> DataFrame: ...
@overload
def all(
self,
Expand Down
4 changes: 2 additions & 2 deletions pandas-stubs/core/series.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1024,8 +1024,8 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]):
def pop(self, item: Hashable) -> S1: ...
def squeeze(self, axis: AxisIndex | None = ...) -> Scalar: ...
def __abs__(self) -> Series[S1]: ...
def add_prefix(self, prefix: _str) -> Series[S1]: ...
def add_suffix(self, suffix: _str) -> Series[S1]: ...
def add_prefix(self, prefix: _str, axis: AxisIndex | None = ...) -> Series[S1]: ...
def add_suffix(self, suffix: _str, axis: AxisIndex | None = ...) -> Series[S1]: ...
def reindex(
self,
index: Axes | None = ...,
Expand Down
44 changes: 44 additions & 0 deletions tests/test_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -2526,3 +2526,47 @@ def test_loc_returns_series() -> None:
df1 = pd.DataFrame({"x": [1, 2, 3, 4]}, index=[10, 20, 30, 40])
df2 = df1.loc[10, :]
check(assert_type(df2, Union[pd.Series, pd.DataFrame]), pd.Series)


def test_to_dict_index() -> None:
df = pd.DataFrame({"a": [1, 2], "b": [9, 10]})
check(
assert_type(
df.to_dict(orient="records", index=True), List[Dict[Hashable, Any]]
),
list,
)
check(assert_type(df.to_dict(orient="dict", index=True), Dict[Hashable, Any]), dict)
check(
assert_type(df.to_dict(orient="series", index=True), Dict[Hashable, Any]), dict
)
check(
assert_type(df.to_dict(orient="index", index=True), Dict[Hashable, Any]), dict
)
check(
assert_type(df.to_dict(orient="split", index=True), Dict[Hashable, Any]), dict
)
check(
assert_type(df.to_dict(orient="tight", index=True), Dict[Hashable, Any]), dict
)
check(
assert_type(df.to_dict(orient="tight", index=False), Dict[Hashable, Any]), dict
)
check(
assert_type(df.to_dict(orient="split", index=False), Dict[Hashable, Any]), dict
)
if TYPE_CHECKING_INVALID_USAGE:
check(assert_type(df.to_dict(orient="records", index=False), List[Dict[Hashable, Any]]), list) # type: ignore[assert-type, call-overload] # pyright: ignore[reportGeneralTypeIssues]
check(assert_type(df.to_dict(orient="dict", index=False), Dict[Hashable, Any]), dict) # type: ignore[assert-type, call-overload] # pyright: ignore[reportGeneralTypeIssues]
check(assert_type(df.to_dict(orient="series", index=False), Dict[Hashable, Any]), dict) # type: ignore[assert-type, call-overload] # pyright: ignore[reportGeneralTypeIssues]
check(assert_type(df.to_dict(orient="index", index=False), Dict[Hashable, Any]), dict) # type: ignore[assert-type, call-overload] # pyright: ignore[reportGeneralTypeIssues]


def test_suffix_prefix_index() -> None:
df = pd.DataFrame({"A": [1, 2, 3, 4], "B": [3, 4, 5, 6]})
check(assert_type(df.add_suffix("_col", axis=1), pd.DataFrame), pd.DataFrame)
check(assert_type(df.add_suffix("_col", axis="index"), pd.DataFrame), pd.DataFrame)
check(assert_type(df.add_prefix("_col", axis="index"), pd.DataFrame), pd.DataFrame)
check(
assert_type(df.add_prefix("_col", axis="columns"), pd.DataFrame), pd.DataFrame
)
12 changes: 12 additions & 0 deletions tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1813,3 +1813,15 @@ def test_types_apply_set() -> None:
{"list1": [1, 2, 3], "list2": ["a", "b", "c"], "list3": [True, False, True]}
)
check(assert_type(series_of_lists.apply(lambda x: set(x)), pd.Series), pd.Series)


def test_prefix_summix_axis() -> None:
s = pd.Series([1, 2, 3, 4])
check(assert_type(s.add_suffix("_item", axis=0), pd.Series), pd.Series)
check(assert_type(s.add_suffix("_item", axis="index"), pd.Series), pd.Series)
check(assert_type(s.add_prefix("_item", axis=0), pd.Series), pd.Series)
check(assert_type(s.add_prefix("_item", axis="index"), pd.Series), pd.Series)
ramvikrams marked this conversation as resolved.
Show resolved Hide resolved

if TYPE_CHECKING_INVALID_USAGE:
check(assert_type(s.add_prefix("_item", axis=1), pd.Series), pd.Series) # type: ignore[arg-type] # pyright: ignore[reportGeneralTypeIssues]
check(assert_type(s.add_suffix("_item", axis="columns"), pd.Series), pd.Series) # type: ignore[arg-type] # pyright: ignore[reportGeneralTypeIssues]
ramvikrams marked this conversation as resolved.
Show resolved Hide resolved