Skip to content

Commit

Permalink
fix build by transformation thresholds
Browse files Browse the repository at this point in the history
  • Loading branch information
sbalandi committed Mar 4, 2024
1 parent 1cd2268 commit 1049528
Showing 1 changed file with 33 additions and 13 deletions.
46 changes: 33 additions & 13 deletions src/tests/test_utils/common_test_utils/src/ov_tensor_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -337,21 +337,20 @@ class Error {
protected:
struct IncorrectValue {
size_t coordinate;
double actual_value, expected_value, abs_threshold, rel_threshold;
double actual_value, expected_value, threshold;

IncorrectValue(double in_actual_value,
double in_expected_value,
double in_abs_threshold,
double in_rel_threshold,
double in_threshold,
size_t in_coordinate)
: actual_value(in_actual_value),
expected_value(in_expected_value),
abs_threshold(in_abs_threshold),
rel_threshold(in_rel_threshold),
threshold(in_threshold),
coordinate(in_coordinate) {}
};

std::vector<IncorrectValue> incorrect_values;
std::vector<IncorrectValue> incorrect_values_abs;
std::vector<IncorrectValue> incorrect_values_rel;
double abs_threshold, rel_threshold;

public:
Expand All @@ -364,25 +363,41 @@ class Error {

const auto calculated_abs_threshold = abs_threshold * expected;
const auto calculated_rel_threshold = calculated_abs_threshold ? diff / calculated_abs_threshold : 1.;
if (less_or_equal(diff, calculated_abs_threshold) && less_or_equal(calculated_rel_threshold, rel_threshold)) {
incorrect_values_rel.emplace_back(
IncorrectValue(actual, expected, calculated_rel_threshold, coordinate));
if (less_or_equal(diff, calculated_abs_threshold)) {
return;
}
incorrect_values.emplace_back(
IncorrectValue(actual, expected, calculated_abs_threshold, calculated_rel_threshold, coordinate));
incorrect_values_abs.emplace_back(
IncorrectValue(actual, expected, calculated_abs_threshold, coordinate));
}

void get_results() {
if (!incorrect_values.empty()) {
if (!incorrect_values_abs.empty()) {
std::string msg = "[ COMPARATION ] COMPARATION IS FAILED! incorrect elem counter: ";
msg += std::to_string(incorrect_values.size());
msg += ". Please print `incorrect_values` to get detailed information!";
msg += std::to_string(incorrect_values_abs.size());
msg += ". Please print `incorrect_values_abs` to get detailed information!";
throw std::runtime_error(msg);
}
if (!incorrect_values_rel.empty()) {
double rel_error = 0;
std::for_each(incorrect_values_rel.begin(), incorrect_values_rel.end(), [&](IncorrectValue val) {
rel_error += val.threshold;
});
rel_error /= incorrect_values_rel.size();

if (!less_or_equal(rel_error, rel_threshold)) {
std::string msg = "[ COMPARATION ] COMPARATION IS FAILED! incorrect elem counter: ";
msg += std::to_string(incorrect_values_abs.size());
msg += ". Please print `incorrect_values_rel` to get detailed information!";
throw std::runtime_error(msg);
}
}
}
};

template <typename ExpectedT, typename ActualT>
void compare(const ov::Tensor& expected, const ov::Tensor& actual, double abs_threshold, const double rel_threshold) {
void compare(const ov::Tensor& expected, const ov::Tensor& actual, double abs_threshold, double rel_threshold) {
auto expected_shape = expected.get_shape();
auto actual_shape = actual.get_shape();
if (expected_shape != actual_shape) {
Expand All @@ -400,6 +415,11 @@ void compare(const ov::Tensor& expected, const ov::Tensor& actual, double abs_th
abs_threshold = default_abs_threshold;
}

const auto default_rel_threshold = 0.5;
if (rel_threshold == std::numeric_limits<double>::max() || rel_threshold < default_rel_threshold) {
rel_threshold = default_rel_threshold;
}

size_t shape_size_cnt = shape_size(expected_shape);
Error error(abs_threshold, rel_threshold);
const auto expected_data = expected.data<ExpectedT>();
Expand Down

0 comments on commit 1049528

Please sign in to comment.