diff --git a/.gitignore b/.gitignore
index 314e49809..66958385c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -36,3 +36,4 @@ __pycache__
 /tests/mnist_data
 catboost_info
 node_modules
+mpl-results
diff --git a/pyproject.toml b/pyproject.toml
index acce59f10..9434771e5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -141,7 +141,6 @@ module = [
     "shap.plots._benchmark",
     "shap.plots._force_matplotlib",
     "shap.plots._force",
-    "shap.plots._scatter.*",
     "shap.plots._text",
     "shap.plots._waterfall",
 ]
diff --git a/shap/plots/_scatter.py b/shap/plots/_scatter.py
index 5339c8e4d..6fbcb08c8 100644
--- a/shap/plots/_scatter.py
+++ b/shap/plots/_scatter.py
@@ -1,38 +1,47 @@
+from __future__ import annotations
+
+import typing
 import warnings
+from typing import Any, Literal, Union

 import matplotlib
-import matplotlib.pyplot as pl
+import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
+from matplotlib.markers import MarkerStyle

 from .._explanation import Explanation
 from ..utils import approximate_interactions, convert_name
+from ..utils._exceptions import DimensionError
 from ..utils._general import encode_array_if_needed
 from . import colors
 from ._labels import labels

+# Various ways to specify a desired axis limit
+LimitSpec = Union[Explanation, str, float, None]
+

 # TODO: Make the color bar a one-sided beeswarm plot so we can see the density along the color axis
 def scatter(
-    shap_values,
-    color="#1E88E5",
-    hist=True,
+    shap_values: Explanation,
+    color: str | Explanation | None = "#1E88E5",
+    hist: bool = True,
     axis_color="#333333",
     cmap=colors.red_blue,
     dot_size=16,
-    x_jitter="auto",
-    alpha=1,
-    title=None,
-    xmin=None,
-    xmax=None,
-    ymin=None,
-    ymax=None,
-    overlay=None,
-    ax=None,
-    ylabel="SHAP value",
-    show=True,
+    x_jitter: float | Literal["auto"] = "auto",
+    alpha: float = 1.0,
+    title: str | None = None,
+    xmin: LimitSpec = None,
+    xmax: LimitSpec = None,
+    ymin: LimitSpec = None,
+    ymax: LimitSpec = None,
+    overlay: dict[str, Any] | None = None,
+    ax: plt.Axes | None = None,
+    ylabel: str = "SHAP value",
+    show: bool = True,
 ):
-    """Create a SHAP dependence scatter plot, colored by an interaction feature.
+    """Create a SHAP dependence scatter plot, optionally colored by an interaction feature.

     Plots the value of the feature on the x-axis and the SHAP value of the same feature
     on the y-axis. This shows how the model depends on the given feature, and is like a
@@ -47,17 +56,22 @@ def scatter(
     Parameters
     ----------
     shap_values : shap.Explanation
-        Typically a single column of an :class:`.Explanation` object (i.e.
-        ``shap_values[:,"Feature A"]``).
-        Alternatively, pass multiple columns to create several subplots.
+        Typically a single column of an :class:`.Explanation` object
+        (i.e. ``shap_values[:, "Feature A"]``).
+
+        Alternatively, pass multiple columns to create several subplots
+        (i.e. ``shap_values[:, ["Feature A", "Feature B"]]``).

-    color : string or shap.Explanation
+    color : string or shap.Explanation, optional
         How to color the scatter plot points. This can be a fixed color string, or an
-        :class:`.Explanation` object. If it is an :class:`.Explanation` object, then the
-        scatter plot points are colored by the feature that seems to have the strongest
-        interaction effect with the feature given by the ``shap_values`` argument. This
-        is calculated using :func:`shap.utils.approximate_interactions`. If only a
-        single column of an :class:`.Explanation` object is passed, then that
+        :class:`.Explanation` object.
+
+        If it is an :class:`.Explanation` object, then the scatter plot points are
+        colored by the feature that seems to have the strongest interaction effect with
+        the feature given by the ``shap_values`` argument. This is calculated using
+        :func:`shap.utils.approximate_interactions`.
+
+        If only a single column of an :class:`.Explanation` object is passed, then that
         feature column will be used to color the data points.

     hist : bool
@@ -71,56 +85,72 @@ def scatter(
         increase plot readability when a feature is discrete. By default, ``x_jitter``
         is chosen based on auto-detection of categorical features.

+    title : str, optional
+        Plot title.
+
     alpha : float
         The transparency of the data points (between 0 and 1). This can be useful to show
         the density of the data points when using a large dataset.

-    xmin : float or string
-        Represents the lower bound of the plot's x-axis. It can be a string of the format
-        "percentile(float)" to denote that percentile of the feature's value used on the x-axis.
+    xmin, xmax, ymin, ymax : float, string, aggregated Explanation or None
+        Desired axis limits. Can be a float to specify a fixed limit.

-    xmax : float or string
-        Represents the upper bound of the plot's x-axis. It can be a string of the format
-        "percentile(float)" to denote that percentile of the feature's value used on the x-axis.
+        It can be a string of the format ``"percentile(float)"`` to denote that
+        percentile of the feature's value.

-    ax : matplotlib Axes object
-        Optionally specify an existing matplotlib ``Axes`` object, into which the plot will be placed.
-        In this case, we do not create a ``Figure``, otherwise we do.
+        It can also be an aggregated value of a single column of an :class:`.Explanation`,
+        such as ``explanation[:, "feature_name"].percentile(20)``.
+
+    overlay : dict, optional
+        Optional dictionary of up to three additional curves to overlay as line plots.
+
+        The dictionary maps a curve name to a list of (xvalues, yvalues) pairs, where
+        there is one pair for each feature to be plotted.
+
+    ax : matplotlib Axes, optional
+        Optionally specify an existing matplotlib ``Axes`` object, into which
+        the plot will be placed.
+
+        Only supported when plotting a single feature.

     show : bool
         Whether ``matplotlib.pyplot.show()`` is called before returning.
-        Setting this to ``False`` allows the plot
-        to be customized further after it has been created.
+
+        Setting this to ``False`` allows the plot to be customized further after it
+        has been created.
+
+    Returns
+    -------
+    ax : matplotlib Axes object
+        Only returned if ``show=False``.

     Examples
     --------
     See `scatter plot examples `_.

     """
-    assert str(type(shap_values)).endswith(
-        "Explanation'>"
-    ), "The shap_values parameter must be a shap.Explanation object!"
+    if not isinstance(shap_values, Explanation):
+        raise TypeError("The shap_values parameter must be a shap.Explanation object!")

     # see if we are plotting multiple columns
     if not isinstance(shap_values.feature_names, str) and len(shap_values.feature_names) > 0:
+        if ax is not None:
+            raise ValueError("The ax parameter is not supported when plotting multiple features")
+
+        # Define order of columns (features) to plot based on average shap value
         inds = np.argsort(np.abs(shap_values.values).mean(0))
-        nan_min = np.nanmin(shap_values.values)
-        nan_max = np.nanmax(shap_values.values)
-        if ymin is None:
-            ymin = nan_min - (nan_max - nan_min) / 20
-        if ymax is None:
-            ymax = nan_max + (nan_max - nan_min) / 20
-        # FIXME: the following code ignores any passed in `ax`
-        _ = pl.subplots(1, len(inds), figsize=(min(6 * len(inds), 15), 5))
+        ymin = _parse_limit(ymin, shap_values.values, is_shap_axis=True)
+        ymax = _parse_limit(ymax, shap_values.values, is_shap_axis=True)
+        ymin, ymax = _suggest_buffered_limits(ymin, ymax, shap_values.values)
+        _ = plt.subplots(1, len(inds), figsize=(min(6 * len(inds), 15), 5))
         for i in inds:
-            ax = pl.subplot(1, len(inds), i + 1)
+            ax = plt.subplot(1, len(inds), i + 1)
             scatter(shap_values[:, i], color=color, show=False, ax=ax, ymin=ymin, ymax=ymax)
             if overlay is not None:
                 line_styles = ["solid", "dotted", "dashed"]
                 for j, name in enumerate(overlay):
                     vals = overlay[name]
                     if isinstance(vals[i][0][0], (float, int)):
-                        pl.plot(vals[i][0], vals[i][1], color="#000000", linestyle=line_styles[j], label=name)
+                        plt.plot(vals[i][0], vals[i][1], color="#000000", linestyle=line_styles[j], label=name)
             if i == 0:
                 ax.set_ylabel(ylabel)
             else:
@@ -128,44 +158,34 @@
                 ax.set_yticks([])
                 ax.spines["left"].set_visible(False)
         if overlay is not None:
-            pl.legend()
+            plt.legend()
         if show:
-            pl.show()
+            plt.show()
         return

     if len(shap_values.shape) != 1:
-        raise Exception(
-            "The passed Explanation object has multiple columns, please pass a single feature column to "
-            "shap.plots.dependence like: shap_values[:,column]"
+        raise DimensionError(
+            "The passed Explanation object has multiple columns. Please pass a single feature column to "
+            "shap.plots.scatter like: shap_values[:, column]"
         )

     # this unpacks the explanation object for the code that was written earlier
     feature_names = [shap_values.feature_names]
-    ind = 0
+    ind: int = 0
     shap_values_arr = shap_values.values.reshape(-1, 1)
     features = shap_values.data.reshape(-1, 1)
     if shap_values.display_data is None:
         display_features = features
     else:
         display_features = shap_values.display_data.reshape(-1, 1)
-    interaction_index = None
-
-    # unwrap explanation objects used for bounds
-    if issubclass(type(xmin), Explanation):
-        xmin = xmin.data
-    if issubclass(type(xmax), Explanation):
-        xmax = xmax.data
-    if issubclass(type(ymin), Explanation):
-        ymin = ymin.values
-    if issubclass(type(ymax), Explanation):
-        ymax = ymax.values
+    interaction_index: str | int | None = None

     # wrap np.arrays as Explanations
     if isinstance(color, np.ndarray):
         color = Explanation(values=color, base_values=None, data=color)

     # TODO: This stacking could be avoided if we use the new shap.utils.potential_interactions function
-    if str(type(color)).endswith("Explanation'>"):
+    if isinstance(color, Explanation):
         shap_values2 = color
         if issubclass(type(shap_values2.feature_names), (str, int)):
             feature_names.append(shap_values2.feature_names)
@@ -211,86 +231,20 @@
     if len(features.shape) == 1:
         features = np.reshape(features, (len(features), 1))

-    ind = convert_name(ind, shap_values_arr, feature_names)
-    # pick jitter for categorical features
-    vals = np.sort(np.unique(features[:, ind]))
     if x_jitter == "auto":
-        min_dist = 1
-        for i in range(1, len(vals)):
-            # If vals contains numbers,
-            # check for min_dist based on difference in vals
-            # Otherwise, min_dist remains set arbitrarily at 1
-            try:
-                d = vals[i] - vals[i - 1]
-                if d > 1e-8 and d < min_dist:
-                    min_dist = d
-            except TypeError:
-                pass
-        num_points_per_value = len(features[:, ind]) / len(vals)
-        if num_points_per_value < 10:
-            # categorical = False
-            x_jitter = 0
-        elif num_points_per_value < 100:
-            # categorical = True
-            if x_jitter == "auto":
-                x_jitter = min_dist * 0.1
-        else:
-            # categorical = True
-            if x_jitter == "auto":
-                x_jitter = min_dist * 0.2
+        x_jitter = _suggest_x_jitter(features[:, ind])

     # guess what other feature as the stongest interaction with the plotted feature
-    if not hasattr(ind, "__len__"):
-        if interaction_index == "auto":
-            interaction_index = approximate_interactions(ind, shap_values_arr, features)[0]
-        interaction_index = convert_name(interaction_index, shap_values_arr, feature_names)
+    if interaction_index == "auto":
+        interaction_index = approximate_interactions(ind, shap_values_arr, features)[0]
+    interaction_index = convert_name(interaction_index, shap_values_arr, feature_names)

     categorical_interaction = False

     # create a matplotlib figure, if `ax` hasn't been specified.
-    if not ax:
+    if ax is None:
         figsize = (7.5, 5) if interaction_index != ind and interaction_index is not None else (6, 5)
-        fig = pl.figure(figsize=figsize)
-        ax = fig.gca()
-    else:
-        fig = ax.get_figure()
-
-    # plotting SHAP interaction values
-    if len(shap_values_arr.shape) == 3 and hasattr(ind, "__len__") and len(ind) == 2:
-        ind1 = convert_name(ind[0], shap_values_arr, feature_names)
-        ind2 = convert_name(ind[1], shap_values_arr, feature_names)
-        if ind1 == ind2:
-            proj_shap_values_arr = shap_values_arr[:, ind2, :]
-        else:
-            proj_shap_values_arr = shap_values_arr[:, ind2, :] * 2  # off-diag values are split in half
-
-        # there is no interaction coloring for the main effect
-        if ind1 == ind2:
-            fig.set_size_inches(6, 5, forward=True)
-
-        # TODO: remove recursion; generally the functions should be shorter for more maintainable code
-        dependence_legacy(
-            ind1,
-            proj_shap_values_arr,
-            features,
-            feature_names=feature_names,
-            interaction_index=(None if ind1 == ind2 else ind2),
-            display_features=display_features,
-            ax=ax,
-            show=False,
-            xmin=xmin,
-            xmax=xmax,
-            x_jitter=x_jitter,
-            alpha=alpha,
-        )
-        if ind1 == ind2:
-            ax.set_ylabel(labels["MAIN_EFFECT"] % feature_names[ind1])
-        else:
-            ax.set_ylabel(labels["INTERACTION_EFFECT"] % (feature_names[ind1], feature_names[ind2]))
-
-        if show:
-            pl.show()
-        return
+        _, ax = plt.subplots(figsize=figsize)

     assert (
         shap_values_arr.shape[0] == features.shape[0]
@@ -400,47 +354,30 @@
             tick_positions = np.array([cname_map[n] for n in cnames])
             tick_positions *= 1 - 1 / len(cnames)
             tick_positions += 0.5 * (chigh - clow) / (chigh - clow + 1)
-            cb = pl.colorbar(p, ticks=tick_positions, ax=ax, aspect=80)
+            cb = plt.colorbar(p, ticks=tick_positions, ax=ax, aspect=80)
             cb.set_ticklabels(cnames)
         else:
-            cb = pl.colorbar(p, ax=ax, aspect=80)
+            cb = plt.colorbar(p, ax=ax, aspect=80)
+        # Type narrowing for mypy
+        assert isinstance(interaction_index, (int, np.integer)), f"Unexpected {type(interaction_index)=}"
         cb.set_label(feature_names[interaction_index], size=13)
         cb.ax.tick_params(labelsize=11)
         if categorical_interaction:
             cb.ax.tick_params(length=0)
         cb.set_alpha(1)
-        cb.outline.set_visible(False)
+        cb.outline.set_visible(False)  # type: ignore
         # bbox = cb.ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
        # cb.ax.set_aspect((bbox.height - 0.7) * 20)

-    # handles any setting of xmax and xmin
-    # note that we handle None,float, or "percentile(float)" formats
+    xmin = _parse_limit(xmin, xv, is_shap_axis=False)
+    xmax = _parse_limit(xmax, xv, is_shap_axis=False)
+    ymin = _parse_limit(ymin, s, is_shap_axis=True)
+    ymax = _parse_limit(ymax, s, is_shap_axis=True)
     if xmin is not None or xmax is not None:
-        if isinstance(xmin, str) and xmin.startswith("percentile"):
-            xmin = np.nanpercentile(xv, float(xmin[11:-1]))
-        if isinstance(xmax, str) and xmax.startswith("percentile"):
-            xmax = np.nanpercentile(xv, float(xmax[11:-1]))
-
-        if xmin is None or xmin == np.nanmin(xv):
-            xmin = np.nanmin(xv) - (xmax - np.nanmin(xv)) / 20
-        if xmax is None or xmax == np.nanmax(xv):
-            xmax = np.nanmax(xv) + (np.nanmax(xv) - xmin) / 20
-
-        ax.set_xlim(xmin, xmax)
-
+        ax.set_xlim(*_suggest_buffered_limits(xmin, xmax, xv))
     if ymin is not None or ymax is not None:
-        # if type(ymin) == str and ymin.startswith("percentile"):
-        #     ymin = np.nanpercentile(xv, float(ymin[11:-1]))
-        # if type(ymax) == str and ymax.startswith("percentile"):
-        #     ymax = np.nanpercentile(xv, float(ymax[11:-1]))
-
-        if ymin is None or ymin == np.nanmin(xv):
-            ymin = np.nanmin(xv) - (ymax - np.nanmin(xv)) / 20
-        if ymax is None or ymax == np.nanmax(xv):
-            ymax = np.nanmax(xv) + (np.nanmax(xv) - ymin) / 20
-
-        ax.set_ylim(ymin, ymax)
+        ax.set_ylim(*_suggest_buffered_limits(ymin, ymax, s))

     # plot any nan feature values as tick marks along the y-axis
     xlim = ax.get_xlim()
@@ -448,7 +385,7 @@
         p = ax.scatter(
             xlim[0] * np.ones(xv_nan.sum()),
             s[xv_nan],
-            marker=1,
+            marker=MarkerStyle(1),
             linewidth=2,
             c=cvals_imp[xv_nan],
             cmap=cmap,
@@ -458,57 +395,16 @@
         )
         p.set_array(cvals[xv_nan])
     else:
-        ax.scatter(xlim[0] * np.ones(xv_nan.sum()), s[xv_nan], marker=1, linewidth=2, color=color, alpha=alpha)
+        ax.scatter(
+            xlim[0] * np.ones(xv_nan.sum()), s[xv_nan], marker=MarkerStyle(1), linewidth=2, color=color, alpha=alpha
+        )
     ax.set_xlim(xlim)

     # the histogram of the data
     if hist:
-        ax2 = ax.twinx()
-        # n, bins, patches =
-        xlim = ax.get_xlim()
-        xvals = np.unique(xv_no_jitter)
-
-        if len(xvals) / len(xv_no_jitter) < 0.2 and len(xvals) < 75 and np.max(xvals) < 75 and np.min(xvals) >= 0:
-            np.sort(xvals)
-            bin_edges = []
-            for i in range(int(np.max(xvals) + 1)):
-                bin_edges.append(i - 0.5)
-
-            # bin_edges.append((xvals[i] + xvals[i+1])/2)
-            bin_edges.append(int(np.max(xvals)) + 0.5)
-
-            lim = np.floor(np.min(xvals) - 0.5) + 0.5, np.ceil(np.max(xvals) + 0.5) - 0.5
-            ax.set_xlim(lim)
-        else:
-            if len(xv_no_jitter) >= 500:
-                bin_edges = 50
-            elif len(xv_no_jitter) >= 200:
-                bin_edges = 20
-            elif len(xv_no_jitter) >= 100:
-                bin_edges = 10
-            else:
-                bin_edges = 5
-
-        ax2.hist(
-            xv[~np.isnan(xv)],
-            bin_edges,
-            density=False,
-            facecolor="#000000",
-            alpha=0.1,
-            range=(xlim[0], xlim[1]),
-            zorder=-1,
-        )
-        ax2.set_ylim(0, len(xv))
-
-        ax2.xaxis.set_ticks_position("bottom")
-        ax2.yaxis.set_ticks_position("left")
-        ax2.yaxis.set_ticks([])
-        ax2.spines["right"].set_visible(False)
-        ax2.spines["top"].set_visible(False)
-        ax2.spines["left"].set_visible(False)
-        ax2.spines["bottom"].set_visible(False)
+        _plot_histogram(ax, xv, xv_no_jitter)

-    pl.sca(ax)
+    plt.sca(ax)

     # make the plot more readable
     ax.set_xlabel(name, color=axis_color, fontsize=13)
@@ -528,7 +424,108 @@
     if show:
         with warnings.catch_warnings():  # ignore expected matplotlib warnings
             warnings.simplefilter("ignore", RuntimeWarning)
-            pl.show()
+            plt.show()
+    else:
+        return ax
+
+
+def _parse_limit(ax_limit: LimitSpec, ax_values: np.ndarray, is_shap_axis: bool) -> float | None:
+    """Handle axis limits in "percentile(float)" format or from Explanation objects"""
+    if isinstance(ax_limit, str):
+        try:
+            percentage = float(ax_limit.removeprefix("percentile(").removesuffix(")"))
+        except ValueError as e:
+            raise ValueError("Only strings of the format `percentile(x)` are supported.") from e
+        return np.nanpercentile(ax_values, percentage)
+    if isinstance(ax_limit, Explanation):
+        # Expect Explanation aggregated to a single value, e.g. `explanation[:, "feature_name"].percentile(20)`
+        # Extract the relevant attribute, depending on whether this is the x- or y-axis
+        return float(ax_limit.values) if is_shap_axis else float(ax_limit.data)
+    # Else, should be float or None
+    return ax_limit
+
+
+def _suggest_buffered_limits(ax_min: float | None, ax_max: float | None, values: np.ndarray) -> tuple[float, float]:
+    """If either limit is None, suggest a suitable value that includes a buffer on either side"""
+    nan_max = np.nanmax(values) if ax_max is None else ax_max
+    nan_min = np.nanmin(values) if ax_min is None else ax_min
+    buffer = (nan_max - nan_min) / 20
+    if ax_min is None:
+        ax_min = float(nan_min - buffer)
+    if ax_max is None:
+        ax_max = float(nan_max + buffer)
+    return ax_min, ax_max
+
+
+def _suggest_x_jitter(values: np.ndarray) -> float:
+    """Suggest a suitable x_jitter value based on the unique values in the feature"""
+    unique_vals = np.sort(np.unique(values))
+    try:
+        # Identify the smallest difference between unique values
+        diffs = np.diff(unique_vals)
+        min_dist = np.min(diffs[diffs > 1e-8])
+    except (TypeError, ValueError):
+        # TypeError: unique_vals contains non-numeric values.
+        # ValueError: no positive differences exist (e.g. only a single unique value).
+        # In both cases, fall back to an arbitrary spacing of 1.
+        min_dist = 1
+
+    num_points_per_value = len(values) / len(unique_vals)
+    if num_points_per_value < 10:
+        # categorical = False
+        x_jitter = 0
+    elif num_points_per_value < 100:
+        # categorical = True
+        x_jitter = min_dist * 0.1
+    else:
+        # categorical = True
+        x_jitter = min_dist * 0.2
+    return x_jitter
+
+
+def _plot_histogram(ax: plt.Axes, xv, xv_no_jitter):
+    """Add a histogram of the data on a matching secondary axis"""
+    ax2 = typing.cast(plt.Axes, ax.twinx())
+    xlim = ax.get_xlim()
+    xvals = np.unique(xv_no_jitter)
+
+    # Determine suitable bins and limits
+    bins: list[float] | int  # Hint for mypy
+    if len(xvals) / len(xv_no_jitter) < 0.2 and len(xvals) < 75 and np.max(xvals) < 75 and np.min(xvals) >= 0:
+        np.sort(xvals)
+        bins = []
+        for i in range(int(np.max(xvals) + 1)):
+            bins.append(i - 0.5)
+        bins.append(int(np.max(xvals)) + 0.5)
+
+        lim = np.floor(np.min(xvals) - 0.5) + 0.5, np.ceil(np.max(xvals) + 0.5) - 0.5
+        ax.set_xlim(lim)
+    else:
+        if len(xv_no_jitter) >= 500:
+            bins = 50
+        elif len(xv_no_jitter) >= 200:
+            bins = 20
+        elif len(xv_no_jitter) >= 100:
+            bins = 10
+        else:
+            bins = 5
+
+    # Plot the histogram
+    ax2.hist(
+        xv[~np.isnan(xv)],
+        bins,
+        density=False,
+        facecolor="#000000",
+        alpha=0.1,
+        range=(xlim[0], xlim[1]),
+        zorder=-1,
+    )
+    ax2.set_ylim(0, len(xv))
+    ax2.xaxis.set_ticks_position("bottom")
+    ax2.yaxis.set_ticks_position("left")
+    ax2.yaxis.set_ticks([])
+    ax2.spines["right"].set_visible(False)
+    ax2.spines["top"].set_visible(False)
+    ax2.spines["left"].set_visible(False)
+    ax2.spines["bottom"].set_visible(False)


 def dependence_legacy(
@@ -655,7 +652,7 @@
     # create a matplotlib figure, if `ax` hasn't been specified.
     if not ax:
         figsize = (7.5, 5) if interaction_index != ind and interaction_index is not None else (6, 5)
-        fig = pl.figure(figsize=figsize)
+        fig = plt.figure(figsize=figsize)
         ax = fig.gca()
     else:
         fig = ax.get_figure()
@@ -694,7 +691,7 @@
             ax.set_ylabel(labels["INTERACTION_EFFECT"] % (feature_names[ind1], feature_names[ind2]))

         if show:
-            pl.show()
+            plt.show()
         return

     assert (
@@ -798,17 +795,17 @@
         if len(tick_positions) == 2:
             tick_positions[0] -= 0.25
             tick_positions[1] += 0.25
-        cb = pl.colorbar(p, ticks=tick_positions, ax=ax, aspect=80)
+        cb = plt.colorbar(p, ticks=tick_positions, ax=ax, aspect=80)
         cb.set_ticklabels(cnames)
     else:
-        cb = pl.colorbar(p, ax=ax, aspect=80)
+        cb = plt.colorbar(p, ax=ax, aspect=80)

     cb.set_label(feature_names[interaction_index], size=13)
     cb.ax.tick_params(labelsize=11)
     if categorical_interaction:
         cb.ax.tick_params(length=0)
     cb.set_alpha(1)
-    cb.outline.set_visible(False)
+    cb.outline.set_visible(False)  # type: ignore
     # bbox = cb.ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
     # cb.ax.set_aspect((bbox.height - 0.7) * 20)

@@ -873,4 +870,4 @@
     if show:
         with warnings.catch_warnings():  # ignore expected matplotlib warnings
             warnings.simplefilter("ignore", RuntimeWarning)
-            pl.show()
+            plt.show()
diff --git a/tests/plots/test_scatter.py b/tests/plots/test_scatter.py
index ea9080eea..f1030d38e 100644
--- a/tests/plots/test_scatter.py
+++ b/tests/plots/test_scatter.py
@@ -8,27 +8,24 @@
 def test_scatter_single(explainer):
     explanation = explainer(explainer.data)
     shap.plots.scatter(explanation[:, "Age"], show=False)
-    fig = plt.gcf()
     plt.tight_layout()
-    return fig
+    return plt.gcf()


 @pytest.mark.mpl_image_compare
 def test_scatter_interaction(explainer):
     explanation = explainer(explainer.data)
     shap.plots.scatter(explanation[:, "Age"], color=explanation[:, "Workclass"], show=False)
-    fig = plt.gcf()
     plt.tight_layout()
-    return fig
+    return plt.gcf()


 @pytest.mark.mpl_image_compare
 def test_scatter_dotchain(explainer):
     explanation = explainer(explainer.data)
     shap.plots.scatter(explanation[:, explanation.abs.mean(0).argsort[-2]], show=False)
-    fig = plt.gcf()
     plt.tight_layout()
-    return fig
+    return plt.gcf()


 @pytest.mark.mpl_image_compare
@@ -42,9 +39,8 @@ def test_scatter_multiple_cols_overlay(explainer):
         ],
     }
     shap.plots.scatter(shap_values, overlay=overlay, show=False)
-    fig = plt.gcf()
     plt.tight_layout()
-    return fig
+    return plt.gcf()


 @pytest.mark.mpl_image_compare
@@ -63,9 +59,8 @@ def test_scatter_custom(explainer):
         cmap=plt.get_cmap("cool"),
         show=False,
     )
-    fig = plt.gcf()
     plt.tight_layout()
-    return fig
+    return plt.gcf()


 @pytest.fixture()
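For reference, a minimal usage sketch of the scatter() options this diff documents (not part of the diff itself). The dataset, model settings, and the "Age" feature name are illustrative assumptions; the axis-limit forms shown are the three described in the updated docstring: a plain float, a "percentile(float)" string, and an aggregated single-column Explanation.

# Usage sketch only -- model choice and feature name are assumptions, not taken from the diff.
import matplotlib.pyplot as plt
import shap
import xgboost

X, y = shap.datasets.adult()
model = xgboost.XGBClassifier(n_estimators=50).fit(X, y)
explanation = shap.Explainer(model, X)(X)

ax = shap.plots.scatter(
    explanation[:, "Age"],
    color=explanation,                          # colors by the strongest interacting feature
    xmin="percentile(1)",                       # string form of an axis limit
    xmax=explanation[:, "Age"].percentile(99),  # aggregated-Explanation form
    ymin=-2.0,                                  # plain float form
    show=False,                                 # with show=False, the Axes is returned
)
ax.set_title("SHAP dependence for Age")
plt.show()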