diff --git a/sharp/visualization/_aggregate.py b/sharp/visualization/_aggregate.py
new file mode 100644
index 0000000..90cd7d5
--- /dev/null
+++ b/sharp/visualization/_aggregate.py
@@ -0,0 +1,131 @@
+"""
+Produce dataset-wide plots.
+"""
+
+import numpy as np
+import pandas as pd
+from sharp.utils._utils import _optional_import
+from sharp.utils import check_feature_names, scores_to_ordering
+
+
+def strata_boxplots(
+    X,
+    y,
+    contributions,
+    feature_names=None,
+    n_strata=5,
+    gap_size=1,
+    cmap=None,
+    ax=None,
+    **kwargs,
+):
+    """
+    Draw one group of per-feature boxplots for every rank stratum.
+
+    Items are ranked with ``scores_to_ordering(y, -1)`` and split into
+    ``n_strata`` rank bins of equal width. Each bin gets one box per feature,
+    summarising that feature's contributions over the items in the bin.
+
+    Parameters
+    ----------
+    X : array-like
+        Input data. Used for its number of rows (``X.shape[0]``) and, when
+        ``feature_names`` is None, to infer the feature names.
+    y : array-like
+        Scores from which the ranking is derived.
+    contributions : array-like
+        Feature contributions, one row per item and one column per feature.
+    feature_names : list of str, default=None
+        Column labels for ``contributions``. Inferred from ``X`` when None.
+    n_strata : int, default=5
+        Number of rank bins.
+    gap_size : float, default=1
+        Space left between consecutive groups, in box widths.
+    cmap : str or Colormap, default=None
+        Passed to ``plt.get_cmap``; feature ``i`` is filled with ``cmap(i)``.
+    ax : matplotlib Axes, default=None
+        Axes to draw on. When None, a new figure is created and shown.
+    **kwargs
+        Extra keyword arguments forwarded to ``Axes.boxplot``.
+
+    Returns
+    -------
+    ax : matplotlib Axes
+        The axes holding the plot.
+
+    Raises
+    ------
+    ValueError
+        If ``n_strata`` is smaller than 1.
+    """
+    if n_strata < 1:
+        raise ValueError("n_strata must be at least 1.")
+
+    plt = _optional_import("matplotlib.pyplot")
+
+    if feature_names is None:
+        feature_names = check_feature_names(X)
+    n_features = len(feature_names)
+
+    # Only call ``plt.show()`` for a figure created here; a caller-supplied
+    # axes is left for the caller to display or save.
+    created_figure = ax is None
+    if created_figure:
+        _, ax = plt.subplots()
+
+    df = pd.DataFrame(contributions, columns=feature_names)
+
+    perc_step = 100 / n_strata
+    stratum_size = X.shape[0] / n_strata
+
+    # ``rank - 1`` assumes 1-based ranks (TODO confirm against
+    # ``scores_to_ordering``). Strata stay integers so that groups are ordered
+    # numerically; sorting the text labels would put "5-..." after "10-...".
+    ranks = pd.Series(scores_to_ordering(y, -1), index=df.index)
+    strata = np.floor((ranks - 1) / stratum_size).astype(int)
+    bin_ids = np.sort(strata.unique())
+    bin_names = [
+        f"{int(k * perc_step)}-\n{int((k + 1) * perc_step)}%" for k in bin_ids
+    ]
+
+    colormap = plt.get_cmap(cmap)
+    colors = [colormap(i) for i in range(n_features)]
+    pos_increment = 1 / (n_features + gap_size)
+    offsets = pos_increment * np.arange(n_features)
+
+    patches = []
+    for i, k in enumerate(bin_ids):
+        box = ax.boxplot(
+            df.loc[strata == k, feature_names],
+            widths=pos_increment,
+            positions=i + offsets,
+            patch_artist=True,
+            medianprops={"color": "black"},
+            boxprops={"facecolor": "C0", "edgecolor": "black"},
+            **kwargs,
+        )
+        # One patch per feature; the last group's patches double as legend
+        # handles.
+        patches = box["boxes"]
+        for patch, color in zip(patches, colors):
+            patch.set_facecolor(color)
+
+    # Centre each tick under its group of boxes.
+    ax.set_xticks(
+        np.arange(len(bin_names)) + pos_increment * (n_features - 1) / 2
+    )
+    ax.set_xticklabels(bin_names)
+
+    if patches:
+        ax.legend(
+            patches,
+            feature_names,
+            loc="upper center",
+            bbox_to_anchor=(0.5, 1.05),
+            ncol=n_features,
+        )
+
+    if created_figure:
+        plt.show()
+
+    return ax
diff --git a/sharp/visualization/_visualization.py b/sharp/visualization/_visualization.py
index 22821a0..9fc94bc 100644
--- a/sharp/visualization/_visualization.py
+++ b/sharp/visualization/_visualization.py
@@ -7,6 +7,7 @@
 import pandas as pd
 from sharp.utils._utils import _optional_import
 from ._waterfall import _waterfall
+from ._aggregate import strata_boxplots
 
 
 class ShaRPViz:  # TODO
@@ -44,3 +45,37 @@ def waterfall(self, scores, mean_shapley_value=0):
             "values": pd.Series(scores, index=feature_names),
         }
         return _waterfall(rank_dict, max_display=10)
+
+    def strata_boxplot(
+        self,
+        X,
+        y,
+        contributions,
+        feature_names=None,
+        n_strata=5,
+        gap_size=1,
+        cmap=None,
+        ax=None,
+        **kwargs
+    ):
+        """
+        Draw per-stratum boxplots of feature contributions.
+
+        Thin wrapper around ``strata_boxplots``: when ``feature_names`` is
+        None the names are taken from ``self.sharp.feature_names_``; every
+        other argument is forwarded unchanged and the callee's result is
+        returned.
+        """
+        if feature_names is None:
+            feature_names = self.sharp.feature_names_.astype(str).tolist()
+        return strata_boxplots(
+            X=X,
+            y=y,
+            contributions=contributions,
+            feature_names=feature_names,
+            n_strata=n_strata,
+            gap_size=gap_size,
+            cmap=cmap,
+            ax=ax,
+            **kwargs
+        )