Merge pull request #8 from acostapazo/feature/add-tests
ENH: add some missing unit testing
acostapazo authored Jan 10, 2022
2 parents a28c457 + fbe4cf2 commit 3df0eab
Showing 7 changed files with 101 additions and 18 deletions.
2 changes: 1 addition & 1 deletion gradgpad/tools/__init__.py
@@ -1,5 +1,5 @@
 from gradgpad.foundations.metrics.metric import *  # noqa

-from .create_dataframe import *  # noqa
+# from .create_dataframe import *  # noqa
 from .group_dataframe import *  # noqa
 from .open_result_json import *  # noqa
28 changes: 14 additions & 14 deletions gradgpad/tools/create_dataframe.py
@@ -1,14 +1,14 @@
-import pandas as pd
-
-
-def create_dataframe(metric_retriever, results):
-
-    data = {"Metric": [], "Error Rate (%)": [], "Protocol": []}
-
-    for protocol_name, performance_info in sorted(results.items()):
-        metric, value = metric_retriever(performance_info)
-        data["Metric"].append(metric)
-        data["Error Rate (%)"].append(value)
-        data["Protocol"].append(protocol_name)
-    df = pd.DataFrame(data, columns=list(data.keys()))
-    return df
+# import pandas as pd
+#
+#
+# def create_dataframe(metric_retriever, results):
+#
+#     data = {"Metric": [], "Error Rate (%)": [], "Protocol": []}
+#
+#     for protocol_name, performance_info in sorted(results.items()):
+#         metric, value = metric_retriever(performance_info)
+#         data["Metric"].append(metric)
+#         data["Error Rate (%)"].append(value)
+#         data["Protocol"].append(protocol_name)
+#     df = pd.DataFrame(data, columns=list(data.keys()))
+#     return df
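This change comments the helper out instead of covering it with a test. Purely for illustration, a minimal sketch of a unit test that could cover create_dataframe if the helper were restored; the fake metric_retriever, the "acer" key and the protocol names below are made up, and pandas would have to be installed:

import pytest

from gradgpad.tools.create_dataframe import create_dataframe  # only valid if the module is restored


@pytest.mark.unit
def test_should_create_dataframe_from_results():
    # Hypothetical retriever: returns a metric name and its value for one result entry.
    def fake_metric_retriever(performance_info):
        return "ACER", performance_info["acer"]

    # Made-up results keyed by protocol name, as create_dataframe expects.
    results = {
        "grandtest": {"acer": 12.5},
        "cross-device": {"acer": 20.0},
    }

    df = create_dataframe(fake_metric_retriever, results)

    # One row per protocol, with the three columns built by the helper.
    assert list(df.columns) == ["Metric", "Error Rate (%)", "Protocol"]
    assert len(df) == 2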
4 changes: 1 addition & 3 deletions gradgpad/tools/visualization/histogram/histogram_plotter.py
@@ -240,9 +240,7 @@ def save(self, output_filename: str, scores: Scores):

         if not os.path.isdir(os.path.dirname(output_filename)):
             raise IOError(
-                "Output path [{}] does not exist".format(
-                    os.path.dirname(output_filename)
-                )
+                f"Output path [{os.path.dirname(output_filename)}] does not exist"
             )

         plt = self.create_figure(scores)
32 changes: 32 additions & 0 deletions tests/unit/foundations/scores/test_scores.py
@@ -116,3 +116,35 @@ def test_should_load_scores_with_fair_skin_tone_subset(scores):
    fair_skin_tone_subset = scores.get_fair_skin_tone_subset()
    for skin_tone in SkinTone.options():
        assert len(fair_skin_tone_subset.get(skin_tone.name)) == 34  # 113


@pytest.mark.unit
@pytest.mark.parametrize("scores", scores_approaches_test)
def test_should_load_scores_filter_with_random_values(scores):
    assert len(scores.filtered_by(Filter(random_values=10))) == 10


@pytest.mark.unit
@pytest.mark.parametrize("scores", scores_approaches_test)
def test_should_load_scores_filter_with_pseudo_random_values(scores):
    assert len(scores.filtered_by(Filter(pseudo_random_values=10))) == 10


@pytest.mark.unit
@pytest.mark.parametrize("scores", scores_approaches_test)
def test_should_load_scores_get_genuine(scores):
    assert len(scores.get_genuine()) == 2281


@pytest.mark.unit
@pytest.mark.parametrize("scores", scores_approaches_test)
def test_should_load_scores_get_attacks(scores):
    assert len(scores.get_attacks()) == 10209


@pytest.mark.unit
@pytest.mark.parametrize("scores", scores_approaches_test)
def test_should_load_scores_get_numpy_scores_and_labels_filtered_by_labels(scores):
    scores, labels = scores.get_numpy_scores_and_labels_filtered_by_labels()
    assert len(scores) == 12490
    assert len(labels) == 12490
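For orientation, a short usage sketch assembled only from calls that appear in the tests above and in the new histogram test below; it assumes Filter can be imported from the gradgpad top-level package like the other names (its import is not shown in this diff):

from gradgpad import Approach, Filter, Protocol, ScoresProvider, Subset

# Load one scores split, then slice it the same way the tests do.
scores = ScoresProvider.get(Approach.AUXILIARY, Protocol.GRANDTEST, Subset.TEST)

sample = scores.filtered_by(Filter(random_values=10))  # 10 randomly sampled scores
genuine = scores.get_genuine()  # genuine (bona fide) scores
attacks = scores.get_attacks()  # attack scores
values, labels = scores.get_numpy_scores_and_labels_filtered_by_labels()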
6 changes: 6 additions & 0 deletions tests/unit/foundations/scores/test_scores_provider.py
@@ -11,6 +11,12 @@
)


@pytest.mark.unit
def test_should_success_get_all_scores_auxiliary():
    scores = ScoresProvider.all(Approach.AUXILIARY)
    assert len(scores.keys()) == 35


@pytest.mark.unit
@pytest.mark.parametrize(
    "approach,protocol,subset,expected_scores_length",
38 changes: 38 additions & 0 deletions tests/unit/tools/visualization/histogram/test_histogram_plotter.py
@@ -0,0 +1,38 @@
import os

import pytest

from gradgpad import (
    Approach,
    HistogramPlotter,
    Protocol,
    ScoresProvider,
    SplitByLabelMode,
    Subset,
)


@pytest.mark.unit
@pytest.mark.parametrize("split_by_label_mode", SplitByLabelMode.options_for_curves())
def test_should_save_a_histogram_plotter(split_by_label_mode: SplitByLabelMode):
    os.makedirs("output", exist_ok=True)
    output_filename = "output/radar.png"
    scores = ScoresProvider.get(Approach.AUXILIARY, Protocol.GRANDTEST, Subset.TEST)
    plotter = HistogramPlotter(
        title="My Title",
        split_by_label_mode=split_by_label_mode,
    )
    plotter.save(output_filename=output_filename, scores=scores)
    assert os.path.isfile(output_filename)


@pytest.mark.unit
def test_should_raise_a_io_error_when_save_a_histogram_plotter_with_no_valid_output_filename():
    plotter = HistogramPlotter(
        title="My Title",
        split_by_label_mode=SplitByLabelMode.DATASET,
    )
    pytest.raises(
        IOError,
        lambda: plotter.save(output_filename="not_valid_folder/name", scores=None),
    )
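The lambda form of pytest.raises works, but the context-manager form is the more common idiom. An equivalent sketch, reusing the imports already in the file above (the test name is illustrative):

@pytest.mark.unit
def test_should_raise_io_error_for_missing_output_folder():
    plotter = HistogramPlotter(
        title="My Title",
        split_by_label_mode=SplitByLabelMode.DATASET,
    )
    # pytest.raises as a context manager fails the test unless IOError is raised inside the block.
    with pytest.raises(IOError):
        plotter.save(output_filename="not_valid_folder/name", scores=None)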
9 changes: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
import pytest

from gradgpad import CombinedScenario, FineGrainedPaisProvider


@pytest.mark.unit
@pytest.mark.parametrize("combined_scenario", CombinedScenario.options())
def test_should_get_fine_grained_pais_provider(combined_scenario: CombinedScenario):
    FineGrainedPaisProvider.by(combined_scenario)
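The test above is a smoke test: it only checks that the call does not raise. A sketch of a slightly stronger variant, reusing the same imports and assuming FineGrainedPaisProvider.by() returns a usable (non-None) value:

@pytest.mark.unit
@pytest.mark.parametrize("combined_scenario", CombinedScenario.options())
def test_should_get_fine_grained_pais_provider_value(combined_scenario: CombinedScenario):
    # Assumption: .by() returns the fine-grained PAIs for the scenario rather than None.
    assert FineGrainedPaisProvider.by(combined_scenario) is not None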
