ENH: DataFrame.plot.scatter argument c now accepts a column of strings, where rows with the same string are colored identically #59239

Merged: 17 commits, Sep 3, 2024
1 change: 1 addition & 0 deletions doc/source/whatsnew/v3.0.0.rst
@@ -48,6 +48,7 @@ Other enhancements
- :meth:`DataFrame.pivot_table` and :func:`pivot_table` now allow the passing of keyword arguments to ``aggfunc`` through ``**kwargs`` (:issue:`57884`)
- :meth:`Series.cummin` and :meth:`Series.cummax` now supports :class:`CategoricalDtype` (:issue:`52335`)
- :meth:`Series.plot` now correctly handle the ``ylabel`` parameter for pie charts, allowing for explicit control over the y-axis label (:issue:`58239`)
- :meth:`DataFrame.plot.scatter` argument ``c`` now accepts a column of strings, where rows with the same string are colored identically (:issue:`16827` and :issue:`16485`)
- Restore support for reading Stata 104-format and enable reading 103-format dta files (:issue:`58554`)
- Support reading Stata 110-format (Stata 7) dta files (:issue:`47176`)

49 changes: 49 additions & 0 deletions pandas/plotting/_matplotlib/core.py
@@ -10,6 +10,7 @@
    Iterator,
    Sequence,
)
from random import shuffle
from typing import (
    TYPE_CHECKING,
    Any,
@@ -1337,6 +1338,13 @@ def _make_plot(self, fig: Figure) -> None:
        norm, cmap = self._get_norm_and_cmap(c_values, color_by_categorical)
        cb = self._get_colorbar(c_values, c_is_column)

        # if a list of non color strings is passed in as c, generate a list
        # colored by uniqueness of the strings, such same strings get same color
        create_colors = not self._are_valid_colors(c_values)
        if create_colors:
            custom_color_mapping, c_values = self._uniquely_color_strs(c_values)
            cb = False  # no colorbar; opt for legend

        if self.legend:
            label = self.label
        else:
@@ -1367,6 +1375,15 @@ def _make_plot(self, fig: Figure) -> None:
                label,  # type: ignore[arg-type]
            )

        # build legend for labeling custom colors
        if create_colors:
            ax.legend(
                handles=[
                    mpl.patches.Circle((0, 0), facecolor=color, label=string)
                    for string, color in custom_color_mapping.items()
                ]
            )

        errors_x = self._get_errorbars(label=x, index=0, yerr=False)
        errors_y = self._get_errorbars(label=y, index=0, xerr=False)
        if len(errors_x) > 0 or len(errors_y) > 0:
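The legend in this hunk is built from proxy artists, one per string-to-color pair, rather than from the scatter collection itself. A minimal standalone sketch of that idea follows. It is an illustration under stated assumptions and not the pandas implementation: the `color_mapping` dictionary is made-up data, and `matplotlib.patches.Patch` is swapped in for the `Circle` used in the diff.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
from matplotlib.patches import Patch

# Hypothetical mapping from category string to color (not taken from the PR).
color_mapping = {"NY": "tab:blue", "MD": "tab:orange", "MA": "tab:green"}

fig, ax = plt.subplots()

# One proxy artist per category; the label becomes the legend entry text.
handles = [
    Patch(facecolor=color, label=name) for name, color in color_mapping.items()
]
legend = ax.legend(handles=handles)

# Collect the legend entry texts before closing the figure.
legend_labels = [text.get_text() for text in legend.get_texts()]
plt.close(fig)
```

Because the handles are passed explicitly, the legend entries follow the insertion order of the mapping and do not depend on what was drawn on the axes.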
@@ -1390,6 +1407,38 @@ def _get_c_values(self, color, color_by_categorical: bool, c_is_column: bool):
            c_values = c
        return c_values

    def _are_valid_colors(self, c_values: np.ndarray | list):
Member

In what instances is c_values a list? I might be misreading, but it would be better if we only worked with a pd.Series and could call .unique on that, instead of checking every single value in a loop.

Contributor Author

Updated to take a pd.Series, not np.ndarray | list
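
The Series-based check discussed in this thread can be sketched roughly as follows. This is a hedged illustration, not the code that was merged: the function name `are_valid_colors` is hypothetical, and `matplotlib.colors.is_color_like` is used here in place of the `to_rgba_array` call shown in the diff.

```python
import matplotlib.colors as mcolors
import pandas as pd


def are_valid_colors(values: pd.Series) -> bool:
    """Return False only when `values` holds strings that matplotlib rejects."""
    unique = values.unique()  # validate each distinct value once
    if len(unique) == 0 or not all(isinstance(v, str) for v in unique):
        # Empty or non-string data is left for matplotlib to validate later.
        return True
    # Every distinct string must be a color matplotlib understands.
    return all(mcolors.is_color_like(v) for v in unique)
```

Working on the distinct values keeps the cost proportional to the number of categories rather than the number of rows, which is the point the reviewer raises.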

        # check if c_values contains strings and if these strings are valid mpl colors.
        # no need to check numerics as these (and mpl colors) will be validated for us
        # in .Axes.scatter._parse_scatter_color_args(...)
        try:
            if len(c_values) and all(isinstance(c, str) for c in c_values):
                mpl.colors.to_rgba_array(c_values)

            return True

        except (TypeError, ValueError) as _:
            return False

    def _uniquely_color_strs(
Member

I am no matplotlib expert, but I think we need to defer to matplotlib somehow to get the desired colors, instead of trying to write this out ourselves.

Contributor Author

I researched how other libraries that support this functionality, including seaborn, handle this workflow. Each uses the current mpl colormap and draws colors from a linear space across that map. This lets users change the color choices the same way they would for any other plot.

Additionally, I've added an automatic legend to this type of plot, since the chosen colors are not otherwise exposed to the user, similar to how a colorbar is drawn in some cases of the same function.
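
The approach described in this reply, sampling evenly spaced colors from a colormap, can be sketched as below. This is an assumption-laden illustration and not the merged pandas code: the helper name `color_strings_from_cmap`, its `cmap_name` parameter, and the `"viridis"` default are all hypothetical, and the `matplotlib.colormaps` registry requires matplotlib 3.5 or newer.

```python
import matplotlib
import numpy as np


def color_strings_from_cmap(labels, cmap_name="viridis"):
    """Map each distinct label to an RGBA color sampled evenly from a colormap."""
    unique = sorted(set(labels))
    cmap = matplotlib.colormaps[cmap_name]
    # One evenly spaced sample in [0, 1] per distinct label.
    samples = cmap(np.linspace(0, 1, len(unique)))
    mapping = {label: tuple(rgba) for label, rgba in zip(unique, samples)}
    # Expand the mapping back to one RGBA row per input label.
    colors = np.array([mapping[label] for label in labels])
    return mapping, colors


mapping, colors = color_strings_from_cmap(["NY", "MD", "MA", "CA"] * 2)
```

Rows with the same string receive the same RGBA value, and choosing a different colormap name changes the palette without touching the mapping logic.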

        self, c_values: np.ndarray | list
    ) -> tuple[dict, np.ndarray]:
        # well, almost uniquely color them (up to 949)
        unique = np.unique(c_values)

        # for up to 7, lets keep colors consistent
        if len(unique) <= 7:
            possible_colors = list(mpl.colors.BASE_COLORS.values())  # Hex
        # explore better ways to handle this case
        else:
            possible_colors = list(mpl.colors.XKCD_COLORS.values())  # Hex
            shuffle(possible_colors)

        colors = [possible_colors[i % len(possible_colors)] for i in range(len(unique))]
        color_mapping = dict(zip(unique, colors))

        return color_mapping, np.array(list(map(color_mapping.get, c_values)))
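
The modulo indexing in this version of the method means colors wrap around once the number of distinct strings exceeds the palette size, which is why its comment says "almost uniquely". A small pure-Python sketch of that behavior, using a made-up three-color palette and a hypothetical helper name `assign_colors`:

```python
def assign_colors(labels, palette):
    """Assign palette entries to distinct labels, wrapping around when needed."""
    unique = sorted(set(labels))
    return {label: palette[i % len(palette)] for i, label in enumerate(unique)}


# Hypothetical palette with fewer colors than there are distinct labels.
palette = ["red", "green", "blue"]
mapping = assign_colors(["a", "b", "c", "d", "a"], palette)

# "d" is the fourth distinct label, so it reuses the first palette entry.
reused = mapping["a"] == mapping["d"]
```

With four distinct labels and three colors, two labels necessarily share a color, so a legend alone can no longer tell them apart.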

    def _get_norm_and_cmap(self, c_values, color_by_categorical: bool):
        c = self.c
        if self.colormap is not None:
15 changes: 15 additions & 0 deletions pandas/tests/plotting/frame/test_frame_color.py
@@ -207,6 +207,21 @@ def test_scatter_with_c_column_name_with_colors(self, cmap):
        ax = df.plot.scatter(x=0, y=1, c="species", cmap=cmap)
        assert ax.collections[0].colorbar is None

    def test_scatter_with_c_column_name_without_colors(self):
        df = DataFrame(
            {
                "dataX": range(100),
                "dataY": range(100),
                "state": ["NY", "MD", "MA", "CA"] * 25,
            }
        )
        df.plot.scatter("dataX", "dataY", c="state")

        with tm.assert_produces_warning(None):
            ax = df.plot.scatter(x=0, y=1, c="state")
Contributor

I'm not 100% sure that the test here should be in the tm.assert_produces_warning(None) context. My intuition is that it should be removed.

There should also be tests where the column state contains values that are colors, as well as a mix of colors and non-colors.

Contributor Author

I removed the tm.assert_produces_warning(None). I added 2 test cases: 1 containing only strings and 1 containing both valid mpl strings and invalid mpl strings.

Additionally, I updated some other functionality tests to confirm that they do not execute the newly added logic and still color properly.
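
One detail relevant to color-count assertions such as the one in this test: `np.unique` flattens its input unless `axis` is given, so on an RGBA array it counts distinct channel values rather than distinct colors. A minimal sketch with made-up facecolor data (the array below is hypothetical and not taken from the PR):

```python
import numpy as np

# Hypothetical facecolors: 6 points drawn in 3 distinct RGBA colors.
facecolors = np.array(
    [
        [1.0, 0.0, 0.0, 1.0],  # red
        [0.0, 0.0, 1.0, 1.0],  # blue
        [0.0, 1.0, 0.0, 1.0],  # green
    ]
    * 2
)

# Without axis: the array is flattened, so only the values 0.0 and 1.0 remain.
n_scalar_values = len(np.unique(facecolors))

# With axis=0: each row is treated as one color.
n_colors = len(np.unique(facecolors, axis=0))
```

Passing `axis=0` makes the count track the number of distinct colors regardless of which palette produced them.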


        assert len(np.unique(ax.collections[0].get_facecolor())) == 4  # 4 states

    def test_scatter_colors(self):
        df = DataFrame({"a": [1, 2, 3], "b": [1, 2, 3], "c": [1, 2, 3]})
        with pytest.raises(TypeError, match="Specify exactly one of `c` and `color`"):