Merge pull request #6 from acostapazo/feature/improve-tests
ENH: add some missing unit testing
Showing 14 changed files with 267 additions and 155 deletions.
@@ -0,0 +1,2 @@
[run]
omit = ./venv/*,*tests*,*__init__.py,./gradgpad/reproducible_research/cli/*,./gradgpad/cli/*
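This new two-line file (its path is not shown in this capture) is a coverage.py run configuration: the omit patterns keep the virtualenv, the tests themselves, __init__.py files and the CLI entry points out of the coverage report. A minimal sketch of how such a config is consumed, assuming it is picked up as the project's coverage configuration (e.g. a .coveragerc):

import coverage

# Coverage() reads the project's [run] configuration, so the omit patterns
# above are applied automatically when measuring and reporting.
cov = coverage.Coverage()
cov.start()
import gradgpad  # code imported/executed here is measured
cov.stop()
cov.report()  # omitted paths (venv, tests, __init__.py, CLI modules) are excluded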
186 changes: 93 additions & 93 deletions
gradgpad/foundations/metrics/get_target_value_fixing_working_point.py
@@ -1,93 +1,93 @@
import numpy as np


VALUE_NOT_VALID_WORKING_POINT = 1.0
THRESHOLD_NOT_VALID_WORKING_POINT = -1.0
NOT_VALID_WORKING_POINT = (
    VALUE_NOT_VALID_WORKING_POINT,
    VALUE_NOT_VALID_WORKING_POINT,
    THRESHOLD_NOT_VALID_WORKING_POINT,
)


def get_target_value_fixing_working_point(
    fixed_working_point, targeted_values, fixing_values, thresholds, interpolated=False
):
    if not isinstance(targeted_values, np.ndarray):
        targeted_values = np.array(targeted_values)
    if not isinstance(fixing_values, np.ndarray):
        fixing_values = np.array(fixing_values)
    if not isinstance(thresholds, np.ndarray):
        thresholds = np.array(thresholds)

    if (targeted_values == fixing_values).all():
        return NOT_VALID_WORKING_POINT

    # Index of the sampled targeted value closest to the requested working point.
    lower_near_idx = np.abs(targeted_values - fixed_working_point).argmin()

    if (
        not interpolated
        or lower_near_idx == 0
        and targeted_values[lower_near_idx] > fixed_working_point
    ):
        # Without interpolation (or below the first sample) return the nearest sampled point.
        return (
            targeted_values[lower_near_idx],
            fixing_values[lower_near_idx],
            thresholds[lower_near_idx],
        )
    else:
        # Pick the neighbouring sample so that the two indices bracket the working point.
        upper_near_idx = lower_near_idx
        if lower_near_idx == len(targeted_values) - 1 or lower_near_idx == 0:
            upper_near_idx = lower_near_idx
        elif targeted_values[lower_near_idx] < fixed_working_point:
            if targeted_values[lower_near_idx + 1] >= targeted_values[lower_near_idx]:
                upper_near_idx = lower_near_idx + 1
            else:
                upper_near_idx = lower_near_idx - 1
        else:
            if targeted_values[lower_near_idx + 1] <= targeted_values[lower_near_idx]:
                upper_near_idx = lower_near_idx + 1
            else:
                upper_near_idx = lower_near_idx - 1

        if (
            lower_near_idx > upper_near_idx
        ):  # targeted_values[lower_near_idx] >= fixed_working_point >= targeted_values[upper_near_idx]
            l_idx = lower_near_idx
            lower_near_idx = upper_near_idx
            upper_near_idx = l_idx

        # if targeted_values[near_idx] <= fixed_working_point <= targeted_values[near_idx - 1]:
        # if targeted_values[near_idx] <= fixed_working_point:
        #     if targeted_values[near_idx] >= fixed_working_point:
        #         lower_idx = near_idx - 1
        #     else:
        #         lower_idx = near_idx
        #         near_idx = near_idx + 1

        # if near_idx > 99:  # targeted_values.size - 1
        #     near_idx = near_idx - 1
        #     lower_idx = lower_idx - 1

        x0 = targeted_values[lower_near_idx]
        x1 = targeted_values[upper_near_idx]

        y0 = fixing_values[lower_near_idx]
        y1 = fixing_values[upper_near_idx]
        t0 = thresholds[lower_near_idx]
        t1 = thresholds[upper_near_idx]

        if t0 == t1 == 1.0:
            return NOT_VALID_WORKING_POINT

        # Slopes for linear interpolation of the fixing value and the threshold.
        if x1 - x0 == 0.0:
            m = 0
            mt = 0
        else:
            m = (y1 - y0) / (x1 - x0)
            mt = (t1 - t0) / (x1 - x0)

        target_val = fixed_working_point
        fixed_val = (target_val - x0) * m + y0
        threshold_val = (target_val - x0) * mt + t0

        return target_val, fixed_val, threshold_val
(The 93 deleted lines were a commented-out copy of this same implementation; the commit removes the comment markers and activates the code.)
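For orientation, a small worked example of the now-active function, using made-up arrays rather than gradgpad data: with interpolated=True the requested working point 0.12 falls between the sampled targeted values 0.10 and 0.20, so the fixing value and the threshold are linearly interpolated over that segment; without interpolation the nearest sampled point is returned as-is.

import numpy as np

from gradgpad.foundations.metrics.get_target_value_fixing_working_point import (
    get_target_value_fixing_working_point,
)

# Hypothetical illustration values: a targeted error-rate curve, the paired
# error rates to fix, and the decision thresholds at each sample.
targeted_values = np.array([0.05, 0.10, 0.20])
fixing_values = np.array([0.30, 0.25, 0.15])
thresholds = np.array([0.10, 0.20, 0.30])

# 0.12 is bracketed by 0.10 and 0.20; the fixing value has slope -1.0 and the
# threshold slope +1.0 over that segment, so the call returns (0.12, 0.23, 0.22).
target, fixed, threshold = get_target_value_fixing_working_point(
    0.12, targeted_values, fixing_values, thresholds, interpolated=True
)

# With interpolated=False the closest sampled point (index 1) is returned
# unchanged: (0.10, 0.25, 0.20).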
69 changes: 12 additions & 57 deletions
tests/unit/foundations/metrics/test_generalization_metrics.py
@@ -1,64 +1,19 @@

Before (the two Metrics-based tests this commit removes):

import pytest

from gradgpad.foundations.metrics.metrics import Metrics
from gradgpad.foundations.scores.approach import Approach
from gradgpad.foundations.scores.protocol import Protocol
from gradgpad.foundations.scores.scores_provider import ScoresProvider
from gradgpad.foundations.scores.subset import Subset


@pytest.mark.unit
@pytest.mark.parametrize(
    "devel_scores,test_scores",
    [
        (
            ScoresProvider.get(
                approach=Approach.QUALITY_LINEAR,
                protocol=Protocol.GRANDTEST,
                subset=Subset.DEVEL,
            ),
            ScoresProvider.get(
                approach=Approach.QUALITY_LINEAR,
                protocol=Protocol.GRANDTEST,
                subset=Subset.TEST,
            ),
        )
    ],
)
def test_should_calculate_eer_for_devel_and_test(devel_scores, test_scores):

    metrics = Metrics(devel_scores, test_scores)
    assert pytest.approx(metrics.get_eer(Subset.DEVEL), 0.01) == 0.269
    assert pytest.approx(metrics.get_eer(Subset.TEST), 0.01) == 0.246


@pytest.mark.unit
@pytest.mark.parametrize(
    "devel_scores,test_scores",
    [
        (
            ScoresProvider.get(
                approach=Approach.QUALITY_RBF,
                protocol=Protocol.GRANDTEST,
                subset=Subset.DEVEL,
            ),
            ScoresProvider.get(
                approach=Approach.QUALITY_RBF,
                protocol=Protocol.GRANDTEST,
                subset=Subset.TEST,
            ),
        )
    ],
)
def test_should_calculate_indeepth_analysis(devel_scores, test_scores):

    metrics = Metrics(devel_scores, test_scores)
    bpcer_fixing_working_points = [0.10]
    apcer_fixing_working_points = [0.10]

    indepth_analysis = metrics.get_indepth_analysis(
        bpcer_fixing_working_points, apcer_fixing_working_points
    )

    assert pytest.approx(indepth_analysis["fine_grained_pai"]["acer"], 0.01) == 59.60
    assert pytest.approx(indepth_analysis["coarse_grained_pai"]["acer"], 0.01) == 36.06

After (the new, currently skipped, GeneralizationMetrics test):

import os

import pytest

from gradgpad import Approach, GeneralizationMetrics, ResultsProvider


@pytest.mark.unit
@pytest.mark.skip
def test_should_save_generalization_metrics():
    os.makedirs("output", exist_ok=True)
    output_filename = "output/generalization_metrics.txt"
    all_results = ResultsProvider.all(Approach.AUXILIARY)
    generalization_metrics = GeneralizationMetrics()
    generalization_metrics.save(
        output_filename=output_filename, all_results=all_results
    )

    assert os.path.isfile(output_filename)
@@ -0,0 +1,11 @@
import pytest

from gradgpad.tools.visualization.colors import get_color_random_style


@pytest.mark.unit
def test_should_obtain_a_random_color_style():
    color, linestyles, markers = get_color_random_style()
    assert isinstance(color, str)
    assert isinstance(linestyles, str)
    assert isinstance(markers, str)
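The new test only checks that the three returned values are strings. A plausible use of the triple, assuming (as the names suggest, though this diff does not confirm it) that they are matplotlib-compatible color, linestyle and marker specifiers:

import os

import matplotlib.pyplot as plt

from gradgpad.tools.visualization.colors import get_color_random_style

color, linestyle, marker = get_color_random_style()

# Hypothetical data purely for illustration.
os.makedirs("output", exist_ok=True)
plt.plot([0.0, 0.5, 1.0], [0.3, 0.2, 0.1], color=color, linestyle=linestyle, marker=marker)
plt.savefig("output/styled_curve.png")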
@@ -0,0 +1,37 @@
import os

import pytest

from gradgpad import (
    Approach,
    DetPlotter,
    Protocol,
    ScoresProvider,
    SplitByLabelMode,
    Subset,
)


@pytest.mark.unit
@pytest.mark.parametrize("split_by_label_mode", SplitByLabelMode.options_for_curves())
def test_should_save_a_det_plotter(split_by_label_mode: SplitByLabelMode):
    os.makedirs("output", exist_ok=True)
    output_filename = "output/radar.png"
    scores = ScoresProvider.get(Approach.AUXILIARY, Protocol.GRANDTEST, Subset.TEST)
    plotter = DetPlotter(
        title="My Title",
        split_by_label_mode=split_by_label_mode,
    )
    plotter.save(output_filename=output_filename, scores=scores)
    assert os.path.isfile(output_filename)


@pytest.mark.unit
def test_should_raise_a_type_error_when_save_a_det_plotter_with_no_valid_split_by_label_mode():
    pytest.raises(
        TypeError,
        lambda: DetPlotter(
            title="My Title",
            split_by_label_mode=SplitByLabelMode.DATASET,
        ),
    )
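The second test asserts the TypeError by handing a lambda to pytest.raises. An equivalent sketch using the context-manager form of pytest.raises, relying on the same assumed DetPlotter behaviour for SplitByLabelMode.DATASET:

import pytest

from gradgpad import DetPlotter, SplitByLabelMode


@pytest.mark.unit
def test_should_raise_a_type_error_for_an_unsupported_split_mode():
    # Context-manager form of the same assertion made above with pytest.raises(...).
    with pytest.raises(TypeError):
        DetPlotter(
            title="My Title",
            split_by_label_mode=SplitByLabelMode.DATASET,
        )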
27 changes: 27 additions & 0 deletions
tests/unit/tools/visualization/percentile/test_bias_percentile_comparison_plotter.py
@@ -0,0 +1,27 @@
import os

import pytest

from gradgpad import (
    Approach,
    BiasPercentileComparisonPlotter,
    Demographic,
    Protocol,
    ScoresProvider,
    Subset,
)


@pytest.mark.unit
@pytest.mark.parametrize("demographic", Demographic.options())
def test_should_save_a_bias_percentile_comparison_plotter(demographic: Demographic):
    os.makedirs("output", exist_ok=True)
    output_filename = "output/radar.png"
    scores = ScoresProvider.get(Approach.AUXILIARY, Protocol.GRANDTEST, Subset.TEST)
    plotter = BiasPercentileComparisonPlotter(
        title="My Title",
        demographic=demographic,
        working_point=(0.5, 0.7),
    )
    plotter.save(output_filename=output_filename, scores=scores)
    assert os.path.isfile(output_filename)
27 changes: 27 additions & 0 deletions
tests/unit/tools/visualization/percentile/test_bias_percentile_plotter.py
@@ -0,0 +1,27 @@
import os

import pytest

from gradgpad import (
    Approach,
    BiasPercentilePlotter,
    Demographic,
    Protocol,
    ScoresProvider,
    Subset,
)


@pytest.mark.unit
@pytest.mark.parametrize("demographic", Demographic.options())
def test_should_save_a_bias_percentile_plotter(demographic: Demographic):
    os.makedirs("output", exist_ok=True)
    output_filename = "output/radar.png"
    scores = ScoresProvider.get(Approach.AUXILIARY, Protocol.GRANDTEST, Subset.TEST)
    plotter = BiasPercentilePlotter(
        title="My Title",
        demographic=demographic,
        working_point=(0.5, 0.7),
    )
    plotter.save(output_filename=output_filename, scores=scores)
    assert os.path.isfile(output_filename)
File renamed without changes.
File renamed without changes.