Skip to content

Commit

Permalink
ENH feat importance per stratum viz (#30)
Browse files Browse the repository at this point in the history
  • Loading branch information
joaopfonseca committed Mar 3, 2024
1 parent 4d19122 commit f286725
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 0 deletions.
88 changes: 88 additions & 0 deletions sharp/visualization/_aggregate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
"""
Produce dataset-wide plots.
"""

import numpy as np
import pandas as pd
from sharp.utils._utils import _optional_import
from sharp.utils import check_feature_names, scores_to_ordering


def strata_boxplots(
X,
y,
contributions,
feature_names=None,
n_strata=5,
gap_size=1,
cmap=None,
ax=None,
**kwargs,
):

plt = _optional_import("matplotlib.pyplot")

if feature_names is None:
feature_names = check_feature_names(X)

if ax is None:
fig, ax = plt.subplots()

df = pd.DataFrame(contributions, columns=feature_names)

perc_step = 100 / n_strata
stratum_size = X.shape[0] / n_strata

df["target"] = scores_to_ordering(y, -1)
df["target_binned"] = [
(
f"0-\n{int(perc_step)}%"
if np.floor((rank - 1) / stratum_size) == 0
else str(int(np.floor((rank - 1) / stratum_size) * perc_step))
+ "-\n"
+ str(int((np.floor((rank - 1) / stratum_size) + 1) * perc_step))
+ "%"
)
for rank in df["target"]
]
df.sort_values(by=["target_binned"], inplace=True)
df.drop(columns=["target"], inplace=True)

df["target_binned"] = df["target_binned"].str.replace("<", "$<$")

colors = [plt.get_cmap(cmap)(i) for i in range(len(feature_names))]
bin_names = df["target_binned"].unique()
pos_increment = 1 / (len(feature_names) + gap_size)
boxes = []
for i, bin_name in enumerate(bin_names):
box = plt.boxplot(
df[df["target_binned"] == bin_name][feature_names],
widths=pos_increment,
positions=[i + pos_increment * n for n in range(len(feature_names))],
patch_artist=True,
medianprops={"color": "black"},
boxprops={"facecolor": "C0", "edgecolor": "black"},
**kwargs,
)
boxes.append(box)

for box in boxes:
patches = []
for patch, color in zip(box["boxes"], colors):
patch.set_facecolor(color)
patches.append(patch)

plt.xticks(
np.arange(0, len(bin_names)) + pos_increment * (len(feature_names) - 1) / 2,
bin_names,
)

plt.legend(
patches,
feature_names,
loc="upper center",
bbox_to_anchor=(0.5, 1.05),
ncol=len(feature_names),
)

plt.show()
27 changes: 27 additions & 0 deletions sharp/visualization/_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pandas as pd
from sharp.utils._utils import _optional_import
from ._waterfall import _waterfall
from ._aggregate import strata_boxplots


class ShaRPViz: # TODO
Expand Down Expand Up @@ -44,3 +45,29 @@ def waterfall(self, scores, mean_shapley_value=0):
"values": pd.Series(scores, index=feature_names),
}
return _waterfall(rank_dict, max_display=10)

def strata_boxplot(
self,
X,
y,
contributions,
feature_names=None,
n_strata=5,
gap_size=1,
cmap=None,
ax=None,
**kwargs
):
if feature_names is None:
feature_names = self.sharp.feature_names_.astype(str).tolist()
return strata_boxplots(
X=X,
y=y,
contributions=contributions,
feature_names=feature_names,
n_strata=n_strata,
gap_size=gap_size,
cmap=cmap,
ax=ax,
**kwargs
)

0 comments on commit f286725

Please sign in to comment.