From 489233c943e9dd7a5d8eafb72e3ab268f5919e30 Mon Sep 17 00:00:00 2001 From: Ines Oliveira e Silva Date: Thu, 1 Feb 2024 14:08:29 +0000 Subject: [PATCH] Add method to Experiment to plot results (#158) * Added plot method to Experiment class * Saving plot in experiment * Changed default split to test * Passing default dataset to plotting function --- src/aequitas/flow/experiment/experiment.py | 32 +++++++++++++++++----- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/src/aequitas/flow/experiment/experiment.py b/src/aequitas/flow/experiment/experiment.py index 9d15aaec..1f620fce 100644 --- a/src/aequitas/flow/experiment/experiment.py +++ b/src/aequitas/flow/experiment/experiment.py @@ -18,7 +18,8 @@ from ..methods.preprocessing.identity import Identity as PreIdentity from ..methods.preprocessing.preprocessing import PreProcessing from ..optimization import ObjectiveFunction -from ..utils import ConfigReader, create_logger, import_object +from ..utils import ConfigReader, create_logger, import_object, read_results +from ..plots.pareto import Plot class Experiment: @@ -163,7 +164,7 @@ def _read_datasets(self): def run(self) -> None: self.logger.info("Beginning Experiment.") - exp_folder: Path = Path() + self.exp_folder: Path = Path() dataset_folder: Path = Path() if self.artifacts: if self.save_folder is None: @@ -172,14 +173,14 @@ def run(self) -> None: if not self.hash: self.generate_hash() - exp_folder: Path = self.save_folder / self.hash - exp_folder.mkdir(parents=True, exist_ok=True) - self.logger.info(f"Saving objects to '{exp_folder.resolve()}'.") + self.exp_folder: Path = self.save_folder / self.hash + self.exp_folder.mkdir(parents=True, exist_ok=True) + self.logger.info(f"Saving objects to '{self.exp_folder.resolve()}'.") self._read_datasets() for dataset_id, (dataset_name, dataset) in enumerate(self.datasets.items()): self.logger.info(f"Using '{dataset_name}'.") if self.artifacts: - dataset_folder: Path = exp_folder / dataset_name + dataset_folder: Path = self.exp_folder / dataset_name dataset_folder.mkdir(parents=True, exist_ok=True) self.logger.debug( f"Saving dataset-related objects to '{dataset_folder.resolve()}'." @@ -187,7 +188,8 @@ def run(self) -> None: for method in self.config.methods: for method_name, method_items in method.items(): self.logger.info( - f"Testing '{method_name}', saved in '{dataset_folder.resolve()}'." + f"Testing '{method_name}', " + f"saved in '{dataset_folder.resolve()}'." ) if self.artifacts: method_folder: Path = dataset_folder / method_name @@ -278,3 +280,19 @@ def dt_handler(x): json.dumps(self.config, default=dt_handler, sort_keys=True).encode("utf-8") ).hexdigest() self.logger.debug(f"Hash generated: {self.hash}.") + + def plot_pareto( + self, + dataset: str = "Dataset", + fairness_metric: str = "Predictive Equality", + performance_metric: str = "TPR", + split: str = "test", + ) -> None: + results = read_results(self.exp_folder) + if dataset not in results: + raise ValueError(f"Dataset '{dataset}' was not used in experiment." + f" Try on of the following: {list(results.keys())}") + self.plot = Plot( + results, dataset, fairness_metric, performance_metric, split=split + ) + return self.plot.visualize()