Skip to content

Commit

Permalink
Merge pull request #6 from acostapazo/feature/improve-tests
Browse files Browse the repository at this point in the history
ENH: add some missing unit testing
  • Loading branch information
acostapazo authored Jan 9, 2022
2 parents b45f97e + f6ea398 commit 11f1649
Show file tree
Hide file tree
Showing 14 changed files with 267 additions and 155 deletions.
2 changes: 2 additions & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[run]
omit = ./venv/*,*tests*,*__init__.py,./gradgpad/reproducible_research/cli/*,./gradgpad/cli/*
3 changes: 3 additions & 0 deletions gradgpad/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
from gradgpad.tools.visualization.gif_creator import GifCreator
from gradgpad.tools.visualization.histogram.histogram_plotter import HistogramPlotter
from gradgpad.tools.visualization.histogram.split_by_level_mode import SplitByLabelMode
from gradgpad.tools.visualization.percentile.bias_percentile_comparison_plotter import (
BiasPercentileComparisonPlotter,
)
from gradgpad.tools.visualization.percentile.bias_percentile_plotter import (
BiasPercentilePlotter,
)
Expand Down
186 changes: 93 additions & 93 deletions gradgpad/foundations/metrics/get_target_value_fixing_working_point.py
Original file line number Diff line number Diff line change
@@ -1,93 +1,93 @@
import numpy as np

# Sentinel values returned when no valid operating point can be determined
# (identical curves, or a segment whose thresholds are saturated at 1.0).
VALUE_NOT_VALID_WORKING_POINT = 1.0
THRESHOLD_NOT_VALID_WORKING_POINT = -1.0
# (target_value, fixing_value, threshold) triple marking an invalid point.
NOT_VALID_WORKING_POINT = (
    VALUE_NOT_VALID_WORKING_POINT,
    VALUE_NOT_VALID_WORKING_POINT,
    THRESHOLD_NOT_VALID_WORKING_POINT,
)


def get_target_value_fixing_working_point(
    fixed_working_point, targeted_values, fixing_values, thresholds, interpolated=False
):
    """Locate the operating point on an error-rate curve closest to a target.

    Given parallel arrays describing a curve (``targeted_values``), the
    complementary values at the same operating points (``fixing_values``) and
    the decision ``thresholds``, return the triple
    ``(target_value, fixing_value, threshold)`` nearest to
    ``fixed_working_point``.  With ``interpolated=True`` the result is
    linearly interpolated between the two curve samples surrounding the
    requested working point.

    Args:
        fixed_working_point: Desired value on the targeted curve.
        targeted_values: Curve values to match against (array-like).
        fixing_values: Values paired with ``targeted_values`` (array-like).
        thresholds: Decision thresholds paired with the curves (array-like).
        interpolated: Interpolate between neighbouring samples when True.

    Returns:
        Tuple ``(target_value, fixing_value, threshold)``, or
        ``NOT_VALID_WORKING_POINT`` when no valid point exists (identical
        targeted/fixing curves, or both segment thresholds equal to 1.0).
    """
    # np.asarray is a no-op on ndarrays and converts lists/tuples otherwise,
    # replacing the original isinstance checks.
    targeted_values = np.asarray(targeted_values)
    fixing_values = np.asarray(fixing_values)
    thresholds = np.asarray(thresholds)

    # Identical curves carry no information about a working point.
    if (targeted_values == fixing_values).all():
        return NOT_VALID_WORKING_POINT

    lower_near_idx = np.abs(targeted_values - fixed_working_point).argmin()

    # No interpolation requested, or the working point falls below the very
    # first curve sample: answer with the nearest sample directly.  The
    # parentheses make the original and/or precedence explicit.
    if not interpolated or (
        lower_near_idx == 0 and targeted_values[lower_near_idx] > fixed_working_point
    ):
        return (
            targeted_values[lower_near_idx],
            fixing_values[lower_near_idx],
            thresholds[lower_near_idx],
        )

    # Pick the neighbouring sample on the far side of the working point,
    # walking in the direction the curve moves towards it.  At either array
    # edge there is no second sample, so both indices stay equal and the
    # segment degenerates (handled by the x1 - x0 == 0 branch below).
    # BUG FIX: the original edge branch contained ``upper_near_idx ==
    # lower_near_idx`` — a no-op comparison where an assignment was meant;
    # the guard below expresses the intent explicitly.
    upper_near_idx = lower_near_idx
    if 0 < lower_near_idx < len(targeted_values) - 1:
        if targeted_values[lower_near_idx] < fixed_working_point:
            going_up = (
                targeted_values[lower_near_idx + 1] >= targeted_values[lower_near_idx]
            )
        else:
            going_up = (
                targeted_values[lower_near_idx + 1] <= targeted_values[lower_near_idx]
            )
        upper_near_idx = lower_near_idx + 1 if going_up else lower_near_idx - 1

    # Keep the index pair ordered so the interpolation segment is well defined.
    if lower_near_idx > upper_near_idx:
        lower_near_idx, upper_near_idx = upper_near_idx, lower_near_idx

    x0, x1 = targeted_values[lower_near_idx], targeted_values[upper_near_idx]
    y0, y1 = fixing_values[lower_near_idx], fixing_values[upper_near_idx]
    t0, t1 = thresholds[lower_near_idx], thresholds[upper_near_idx]

    # Both segment thresholds saturated: outside the usable range.
    if t0 == t1 == 1.0:
        return NOT_VALID_WORKING_POINT

    # Degenerate (single-sample) segment: fall back to the lower sample's
    # values instead of dividing by zero.
    if x1 - x0 == 0.0:
        m = 0
        mt = 0
    else:
        m = (y1 - y0) / (x1 - x0)
        mt = (t1 - t0) / (x1 - x0)

    target_val = fixed_working_point
    fixed_val = (target_val - x0) * m + y0
    threshold_val = (target_val - x0) * mt + t0

    return target_val, fixed_val, threshold_val
4 changes: 0 additions & 4 deletions tests/unit/foundations/metrics/test_apcer.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
import numpy as np
import pytest

# scores = np.array([0.5, 0.6, 0.2, 0.0, 0.0])
# labels = np.array([1, 2, 2, 0, 0])
# expected_apcer = 0.0
# th_eer_dev = 0.15
from gradgpad.foundations.metrics.apcer import apcer


Expand Down
69 changes: 12 additions & 57 deletions tests/unit/foundations/metrics/test_generalization_metrics.py
Original file line number Diff line number Diff line change
@@ -1,64 +1,19 @@
import pytest

from gradgpad.foundations.metrics.metrics import Metrics
from gradgpad.foundations.scores.approach import Approach
from gradgpad.foundations.scores.protocol import Protocol
from gradgpad.foundations.scores.scores_provider import ScoresProvider
from gradgpad.foundations.scores.subset import Subset
import os

import pytest

@pytest.mark.unit
@pytest.mark.parametrize(
    "devel_scores,test_scores",
    [
        (
            # NOTE: these ScoresProvider.get calls execute at module import
            # time (parametrize arguments are evaluated on collection).
            ScoresProvider.get(
                approach=Approach.QUALITY_LINEAR,
                protocol=Protocol.GRANDTEST,
                subset=Subset.DEVEL,
            ),
            ScoresProvider.get(
                approach=Approach.QUALITY_LINEAR,
                protocol=Protocol.GRANDTEST,
                subset=Subset.TEST,
            ),
        )
    ],
)
def test_should_calculate_eer_for_devel_and_test(devel_scores, test_scores):
    """EER on the QUALITY_LINEAR / GRANDTEST scores matches the reference
    values (0.269 devel, 0.246 test) within 1% relative tolerance."""
    metrics = Metrics(devel_scores, test_scores)
    assert pytest.approx(metrics.get_eer(Subset.DEVEL), 0.01) == 0.269
    assert pytest.approx(metrics.get_eer(Subset.TEST), 0.01) == 0.246
from gradgpad import Approach, GeneralizationMetrics, ResultsProvider


@pytest.mark.unit
@pytest.mark.parametrize(
"devel_scores,test_scores",
[
(
ScoresProvider.get(
approach=Approach.QUALITY_RBF,
protocol=Protocol.GRANDTEST,
subset=Subset.DEVEL,
),
ScoresProvider.get(
approach=Approach.QUALITY_RBF,
protocol=Protocol.GRANDTEST,
subset=Subset.TEST,
),
)
],
)
def test_should_calculate_indeepth_analysis(devel_scores, test_scores):

metrics = Metrics(devel_scores, test_scores)
bpcer_fixing_working_points = [0.10]
apcer_fixing_working_points = [0.10]

indepth_analysis = metrics.get_indepth_analysis(
bpcer_fixing_working_points, apcer_fixing_working_points
@pytest.mark.skip
def test_should_save_generalization_metrics():
os.makedirs("output", exist_ok=True)
output_filename = "output/generalization_metrics.txt"
all_results = ResultsProvider.all(Approach.AUXILIARY)
generalization_metrics = GeneralizationMetrics()
generalization_metrics.save(
output_filename=output_filename, all_results=all_results
)

assert pytest.approx(indepth_analysis["fine_grained_pai"]["acer"], 0.01) == 59.60
assert pytest.approx(indepth_analysis["coarse_grained_pai"]["acer"], 0.01) == 36.06
assert os.path.isfile(output_filename)
11 changes: 11 additions & 0 deletions tests/unit/tools/visualization/det/test_colors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import pytest

from gradgpad.tools.visualization.colors import get_color_random_style


@pytest.mark.unit
def test_should_obtain_a_random_color_style():
    """The random style helper yields a (color, linestyle, marker) triple of strings."""
    color, linestyles, markers = get_color_random_style()
    for element in (color, linestyles, markers):
        assert isinstance(element, str)
37 changes: 37 additions & 0 deletions tests/unit/tools/visualization/det/test_det_plotter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import os

import pytest

from gradgpad import (
Approach,
DetPlotter,
Protocol,
ScoresProvider,
SplitByLabelMode,
Subset,
)


@pytest.mark.unit
@pytest.mark.parametrize("split_by_label_mode", SplitByLabelMode.options_for_curves())
def test_should_save_a_det_plotter(split_by_label_mode: SplitByLabelMode):
    """A DET figure is written to disk for every curve-compatible split mode."""
    output_filename = "output/radar.png"
    os.makedirs("output", exist_ok=True)
    test_scores = ScoresProvider.get(Approach.AUXILIARY, Protocol.GRANDTEST, Subset.TEST)
    det_plotter = DetPlotter(title="My Title", split_by_label_mode=split_by_label_mode)
    det_plotter.save(output_filename=output_filename, scores=test_scores)
    assert os.path.isfile(output_filename)


@pytest.mark.unit
def test_should_raise_a_type_error_when_save_a_det_plotter_with_no_valid_split_by_label_mode():
    """DetPlotter rejects split modes that are not valid for curve plotting."""
    with pytest.raises(TypeError):
        DetPlotter(
            title="My Title",
            split_by_label_mode=SplitByLabelMode.DATASET,
        )
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import os

import pytest

from gradgpad import (
Approach,
BiasPercentileComparisonPlotter,
Demographic,
Protocol,
ScoresProvider,
Subset,
)


@pytest.mark.unit
@pytest.mark.parametrize("demographic", Demographic.options())
def test_should_save_a_bias_percentile_comparison_plotter(demographic: Demographic):
    """A bias percentile comparison figure is saved for every demographic."""
    output_filename = "output/radar.png"
    os.makedirs("output", exist_ok=True)
    grandtest_scores = ScoresProvider.get(Approach.AUXILIARY, Protocol.GRANDTEST, Subset.TEST)
    comparison_plotter = BiasPercentileComparisonPlotter(
        title="My Title",
        demographic=demographic,
        working_point=(0.5, 0.7),
    )
    comparison_plotter.save(output_filename=output_filename, scores=grandtest_scores)
    assert os.path.isfile(output_filename)
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import os

import pytest

from gradgpad import (
Approach,
BiasPercentilePlotter,
Demographic,
Protocol,
ScoresProvider,
Subset,
)


@pytest.mark.unit
@pytest.mark.parametrize("demographic", Demographic.options())
def test_should_save_a_bias_percentile_plotter(demographic: Demographic):
    """A bias percentile figure is saved for every demographic option."""
    output_filename = "output/radar.png"
    os.makedirs("output", exist_ok=True)
    grandtest_scores = ScoresProvider.get(Approach.AUXILIARY, Protocol.GRANDTEST, Subset.TEST)
    percentile_plotter = BiasPercentilePlotter(
        title="My Title",
        demographic=demographic,
        working_point=(0.5, 0.7),
    )
    percentile_plotter.save(output_filename=output_filename, scores=grandtest_scores)
    assert os.path.isfile(output_filename)
Loading

0 comments on commit 11f1649

Please sign in to comment.