Skip to content

Commit

Permalink
value range fix
Browse files Browse the repository at this point in the history
  • Loading branch information
shurale-nkn committed Nov 2, 2023
1 parent 3ff724f commit 34c1094
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 8 deletions.
3 changes: 3 additions & 0 deletions test/rnn_seq_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ struct rnn_seq_driver : rnn_seq_api_test_driver<T>

bool is_correct_params()
{
if(this->useDropout == 1 && (this->hiddenSize == 1 || this->batchSize == 1))
return false;

if(this->inputMode == 1 && this->hiddenSize != this->inVecLen)
return false;

Expand Down
20 changes: 12 additions & 8 deletions test/rnn_seq_api.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1462,8 +1462,8 @@ struct rnn_seq_api_test_driver : test_driver
tensor<T>& dcy,
std::vector<T>& weights)
{
const double scale = 0.01;
const double bwd_scale = scale * 0.001;
const double scale = 0.1;
const double bwd_scale = scale;

struct scalar_gen_random_float
{
Expand All @@ -1476,22 +1476,26 @@ struct rnn_seq_api_test_driver : test_driver
}
};

auto gen_positive_value = [=](auto...) { return scalar_gen_random_float{0, 1 * scale}(); };
auto gen_positive_value = [=](auto...) {
return scalar_gen_random_float{std::numeric_limits<T>::epsilon(), 1 * scale}();
};

auto gen_positive_value_bwd = [=](auto...) {
double bwd_max = 1. * scale;
double bwd_min = std::numeric_limits<T>::epsilon();
double bwd_max = std::min(bwd_min * 100, 1. * scale);
return scalar_gen_random_float{std::numeric_limits<T>::epsilon(), bwd_max}();
return scalar_gen_random_float{bwd_min, bwd_max}();
};

auto fill_array_via_gen = [](auto& dst, size_t dst_sz, double range_l, double range_r) {
for(size_t it = 0; it < dst_sz; it++)
dst[it] = prng::gen_A_to_B(static_cast<T>(range_l), static_cast<T>(range_r));
};
prng::reset_seed();
fill_array_via_gen(input.data, input.data.size(), 0.0, 1.0 * scale);
fill_array_via_gen(
input.data, input.data.size(), std::numeric_limits<T>::epsilon(), 1. * scale);
prng::reset_seed();
fill_array_via_gen(dy.data, dy.data.size(), 0, 1.0 * bwd_scale);
fill_array_via_gen(
dy.data, dy.data.size(), std::numeric_limits<T>::epsilon(), 1. * bwd_scale);
prng::reset_seed();

const auto hidden_size = hx.desc.GetLengths()[2];
Expand Down Expand Up @@ -1646,7 +1650,7 @@ struct rnn_seq_api_test_driver : test_driver
// avoid BWD unexpected fails
if(inVecLen == 1)
{
tolerance = 160;
tolerance = 80;
}
else
{
Expand Down

0 comments on commit 34c1094

Please sign in to comment.