-
-
Notifications
You must be signed in to change notification settings - Fork 18k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ENH: DataFrame.plot.scatter argument c
now accepts a column of strings, where rows with the same string are colored identically
#59239
Changes from 10 commits
b91e635
8609ea5
b4440c1
571c0c8
e9511d0
1ca57ed
fb0d6e4
4bcdbfc
7972138
45886d9
1713727
62427ad
609fe40
6e86858
5223f2a
d97606c
7e5a02a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,6 +10,7 @@ | |
Iterator, | ||
Sequence, | ||
) | ||
from random import shuffle | ||
from typing import ( | ||
TYPE_CHECKING, | ||
Any, | ||
|
@@ -1337,6 +1338,13 @@ def _make_plot(self, fig: Figure) -> None: | |
norm, cmap = self._get_norm_and_cmap(c_values, color_by_categorical) | ||
cb = self._get_colorbar(c_values, c_is_column) | ||
|
||
# if a list of non color strings is passed in as c, generate a list | ||
# colored by uniqueness of the strings, such same strings get same color | ||
create_colors = not self._are_valid_colors(c_values) | ||
if create_colors: | ||
custom_color_mapping, c_values = self._uniquely_color_strs(c_values) | ||
cb = False # no colorbar; opt for legend | ||
|
||
if self.legend: | ||
label = self.label | ||
else: | ||
|
@@ -1367,6 +1375,15 @@ def _make_plot(self, fig: Figure) -> None: | |
label, # type: ignore[arg-type] | ||
) | ||
|
||
# build legend for labeling custom colors | ||
if create_colors: | ||
ax.legend( | ||
handles=[ | ||
mpl.patches.Circle((0, 0), facecolor=color, label=string) | ||
for string, color in custom_color_mapping.items() | ||
] | ||
) | ||
|
||
errors_x = self._get_errorbars(label=x, index=0, yerr=False) | ||
errors_y = self._get_errorbars(label=y, index=0, xerr=False) | ||
if len(errors_x) > 0 or len(errors_y) > 0: | ||
|
@@ -1390,6 +1407,38 @@ def _get_c_values(self, color, color_by_categorical: bool, c_is_column: bool): | |
c_values = c | ||
return c_values | ||
|
||
def _are_valid_colors(self, c_values: np.ndarray | list): | ||
# check if c_values contains strings and if these strings are valid mpl colors. | ||
# no need to check numerics as these (and mpl colors) will be validated for us | ||
# in .Axes.scatter._parse_scatter_color_args(...) | ||
try: | ||
if len(c_values) and all(isinstance(c, str) for c in c_values): | ||
mpl.colors.to_rgba_array(c_values) | ||
|
||
return True | ||
|
||
except (TypeError, ValueError) as _: | ||
return False | ||
|
||
def _uniquely_color_strs( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am no matplotlib expert but I think we need to defer to that somehow to get the desired colors, instead of trying to write this out ourselves There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I did some research looking into how other libs which support this functionality, including seaborn, handle this workflow. Each utilize the current mpl.colormap and draw colors from a linear space across this map. This allows users to change the choice colors of the same way they would with all other graphs Additionally, I've added an automatic legend to this type of plot since the chosen colors are not exposed to the user, similar to how a colorbar is drawn in some cases of the same function |
||
self, c_values: np.ndarray | list | ||
) -> tuple[dict, np.ndarray]: | ||
# well, almost uniquely color them (up to 949) | ||
unique = np.unique(c_values) | ||
|
||
# for up to 7, lets keep colors consistent | ||
if len(unique) <= 7: | ||
possible_colors = list(mpl.colors.BASE_COLORS.values()) # Hex | ||
# explore better ways to handle this case | ||
else: | ||
possible_colors = list(mpl.colors.XKCD_COLORS.values()) # Hex | ||
shuffle(possible_colors) | ||
|
||
colors = [possible_colors[i % len(possible_colors)] for i in range(len(unique))] | ||
color_mapping = dict(zip(unique, colors)) | ||
|
||
return color_mapping, np.array(list(map(color_mapping.get, c_values))) | ||
|
||
def _get_norm_and_cmap(self, c_values, color_by_categorical: bool): | ||
c = self.c | ||
if self.colormap is not None: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -207,6 +207,21 @@ def test_scatter_with_c_column_name_with_colors(self, cmap): | |
ax = df.plot.scatter(x=0, y=1, c="species", cmap=cmap) | ||
assert ax.collections[0].colorbar is None | ||
|
||
def test_scatter_with_c_column_name_without_colors(self): | ||
df = DataFrame( | ||
{ | ||
"dataX": range(100), | ||
"dataY": range(100), | ||
"state": ["NY", "MD", "MA", "CA"] * 25, | ||
} | ||
) | ||
df.plot.scatter("dataX", "dataY", c="state") | ||
|
||
with tm.assert_produces_warning(None): | ||
ax = df.plot.scatter(x=0, y=1, c="state") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not 100% sure that the test here should be in the There should also be tests where the column There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I removed the Additionally, I updated some other functionality tests to confirm they are not executing the new logic added and coloring properly |
||
|
||
assert len(np.unique(ax.collections[0].get_facecolor())) == 4 # 4 states | ||
|
||
def test_scatter_colors(self): | ||
df = DataFrame({"a": [1, 2, 3], "b": [1, 2, 3], "c": [1, 2, 3]}) | ||
with pytest.raises(TypeError, match="Specify exactly one of `c` and `color`"): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In what instances is
c_values
a list? Might be misreading but would be better if we only worked with a pd.Series and could call .unique on that, instead of checking every single value in a loopThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated to take a
pd.Series
, notnp.ndarray | list