Skip to content

Commit

Permalink
Add method to Experiment to plot results (#158)
Browse files Browse the repository at this point in the history
* Added plot method to Experiment class

* Saving plot in experiment

* Changed default split to test

* Passing default dataset to plotting function
  • Loading branch information
reluzita authored Feb 1, 2024
1 parent cdff72e commit 489233c
Showing 1 changed file with 25 additions and 7 deletions.
32 changes: 25 additions & 7 deletions src/aequitas/flow/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
from ..methods.preprocessing.identity import Identity as PreIdentity
from ..methods.preprocessing.preprocessing import PreProcessing
from ..optimization import ObjectiveFunction
from ..utils import ConfigReader, create_logger, import_object
from ..utils import ConfigReader, create_logger, import_object, read_results
from ..plots.pareto import Plot


class Experiment:
Expand Down Expand Up @@ -163,7 +164,7 @@ def _read_datasets(self):

def run(self) -> None:
self.logger.info("Beginning Experiment.")
exp_folder: Path = Path()
self.exp_folder: Path = Path()
dataset_folder: Path = Path()
if self.artifacts:
if self.save_folder is None:
Expand All @@ -172,22 +173,23 @@ def run(self) -> None:
if not self.hash:
self.generate_hash()

exp_folder: Path = self.save_folder / self.hash
exp_folder.mkdir(parents=True, exist_ok=True)
self.logger.info(f"Saving objects to '{exp_folder.resolve()}'.")
self.exp_folder: Path = self.save_folder / self.hash
self.exp_folder.mkdir(parents=True, exist_ok=True)
self.logger.info(f"Saving objects to '{self.exp_folder.resolve()}'.")
self._read_datasets()
for dataset_id, (dataset_name, dataset) in enumerate(self.datasets.items()):
self.logger.info(f"Using '{dataset_name}'.")
if self.artifacts:
dataset_folder: Path = exp_folder / dataset_name
dataset_folder: Path = self.exp_folder / dataset_name
dataset_folder.mkdir(parents=True, exist_ok=True)
self.logger.debug(
f"Saving dataset-related objects to '{dataset_folder.resolve()}'."
)
for method in self.config.methods:
for method_name, method_items in method.items():
self.logger.info(
f"Testing '{method_name}', saved in '{dataset_folder.resolve()}'."
f"Testing '{method_name}', "
f"saved in '{dataset_folder.resolve()}'."
)
if self.artifacts:
method_folder: Path = dataset_folder / method_name
Expand Down Expand Up @@ -278,3 +280,19 @@ def dt_handler(x):
json.dumps(self.config, default=dt_handler, sort_keys=True).encode("utf-8")
).hexdigest()
self.logger.debug(f"Hash generated: {self.hash}.")

def plot_pareto(
self,
dataset: str = "Dataset",
fairness_metric: str = "Predictive Equality",
performance_metric: str = "TPR",
split: str = "test",
) -> None:
results = read_results(self.exp_folder)
if dataset not in results:
raise ValueError(f"Dataset '{dataset}' was not used in experiment."
f" Try on of the following: {list(results.keys())}")
self.plot = Plot(
results, dataset, fairness_metric, performance_metric, split=split
)
return self.plot.visualize()

0 comments on commit 489233c

Please sign in to comment.