Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generalize internal scaling operation #3440

Merged
merged 6 commits into from
Aug 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 25 additions & 27 deletions seaborn/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1025,9 +1025,7 @@ def iter_data(
)

# Reduce to the semantics used in this plot
grouping_vars = [
var for var in grouping_vars if var in self.variables
]
grouping_vars = [var for var in grouping_vars if var in self.variables]

if from_comp_data:
data = self.comp_data
Expand All @@ -1040,22 +1038,21 @@ def iter_data(
levels = self.var_levels.copy()
if from_comp_data:
for axis in {"x", "y"} & set(grouping_vars):
converter = self.converters[axis].iloc[0]
if self.var_types[axis] == "categorical":
if self._var_ordered[axis]:
# If the axis is ordered, then the axes in a possible
# facet grid are by definition "shared", or there is a
# single axis with a unique cat -> idx mapping.
# So we can just take the first converter object.
converter = self.converters[axis].iloc[0]
levels[axis] = converter.convert_units(levels[axis])
else:
# Otherwise, the mappings may not be unique, but we can
# use the unique set of index values in comp_data.
levels[axis] = np.sort(data[axis].unique())
elif self.var_types[axis] == "datetime":
levels[axis] = mpl.dates.date2num(levels[axis])
elif self.var_types[axis] == "numeric" and self._log_scaled(axis):
levels[axis] = np.log10(levels[axis])
else:
transform = converter.get_transform().transform
levels[axis] = transform(converter.convert_units(levels[axis]))

if grouping_vars:

Expand Down Expand Up @@ -1129,9 +1126,8 @@ def comp_data(self):
# supporting `order` in categorical plots is tricky
orig = orig[orig.isin(self.var_levels[var])]
comp = pd.to_numeric(converter.convert_units(orig)).astype(float)
if converter.get_scale() == "log":
comp = np.log10(comp)
parts.append(pd.Series(comp, orig.index, name=orig.name))
transform = converter.get_transform().transform
parts.append(pd.Series(transform(comp), orig.index, name=orig.name))
if parts:
comp_col = pd.concat(parts)
else:
Expand Down Expand Up @@ -1300,25 +1296,27 @@ def _attach(

# TODO -- Add axes labels

def _log_scaled(self, axis):
"""Return True if specified axis is log scaled on all attached axes."""
if not hasattr(self, "ax"):
return False

def _get_scale_transforms(self, axis):
"""Return a function implementing the scale transform (or its inverse)."""
if self.ax is None:
axes_list = self.facets.axes.flatten()
axis_list = [getattr(ax, f"{axis}axis") for ax in self.facets.axes.flat]
scales = {axis.get_scale() for axis in axis_list}
if len(scales) > 1:
# It is a simplifying assumption that faceted axes will always have
# the same scale (even if they are unshared and have distinct limits).
# Nothing in the seaborn API allows you to create a FacetGrid with
# a mixture of scales, although it's possible via matplotlib.
# This is constraining, but no more so than previous behavior that
# only (properly) handled log scales, and there are some places where
# it would be much too complicated to use axes-specific transforms.
err = "Cannot determine transform with mixed scales on faceted axes."
raise RuntimeError(err)
transform_obj = axis_list[0].get_transform()
else:
axes_list = [self.ax]

log_scaled = []
for ax in axes_list:
data_axis = getattr(ax, f"{axis}axis")
log_scaled.append(data_axis.get_scale() == "log")

if any(log_scaled) and not all(log_scaled):
raise RuntimeError("Axis scaling is not consistent")
# This case is more straightforward
transform_obj = getattr(self.ax, f"{axis}axis").get_transform()

return any(log_scaled)
return transform_obj.transform, transform_obj.inverted().transform

def _add_axis_labels(self, ax, default_x="", default_y=""):
"""Add axis labels if not present, set visibility to match ticklabels."""
Expand Down
79 changes: 37 additions & 42 deletions seaborn/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
_check_argument,
_draw_figure,
_default_color,
_get_transform_functions,
_normalize_kwargs,
_version_predates,
)
Expand Down Expand Up @@ -371,7 +372,7 @@ def _dodge(self, keys, data):
def _invert_scale(self, ax, data, vars=("x", "y")):
"""Undo scaling after computation so data are plotted correctly."""
for var in vars:
_, inv = utils._get_transform_functions(ax, var[0])
_, inv = _get_transform_functions(ax, var[0])
if var == self.orient and "width" in data:
hw = data["width"] / 2
data["edge"] = inv(data[var] - hw)
Expand Down Expand Up @@ -528,9 +529,7 @@ def plot_swarms(
if not sub_data.empty:
point_collections[(ax, sub_data[self.orient].iloc[0])] = points

beeswarm = Beeswarm(
width=width, orient=self.orient, warn_thresh=warn_thresh,
)
beeswarm = Beeswarm(width=width, orient=self.orient, warn_thresh=warn_thresh)
for (ax, center), points in point_collections.items():
if points.get_offsets().shape[0] > 1:

Expand Down Expand Up @@ -627,6 +626,12 @@ def get_props(element, artist=mpl.lines.Line2D):
capwidth = plot_kws.get("capwidths", 0.5 * data["width"])

self._invert_scale(ax, data)
_, inv = _get_transform_functions(ax, value_var)
for stat in ["mean", "med", "q1", "q3", "cilo", "cihi", "whislo", "whishi"]:
stats[stat] = inv(stats[stat])
stats["fliers"] = stats["fliers"].map(inv)

linear_orient_scale = getattr(ax, f"get_{self.orient}scale")() == "linear"

maincolor = self._hue_map(sub_vars["hue"]) if "hue" in sub_vars else color
if fill:
Expand All @@ -651,8 +656,8 @@ def get_props(element, artist=mpl.lines.Line2D):
default_kws = dict(
bxpstats=stats.to_dict("records"),
positions=data[self.orient],
# Set width to 0 with log scaled orient axis to avoid going < 0
widths=0 if self._log_scaled(self.orient) else data["width"],
# Set width to 0 to avoid going out of domain
widths=data["width"] if linear_orient_scale else 0,
patch_artist=fill,
vert=self.orient == "x",
manage_ticks=False,
Expand All @@ -673,7 +678,8 @@ def get_props(element, artist=mpl.lines.Line2D):

# Reset artist widths after adding so everything stays positive
ori_idx = ["x", "y"].index(self.orient)
if self._log_scaled(self.orient):

if not linear_orient_scale:
for i, box in enumerate(data.to_dict("records")):
p0 = box["edge"]
p1 = box["edge"] + box["width"]
Expand Down Expand Up @@ -702,9 +708,10 @@ def get_props(element, artist=mpl.lines.Line2D):
artists["medians"][i].set_data(verts)

if artists["caps"]:
f_fwd, f_inv = _get_transform_functions(ax, self.orient)
for line in artists["caps"][2 * i:2 * i + 2]:
p0 = 10 ** (np.log10(box[self.orient]) - capwidth[i] / 2)
p1 = 10 ** (np.log10(box[self.orient]) + capwidth[i] / 2)
p0 = f_inv(f_fwd(box[self.orient]) - capwidth[i] / 2)
p1 = f_inv(f_fwd(box[self.orient]) + capwidth[i] / 2)
verts = line.get_xydata().T
verts[ori_idx][:] = p0, p1
line.set_data(verts)
Expand Down Expand Up @@ -769,8 +776,8 @@ def plot_boxens(
allow_empty=False):

ax = self._get_axes(sub_vars)
_, inv_ori = utils._get_transform_functions(ax, self.orient)
_, inv_val = utils._get_transform_functions(ax, value_var)
_, inv_ori = _get_transform_functions(ax, self.orient)
_, inv_val = _get_transform_functions(ax, value_var)

# Statistics
lv_data = estimator(sub_data[value_var])
Expand Down Expand Up @@ -1010,8 +1017,8 @@ def vars_to_key(sub_vars):
offsets = span, span

ax = violin["ax"]
_, invx = utils._get_transform_functions(ax, "x")
_, invy = utils._get_transform_functions(ax, "y")
_, invx = _get_transform_functions(ax, "x")
_, invy = _get_transform_functions(ax, "y")
inv_pos = {"x": invx, "y": invy}[self.orient]
inv_val = {"x": invx, "y": invy}[value_var]

Expand Down Expand Up @@ -1168,17 +1175,11 @@ def plot_points(
markers = self._map_prop_with_hue("marker", markers, "o", plot_kws)
linestyles = self._map_prop_with_hue("linestyle", linestyles, "-", plot_kws)

positions = self.var_levels[self.orient]
base_positions = self.var_levels[self.orient]
if self.var_types[self.orient] == "categorical":
min_cat_val = int(self.comp_data[self.orient].min())
max_cat_val = int(self.comp_data[self.orient].max())
positions = [i for i in range(min_cat_val, max_cat_val + 1)]
else:
if self._log_scaled(self.orient):
positions = np.log10(positions)
if self.var_types[self.orient] == "datetime":
positions = mpl.dates.date2num(positions)
positions = pd.Index(positions, name=self.orient)
base_positions = [i for i in range(min_cat_val, max_cat_val + 1)]

n_hue_levels = 0 if self._hue_map.levels is None else len(self._hue_map.levels)
if dodge is True:
Expand All @@ -1192,11 +1193,14 @@ def plot_points(

ax = self._get_axes(sub_vars)

ori_axis = getattr(ax, f"{self.orient}axis")
transform, _ = _get_transform_functions(ax, self.orient)
positions = transform(ori_axis.convert_units(base_positions))
agg_data = sub_data if sub_data.empty else (
sub_data
.groupby(self.orient)
.apply(aggregator, agg_var)
.reindex(positions)
.reindex(pd.Index(positions, name=self.orient))
.reset_index()
)

Expand Down Expand Up @@ -1316,14 +1320,12 @@ def plot_errorbars(self, ax, data, capsize, err_kws):
pos = np.array([row[self.orient], row[self.orient]])
val = np.array([row[f"{var}min"], row[f"{var}max"]])

cw = capsize * self._native_width / 2
if self._log_scaled(self.orient):
log_pos = np.log10(pos)
cap = 10 ** (log_pos[0] - cw), 10 ** (log_pos[1] + cw)
else:
cap = pos[0] - cw, pos[1] + cw

if capsize:

cw = capsize * self._native_width / 2
scl, inv = _get_transform_functions(ax, self.orient)
cap = inv(scl(pos[0]) - cw), inv(scl(pos[1]) + cw)

pos = np.concatenate([
[*cap, np.nan], pos, [np.nan, *cap]
])
Expand Down Expand Up @@ -3220,13 +3222,12 @@ def __call__(self, points, center):
new_xy = new_xyr[:, :2]
new_x_data, new_y_data = ax.transData.inverted().transform(new_xy).T

log_scale = getattr(ax, f"get_{self.orient}scale")() == "log"

# Add gutters
t_fwd, t_inv = _get_transform_functions(ax, self.orient)
if self.orient == "y":
self.add_gutters(new_y_data, center, log_scale=log_scale)
self.add_gutters(new_y_data, center, t_fwd, t_inv)
else:
self.add_gutters(new_x_data, center, log_scale=log_scale)
self.add_gutters(new_x_data, center, t_fwd, t_inv)

# Reposition the points so they do not overlap
if self.orient == "y":
Expand Down Expand Up @@ -3330,20 +3331,14 @@ def first_non_overlapping_candidate(self, candidates, neighbors):
"No non-overlapping candidates found. This should not happen."
)

def add_gutters(self, points, center, log_scale=False):
def add_gutters(self, points, center, trans_fwd, trans_inv):
"""Stop points from extending beyond their territory."""
half_width = self.width / 2
if log_scale:
low_gutter = 10 ** (np.log10(center) - half_width)
else:
low_gutter = center - half_width
low_gutter = trans_inv(trans_fwd(center) - half_width)
off_low = points < low_gutter
if off_low.any():
points[off_low] = low_gutter
if log_scale:
high_gutter = 10 ** (np.log10(center) + half_width)
else:
high_gutter = center + half_width
high_gutter = trans_inv(trans_fwd(center) + half_width)
off_high = points > high_gutter
if off_high.any():
points[off_high] = high_gutter
Expand Down
Loading
Loading