Skip to content

Commit

Permalink
Label flipping corrections (#148)
Browse files Browse the repository at this point in the history
* Fixed label flipping docstring and disparity calculation

* Fixed spelling mistakes

* Fixed max prevalence to mean prevalence

* Faster disparity fair ordering

* Deleted old code
  • Loading branch information
reluzita authored Jan 30, 2024
1 parent 0f079a4 commit 3f94bdf
Showing 1 changed file with 51 additions and 28 deletions.
79 changes: 51 additions & 28 deletions src/aequitas/flow/methods/preprocessing/label_flipping.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import inspect
import pandas as pd
import math
from typing import Optional, Tuple, Literal, Union, Callable
import numpy as np
from sklearn.ensemble import BaggingClassifier
Expand All @@ -15,8 +16,8 @@
class LabelFlipping(PreProcessing):
def __init__(
self,
flip_rate: float = 0.1,
disparity_target: Optional[float] = None,
max_flip_rate: float = 0.1,
disparity_target: Optional[float] = 0.05,
score_threshold: Optional[float] = None,
bagging_max_samples: float = 0.5,
bagging_base_estimator: Union[
Expand All @@ -34,8 +35,16 @@ def __init__(
Parameters
----------
flip_rate : float, optional
max_flip_rate : float, optional
Maximum fraction of the training data to flip, by default 0.1
disparity_target : float, optional
The target disparity between the groups (difference between the prevalence
of a group and the mean prevalence). By default None, which means the
method will attempt to equalize the prevalence of the groups.
score_threshold : float, optional
The threshold above which the labels are flipped. By default None,
which means the method will flip the labels of the instances with
a score value higher than 0.
bagging_max_samples : float, optional
The number of samples to draw from X to train each base estimator of the
bagging classifier (with replacement).
Expand All @@ -45,17 +54,17 @@ def __init__(
bagging_n_estimators : int, optional
The number of base estimators in the ensemble, by default 10.
fair_ordering : bool, optional
Whether to take additional fairness criteria into account when flipping
Whether to take additional fairness criteria into account when flipping
labels, only modifying the labels that contribute to equalizing the
prevalence of the groups. By default True.
ordering_method : str, optional
The method used to calculate the margin of the base estimator. If
"ensemble_margin", calculates the ensemble margins based on the binary
predictions of the classifiers. If "residuals", oreders the missclafied
instances based on the average residuals of the classifiers predictions. By
The method used to calculate the margin of the base estimator. If
"ensemble_margin", calculates the ensemble margins based on the binary
predictions of the classifiers. If "residuals", orders the misclassified
instances based on the average residuals of the classifiers predictions. By
default "ensemble_margin".
unawareness_features : list, optional
The sensitive attributes (or proxies) to ignore when fitting the ensemble
The sensitive attributes (or proxies) to ignore when fitting the ensemble
to enable fairness through unawareness.
seed : int, optional
The seed to use when fitting the ensemble.
Expand All @@ -67,17 +76,17 @@ def __init__(
>>> from aequitas.preprocessing import LabelFlipping
>>> from sklearn.tree import DecisionTreeClassifier
>>> from sklearn.datasets import make_classification
>>> X, y = make_classification(n_samples=1000, n_features=10, n_informative=5,
>>> X, y = make_classification(n_samples=1000, n_features=10, n_informative=5,
n_redundant=0, random_state=42)
>>> lf = LabelFlipping(bagging_base_estimator=DecisionTreeClassifier,
flip_rate=0.1, max_depth=3)
>>> lf = LabelFlipping(bagging_base_estimator=DecisionTreeClassifier,
max_flip_rate=0.1, max_depth=3)
>>> lf.fit(X, y)
>>> X_transformed, y_transformed = lf.transform(X, y)
"""
self.logger = create_logger("methods.preprocessing.LabelFlipping")
self.logger.info("Instantiating a LabelFlipping preprocessing method.")

self.flip_rate = flip_rate
self.max_flip_rate = max_flip_rate

if disparity_target is not None:
if disparity_target < 0 or disparity_target > 1:
Expand Down Expand Up @@ -114,7 +123,7 @@ def __init__(
self.bagging_base_estimator = bagging_base_estimator(**args)
self.logger.info(
f"Created base estimator {self.bagging_base_estimator} with params {args}, "
F"discarded args:{list(set(base_estimator_args.keys()) - set(args.keys()))}"
f"discarded args:{list(set(base_estimator_args.keys()) - set(args.keys()))}"
)
self.bagging_n_estimators = bagging_n_estimators

Expand Down Expand Up @@ -159,8 +168,8 @@ def fit(self, X: pd.DataFrame, y: pd.Series, s: Optional[pd.Series]) -> None:
def _score_instances(self, X: pd.DataFrame, y: pd.Series) -> pd.Series:
"""Scores the instances based on the predictions of the ensemble of classifiers.
If the ordering method is "ensemble_margin", the scores are the ensemble
margins. If the ordering method is "residuals", the scores are the average
If the ordering method is "ensemble_margin", the scores are the ensemble
margins. If the ordering method is "residuals", the scores are the average
residuals of the classifiers predictions.
Parameters
Expand Down Expand Up @@ -202,19 +211,28 @@ def _score_instances(self, X: pd.DataFrame, y: pd.Series) -> pd.Series:

return scores

def _calculate_prevalence_disparity(self, y: pd.Series, s: pd.Series):
def _calculate_group_flips(self, y: pd.Series, s: pd.Series):
prevalence = y.mean()
group_prevalence = y.groupby(s).mean().to_dict()
group_disparity = {k: v - prevalence for k, v in group_prevalence.items()}
group_prevalences = y.groupby(s).mean()

return group_disparity
min_prevalence = prevalence - self.disparity_target * prevalence
max_prevalence = prevalence + self.disparity_target * prevalence

group_flips = {
group: math.ceil(min_prevalence * len(y[s == group])) - y[s == group].sum()
if group_prevalences[group] < min_prevalence
else math.floor(max_prevalence * len(y[s == group])) - y[s == group].sum()
for group in group_prevalences.index
}

return group_flips

def _label_flipping(self, y: pd.Series, s: Optional[pd.Series], scores: pd.Series):
"""Flips the labels of the desired fraction of the training data.
If fair_ordering is True, only flips the labels of the instances that contribute
to equalizing the prevalence of the groups.
Otherwise, the labels of the instances with the largest score values are
Otherwise, the labels of the instances with the largest score values are
flipped.
Parameters
Expand All @@ -236,10 +254,10 @@ def _label_flipping(self, y: pd.Series, s: Optional[pd.Series], scores: pd.Serie
ascending=(self.ordering_method == "ensemble_margin")
).index
)
n_flip = int(self.flip_rate * len(y))
n_flip = int(self.max_flip_rate * len(y))

if self.fair_ordering:
disparity = self._calculate_prevalence_disparity(y_flipped, s)
group_flips = self._calculate_group_flips(y_flipped, s)
flip_index = (
y_flipped.index
if self.ordering_method == "residuals"
Expand All @@ -251,12 +269,15 @@ def _label_flipping(self, y: pd.Series, s: Optional[pd.Series], scores: pd.Serie
if abs(scores.loc[i]) < self.score_threshold:
break

if (disparity[s.loc[i]] > self.disparity_target and y.loc[i] == 1) or (
disparity[s.loc[i]] < self.disparity_target and y.loc[i] == 0
if (group_flips[s.loc[i]] > 0 and y.loc[i] == 0) or (
group_flips[s.loc[i]] < 0 and y.loc[i] == 1
):
y_flipped.loc[i] = 1 - y.loc[i]
disparity = self._calculate_prevalence_disparity(y_flipped, s)
flip_count += 1
if group_flips[s.loc[i]] > 0:
group_flips[s.loc[i]] -= 1
else:
group_flips[s.loc[i]] += 1

if flip_count == n_flip:
break
Expand All @@ -282,7 +303,9 @@ def transform(
Parameters
----------
X : pd.DataFrame
Feature[s.loc[i]]ector.
Feature matrix.
y : pd.Series
Label vector.
s : pd.Series, optional
Protected attribute vector.
Expand All @@ -295,7 +318,7 @@ def transform(

if s is None and self.fair_ordering:
raise ValueError(
"Sensitive Attribute `s` not passed. Must be passed if `fair_ordering` "
"Sensitive Attribute `s` not passed. Must be passed if `fair_ordering` "
"is True."
)

Expand Down

0 comments on commit 3f94bdf

Please sign in to comment.