diff --git a/src/layer/diag.cpp b/src/layer/diag.cpp index 936b32b2366..f6f9c64a2ac 100644 --- a/src/layer/diag.cpp +++ b/src/layer/diag.cpp @@ -37,8 +37,7 @@ int Diag::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) cons if (dims == 1) { int w = bottom_blob.w; - int top_w = w + std::abs(diagonal); - int stride = top_w + 1; + int top_w = w + ((diagonal >= 0) ? diagonal : -diagonal); top_blob.create(top_w, top_w, elemsize, opt.blob_allocator); if (top_blob.empty()) @@ -58,10 +57,16 @@ int Diag::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) cons { int w = bottom_blob.w; int h = bottom_blob.h; - float tmp = (w - h) / 2.0; - int len = std::min(w, h) - (int)std::max(std::abs(diagonal - tmp) - std::abs(tmp), 0.0f); - len = std::max(len, 0); + int len = 0; + int minimum = std::min(w - h, 0); + int maximum = std::max(w - h, 0); + if (diagonal <= maximum && diagonal >= minimum) + len = std::min(w, h); + else if (diagonal > -h && diagonal < minimum) + len = diagonal + h; + else if (diagonal > maximum && diagonal < w) + len = -diagonal + w; top_blob.create(len, elemsize, opt.blob_allocator); if (top_blob.empty())