Skip to content

Commit

Permalink
[Pyton][Arena] Add Fairnes plot (#373)
Browse files Browse the repository at this point in the history
* [Pyton][Arena] Add Fairnes plot

* [Pyton] Update NEWS.md
  • Loading branch information
piotrpiatyszek committed Dec 28, 2020
1 parent 5e13f5b commit a22a6da
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 4 deletions.
1 change: 1 addition & 0 deletions python/dalex/NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ These are summed up in ([#368](https://github.com/ModelOriented/DALEX/issues/368

* added `geom='bars'` to `AggregateProfiles.plot` to force the categorical plot
* added `geom='roc'` and `geom='lift'` to `ModelPerformance.plot`
* added Fairness plot to Arena

#### other

Expand Down
3 changes: 2 additions & 1 deletion python/dalex/dalex/arena/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ def __init__(self, precalculate=False, enable_attributes=True, enable_custom_par
CeterisParibusContainer,
BreakDownContainer,
MetricsContainer,
ROCContainer
ROCContainer,
FairnessCheckContainer
]
self.options = {}
for plot in self.plots:
Expand Down
4 changes: 3 additions & 1 deletion python/dalex/dalex/arena/plots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from ._ceteris_paribus_container import CeterisParibusContainer
from ._metrics_container import MetricsContainer
from ._roc_container import ROCContainer
from ._fairness_check_container import FairnessCheckContainer

__all__ = [
'ShapleyValuesContainer',
Expand All @@ -15,5 +16,6 @@
'CeterisParibusContainer',
'BreakDownContainer',
'MetricsContainer',
'ROCContainer'
'ROCContainer',
'FairnessCheckContainer'
]
53 changes: 53 additions & 0 deletions python/dalex/dalex/arena/plots/_fairness_check_container.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import pandas as pd
import numpy as np
from pandas.api.types import is_object_dtype
from .._plot_container import PlotContainer
from dalex.fairness._group_fairness import utils, checks

def rm_nan(obj):
return { k: (None if np.isnan(obj[k]) or np.isinf(obj[k]) else obj[k]) for k in obj.keys() }

class FairnessCheckContainer(PlotContainer):
info = {
'name': 'Fairness',
'plotType': 'Fairness',
'plotCategory': 'Dataset Level',
'requiredParams': ['model', 'variable']
}
options = {
'cutoffs': { 'default': [x / 100 for x in range(5, 100, 5)], 'desc': 'List of tested cutoff levels' },
}
def _fit(self, model, variable):
if not variable.variable in model.variables:
raise Exception('Variable is not a column of explainer')
exp = model.explainer
y_hat = exp.predict(exp.data) if exp.y_hat is None else exp.y_hat
protected = exp.data[variable.variable]
if exp.model_type != 'classification':
self.set_message('Fairness plot is only available for classificators')
return
if not is_object_dtype(protected):
self.set_message('Select categorical variable to check fairness')
return

output_df = None
for cutoff in self.arena.get_option(self.plot_type, 'cutoffs'):
cutoff_dict = checks.check_cutoff(protected, cutoff, False)
sub_confusion_matrix = utils.SubgroupConfusionMatrix(exp.y, y_hat, protected, cutoff_dict)
sub_confusion_matrix_metrics = utils.SubgroupConfusionMatrixMetrics(sub_confusion_matrix)
df = sub_confusion_matrix_metrics.to_vertical_DataFrame()
df['cutoff'] = cutoff
output_df = df if output_df is None else output_df.append(df)

output = {}
for (subgroup, x) in output_df.set_index('metric').groupby('subgroup'):
output[subgroup] = {}
for (cutoff, y) in x.groupby('cutoff'):
output[subgroup][cutoff] = rm_nan(y['score'].to_dict())

self.data = { 'subgroups': output }

def test_arena(arena):
if type(arena).__name__ != 'Arena' or type(arena).__module__ != 'dalex.arena.object':
raise Exception('Invalid Arena argument')
return next((True for model in arena.get_params('model') if model.explainer.model_type == 'classification'), False)
2 changes: 1 addition & 1 deletion python/dalex/dalex/arena/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def get_params(request):
@app.route("/<string:plot_type>", methods=['GET'])
def get_plot(plot_type):
if plot_type == 'timestamp':
return {'timestamp': arena.timestamp * 1000}
return Response(json.dumps({'timestamp': arena.timestamp * 1000}, default=convert), content_type='application/json')
elif plot_type == 'shutdown':
if request.args.get('token') != shutdown_token:
abort(403)
Expand Down
2 changes: 1 addition & 1 deletion python/dalex/test/test_arena_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def setUp(self):

# This plots should be supported
self.reference_plots = [ROCContainer, ShapleyValuesContainer, BreakDownContainer, CeterisParibusContainer,
FeatureImportanceContainer, PartialDependenceContainer, AccumulatedDependenceContainer, MetricsContainer]
FeatureImportanceContainer, PartialDependenceContainer, AccumulatedDependenceContainer, MetricsContainer, FairnessCheckContainer]

def test_supported_plots(self):
arena = dx.Arena()
Expand Down

0 comments on commit a22a6da

Please sign in to comment.