Skip to content

Commit

Permalink
gh-624 : Added index parameter to to_dict and added axis argume…
Browse files Browse the repository at this point in the history
…nt to `Series.add_suffix(), DataFrame.add_suffix(), Series.add_prefix() and DataFrame.add_prefix()` (#638)

* added arguments

* changed axis parameters

* req changes and created a different overload and corrected index args

* creating overloads

* Update frame.pyi

* corrected diff overload args

* added the tests

* corrected the tests

* Update test_series.py
  • Loading branch information
ramvikrams authored Apr 12, 2023
1 parent c03e23c commit 5dd1820
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 14 deletions.
53 changes: 41 additions & 12 deletions pandas-stubs/core/frame.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -268,32 +268,61 @@ class DataFrame(NDFrame, OpsMixin):
@overload
def to_dict(
self,
orient: Literal["dict", "list", "series", "split", "tight", "index"],
orient: Literal["records"],
into: Mapping | type[Mapping],
index: Literal[True] = ...,
) -> list[Mapping[Hashable, Any]]: ...
@overload
def to_dict(
self,
orient: Literal["records"],
into: None = ...,
index: Literal[True] = ...,
) -> list[dict[Hashable, Any]]: ...
@overload
def to_dict(
self,
orient: Literal["dict", "list", "series", "index"],
into: Mapping | type[Mapping],
index: Literal[True] = ...,
) -> Mapping[Hashable, Any]: ...
@overload
def to_dict(
self,
orient: Literal["dict", "list", "series", "split", "tight", "index"] = ...,
*,
orient: Literal["split", "tight"],
into: Mapping | type[Mapping],
index: bool = ...,
) -> Mapping[Hashable, Any]: ...
@overload
def to_dict(
self,
orient: Literal["dict", "list", "series", "split", "tight", "index"] = ...,
into: None = ...,
) -> dict[Hashable, Any]: ...
orient: Literal["dict", "list", "series", "index"] = ...,
*,
into: Mapping | type[Mapping],
index: Literal[True] = ...,
) -> Mapping[Hashable, Any]: ...
@overload
def to_dict(
self,
orient: Literal["records"],
orient: Literal["split", "tight"] = ...,
*,
into: Mapping | type[Mapping],
) -> list[Mapping[Hashable, Any]]: ...
index: bool = ...,
) -> Mapping[Hashable, Any]: ...
@overload
def to_dict(
self, orient: Literal["records"], into: None = ...
) -> list[dict[Hashable, Any]]: ...
self,
orient: Literal["dict", "list", "series", "index"] = ...,
into: None = ...,
index: Literal[True] = ...,
) -> dict[Hashable, Any]: ...
@overload
def to_dict(
self,
orient: Literal["split", "tight"] = ...,
into: None = ...,
index: bool = ...,
) -> dict[Hashable, Any]: ...
def to_gbq(
self,
destination_table: str,
Expand Down Expand Up @@ -1400,8 +1429,8 @@ class DataFrame(NDFrame, OpsMixin):
level: Level | None = ...,
fill_value: float | None = ...,
) -> DataFrame: ...
def add_prefix(self, prefix: _str) -> DataFrame: ...
def add_suffix(self, suffix: _str) -> DataFrame: ...
def add_prefix(self, prefix: _str, axis: Axis | None = None) -> DataFrame: ...
def add_suffix(self, suffix: _str, axis: Axis | None = None) -> DataFrame: ...
@overload
def all(
self,
Expand Down
4 changes: 2 additions & 2 deletions pandas-stubs/core/series.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1024,8 +1024,8 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]):
def pop(self, item: Hashable) -> S1: ...
def squeeze(self, axis: AxisIndex | None = ...) -> Scalar: ...
def __abs__(self) -> Series[S1]: ...
def add_prefix(self, prefix: _str) -> Series[S1]: ...
def add_suffix(self, suffix: _str) -> Series[S1]: ...
def add_prefix(self, prefix: _str, axis: AxisIndex | None = ...) -> Series[S1]: ...
def add_suffix(self, suffix: _str, axis: AxisIndex | None = ...) -> Series[S1]: ...
def reindex(
self,
index: Axes | None = ...,
Expand Down
44 changes: 44 additions & 0 deletions tests/test_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -2526,3 +2526,47 @@ def test_loc_returns_series() -> None:
df1 = pd.DataFrame({"x": [1, 2, 3, 4]}, index=[10, 20, 30, 40])
df2 = df1.loc[10, :]
check(assert_type(df2, Union[pd.Series, pd.DataFrame]), pd.Series)


def test_to_dict_index() -> None:
df = pd.DataFrame({"a": [1, 2], "b": [9, 10]})
check(
assert_type(
df.to_dict(orient="records", index=True), List[Dict[Hashable, Any]]
),
list,
)
check(assert_type(df.to_dict(orient="dict", index=True), Dict[Hashable, Any]), dict)
check(
assert_type(df.to_dict(orient="series", index=True), Dict[Hashable, Any]), dict
)
check(
assert_type(df.to_dict(orient="index", index=True), Dict[Hashable, Any]), dict
)
check(
assert_type(df.to_dict(orient="split", index=True), Dict[Hashable, Any]), dict
)
check(
assert_type(df.to_dict(orient="tight", index=True), Dict[Hashable, Any]), dict
)
check(
assert_type(df.to_dict(orient="tight", index=False), Dict[Hashable, Any]), dict
)
check(
assert_type(df.to_dict(orient="split", index=False), Dict[Hashable, Any]), dict
)
if TYPE_CHECKING_INVALID_USAGE:
check(assert_type(df.to_dict(orient="records", index=False), List[Dict[Hashable, Any]]), list) # type: ignore[assert-type, call-overload] # pyright: ignore[reportGeneralTypeIssues]
check(assert_type(df.to_dict(orient="dict", index=False), Dict[Hashable, Any]), dict) # type: ignore[assert-type, call-overload] # pyright: ignore[reportGeneralTypeIssues]
check(assert_type(df.to_dict(orient="series", index=False), Dict[Hashable, Any]), dict) # type: ignore[assert-type, call-overload] # pyright: ignore[reportGeneralTypeIssues]
check(assert_type(df.to_dict(orient="index", index=False), Dict[Hashable, Any]), dict) # type: ignore[assert-type, call-overload] # pyright: ignore[reportGeneralTypeIssues]


def test_suffix_prefix_index() -> None:
df = pd.DataFrame({"A": [1, 2, 3, 4], "B": [3, 4, 5, 6]})
check(assert_type(df.add_suffix("_col", axis=1), pd.DataFrame), pd.DataFrame)
check(assert_type(df.add_suffix("_col", axis="index"), pd.DataFrame), pd.DataFrame)
check(assert_type(df.add_prefix("_col", axis="index"), pd.DataFrame), pd.DataFrame)
check(
assert_type(df.add_prefix("_col", axis="columns"), pd.DataFrame), pd.DataFrame
)
12 changes: 12 additions & 0 deletions tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1813,3 +1813,15 @@ def test_types_apply_set() -> None:
{"list1": [1, 2, 3], "list2": ["a", "b", "c"], "list3": [True, False, True]}
)
check(assert_type(series_of_lists.apply(lambda x: set(x)), pd.Series), pd.Series)


def test_prefix_summix_axis() -> None:
s = pd.Series([1, 2, 3, 4])
check(assert_type(s.add_suffix("_item", axis=0), pd.Series), pd.Series)
check(assert_type(s.add_suffix("_item", axis="index"), pd.Series), pd.Series)
check(assert_type(s.add_prefix("_item", axis=0), pd.Series), pd.Series)
check(assert_type(s.add_prefix("_item", axis="index"), pd.Series), pd.Series)

if TYPE_CHECKING_INVALID_USAGE:
check(assert_type(s.add_prefix("_item", axis=1), pd.Series), pd.Series) # type: ignore[arg-type] # pyright: ignore[reportGeneralTypeIssues]
check(assert_type(s.add_suffix("_item", axis="columns"), pd.Series), pd.Series) # type: ignore[arg-type] # pyright: ignore[reportGeneralTypeIssues]

0 comments on commit 5dd1820

Please sign in to comment.