Skip to content

Commit

Permalink
range fixes and checks
Browse files Browse the repository at this point in the history
  • Loading branch information
brightening-eyes committed Sep 23, 2023
1 parent 98c1fb5 commit 3aec165
Showing 1 changed file with 26 additions and 14 deletions.
40 changes: 26 additions & 14 deletions src/layer/range.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@ Range::Range()

int Range::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
{
if (bottom_blobs.size() < 2 || bottom_blobs.size() > 3)
if (bottom_blobs.size() < 2 || bottom_blobs.size() > 3 || top_blobs.size() != 1)
return -100;

const Mat& start = bottom_blobs[0];
if(start.empty())
if (start.empty())
return -100;

const Mat& limit = bottom_blobs[1];
if(limit.empty())
if (limit.empty())
return -100;

const Mat& delta = bottom_blobs.size() == 3 ? bottom_blobs[2] : Mat();
Expand All @@ -44,29 +44,41 @@ int Range::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_b
if (start.w * start.h * start.d * start.c != 1 || limit.w * limit.h * limit.d * limit.c != 1 || (!delta.empty() && delta.w * delta.h * delta.d * delta.c != 1))
return -100;

const int* start_ptr = start;
const int* limit_ptr = limit;
const int* delta_ptr = delta;
if (start.elemsize != limit.elemsize || (!delta.empty() && start.elemsize != delta.elemsize))
return -100;

int start_val = *start_ptr;
int limit_val = *limit_ptr;
int delta_val = delta.empty() ? 1 : *delta_ptr;
const float* start_ptr = start;
const float* limit_ptr = limit;

if (delta_val == 0)
float start_val = start_ptr[0];
float limit_val = limit_ptr[0];
float delta_val = 1.0f;
if (!delta.empty())
{
const float* delta_ptr = delta;
delta_val = delta_ptr[0];
}

if (delta_val == 0.0f || (limit_val - start_val) * delta_val <= 0.0f)
return -100;

int number_of_elements = (int) std::max((int) ceilf((limit_val - start_val) / delta_val), 0);
if (limit_val < start_val && delta_val > 0.0f)
delta_val = -delta_val;

int number_of_elements = static_cast<int>(ceil((limit_val - start_val) / delta_val));
if (number_of_elements < 0)
number_of_elements = 0;

output.create(number_of_elements, start.elemsize, opt.blob_allocator);
output.create(number_of_elements, start.elemsize, start.elempack, opt.blob_allocator);
if (output.empty())
return -100;

int* outptr = output;
float* outptr = output;

#pragma omp parallel for num_threads(opt.num_threads)
for (int i = 0; i < number_of_elements; i++)
{
((int*)outptr)[i] = start_val + (i * delta_val);
((float*)outptr)[i] = start_val + (i * delta_val);
}

return 0;
Expand Down

0 comments on commit 3aec165

Please sign in to comment.