diff --git a/python/dalex/NEWS.md b/python/dalex/NEWS.md index b89e81a40..dd3c64b1a 100644 --- a/python/dalex/NEWS.md +++ b/python/dalex/NEWS.md @@ -22,6 +22,7 @@ These are summed up in ([#368](https://github.com/ModelOriented/DALEX/issues/368 * added `geom='bars'` to `AggregateProfiles.plot` to force the categorical plot * added `geom='roc'` and `geom='lift'` to `ModelPerformance.plot` +* added Fairness plot to Arena #### other diff --git a/python/dalex/dalex/arena/object.py b/python/dalex/dalex/arena/object.py index 2e33267b8..54e7d1cab 100644 --- a/python/dalex/dalex/arena/object.py +++ b/python/dalex/dalex/arena/object.py @@ -85,7 +85,8 @@ def __init__(self, precalculate=False, enable_attributes=True, enable_custom_par CeterisParibusContainer, BreakDownContainer, MetricsContainer, - ROCContainer + ROCContainer, + FairnessCheckContainer ] self.options = {} for plot in self.plots: diff --git a/python/dalex/dalex/arena/plots/__init__.py b/python/dalex/dalex/arena/plots/__init__.py index 0b067223a..3649d642b 100644 --- a/python/dalex/dalex/arena/plots/__init__.py +++ b/python/dalex/dalex/arena/plots/__init__.py @@ -6,6 +6,7 @@ from ._ceteris_paribus_container import CeterisParibusContainer from ._metrics_container import MetricsContainer from ._roc_container import ROCContainer +from ._fairness_check_container import FairnessCheckContainer __all__ = [ 'ShapleyValuesContainer', @@ -15,5 +16,6 @@ 'CeterisParibusContainer', 'BreakDownContainer', 'MetricsContainer', - 'ROCContainer' + 'ROCContainer', + 'FairnessCheckContainer' ] diff --git a/python/dalex/dalex/arena/plots/_fairness_check_container.py b/python/dalex/dalex/arena/plots/_fairness_check_container.py new file mode 100644 index 000000000..9e7b66e9f --- /dev/null +++ b/python/dalex/dalex/arena/plots/_fairness_check_container.py @@ -0,0 +1,53 @@ +import pandas as pd +import numpy as np +from pandas.api.types import is_object_dtype +from .._plot_container import PlotContainer +from dalex.fairness._group_fairness 
def rm_nan(obj):
    """Return a copy of ``obj`` with NaN and infinite values replaced by ``None``.

    Parameters
    ----------
    obj : dict
        Mapping whose values are numeric scalars accepted by ``np.isnan``
        and ``np.isinf``.

    Returns
    -------
    dict
        New dict with the same keys; every NaN / +inf / -inf value is
        replaced by ``None`` and all other values are kept unchanged.
    """
    return {
        key: (None if np.isnan(value) or np.isinf(value) else value)
        for key, value in obj.items()
    }


class FairnessCheckContainer(PlotContainer):
    """Arena plot container holding per-subgroup fairness metrics.

    For every configured cutoff the subgroup confusion-matrix metrics of the
    selected model are computed with respect to the selected (protected)
    variable and stored in ``self.data``.
    """

    # Static description of the plot.
    info = {
        'name': 'Fairness',
        'plotType': 'Fairness',
        'plotCategory': 'Dataset Level',
        'requiredParams': ['model', 'variable']
    }
    # Options of this plot; the default cutoffs are 0.05, 0.10, ..., 0.95.
    options = {
        'cutoffs': {
            'default': [x / 100 for x in range(5, 100, 5)],
            'desc': 'List of tested cutoff levels'
        },
    }

    def _fit(self, model, variable):
        """Compute fairness metrics for ``model`` grouped by ``variable``.

        On success ``self.data`` is set to
        ``{'subgroups': {subgroup: {cutoff: {metric: score_or_None}}}}``.
        When the model is not a classification model, or the selected column
        is not of object dtype, only a message is set via ``set_message`` and
        ``self.data`` is left untouched.

        Parameters
        ----------
        model : object
            Arena model parameter; its ``variables`` and ``explainer``
            attributes are read.
        variable : object
            Arena variable parameter; ``variable.variable`` is the name of the
            protected column in the explainer's data.

        Raises
        ------
        Exception
            If the variable is not listed in ``model.variables``.
        """
        # Local import keeps the fairness helpers next to their only use.
        from dalex.fairness._group_fairness import checks, utils

        if variable.variable not in model.variables:
            raise Exception('Variable is not a column of explainer')
        exp = model.explainer
        protected = exp.data[variable.variable]
        if exp.model_type != 'classification':
            self.set_message('Fairness plot is only available for classificators')
            return
        if not is_object_dtype(protected):
            self.set_message('Select categorical variable to check fairness')
            return

        # Predictions are only needed once the guards above have passed.
        y_hat = exp.predict(exp.data) if exp.y_hat is None else exp.y_hat

        frames = []
        for cutoff in self.arena.get_option(self.plot_type, 'cutoffs'):
            # NOTE(review): meaning of the third positional argument (False)
            # is defined by checks.check_cutoff - confirm against its signature.
            cutoff_dict = checks.check_cutoff(protected, cutoff, False)
            sub_confusion_matrix = utils.SubgroupConfusionMatrix(
                exp.y, y_hat, protected, cutoff_dict
            )
            metrics = utils.SubgroupConfusionMatrixMetrics(sub_confusion_matrix)
            frame = metrics.to_vertical_DataFrame()
            frame['cutoff'] = cutoff
            frames.append(frame)

        output = {}
        if frames:
            # DataFrame.append was removed in pandas 2.0; a single concat is
            # also linear instead of re-copying the frame on every iteration.
            output_df = pd.concat(frames)
            for subgroup, group_df in output_df.set_index('metric').groupby('subgroup'):
                output[subgroup] = {
                    cutoff: rm_nan(cutoff_df['score'].to_dict())
                    for cutoff, cutoff_df in group_df.groupby('cutoff')
                }

        self.data = {'subgroups': output}

    @staticmethod
    def test_arena(arena):
        """Tell whether the Fairness plot is applicable to ``arena``.

        Parameters
        ----------
        arena : dalex.arena.object.Arena
            Arena instance whose 'model' params are inspected.

        Returns
        -------
        bool
            True when at least one model's explainer has
            ``model_type == 'classification'``, False otherwise.

        Raises
        ------
        Exception
            If ``arena`` is not an ``Arena`` from ``dalex.arena.object``.
        """
        # The type is compared by name and module rather than isinstance
        # (presumably to avoid importing the Arena class here - confirm).
        arena_type = type(arena)
        if arena_type.__name__ != 'Arena' or arena_type.__module__ != 'dalex.arena.object':
            raise Exception('Invalid Arena argument')
        return any(
            model.explainer.model_type == 'classification'
            for model in arena.get_params('model')
        )
a/python/dalex/dalex/arena/server.py b/python/dalex/dalex/arena/server.py index d9ce9ea14..8e6681e8c 100644 --- a/python/dalex/dalex/arena/server.py +++ b/python/dalex/dalex/arena/server.py @@ -65,7 +65,7 @@ def get_params(request): @app.route("/", methods=['GET']) def get_plot(plot_type): if plot_type == 'timestamp': - return {'timestamp': arena.timestamp * 1000} + return Response(json.dumps({'timestamp': arena.timestamp * 1000}, default=convert), content_type='application/json') elif plot_type == 'shutdown': if request.args.get('token') != shutdown_token: abort(403) diff --git a/python/dalex/test/test_arena_classification.py b/python/dalex/test/test_arena_classification.py index 0eb247070..aaa40f88b 100644 --- a/python/dalex/test/test_arena_classification.py +++ b/python/dalex/test/test_arena_classification.py @@ -54,7 +54,7 @@ def setUp(self): # This plots should be supported self.reference_plots = [ROCContainer, ShapleyValuesContainer, BreakDownContainer, CeterisParibusContainer, - FeatureImportanceContainer, PartialDependenceContainer, AccumulatedDependenceContainer, MetricsContainer] + FeatureImportanceContainer, PartialDependenceContainer, AccumulatedDependenceContainer, MetricsContainer, FairnessCheckContainer] def test_supported_plots(self): arena = dx.Arena()