Skip to content

Commit

Permalink
[python] add ROC, LIFT, update shap, CP, AP
Browse files Browse the repository at this point in the history
  • Loading branch information
hbaniecki committed Dec 23, 2020
1 parent f783000 commit 5e13f5b
Show file tree
Hide file tree
Showing 7 changed files with 160 additions and 43 deletions.
7 changes: 6 additions & 1 deletion python/dalex/NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,19 @@ These are summed up in ([#368](https://github.com/ModelOriented/DALEX/issues/368
* rename modules: `dataset_level` into `model_explanations`, `instance_level` into `predict_explanations`, `_arena` module into `arena`
* use `__dir__` method to define autocompletion in IPython environment - show only `['Explainer', 'Arena', 'fairness', 'datasets']`
* add `plot` method and `result` attribute to `LimeExplanation` (use `lime.explanation.Explanation.as_pyplot_figure()` and `lime.explanation.Explanation.as_list()`)
* `CeterisParibus.plot(variable_type='categorical')` now has horizontal barplots - `horizontal_spacing=None` by default (varies on `variable_type`)
* `CeterisParibus.plot(variable_type='categorical')` now has horizontal barplots - `horizontal_spacing=None` by default (varies on `variable_type`). Also, once again added the "dot" for observation value.
* `predict_fn` in `predict_surrogate` now uses `predict_function` (trying to make it work for more frameworks)

#### fixes

* fixed wrong verbose output when any value in `y_hat/residuals` was an `int` not `float`
* added proper `"-"` sign to negative dropout losses in `VariableImportance.plot`

#### features

* added `geom='bars'` to `AggregateProfiles.plot` to force the categorical plot
* added `geom='roc'` and `geom='lift'` to `ModelPerformance.plot`

#### other

* remove `colorize` from `Explainer`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,9 @@ def plot(self,
-----------
objects : AggregatedProfiles object or array_like of AggregatedProfiles objects
Additional objects to plot in subplots (default is `None`).
geom : {'aggregates', 'profiles'}
If `'profiles'` then raw profiles will be plotted in the background
geom : {'aggregates', 'profiles', 'bars'}
If `'profiles'` then raw profiles will be plotted in the background,
'bars' overrides the `_x_` column type and uses barplots for categorical data
(default is `'aggregates'`, which means plot only aggregated profiles).
NOTE: It is useful to use small values of the `N` parameter in object creation
before using `'profiles'`, because of plot performance and clarity (e.g. `100`).
Expand Down Expand Up @@ -206,8 +207,8 @@ def plot(self,
Return figure that can be edited or saved. See `show` parameter.
"""

if geom not in ("aggregates", "profiles"):
raise TypeError("geom should be 'aggregates' or 'profiles'")
if geom not in ("aggregates", "profiles", "bars"):
raise TypeError("geom should be one of {'aggregates', 'profiles', 'bars'}")
if isinstance(variables, str):
variables = (variables,)

Expand Down Expand Up @@ -240,7 +241,7 @@ def plot(self,
min_max_margin = dl.ptp() * 0.10
min_max = [dl.min() - min_max_margin, dl.max() + min_max_margin]

is_x_numeric = pd.api.types.is_numeric_dtype(_result_df['_x_'])
is_x_numeric = False if geom == 'bars' else pd.api.types.is_numeric_dtype(_result_df['_x_'])
n = len(all_variables)

facet_nrow = int(np.ceil(n / facet_ncol))
Expand Down Expand Up @@ -298,7 +299,7 @@ def plot(self,
fig = px.bar(_result_df,
x="_x_", y="_diff_", color="_label_", facet_col="_vname_",
category_orders={"_vname_": list(all_variables)},
labels={'_yhat_': 'prediction', '_mp_': 'mean_prediction'}, # , color: 'group'},
labels={'_yhat_': 'prediction', '_label_': 'label', '_mp_': 'mean_prediction'}, # , color: 'group'},
hover_name=color,
base="_mp_",
hover_data={'_yhat_': ':.3f', '_mp_': mp_format, '_diff_': False,
Expand All @@ -316,10 +317,8 @@ def plot(self,
'ticks': 'outside', 'tickcolor': 'white', 'ticklen': 3, 'fixedrange': True,
'range': min_max})

# add hline https://github.com/plotly/plotly.py/issues/2141
for i, bar in enumerate(fig.data):
fig.add_shape(type='line', y0=bar.base[0], y1=bar.base[0], x0=-1, x1=len(bar.x),
xref=bar.xaxis, yref=bar.yaxis, layer='below',
for _, bar in enumerate(fig.data):
fig.add_hline(y=bar.base[0], layer='below',
line={'color': "#371ea3", 'width': 1.5, 'dash': 'dot'})

fig = _theme.fig_update_line_plot(fig, title, y_title, plot_height, hovermode)
Expand Down
40 changes: 16 additions & 24 deletions python/dalex/dalex/model_explanations/_model_performance/object.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import numpy as np
import pandas as pd
import plotly.graph_objects as go

from . import plot, utils
from ... import _theme, _global_checks
Expand Down Expand Up @@ -119,14 +118,17 @@ def fit(self, explainer):

def plot(self,
objects=None,
title="Reverse cumulative distribution of |residual|",
geom="ecdf",
title=None,
show=False):
"""Plot the Model Performance explanation
Parameters
-----------
objects : ModelPerformance object or array_like of ModelPerformance objects
Additional objects to plot (default is `None`).
geom: {'ecdf', 'roc', 'lift'}
Type of plot determines how residuals shall be summarized.
title : str, optional
Title of the plot (default depends on the `type` attribute).
show : bool, optional
Expand All @@ -139,6 +141,9 @@ def plot(self,
Return figure that can be edited or saved. See `show` parameter.
"""

if geom not in ("ecdf", "roc", "lift"):
raise TypeError("geom should be one of {'ecdf', 'roc', 'lift'}")

# are there any other objects to plot?
if objects is None:
_df_list = [self.residuals.copy()]
Expand All @@ -153,28 +158,15 @@ def plot(self,
_global_checks.global_raise_objects_class(objects, self.__class__)

colors = _theme.get_default_colors(len(_df_list), 'line')
fig = go.Figure()

for i, _df in enumerate(_df_list):
_abs_residuals = np.abs(_df['residuals'])
_unique_abs_residuals = np.unique(_abs_residuals)

fig.add_scatter(
x=_unique_abs_residuals,
y=1 - plot.ecdf(_abs_residuals)(_unique_abs_residuals),
line_shape='hv',
name=_df.iloc[0, _df.columns.get_loc('label')],
marker=dict(color=colors[i])
)

fig.update_yaxes({'type': 'linear', 'gridwidth': 2, 'zeroline': False, 'automargin': True, 'ticks': 'outside',
'tickcolor': 'white', 'ticklen': 10, 'fixedrange': True, 'tickformat': ',.0%'})

fig.update_xaxes({'type': 'linear', 'gridwidth': 2, 'zeroline': False, 'automargin': True, 'ticks': "outside",
'tickcolor': 'white', 'ticklen': 10, 'fixedrange': True, 'title_text': '|residual|'})

fig.update_layout(title_text=title, title_x=0.15, font={'color': "#371ea3"}, template="none",
margin={'t': 78, 'b': 71, 'r': 30})

if geom == 'ecdf':
fig = plot.plot_ecdf(_df_list, colors, title)
elif geom == 'roc':
fig = plot.plot_roc(_df_list, colors, title)
elif geom == 'lift':
fig = plot.plot_lift(_df_list, colors, title)
else:
raise TypeError("geom should be one of {'ecdf', 'roc', 'lift'}")

if show:
fig.show(config=_theme.get_default_config())
Expand Down
98 changes: 97 additions & 1 deletion python/dalex/dalex/model_explanations/_model_performance/plot.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np

import pandas as pd
import plotly.graph_objects as go

def ecdf(x):
# https://community.plot.ly/t/plot-the-empirical-cdf/29045
Expand All @@ -9,3 +10,98 @@ def result(v):
return np.searchsorted(x, v, side='right') / x.size

return result


def plot_ecdf(df_list, colors, title):
    """Draw the reverse cumulative distribution of absolute residuals.

    One step-shaped trace is added per data frame: the x values are the
    distinct absolute residuals and the y values are one minus the value
    returned by `ecdf` at those points.

    Parameters
    -----------
    df_list : list of pd.DataFrame
        Frames to plot; the columns `residuals` and `label` are read.
    colors : sequence
        Trace colors, indexed by the position of the frame in `df_list`.
    title : str or None
        Plot title (`None` selects the default title).

    Returns
    -----------
    plotly.graph_objects.Figure
    """
    if title is None:
        title = "Reverse cumulative distribution of |residual|"

    # settings common to both axes; each axis adds its own extra key below
    shared_axis = {'type': 'linear', 'gridwidth': 2, 'zeroline': False, 'automargin': True,
                   'ticks': 'outside', 'tickcolor': 'white', 'ticklen': 10, 'fixedrange': True}

    fig = go.Figure()
    for position, frame in enumerate(df_list):
        abs_res = np.abs(frame['residuals'])
        steps = np.unique(abs_res)
        reverse_cdf = 1 - ecdf(abs_res)(steps)

        fig.add_scatter(
            x=steps,
            y=reverse_cdf,
            line_shape='hv',
            name=frame['label'].iloc[0],  # the first row's label names the trace
            marker={'color': colors[position]}
        )

    fig.update_yaxes(dict(shared_axis, tickformat=',.0%'))
    fig.update_xaxes(dict(shared_axis, title_text='|residual|'))
    fig.update_layout(title_text=title, title_x=0.15, font={'color': "#371ea3"}, template="none",
                      margin={'t': 78, 'b': 71, 'r': 30})

    return fig


def plot_roc(df_list, colors, title):
    """Plot Receiver Operating Characteristic curves, one trace per data frame.

    Rows are ordered by decreasing `y_hat`; the cumulative share of `y` gives
    the true positive rate and the cumulative share of `1 - y` gives the false
    positive rate.

    Parameters
    -----------
    df_list : list of pd.DataFrame
        Frames to plot; the columns `y`, `y_hat` and `label` are read.
        NOTE(review): the rate formulas assume `y` is coded as 0/1 - confirm
        against the caller.
    colors : sequence
        Trace colors, indexed by the position of the frame in `df_list`.
    title : str or None
        Plot title (`None` selects the default title).

    Returns
    -----------
    plotly.graph_objects.Figure
    """
    fig = go.Figure()
    grid_points = 101

    for i, _df in enumerate(df_list):
        _df = _df.sort_values('y_hat', ascending=False)
        _df = _df.assign(TPR=np.cumsum(_df.y) / np.sum(_df.y),
                         FPR=np.cumsum(1 - _df.y) / np.sum(1 - _df.y))

        n_rows = _df.shape[0]
        if n_rows > grid_points:
            # The thinning index is built per frame and only when thinning is
            # needed: a frame with fewer rows than `grid_points` would give a
            # zero step (np.arange fails on it), and an index derived from a
            # different frame can point past the end of this one.
            idx = np.arange(n_rows, step=n_rows // grid_points)
            _df = _df.iloc[idx, :].sort_values('FPR', ascending=True)

        fig.add_scatter(
            x=_df.FPR,
            y=_df.TPR,
            line_shape='hv',
            name=_df.iloc[0, _df.columns.get_loc('label')],
            marker=dict(color=colors[i])
        )

    fig.update_yaxes({'type': 'linear', 'gridwidth': 2, 'zeroline': False, 'automargin': True, 'ticks': 'outside',
                      'tickcolor': 'white', 'ticklen': 10, 'fixedrange': True, 'title_text': 'True positive rate'})

    fig.update_xaxes({'type': 'linear', 'gridwidth': 2, 'zeroline': False, 'automargin': True, 'ticks': "outside",
                      'tickcolor': 'white', 'ticklen': 10, 'fixedrange': True, 'title_text': 'False positive rate'})
    title = "Receiver Operating Characteristic" if title is None else title
    fig.update_layout(title_text=title, title_x=0.15, font={'color': "#371ea3"}, template="none",
                      margin={'t': 78, 'b': 71, 'r': 30})

    return fig

def plot_lift(df_list, colors, title):
    """Plot LIFT curves, one trace per data frame.

    Rows are ordered by decreasing `y_hat`. For the top `k` rows the plotted
    value is the mean of `y` among those rows divided by the mean of `y` over
    all frames in `df_list` taken together; the x value is `k / n`.

    Parameters
    -----------
    df_list : list of pd.DataFrame
        Frames to plot; the columns `y`, `y_hat` and `label` are read.
        NOTE(review): the formulas assume `y` is coded as 0/1 - confirm
        against the caller.
    colors : sequence
        Trace colors, indexed by the position of the frame in `df_list`.
    title : str or None
        Plot title (`None` selects the default title).

    Returns
    -----------
    plotly.graph_objects.Figure
    """
    fig = go.Figure()
    grid_points = 101

    # baseline: mean of `y` over all supplied frames pooled together
    _temp_df = pd.concat(df_list)
    max_lift = _temp_df.y.sum() / _temp_df.shape[0]

    for i, _df in enumerate(df_list):
        _df = _df.sort_values('y_hat', ascending=False)
        n = _df.shape[0]
        lift = np.cumsum(_df.y) / n
        # Fraction of rows covered after each row: 1/n, 2/n, ..., 1.
        # Starting at 1/n (not 0) keeps the fraction aligned with the
        # cumulative sum and avoids a division by zero at the first point.
        pr = np.arange(1, n + 1) / n
        _df = _df.assign(lift=lift / pr, pr=pr)

        if n > grid_points:
            # The thinning index is built per frame and only when thinning is
            # needed: a frame with fewer rows than `grid_points` would give a
            # zero step (np.arange fails on it), and an index derived from a
            # different frame can point past the end of this one.
            idx = np.arange(n, step=n // grid_points)
            _df = _df.iloc[idx, :].sort_values('pr', ascending=True)

        fig.add_scatter(
            x=_df.pr,
            y=_df.lift / max_lift,
            line_shape='hv',
            name=_df.iloc[0, _df.columns.get_loc('label')],
            marker=dict(color=colors[i])
        )

    fig.update_yaxes({'type': 'linear', 'gridwidth': 2, 'zeroline': False, 'automargin': True, 'ticks': 'outside',
                      'tickcolor': 'white', 'ticklen': 10, 'fixedrange': True, 'title_text': 'Lift'})

    fig.update_xaxes({'type': 'linear', 'gridwidth': 2, 'zeroline': False, 'automargin': True, 'ticks': "outside",
                      'tickcolor': 'white', 'ticklen': 10, 'fixedrange': True, 'title_text': 'Positive rate'})
    title = "LIFT chart" if title is None else title
    fig.update_layout(title_text=title, title_x=0.15, font={'color': "#371ea3"}, template="none",
                      margin={'t': 78, 'b': 71, 'r': 30})

    return fig

25 changes: 21 additions & 4 deletions python/dalex/dalex/predict_explanations/_ceteris_paribus/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,13 +326,13 @@ def plot(self,
for _, value in enumerate(fig_points.data):
fig.add_trace(value)

fig = _theme.fig_update_line_plot(fig, title, y_title, plot_height, 'closest')
fig = _theme.fig_update_line_plot(fig, title, y_title, plot_height, 'closest')

else:
if color=="_label_" and len(_result_df['_ids_'].unique()) > 1 and len(_result_df['_label_'].unique()) == 1:
warnings.warn("'color' parameter changed to '_ids_' because there are multiple observations for one model.")
color = '_ids_'
elif color=="_label_" and len(_result_df['_ids_'].unique()) != len(_result_df['_label_'].unique()):
elif color=="_label_" and len(_result_df['_ids_'].unique()) > len(_result_df['_label_'].unique()):
# https://github.com/plotly/plotly.py/issues/2657
raise TypeError("Please pick one observation per label or change the `color` parameter.")

Expand Down Expand Up @@ -362,11 +362,28 @@ def plot(self,
.update_xaxes({'type': 'linear', 'gridwidth': 2, 'zeroline': False, 'automargin': True,
'ticks': 'outside', 'tickcolor': 'white', 'ticklen': 3, 'fixedrange': True,
'range': min_max})

# add hline https://github.com/plotly/plotly.py/issues/2141

for _, bar in enumerate(fig.data):
fig.add_vline(x=bar.base[0], layer='below',
line={'color': "#371ea3", 'width': 1.5, 'dash': 'dot'})

if show_observations:
_points_df = _result_df.loc[_result_df['_original_'] == _result_df['_x_'], :].copy()

fig_points = px.scatter(_points_df,
x='_yhat_', y='_x_', facet_col='_vname_',
category_orders={"_vname_": list(variable_names)},
labels={'_yhat_': 'prediction', '_label_': 'label', '_ids_': 'id'},
custom_data=['_text_'],
facet_col_wrap=facet_ncol,
facet_row_spacing=vertical_spacing,
facet_col_spacing=horizontal_spacing,
color_discrete_sequence=["#371ea3"]) \
.update_traces(dict(marker_size=5*size, opacity=alpha),
hovertemplate="%{customdata[0]}<extra></extra>")

for _, value in enumerate(fig_points.data):
fig.add_trace(value)

fig = _theme.fig_update_bar_plot(fig, title, y_title, plot_height, 'closest')

Expand Down
4 changes: 3 additions & 1 deletion python/dalex/dalex/wrappers/_shap/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ def check_shap_explainer_type(shap_explainer_type, model):
if model_type.endswith("sklearn.ensemble._forest.RandomForestRegressor'>") or\
model_type.endswith("sklearn.ensemble._forest.RandomForestClassifier'>") or\
model_type.endswith("xgboost.core.Booster'>") or\
model_type.endswith("lightgbm.basic.Booster'>"):
model_type.endswith("lightgbm.basic.Booster'>") or\
model_type.endswith("catboost.core.CatBoostRegressor'>") or\
model_type.endswith("catboost.core.CatBoostClassifier'>"):
shap_explainer_type = "TreeExplainer"
elif model_type.endswith("'keras.engine.training.Model'>") or\
model_type.endswith("nn.Module'>"):
Expand Down
10 changes: 8 additions & 2 deletions python/dalex/dalex/wrappers/_shap/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ def fit(self,
shap_explainer_type : {'TreeExplainer', 'DeepExplainer', 'GradientExplainer', 'LinearExplainer', 'KernelExplainer'}
String name of the Explainer class (default is `None`, which automatically
chooses an Explainer to use).
kwargs: dict
Keyword parameters passed to the `shapley_values` method.
Returns
-----------
Expand All @@ -72,7 +74,10 @@ def fit(self,
new_observation = checks.check_new_observation_predict_parts(new_observation, explainer)

if shap_explainer_type == "TreeExplainer":
self.shap_explainer = TreeExplainer(explainer.model, explainer.data.values)
try:
self.shap_explainer = TreeExplainer(explainer.model, explainer.data.values)
except: # https://github.com/ModelOriented/DALEX/issues/371
self.shap_explainer = TreeExplainer(explainer.model)
elif shap_explainer_type == "DeepExplainer":
self.shap_explainer = DeepExplainer(explainer.model, explainer.data.values)
elif shap_explainer_type == "GradientExplainer":
Expand All @@ -81,7 +86,8 @@ def fit(self,
self.shap_explainer = LinearExplainer(explainer.model, explainer.data.values)
elif shap_explainer_type == "KernelExplainer":
self.shap_explainer = KernelExplainer(
lambda x: explainer.predict(x), explainer.data.values)
lambda x: explainer.predict(x), explainer.data.values
)

self.result = self.shap_explainer.shap_values(new_observation.values, **kwargs)
self.new_observation = new_observation
Expand Down

0 comments on commit 5e13f5b

Please sign in to comment.