diff --git a/gradgpad/tools/__init__.py b/gradgpad/tools/__init__.py
index b447941..8ed9872 100644
--- a/gradgpad/tools/__init__.py
+++ b/gradgpad/tools/__init__.py
@@ -1,5 +1,5 @@
 from gradgpad.foundations.metrics.metric import *  # noqa
 
-from .create_dataframe import *  # noqa
+# from .create_dataframe import *  # noqa
 from .group_dataframe import *  # noqa
 from .open_result_json import *  # noqa
diff --git a/gradgpad/tools/create_dataframe.py b/gradgpad/tools/create_dataframe.py
index 1857381..a289be9 100644
--- a/gradgpad/tools/create_dataframe.py
+++ b/gradgpad/tools/create_dataframe.py
@@ -1,14 +1,14 @@
-import pandas as pd
-
-
-def create_dataframe(metric_retriever, results):
-
-    data = {"Metric": [], "Error Rate (%)": [], "Protocol": []}
-
-    for protocol_name, performance_info in sorted(results.items()):
-        metric, value = metric_retriever(performance_info)
-        data["Metric"].append(metric)
-        data["Error Rate (%)"].append(value)
-        data["Protocol"].append(protocol_name)
-    df = pd.DataFrame(data, columns=list(data.keys()))
-    return df
+# import pandas as pd
+#
+#
+# def create_dataframe(metric_retriever, results):
+#
+#     data = {"Metric": [], "Error Rate (%)": [], "Protocol": []}
+#
+#     for protocol_name, performance_info in sorted(results.items()):
+#         metric, value = metric_retriever(performance_info)
+#         data["Metric"].append(metric)
+#         data["Error Rate (%)"].append(value)
+#         data["Protocol"].append(protocol_name)
+#     df = pd.DataFrame(data, columns=list(data.keys()))
+#     return df
diff --git a/gradgpad/tools/visualization/histogram/histogram_plotter.py b/gradgpad/tools/visualization/histogram/histogram_plotter.py
index c08271d..759e66d 100755
--- a/gradgpad/tools/visualization/histogram/histogram_plotter.py
+++ b/gradgpad/tools/visualization/histogram/histogram_plotter.py
@@ -240,9 +240,7 @@ def save(self, output_filename: str, scores: Scores):
 
         if not os.path.isdir(os.path.dirname(output_filename)):
             raise IOError(
-                "Output path [{}] does not exist".format(
-                    os.path.dirname(output_filename)
-                )
+                f"Output path [{os.path.dirname(output_filename)}] does not exist"
             )
 
         plt = self.create_figure(scores)
diff --git a/tests/unit/foundations/scores/test_scores.py b/tests/unit/foundations/scores/test_scores.py
index 2d11edf..ce378c0 100644
--- a/tests/unit/foundations/scores/test_scores.py
+++ b/tests/unit/foundations/scores/test_scores.py
@@ -116,3 +116,35 @@ def test_should_load_scores_with_fair_skin_tone_subset(scores):
     fair_skin_tone_subset = scores.get_fair_skin_tone_subset()
     for skin_tone in SkinTone.options():
         assert len(fair_skin_tone_subset.get(skin_tone.name)) == 34  # 113
+
+
+@pytest.mark.unit
+@pytest.mark.parametrize("scores", scores_approaches_test)
+def test_should_load_scores_filter_with_random_values(scores):
+    assert len(scores.filtered_by(Filter(random_values=10))) == 10
+
+
+@pytest.mark.unit
+@pytest.mark.parametrize("scores", scores_approaches_test)
+def test_should_load_scores_filter_with_pseudo_random_values(scores):
+    assert len(scores.filtered_by(Filter(pseudo_random_values=10))) == 10
+
+
+@pytest.mark.unit
+@pytest.mark.parametrize("scores", scores_approaches_test)
+def test_should_load_scores_get_genuine(scores):
+    assert len(scores.get_genuine()) == 2281
+
+
+@pytest.mark.unit
+@pytest.mark.parametrize("scores", scores_approaches_test)
+def test_should_load_scores_get_attacks(scores):
+    assert len(scores.get_attacks()) == 10209
+
+
+@pytest.mark.unit
+@pytest.mark.parametrize("scores", scores_approaches_test)
+def test_should_load_scores_get_numpy_scores_and_labels_filtered_by_labels(scores):
+    scores, labels = scores.get_numpy_scores_and_labels_filtered_by_labels()
+    assert len(scores) == 12490
+    assert len(labels) == 12490
diff --git a/tests/unit/foundations/scores/test_scores_provider.py b/tests/unit/foundations/scores/test_scores_provider.py
index 7c30674..b4a76e6 100644
--- a/tests/unit/foundations/scores/test_scores_provider.py
+++ b/tests/unit/foundations/scores/test_scores_provider.py
@@ -11,6 +11,12 @@
 )
 
 
+@pytest.mark.unit
+def test_should_success_get_all_scores_auxiliary():
+    scores = ScoresProvider.all(Approach.AUXILIARY)
+    assert len(scores.keys()) == 35
+
+
 @pytest.mark.unit
 @pytest.mark.parametrize(
     "approach,protocol,subset,expected_scores_length",
diff --git a/tests/unit/tools/visualization/histogram/test_histogram_plotter.py b/tests/unit/tools/visualization/histogram/test_histogram_plotter.py
new file mode 100644
index 0000000..9fd5a92
--- /dev/null
+++ b/tests/unit/tools/visualization/histogram/test_histogram_plotter.py
@@ -0,0 +1,38 @@
+import os
+
+import pytest
+
+from gradgpad import (
+    Approach,
+    HistogramPlotter,
+    Protocol,
+    ScoresProvider,
+    SplitByLabelMode,
+    Subset,
+)
+
+
+@pytest.mark.unit
+@pytest.mark.parametrize("split_by_label_mode", SplitByLabelMode.options_for_curves())
+def test_should_save_a_histogram_plotter(split_by_label_mode: SplitByLabelMode):
+    os.makedirs("output", exist_ok=True)
+    output_filename = "output/radar.png"
+    scores = ScoresProvider.get(Approach.AUXILIARY, Protocol.GRANDTEST, Subset.TEST)
+    plotter = HistogramPlotter(
+        title="My Title",
+        split_by_label_mode=split_by_label_mode,
+    )
+    plotter.save(output_filename=output_filename, scores=scores)
+    assert os.path.isfile(output_filename)
+
+
+@pytest.mark.unit
+def test_should_raise_a_io_error_when_save_a_histogram_plotter_with_no_valid_output_filename():
+    plotter = HistogramPlotter(
+        title="My Title",
+        split_by_label_mode=SplitByLabelMode.DATASET,
+    )
+    pytest.raises(
+        IOError,
+        lambda: plotter.save(output_filename="not_valid_folder/name", scores=None),
+    )
diff --git a/tests/unit/tools/visualization/radar/test_fine_grained_pais_provider.py b/tests/unit/tools/visualization/radar/test_fine_grained_pais_provider.py
new file mode 100644
index 0000000..75ead84
--- /dev/null
+++ b/tests/unit/tools/visualization/radar/test_fine_grained_pais_provider.py
@@ -0,0 +1,9 @@
+import pytest
+
+from gradgpad import CombinedScenario, FineGrainedPaisProvider
+
+
+@pytest.mark.unit
+@pytest.mark.parametrize("combined_scenario", CombinedScenario.options())
+def test_should_get_fine_grained_pais_provider(combined_scenario: CombinedScenario):
+    FineGrainedPaisProvider.by(combined_scenario)