Skip to content

Commit

Permalink
ENH: DataFrame.plot.scatter argument c now accepts a column of stri…
Browse files Browse the repository at this point in the history
…ngs, where rows with the same string are colored identically (#59239)
  • Loading branch information
michaelmannino authored Sep 3, 2024
1 parent 57a4fb9 commit f3e1991
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 1 deletion.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v3.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ Other enhancements
- :meth:`DataFrame.pivot_table` and :func:`pivot_table` now allow the passing of keyword arguments to ``aggfunc`` through ``**kwargs`` (:issue:`57884`)
- :meth:`Series.cummin` and :meth:`Series.cummax` now supports :class:`CategoricalDtype` (:issue:`52335`)
- :meth:`Series.plot` now correctly handle the ``ylabel`` parameter for pie charts, allowing for explicit control over the y-axis label (:issue:`58239`)
- :meth:`DataFrame.plot.scatter` argument ``c`` now accepts a column of strings, where rows with the same string are colored identically (:issue:`16827` and :issue:`16485`)
- :meth:`pandas.concat` will raise a ``ValueError`` when ``ignore_index=True`` and ``keys`` is not ``None`` (:issue:`59274`)
- Multiplying two :class:`DateOffset` objects will now raise a ``TypeError`` instead of a ``RecursionError`` (:issue:`59442`)
- Restore support for reading Stata 104-format and enable reading 103-format dta files (:issue:`58554`)
Expand Down
41 changes: 41 additions & 0 deletions pandas/plotting/_matplotlib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1343,6 +1343,22 @@ def _make_plot(self, fig: Figure) -> None:
label = self.label
else:
label = None

# if a list of non color strings is passed in as c, color points
# by uniqueness of the strings, such same strings get same color
create_colors = not self._are_valid_colors(c_values)
if create_colors:
color_mapping = self._get_color_mapping(c_values)
c_values = [color_mapping[s] for s in c_values]

# build legend for labeling custom colors
ax.legend(
handles=[
mpl.patches.Circle((0, 0), facecolor=c, label=s)
for s, c in color_mapping.items()
]
)

scatter = ax.scatter(
data[x].values,
data[y].values,
Expand All @@ -1353,6 +1369,7 @@ def _make_plot(self, fig: Figure) -> None:
s=self.s,
**self.kwds,
)

if cb:
cbar_label = c if c_is_column else ""
cbar = self._plot_colorbar(ax, fig=fig, label=cbar_label)
Expand Down Expand Up @@ -1392,6 +1409,30 @@ def _get_c_values(self, color, color_by_categorical: bool, c_is_column: bool):
c_values = c
return c_values

def _are_valid_colors(self, c_values: Series) -> bool:
# check if c_values contains strings and if these strings are valid mpl colors.
# no need to check numerics as these (and mpl colors) will be validated for us
# in .Axes.scatter._parse_scatter_color_args(...)
unique = np.unique(c_values)
try:
if len(c_values) and all(isinstance(c, str) for c in unique):
mpl.colors.to_rgba_array(unique)

return True

except (TypeError, ValueError) as _:
return False

def _get_color_mapping(self, c_values: Series) -> dict[str, np.ndarray]:
unique = np.unique(c_values)
n_colors = len(unique)

# passing `None` here will default to :rc:`image.cmap`
cmap = mpl.colormaps.get_cmap(self.colormap)
colors = cmap(np.linspace(0, 1, n_colors)) # RGB tuples

return dict(zip(unique, colors))

def _get_norm_and_cmap(self, c_values, color_by_categorical: bool):
c = self.c
if self.colormap is not None:
Expand Down
54 changes: 53 additions & 1 deletion pandas/tests/plotting/frame/test_frame_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,53 @@ def test_scatter_with_c_column_name_with_colors(self, cmap):
ax = df.plot.scatter(x=0, y=1, cmap=cmap, c="species")
else:
ax = df.plot.scatter(x=0, y=1, c="species", cmap=cmap)

assert len(np.unique(ax.collections[0].get_facecolor(), axis=0)) == 3 # r/g/b
assert (
np.unique(ax.collections[0].get_facecolor(), axis=0)
== np.array(
[[0.0, 0.0, 1.0, 1.0], [0.0, 0.5, 0.0, 1.0], [1.0, 0.0, 0.0, 1.0]]
) # r/g/b
).all()
assert ax.collections[0].colorbar is None

def test_scatter_with_c_column_name_without_colors(self):
# Given
colors = ["NY", "MD", "MA", "CA"]
color_count = 4 # 4 unique colors

# When
df = DataFrame(
{
"dataX": range(100),
"dataY": range(100),
"color": (colors[i % len(colors)] for i in range(100)),
}
)

# Then
ax = df.plot.scatter("dataX", "dataY", c="color")
assert len(np.unique(ax.collections[0].get_facecolor(), axis=0)) == color_count

# Given
colors = ["r", "g", "not-a-color"]
color_count = 3
# Also, since not all are mpl-colors, points matching 'r' or 'g'
# are not necessarily red or green

# When
df = DataFrame(
{
"dataX": range(100),
"dataY": range(100),
"color": (colors[i % len(colors)] for i in range(100)),
}
)

# Then
ax = df.plot.scatter("dataX", "dataY", c="color")
assert len(np.unique(ax.collections[0].get_facecolor(), axis=0)) == color_count

def test_scatter_colors(self):
df = DataFrame({"a": [1, 2, 3], "b": [1, 2, 3], "c": [1, 2, 3]})
with pytest.raises(TypeError, match="Specify exactly one of `c` and `color`"):
Expand All @@ -229,7 +274,14 @@ def test_scatter_colors_not_raising_warnings(self):
# provided via 'c'. Parameters 'cmap' will be ignored
df = DataFrame({"x": [1, 2, 3], "y": [1, 2, 3]})
with tm.assert_produces_warning(None):
df.plot.scatter(x="x", y="y", c="b")
ax = df.plot.scatter(x="x", y="y", c="b")
assert (
len(np.unique(ax.collections[0].get_facecolor(), axis=0)) == 1
) # blue
assert (
np.unique(ax.collections[0].get_facecolor(), axis=0)
== np.array([[0.0, 0.0, 1.0, 1.0]])
).all() # blue

def test_scatter_colors_default(self):
df = DataFrame({"a": [1, 2, 3], "b": [1, 2, 3], "c": [1, 2, 3]})
Expand Down

0 comments on commit f3e1991

Please sign in to comment.