diff --git a/src/aequitas/flow/experiment/experiment.py b/src/aequitas/flow/experiment/experiment.py index ea243bac..714477b2 100644 --- a/src/aequitas/flow/experiment/experiment.py +++ b/src/aequitas/flow/experiment/experiment.py @@ -17,7 +17,8 @@ from ..methods.preprocessing.identity import Identity as PreIdentity from ..methods.preprocessing.preprocessing import PreProcessing from ..optimization import ObjectiveFunction -from ..utils import ConfigReader, create_logger, import_object +from ..utils import ConfigReader, create_logger, import_object, read_results +from ..plots.pareto import Plot class Experiment: @@ -143,7 +144,7 @@ def _read_datasets(self): def run(self) -> None: self.logger.info("Beginning Experiment.") - exp_folder: Path = Path() + self.exp_folder: Path = Path() dataset_folder: Path = Path() if self.artifacts: if self.save_folder is None: @@ -152,14 +153,14 @@ def run(self) -> None: if not self.hash: self.generate_hash() - exp_folder: Path = self.save_folder / self.hash - exp_folder.mkdir(parents=True, exist_ok=True) - self.logger.info(f"Saving objects to '{exp_folder.resolve()}'.") + self.exp_folder: Path = self.save_folder / self.hash + self.exp_folder.mkdir(parents=True, exist_ok=True) + self.logger.info(f"Saving objects to '{self.exp_folder.resolve()}'.") self._read_datasets() for dataset_id, (dataset_name, dataset) in enumerate(self.datasets.items()): self.logger.info(f"Using '{dataset_name}'.") if self.artifacts: - dataset_folder: Path = exp_folder / dataset_name + dataset_folder: Path = self.exp_folder / dataset_name dataset_folder.mkdir(parents=True, exist_ok=True) self.logger.debug( f"Saving dataset-related objects to '{dataset_folder.resolve()}'." @@ -167,7 +168,8 @@ def run(self) -> None: for method in self.config.methods: for method_name, method_items in method.items(): self.logger.info( - f"Testing '{method_name}', saved in '{dataset_folder.resolve()}'." + f"Testing '{method_name}', " + f"saved in '{dataset_folder.resolve()}'." ) if self.artifacts: method_folder: Path = dataset_folder / method_name @@ -258,3 +260,19 @@ def dt_handler(x): json.dumps(self.config, default=dt_handler, sort_keys=True).encode("utf-8") ).hexdigest() self.logger.debug(f"Hash generated: {self.hash}.") + + def plot_pareto( + self, + dataset: str = "Dataset", + fairness_metric: str = "Predictive Equality", + performance_metric: str = "TPR", + split: str = "test", + ) -> None: + results = read_results(self.exp_folder) + if dataset not in results: + raise ValueError(f"Dataset '{dataset}' was not used in experiment." + f" Try on of the following: {list(results.keys())}") + self.plot = Plot( + results, dataset, fairness_metric, performance_metric, split=split + ) + return self.plot.visualize()