Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: DataFrame.plot.scatter argument c now accepts a column of strings, where rows with the same string are colored identically #59239

Merged
merged 17 commits into from
Sep 3, 2024
Merged
1 change: 1 addition & 0 deletions doc/source/whatsnew/v3.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ Other enhancements
- :meth:`DataFrame.pivot_table` and :func:`pivot_table` now allow the passing of keyword arguments to ``aggfunc`` through ``**kwargs`` (:issue:`57884`)
- :meth:`Series.cummin` and :meth:`Series.cummax` now support :class:`CategoricalDtype` (:issue:`52335`)
- :meth:`Series.plot` now correctly handles the ``ylabel`` parameter for pie charts, allowing for explicit control over the y-axis label (:issue:`58239`)
- :meth:`DataFrame.plot.scatter` argument ``c`` now accepts a column of strings, where rows with the same string are colored identically (:issue:`16827` and :issue:`16485`)
- :meth:`pandas.concat` will raise a ``ValueError`` when ``ignore_index=True`` and ``keys`` is not ``None`` (:issue:`59274`)
- Multiplying two :class:`DateOffset` objects will now raise a ``TypeError`` instead of a ``RecursionError`` (:issue:`59442`)
- Restore support for reading Stata 104-format and enable reading 103-format dta files (:issue:`58554`)
Expand Down
41 changes: 41 additions & 0 deletions pandas/plotting/_matplotlib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1343,6 +1343,22 @@ def _make_plot(self, fig: Figure) -> None:
label = self.label
else:
label = None

# if a list of non-color strings is passed in as c, color points
# by uniqueness of the strings, such that the same strings get the same color
create_colors = not self._are_valid_colors(c_values)
if create_colors:
color_mapping = self._get_color_mapping(c_values)
c_values = [color_mapping[s] for s in c_values]

# build legend for labeling custom colors
ax.legend(
handles=[
mpl.patches.Circle((0, 0), facecolor=c, label=s)
for s, c in color_mapping.items()
]
)

scatter = ax.scatter(
data[x].values,
data[y].values,
Expand All @@ -1353,6 +1369,7 @@ def _make_plot(self, fig: Figure) -> None:
s=self.s,
**self.kwds,
)

if cb:
cbar_label = c if c_is_column else ""
cbar = self._plot_colorbar(ax, fig=fig, label=cbar_label)
Expand Down Expand Up @@ -1392,6 +1409,30 @@ def _get_c_values(self, color, color_by_categorical: bool, c_is_column: bool):
c_values = c
return c_values

def _are_valid_colors(self, c_values: Series) -> bool:
    """
    Report whether ``c_values`` can be passed to matplotlib as-is.

    Only the all-strings case is inspected here: the distinct values are
    run through ``mpl.colors.to_rgba_array`` to see whether matplotlib
    accepts them as colors. Numeric input is deliberately not checked,
    because it (like mpl colors) is validated later on in
    ``Axes.scatter._parse_scatter_color_args(...)``.

    Parameters
    ----------
    c_values : Series
        The values intended for the ``c`` argument of the scatter plot.

    Returns
    -------
    bool
        ``False`` if the check raised ``TypeError`` or ``ValueError``;
        ``True`` otherwise, which includes empty input and input whose
        distinct values are not all strings.
    """
    distinct = np.unique(c_values)
    try:
        if len(c_values) and all(isinstance(val, str) for val in distinct):
            # raises if any of the strings is not a color matplotlib knows
            mpl.colors.to_rgba_array(distinct)
    except (TypeError, ValueError):
        return False
    return True

def _get_color_mapping(self, c_values: Series) -> dict[str, np.ndarray]:
    """
    Assign one colormap color to every distinct value in ``c_values``.

    The colormap named by ``self.colormap`` is sampled at as many evenly
    spaced positions in ``[0, 1]`` as there are distinct values, and the
    samples are paired with the values in the order ``np.unique`` returns
    them.

    Parameters
    ----------
    c_values : Series
        The labels to color by.

    Returns
    -------
    dict
        Maps each distinct label to its sampled color.
    """
    labels = np.unique(c_values)

    # passing `None` here will default to :rc:`image.cmap`
    colormap = mpl.colormaps.get_cmap(self.colormap)

    # one evenly spaced sample position, and therefore one color, per label
    positions = np.linspace(0, 1, len(labels))
    sampled = colormap(positions)

    return dict(zip(labels, sampled))

def _get_norm_and_cmap(self, c_values, color_by_categorical: bool):
c = self.c
if self.colormap is not None:
Expand Down
54 changes: 53 additions & 1 deletion pandas/tests/plotting/frame/test_frame_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,53 @@ def test_scatter_with_c_column_name_with_colors(self, cmap):
ax = df.plot.scatter(x=0, y=1, cmap=cmap, c="species")
else:
ax = df.plot.scatter(x=0, y=1, c="species", cmap=cmap)

assert len(np.unique(ax.collections[0].get_facecolor(), axis=0)) == 3 # r/g/b
assert (
np.unique(ax.collections[0].get_facecolor(), axis=0)
== np.array(
[[0.0, 0.0, 1.0, 1.0], [0.0, 0.5, 0.0, 1.0], [1.0, 0.0, 0.0, 1.0]]
) # r/g/b
).all()
assert ax.collections[0].colorbar is None

def test_scatter_with_c_column_name_without_colors(self):
    """
    A string column passed as ``c`` yields one face color per distinct label.

    Two label sets are exercised in turn: one where no label is a
    matplotlib color, and one mixing valid color names with a non-color.
    """
    scenarios = [
        # 4 unique colors
        (["NY", "MD", "MA", "CA"], 4),
        # since not all are mpl-colors, points matching 'r' or 'g'
        # are not necessarily red or green
        (["r", "g", "not-a-color"], 3),
    ]

    for labels, expected_count in scenarios:
        # cycle through the labels so each one appears in the 100-row frame
        df = DataFrame(
            {
                "dataX": range(100),
                "dataY": range(100),
                "color": (labels[i % len(labels)] for i in range(100)),
            }
        )

        ax = df.plot.scatter("dataX", "dataY", c="color")

        face_colors = ax.collections[0].get_facecolor()
        assert len(np.unique(face_colors, axis=0)) == expected_count

def test_scatter_colors(self):
df = DataFrame({"a": [1, 2, 3], "b": [1, 2, 3], "c": [1, 2, 3]})
with pytest.raises(TypeError, match="Specify exactly one of `c` and `color`"):
Expand All @@ -229,7 +274,14 @@ def test_scatter_colors_not_raising_warnings(self):
# provided via 'c'. Parameters 'cmap' will be ignored
df = DataFrame({"x": [1, 2, 3], "y": [1, 2, 3]})
with tm.assert_produces_warning(None):
df.plot.scatter(x="x", y="y", c="b")
ax = df.plot.scatter(x="x", y="y", c="b")
assert (
len(np.unique(ax.collections[0].get_facecolor(), axis=0)) == 1
) # blue
assert (
np.unique(ax.collections[0].get_facecolor(), axis=0)
== np.array([[0.0, 0.0, 1.0, 1.0]])
).all() # blue

def test_scatter_colors_default(self):
df = DataFrame({"a": [1, 2, 3], "b": [1, 2, 3], "c": [1, 2, 3]})
Expand Down