From 0afc09f88cdf5b430da0d8aef232df9631de3e0c Mon Sep 17 00:00:00 2001 From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> Date: Sat, 21 May 2022 23:25:43 +0800 Subject: [PATCH] fix upsample shape infer bug (#8105) * fix upsample shape infer bug * add more example * fix eager free tensor bug when in job * fix nearest2d bug * fix upsample nearest1d shape infer bug * restruct upsample op * change all float scale to double * Fix upsample shape infer bug continue (#8159) * fix_upsample_shape_infer_bug * fix 5 nearest * add 5 nearest test * fix 5 nearest test * fix 1 linear * fix 4 bilinear * fix 4 bicubic * modify bicubic 2d file name * fix 5 trilinear * fix exception info * fix exception info * fix bug * modify interpolate * change float to double * rm useless SI64ArrayAttr: in OneFlowUserOps * rm useless import in cpp * update * add judge for output_size * update oneflow/oneflow/core/autograd/gradient_funcs/upsample.cpp * add grad in td * test failed * fix small failed case in upsample * solve test error * change float to double * align to fix_upsample_shape_infer_bug * align to fix_upsample_shape_infer_bug Co-authored-by: BBuf <1182563586@qq.com> * fix comment * fix test bug * fix comment * fix jiebao comment * fix comment * fix comment * all scale to double * auto format by CI * fix clang tidy bug * fix clang tidy warning * relax speed test * simplify test case * fix bug Co-authored-by: Shanshan Zhong <62104945+zhongshsh@users.noreply.github.com> Co-authored-by: ZZK <42901638+MARD1NO@users.noreply.github.com> Co-authored-by: oneflow-ci-bot Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> --- .github/workflows/test.yml | 4 +- .../autograd/gradient_funcs/activation.cpp | 2 - .../core/autograd/gradient_funcs/upsample.cpp | 154 +++++++++------ oneflow/core/functional/functional_api.yaml | 36 ++-- .../core/functional/impl/array_functor.cpp | 182 ++++++++++++------ oneflow/ir/include/OneFlow/OneFlowUserOps.td | 70 ++++--- ...nel.cpp 
=> upsample_bicubic_2d_kernel.cpp} | 23 ++- ...ernel.cu => upsample_bicubic_2d_kernel.cu} | 18 +- .../kernels/upsample_bilinear_2d_kernel.cpp | 20 +- .../kernels/upsample_bilinear_2d_kernel.cu | 18 +- oneflow/user/kernels/upsample_kernel.h | 33 ++-- .../kernels/upsample_linear_1d_kernel.cpp | 30 ++- .../user/kernels/upsample_linear_1d_kernel.cu | 28 ++- .../user/kernels/upsample_nearest_kernel.cpp | 80 +++++--- .../user/kernels/upsample_nearest_kernel.cu | 76 ++++++-- .../kernels/upsample_trilinear_3d_kernel.cpp | 26 ++- .../kernels/upsample_trilinear_3d_kernel.cu | 26 ++- oneflow/user/ops/upsample_op.cpp | 145 +++++++++----- python/oneflow/nn/modules/interpolate.py | 30 ++- python/oneflow/nn/modules/upsampling.py | 1 - python/oneflow/test/modules/test_upsample.py | 61 +++++- 21 files changed, 733 insertions(+), 330 deletions(-) rename oneflow/user/kernels/{upsample_bicubic2d_kernel.cpp => upsample_bicubic_2d_kernel.cpp} (90%) rename oneflow/user/kernels/{upsample_bicubic2d_kernel.cu => upsample_bicubic_2d_kernel.cu} (91%) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 539d1924440..030b6a35ae4 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -912,8 +912,8 @@ jobs: with: collect-path: ${{ env.FLOW_VISION_SRC }}/benchmark container-name: ${{ env.TEST_CONTAINER_NAME }} - unknown-threshold: 15 - error-threshold: 20 + unknown-threshold: 30 + error-threshold: 40 - name: Remove automerge if: ${{ failure() && contains(matrix.runs-on, 'self-hosted') && cancelled() == false && contains(github.event.pull_request.labels.*.name, 'automerge') }} uses: actions/github-script@v4 diff --git a/oneflow/core/autograd/gradient_funcs/activation.cpp b/oneflow/core/autograd/gradient_funcs/activation.cpp index e27675b504f..175fffa134b 100644 --- a/oneflow/core/autograd/gradient_funcs/activation.cpp +++ b/oneflow/core/autograd/gradient_funcs/activation.cpp @@ -15,9 +15,7 @@ limitations under the License. 
*/ #include "oneflow/core/common/container_util.h" #include "oneflow/core/framework/op_expr_grad_function.h" -#include "oneflow/core/common/container_util.h" #include "oneflow/core/functional/functional.h" -#include "oneflow/core/common/container_util.h" namespace oneflow { namespace one { diff --git a/oneflow/core/autograd/gradient_funcs/upsample.cpp b/oneflow/core/autograd/gradient_funcs/upsample.cpp index 722b8d2e24c..a97c8d7cd82 100644 --- a/oneflow/core/autograd/gradient_funcs/upsample.cpp +++ b/oneflow/core/autograd/gradient_funcs/upsample.cpp @@ -16,14 +16,15 @@ limitations under the License. #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" +#include "oneflow/core/common/container_util.h" namespace oneflow { namespace one { struct UpsampleCaptureState : public AutoGradCaptureState { - bool requires_grad; - float height_scale; - float width_scale; + bool requires_grad = false; + double height_scale = 0.0; + double width_scale = 0.0; float align_corners; std::string data_format; std::string interpolation; @@ -54,8 +55,8 @@ Maybe Upsample::Capture(UpsampleCaptureState* ctx, const TensorTuple& inpu ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); - ctx->height_scale = JUST(composed_attrs.GetAttr("height_scale")); - ctx->width_scale = JUST(composed_attrs.GetAttr("width_scale")); + ctx->height_scale = JUST(composed_attrs.GetAttr("height_scale")); + ctx->width_scale = JUST(composed_attrs.GetAttr("width_scale")); ctx->align_corners = JUST(composed_attrs.GetAttr("align_corners")); ctx->data_format = JUST(composed_attrs.GetAttr("data_format")); ctx->interpolation = JUST(composed_attrs.GetAttr("interpolation")); @@ -70,18 +71,19 @@ Maybe Upsample::Apply(const UpsampleCaptureState* ctx, const TensorTuple& const std::shared_ptr& x = 
ctx->SavedTensors().at(0); in_grads->resize(1); - in_grads->at(0) = - JUST(functional::UpsampleGrad(out_grads.at(0), x, ctx->height_scale, ctx->width_scale, - ctx->align_corners, ctx->data_format, ctx->interpolation)); + JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleGrad( + JUST(oneflow::VectorAt(out_grads, 0)), x, ctx->height_scale, ctx->width_scale, + ctx->align_corners, ctx->data_format, ctx->interpolation)); return Maybe::Ok(); } REGISTER_OP_EXPR_GRAD_FUNCTION("upsample", Upsample); struct UpsampleNearest2DCaptureState : public AutoGradCaptureState { - bool requires_grad; - float height_scale; - float width_scale; + bool requires_grad = false; + double height_scale = 0.0; + double width_scale = 0.0; + std::vector output_size; std::string data_format; }; @@ -96,8 +98,11 @@ class UpsampleNearest2D : public OpExprGradFunctionrequires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); - ctx->height_scale = JUST(composed_attrs.GetAttr("height_scale")); - ctx->width_scale = JUST(composed_attrs.GetAttr("width_scale")); + ctx->height_scale = JUST(composed_attrs.GetAttr("height_scale")); + ctx->width_scale = JUST(composed_attrs.GetAttr("width_scale")); + if (base_attrs_.find("output_size") != base_attrs_.end()) { + ctx->output_size = JUST(composed_attrs.GetAttr>("output_size")); + } ctx->data_format = JUST(composed_attrs.GetAttr("data_format")); ctx->SaveTensorForBackward(inputs.at(0)); return Maybe::Ok(); @@ -110,8 +115,9 @@ class UpsampleNearest2D : public OpExprGradFunction& x = ctx->SavedTensors().at(0); in_grads->resize(1); - in_grads->at(0) = JUST(functional::UpsampleNearest2DGrad(out_grads.at(0), x, ctx->height_scale, - ctx->width_scale, ctx->data_format)); + JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleNearest2DGrad( + JUST(oneflow::VectorAt(out_grads, 0)), x, ctx->height_scale, ctx->width_scale, + ctx->output_size, ctx->data_format)); 
return Maybe::Ok(); } @@ -123,10 +129,11 @@ class UpsampleNearest2D : public OpExprGradFunction output_size; std::string data_format; }; @@ -141,9 +148,12 @@ class UpsampleBilinear2D : public OpExprGradFunctionrequires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); - ctx->height_scale = JUST(composed_attrs.GetAttr("height_scale")); - ctx->width_scale = JUST(composed_attrs.GetAttr("width_scale")); + ctx->height_scale = JUST(composed_attrs.GetAttr("height_scale")); + ctx->width_scale = JUST(composed_attrs.GetAttr("width_scale")); ctx->align_corners = JUST(composed_attrs.GetAttr("align_corners")); + if (base_attrs_.find("output_size") != base_attrs_.end()) { + ctx->output_size = JUST(composed_attrs.GetAttr>("output_size")); + } ctx->data_format = JUST(composed_attrs.GetAttr("data_format")); ctx->SaveTensorForBackward(inputs.at(0)); return Maybe::Ok(); @@ -156,9 +166,9 @@ class UpsampleBilinear2D : public OpExprGradFunction& x = ctx->SavedTensors().at(0); in_grads->resize(1); - in_grads->at(0) = JUST(functional::UpsampleBilinear2DGrad(out_grads.at(0), x, ctx->height_scale, - ctx->width_scale, ctx->align_corners, - ctx->data_format)); + JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleBilinear2DGrad( + JUST(oneflow::VectorAt(out_grads, 0)), x, ctx->height_scale, ctx->width_scale, + ctx->align_corners, ctx->output_size, ctx->data_format)); return Maybe::Ok(); } @@ -170,9 +180,10 @@ class UpsampleBilinear2D : public OpExprGradFunction output_size; std::string data_format; }; @@ -187,8 +198,11 @@ class UpsampleLinear1D : public OpExprGradFunction ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); - ctx->scale_factor = JUST(composed_attrs.GetAttr("scale_factor")); + ctx->scale_factor = JUST(composed_attrs.GetAttr("scale_factor")); ctx->align_corners = 
JUST(composed_attrs.GetAttr("align_corners")); + if (base_attrs_.find("output_size") != base_attrs_.end()) { + ctx->output_size = JUST(composed_attrs.GetAttr>("output_size")); + } ctx->data_format = JUST(composed_attrs.GetAttr("data_format")); ctx->SaveTensorForBackward(inputs.at(0)); return Maybe::Ok(); @@ -201,8 +215,9 @@ class UpsampleLinear1D : public OpExprGradFunction MutableAttrMap attrs; const std::shared_ptr& x = ctx->SavedTensors().at(0); in_grads->resize(1); - in_grads->at(0) = JUST(functional::UpsampleLinear1DGrad(out_grads.at(0), x, ctx->scale_factor, - ctx->align_corners, ctx->data_format)); + JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleLinear1DGrad( + JUST(oneflow::VectorAt(out_grads, 0)), x, ctx->scale_factor, ctx->align_corners, + ctx->output_size, ctx->data_format)); return Maybe::Ok(); } @@ -214,8 +229,9 @@ class UpsampleLinear1D : public OpExprGradFunction REGISTER_OP_EXPR_GRAD_FUNCTION("upsample_linear_1d", UpsampleLinear1D); struct UpsampleNearest1DCaptureState : public AutoGradCaptureState { - bool requires_grad; - float scale_factor; + bool requires_grad = false; + double scale_factor = 0.0; + std::vector output_size; std::string data_format; }; @@ -230,7 +246,10 @@ class UpsampleNearest1D : public OpExprGradFunctionrequires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); - ctx->scale_factor = JUST(composed_attrs.GetAttr("scale_factor")); + ctx->scale_factor = JUST(composed_attrs.GetAttr("scale_factor")); + if (base_attrs_.find("output_size") != base_attrs_.end()) { + ctx->output_size = JUST(composed_attrs.GetAttr>("output_size")); + } ctx->data_format = JUST(composed_attrs.GetAttr("data_format")); ctx->SaveTensorForBackward(inputs.at(0)); return Maybe::Ok(); @@ -243,8 +262,9 @@ class UpsampleNearest1D : public OpExprGradFunction& x = ctx->SavedTensors().at(0); in_grads->resize(1); - in_grads->at(0) = JUST( - 
functional::UpsampleNearest1DGrad(out_grads.at(0), x, ctx->scale_factor, ctx->data_format)); + JUST(oneflow::VectorAt(*in_grads, 0)) = JUST( + functional::UpsampleNearest1DGrad(JUST(oneflow::VectorAt(out_grads, 0)), x, + ctx->scale_factor, ctx->output_size, ctx->data_format)); return Maybe::Ok(); } @@ -256,10 +276,11 @@ class UpsampleNearest1D : public OpExprGradFunction output_size; std::string data_format; }; @@ -274,9 +295,12 @@ class UpsampleBicubic2D : public OpExprGradFunctionrequires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); - ctx->height_scale = JUST(composed_attrs.GetAttr("height_scale")); - ctx->width_scale = JUST(composed_attrs.GetAttr("width_scale")); + ctx->height_scale = JUST(composed_attrs.GetAttr("height_scale")); + ctx->width_scale = JUST(composed_attrs.GetAttr("width_scale")); ctx->align_corners = JUST(composed_attrs.GetAttr("align_corners")); + if (base_attrs_.find("output_size") != base_attrs_.end()) { + ctx->output_size = JUST(composed_attrs.GetAttr>("output_size")); + } ctx->data_format = JUST(composed_attrs.GetAttr("data_format")); ctx->SaveTensorForBackward(inputs.at(0)); return Maybe::Ok(); @@ -289,9 +313,9 @@ class UpsampleBicubic2D : public OpExprGradFunction& x = ctx->SavedTensors().at(0); in_grads->resize(1); - in_grads->at(0) = JUST(functional::UpsampleBicubic2DGrad(out_grads.at(0), x, ctx->height_scale, - ctx->width_scale, ctx->align_corners, - ctx->data_format)); + JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleBicubic2DGrad( + JUST(oneflow::VectorAt(out_grads, 0)), x, ctx->height_scale, ctx->width_scale, + ctx->align_corners, ctx->output_size, ctx->data_format)); return Maybe::Ok(); } @@ -302,10 +326,11 @@ class UpsampleBicubic2D : public OpExprGradFunction output_size; std::string data_format; }; @@ -320,9 +345,12 @@ class UpsampleNearest3D : public OpExprGradFunctionrequires_grad = inputs.at(0)->requires_grad(); if 
(!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); - ctx->depth_scale = JUST(composed_attrs.GetAttr("depth_scale")); - ctx->height_scale = JUST(composed_attrs.GetAttr("height_scale")); - ctx->width_scale = JUST(composed_attrs.GetAttr("width_scale")); + ctx->depth_scale = JUST(composed_attrs.GetAttr("depth_scale")); + ctx->height_scale = JUST(composed_attrs.GetAttr("height_scale")); + ctx->width_scale = JUST(composed_attrs.GetAttr("width_scale")); + if (base_attrs_.find("output_size") != base_attrs_.end()) { + ctx->output_size = JUST(composed_attrs.GetAttr>("output_size")); + } ctx->data_format = JUST(composed_attrs.GetAttr("data_format")); ctx->SaveTensorForBackward(inputs.at(0)); return Maybe::Ok(); @@ -335,9 +363,9 @@ class UpsampleNearest3D : public OpExprGradFunction& x = ctx->SavedTensors().at(0); in_grads->resize(1); - in_grads->at(0) = JUST(functional::UpsampleNearest3DGrad(out_grads.at(0), x, ctx->depth_scale, - ctx->height_scale, ctx->width_scale, - ctx->data_format)); + JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleNearest3DGrad( + JUST(oneflow::VectorAt(out_grads, 0)), x, ctx->depth_scale, ctx->height_scale, + ctx->width_scale, ctx->output_size, ctx->data_format)); return Maybe::Ok(); } @@ -349,11 +377,12 @@ class UpsampleNearest3D : public OpExprGradFunction output_size; std::string data_format; }; @@ -368,10 +397,13 @@ class UpsampleTrilinear3D : public OpExprGradFunctionrequires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); - ctx->depth_scale = JUST(composed_attrs.GetAttr("depth_scale")); - ctx->height_scale = JUST(composed_attrs.GetAttr("height_scale")); - ctx->width_scale = JUST(composed_attrs.GetAttr("width_scale")); + ctx->depth_scale = JUST(composed_attrs.GetAttr("depth_scale")); + ctx->height_scale = JUST(composed_attrs.GetAttr("height_scale")); + ctx->width_scale = 
JUST(composed_attrs.GetAttr("width_scale")); ctx->align_corners = JUST(composed_attrs.GetAttr("align_corners")); + if (base_attrs_.find("output_size") != base_attrs_.end()) { + ctx->output_size = JUST(composed_attrs.GetAttr>("output_size")); + } ctx->data_format = JUST(composed_attrs.GetAttr("data_format")); ctx->SaveTensorForBackward(inputs.at(0)); return Maybe::Ok(); @@ -384,9 +416,9 @@ class UpsampleTrilinear3D : public OpExprGradFunction& x = ctx->SavedTensors().at(0); in_grads->resize(1); - in_grads->at(0) = JUST(functional::UpsampleTrilinear3DGrad( - out_grads.at(0), x, ctx->depth_scale, ctx->height_scale, ctx->width_scale, - ctx->align_corners, ctx->data_format)); + JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleTrilinear3DGrad( + JUST(oneflow::VectorAt(out_grads, 0)), x, ctx->depth_scale, ctx->height_scale, + ctx->width_scale, ctx->align_corners, ctx->output_size, ctx->data_format)); return Maybe::Ok(); } diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml index ba43d18c651..4c3c33c8cbb 100755 --- a/oneflow/core/functional/functional_api.yaml +++ b/oneflow/core/functional/functional_api.yaml @@ -1298,94 +1298,94 @@ - name: "upsample" signature: - 'Tensor (Tensor x, Float height_scale, Float width_scale, Bool align_corners, + 'Tensor (Tensor x, Double height_scale, Double width_scale, Bool align_corners, String interpolation, String data_format="channels_first") => Upsample' bind_python: True - name: "upsample_grad" signature: - 'Tensor (Tensor dy, Tensor x, Float height_scale, Float width_scale, Bool align_corners, + 'Tensor (Tensor dy, Tensor x, Double height_scale, Double width_scale, Bool align_corners, String data_format, String interpolation) => UpsampleGrad' bind_python: False - name: "upsample_linear_1d" - signature: 'Tensor (Tensor x, Float scale_factor, Bool align_corners, + signature: 'Tensor (Tensor x, Double scale_factor=0.0, Bool align_corners=False, Int64List[1] 
output_size=None, String data_format="channels_first") => UpsampleLinear1D' bind_python: True - name: "upsample_linear_1d_grad" signature: - 'Tensor (Tensor dy, Tensor x, Float scale_factor, Bool align_corners, + 'Tensor (Tensor dy, Tensor x, Double scale_factor=0.0, Bool align_corners=False, Int64List[1] output_size=None, String data_format="channels_first") => UpsampleLinear1DGrad' bind_python: False - name: "upsample_nearest_1d" - signature: 'Tensor (Tensor x, Float scale_factor, + signature: 'Tensor (Tensor x, Double scale_factor=0.0, Int64List[1] output_size=None, String data_format="channels_first") => UpsampleNearest1D' bind_python: True - name: "upsample_nearest_1d_grad" - signature: 'Tensor (Tensor dy, Tensor x, Float scale_factor, + signature: 'Tensor (Tensor dy, Tensor x, Double scale_factor=0.0, Int64List[1] output_size=None, String data_format="channels_first") => UpsampleNearest1DGrad' bind_python: False - name: "upsample_nearest_2d" - signature: 'Tensor (Tensor x, Float height_scale, Float width_scale, + signature: 'Tensor (Tensor x, Double height_scale=0.0, Double width_scale=0.0, Int64List[2] output_size=None, String data_format="channels_first") => UpsampleNearest2D' bind_python: True - name: "upsample_nearest_2d_grad" signature: - 'Tensor (Tensor dy, Tensor x, Float height_scale, Float width_scale, + 'Tensor (Tensor dy, Tensor x, Double height_scale=0.0, Double width_scale=0.0, Int64List[2] output_size=None, String data_format="channels_first") => UpsampleNearest2DGrad' bind_python: False - name: "upsample_bilinear_2d" signature: - 'Tensor (Tensor x, Float height_scale, Float width_scale, Bool align_corners, + 'Tensor (Tensor x, Double height_scale=0.0, Double width_scale=0.0, Bool align_corners=False, Int64List[2] output_size=None, String data_format="channels_first") => UpsampleBilinear2D' bind_python: True - name: "upsample_bilinear_2d_grad" signature: - 'Tensor (Tensor dy, Tensor x, Float height_scale, Float width_scale, Bool align_corners, + 
'Tensor (Tensor dy, Tensor x, Double height_scale=0.0, Double width_scale=0.0, Bool align_corners=False, Int64List[2] output_size=None, String data_format="channels_first") => UpsampleBilinear2DGrad' bind_python: False - name: "upsample_bicubic_2d" signature: - 'Tensor (Tensor x, Float height_scale, Float width_scale, Bool align_corners, + 'Tensor (Tensor x, Double height_scale=0.0, Double width_scale=0.0, Bool align_corners=False, Int64List[2] output_size=None, String data_format="channels_first") => UpsampleBicubic2D' bind_python: True - name: "upsample_bicubic_2d_grad" signature: - 'Tensor (Tensor dy, Tensor x, Float height_scale, Float width_scale, Bool align_corners, + 'Tensor (Tensor dy, Tensor x, Double height_scale=0.0, Double width_scale=0.0, Bool align_corners=False, Int64List[2] output_size=None, String data_format="channels_first") => UpsampleBicubic2DGrad' bind_python: False - name: "upsample_nearest_3d" signature: - 'Tensor (Tensor x, Float depth_scale, Float height_scale, Float width_scale, + 'Tensor (Tensor x, Double depth_scale=0.0, Double height_scale=0.0, Double width_scale=0.0, Int64List[3] output_size=None, String data_format="channels_first") => UpsampleNearest3D' bind_python: True - name: "upsample_nearest_3d_grad" signature: - 'Tensor (Tensor dy, Tensor x, Float depth_scale, Float height_scale, Float width_scale, + 'Tensor (Tensor dy, Tensor x, Double depth_scale=0.0, Double height_scale=0.0, Double width_scale=0.0, Int64List[3] output_size=None, String data_format="channels_first") => UpsampleNearest3DGrad' bind_python: False - name: "upsample_trilinear_3d" signature: - 'Tensor (Tensor x, Float depth_scale, Float height_scale, Float width_scale, - Bool align_corners, String data_format="channels_first") => UpsampleTrilinear3D' + 'Tensor (Tensor x, Double depth_scale=0.0, Double height_scale=0.0, Double width_scale=0.0, Bool align_corners=False, + Int64List[3] output_size=None, String data_format="channels_first") => UpsampleTrilinear3D' 
bind_python: True - name: "upsample_trilinear_3d_grad" signature: - 'Tensor (Tensor dy, Tensor x, Float depth_scale, Float height_scale, Float width_scale, - Bool align_corners, String data_format="channels_first") => UpsampleTrilinear3DGrad' + 'Tensor (Tensor dy, Tensor x, Double depth_scale=0.0, Double height_scale=0.0, Double width_scale=0.0, + Bool align_corners=False, Int64List[3] output_size=None, String data_format="channels_first") => UpsampleTrilinear3DGrad' bind_python: False - name: "abs" diff --git a/oneflow/core/functional/impl/array_functor.cpp b/oneflow/core/functional/impl/array_functor.cpp index 3a7aa3edefc..988076955e8 100644 --- a/oneflow/core/functional/impl/array_functor.cpp +++ b/oneflow/core/functional/impl/array_functor.cpp @@ -1365,12 +1365,12 @@ class UpsampleGradFunctor { op_ = CHECK_JUST(one::OpBuilder("upsample_grad").Input("dy").Input("x").Output("dx").Build()); } Maybe operator()(const std::shared_ptr& dy, - const std::shared_ptr& x, const float& height_scale, - const float& width_scale, const bool& align_corners, + const std::shared_ptr& x, const double& height_scale, + const double& width_scale, const bool& align_corners, const std::string& data_format, const std::string& interpolation) const { MutableAttrMap attrs; - JUST(attrs.SetAttr("height_scale", height_scale)); - JUST(attrs.SetAttr("width_scale", width_scale)); + JUST(attrs.SetAttr("height_scale", height_scale)); + JUST(attrs.SetAttr("width_scale", width_scale)); JUST(attrs.SetAttr("align_corners", align_corners)); JUST(attrs.SetAttr("interpolation", interpolation)); JUST(attrs.SetAttr("data_format", data_format)); @@ -1466,12 +1466,17 @@ class UpsampleLinear1DFunctor { UpsampleLinear1DFunctor() { op_ = CHECK_JUST(one::OpBuilder("upsample_linear_1d").Input("x").Output("y").Build()); } - Maybe operator()(const std::shared_ptr& x, const float& scale_factor, - const bool& align_corners, const std::string& data_format) const { + Maybe operator()(const std::shared_ptr& x, const 
double& scale_factor, + const bool& align_corners, + const Optional>& output_size, + const std::string& data_format) const { MutableAttrMap attrs; - JUST(attrs.SetAttr("scale_factor", scale_factor)); + JUST(attrs.SetAttr("scale_factor", scale_factor)); JUST(attrs.SetAttr("align_corners", align_corners)); JUST(attrs.SetAttr("data_format", data_format)); + if (output_size.has_value()) { + JUST(attrs.SetAttr>("output_size", *JUST(output_size))); + } return OpInterpUtil::Dispatch(*op_, {x}, attrs); } @@ -1486,11 +1491,16 @@ class UpsampleLinear1DGradFunctor { one::OpBuilder("upsample_linear_1d_grad").Input("dy").Input("x").Output("dx").Build()); } Maybe operator()(const std::shared_ptr& dy, - const std::shared_ptr& x, const float& scale_factor, - const bool& align_corners, const std::string& data_format) const { + const std::shared_ptr& x, const double& scale_factor, + const bool& align_corners, + const Optional>& output_size, + const std::string& data_format) const { MutableAttrMap attrs; - JUST(attrs.SetAttr("scale_factor", scale_factor)); + JUST(attrs.SetAttr("scale_factor", scale_factor)); JUST(attrs.SetAttr("align_corners", align_corners)); + if (output_size.has_value()) { + JUST(attrs.SetAttr>("output_size", *JUST(output_size))); + } JUST(attrs.SetAttr("data_format", data_format)); return OpInterpUtil::Dispatch(*op_, {dy, x}, attrs); } @@ -1504,11 +1514,15 @@ class UpsampleNearest1DFunctor { UpsampleNearest1DFunctor() { op_ = CHECK_JUST(one::OpBuilder("upsample_nearest_1d").Input("x").Output("y").Build()); } - Maybe operator()(const std::shared_ptr& x, const float& scale_factor, + Maybe operator()(const std::shared_ptr& x, const double& scale_factor, + const Optional>& output_size, const std::string& data_format) const { MutableAttrMap attrs; - JUST(attrs.SetAttr("scale_factor", scale_factor)); + JUST(attrs.SetAttr("scale_factor", scale_factor)); JUST(attrs.SetAttr("data_format", data_format)); + if (output_size.has_value()) { + JUST(attrs.SetAttr>("output_size", 
*JUST(output_size))); + } return OpInterpUtil::Dispatch(*op_, {x}, attrs); } @@ -1523,11 +1537,15 @@ class UpsampleNearest1DGradFunctor { one::OpBuilder("upsample_nearest_1d_grad").Input("dy").Input("x").Output("dx").Build()); } Maybe operator()(const std::shared_ptr& dy, - const std::shared_ptr& x, const float& scale_factor, + const std::shared_ptr& x, const double& scale_factor, + const Optional>& output_size, const std::string& data_format) const { MutableAttrMap attrs; - JUST(attrs.SetAttr("scale_factor", scale_factor)); + JUST(attrs.SetAttr("scale_factor", scale_factor)); JUST(attrs.SetAttr("data_format", data_format)); + if (output_size.has_value()) { + JUST(attrs.SetAttr>("output_size", *JUST(output_size))); + } return OpInterpUtil::Dispatch(*op_, {dy, x}, attrs); } @@ -1540,12 +1558,17 @@ class UpsampleNearest2DFunctor { UpsampleNearest2DFunctor() { op_ = CHECK_JUST(one::OpBuilder("upsample_nearest_2d").Input("x").Output("y").Build()); } - Maybe operator()(const std::shared_ptr& x, const float& height_scale, - const float& width_scale, const std::string& data_format) const { + Maybe operator()(const std::shared_ptr& x, const double& height_scale, + const double& width_scale, + const Optional>& output_size, + const std::string& data_format) const { MutableAttrMap attrs; - JUST(attrs.SetAttr("height_scale", height_scale)); - JUST(attrs.SetAttr("width_scale", width_scale)); + JUST(attrs.SetAttr("height_scale", height_scale)); + JUST(attrs.SetAttr("width_scale", width_scale)); JUST(attrs.SetAttr("data_format", data_format)); + if (output_size.has_value()) { + JUST(attrs.SetAttr>("output_size", *JUST(output_size))); + } return OpInterpUtil::Dispatch(*op_, {x}, attrs); } @@ -1560,12 +1583,17 @@ class UpsampleNearest2DGradFunctor { one::OpBuilder("upsample_nearest_2d_grad").Input("dy").Input("x").Output("dx").Build()); } Maybe operator()(const std::shared_ptr& dy, - const std::shared_ptr& x, const float& height_scale, - const float& width_scale, const std::string& 
data_format) const { + const std::shared_ptr& x, const double& height_scale, + const double& width_scale, + const Optional>& output_size, + const std::string& data_format) const { MutableAttrMap attrs; - JUST(attrs.SetAttr("height_scale", height_scale)); - JUST(attrs.SetAttr("width_scale", width_scale)); + JUST(attrs.SetAttr("height_scale", height_scale)); + JUST(attrs.SetAttr("width_scale", width_scale)); JUST(attrs.SetAttr("data_format", data_format)); + if (output_size.has_value()) { + JUST(attrs.SetAttr>("output_size", *JUST(output_size))); + } return OpInterpUtil::Dispatch(*op_, {dy, x}, attrs); } @@ -1578,14 +1606,18 @@ class UpsampleBilinear2DFunctor { UpsampleBilinear2DFunctor() { op_ = CHECK_JUST(one::OpBuilder("upsample_bilinear_2d").Input("x").Output("y").Build()); } - Maybe operator()(const std::shared_ptr& x, const float& height_scale, - const float& width_scale, const bool& align_corners, + Maybe operator()(const std::shared_ptr& x, const double& height_scale, + const double& width_scale, const bool& align_corners, + const Optional>& output_size, const std::string& data_format) const { MutableAttrMap attrs; - JUST(attrs.SetAttr("height_scale", height_scale)); - JUST(attrs.SetAttr("width_scale", width_scale)); + JUST(attrs.SetAttr("height_scale", height_scale)); + JUST(attrs.SetAttr("width_scale", width_scale)); JUST(attrs.SetAttr("align_corners", align_corners)); JUST(attrs.SetAttr("data_format", data_format)); + if (output_size.has_value()) { + JUST(attrs.SetAttr>("output_size", *JUST(output_size))); + } return OpInterpUtil::Dispatch(*op_, {x}, attrs); } @@ -1600,13 +1632,17 @@ class UpsampleBilinear2DGradFunctor { one::OpBuilder("upsample_bilinear_2d_grad").Input("dy").Input("x").Output("dx").Build()); } Maybe operator()(const std::shared_ptr& dy, - const std::shared_ptr& x, const float& height_scale, - const float& width_scale, const bool& align_corners, + const std::shared_ptr& x, const double& height_scale, + const double& width_scale, const 
bool& align_corners, + const Optional>& output_size, const std::string& data_format) const { MutableAttrMap attrs; - JUST(attrs.SetAttr("height_scale", height_scale)); - JUST(attrs.SetAttr("width_scale", width_scale)); + JUST(attrs.SetAttr("height_scale", height_scale)); + JUST(attrs.SetAttr("width_scale", width_scale)); JUST(attrs.SetAttr("align_corners", align_corners)); + if (output_size.has_value()) { + JUST(attrs.SetAttr>("output_size", *JUST(output_size))); + } JUST(attrs.SetAttr("data_format", data_format)); return OpInterpUtil::Dispatch(*op_, {dy, x}, attrs); } @@ -1620,14 +1656,18 @@ class UpsampleBicubic2DFunctor { UpsampleBicubic2DFunctor() { op_ = CHECK_JUST(one::OpBuilder("upsample_bicubic_2d").Input("x").Output("y").Build()); } - Maybe operator()(const std::shared_ptr& x, const float& height_scale, - const float& width_scale, const bool& align_corners, + Maybe operator()(const std::shared_ptr& x, const double& height_scale, + const double& width_scale, const bool& align_corners, + const Optional>& output_size, const std::string& data_format) const { MutableAttrMap attrs; - JUST(attrs.SetAttr("height_scale", height_scale)); - JUST(attrs.SetAttr("width_scale", width_scale)); + JUST(attrs.SetAttr("height_scale", height_scale)); + JUST(attrs.SetAttr("width_scale", width_scale)); JUST(attrs.SetAttr("align_corners", align_corners)); JUST(attrs.SetAttr("data_format", data_format)); + if (output_size.has_value()) { + JUST(attrs.SetAttr>("output_size", *JUST(output_size))); + } return OpInterpUtil::Dispatch(*op_, {x}, attrs); } @@ -1642,13 +1682,17 @@ class UpsampleBicubic2DGradFunctor { one::OpBuilder("upsample_bicubic_2d_grad").Input("dy").Input("x").Output("dx").Build()); } Maybe operator()(const std::shared_ptr& dy, - const std::shared_ptr& x, const float& height_scale, - const float& width_scale, const bool& align_corners, + const std::shared_ptr& x, const double& height_scale, + const double& width_scale, const bool& align_corners, + const Optional>& 
output_size, const std::string& data_format) const { MutableAttrMap attrs; - JUST(attrs.SetAttr("height_scale", height_scale)); - JUST(attrs.SetAttr("width_scale", width_scale)); + JUST(attrs.SetAttr("height_scale", height_scale)); + JUST(attrs.SetAttr("width_scale", width_scale)); JUST(attrs.SetAttr("align_corners", align_corners)); + if (output_size.has_value()) { + JUST(attrs.SetAttr>("output_size", *JUST(output_size))); + } JUST(attrs.SetAttr("data_format", data_format)); return OpInterpUtil::Dispatch(*op_, {dy, x}, attrs); } @@ -1662,14 +1706,18 @@ class UpsampleNearest3DFunctor { UpsampleNearest3DFunctor() { op_ = CHECK_JUST(one::OpBuilder("upsample_nearest_3d").Input("x").Output("y").Build()); } - Maybe operator()(const std::shared_ptr& x, const float& depth_scale, - const float& height_scale, const float& width_scale, + Maybe operator()(const std::shared_ptr& x, const double& depth_scale, + const double& height_scale, const double& width_scale, + const Optional>& output_size, const std::string& data_format) const { MutableAttrMap attrs; - JUST(attrs.SetAttr("depth_scale", depth_scale)); - JUST(attrs.SetAttr("height_scale", height_scale)); - JUST(attrs.SetAttr("width_scale", width_scale)); + JUST(attrs.SetAttr("depth_scale", depth_scale)); + JUST(attrs.SetAttr("height_scale", height_scale)); + JUST(attrs.SetAttr("width_scale", width_scale)); JUST(attrs.SetAttr("data_format", data_format)); + if (output_size.has_value()) { + JUST(attrs.SetAttr>("output_size", *JUST(output_size))); + } return OpInterpUtil::Dispatch(*op_, {x}, attrs); } @@ -1684,13 +1732,17 @@ class UpsampleNearest3DGradFunctor { one::OpBuilder("upsample_nearest_3d_grad").Input("dy").Input("x").Output("dx").Build()); } Maybe operator()(const std::shared_ptr& dy, - const std::shared_ptr& x, const float& depth_scale, - const float& height_scale, const float& width_scale, + const std::shared_ptr& x, const double& depth_scale, + const double& height_scale, const double& width_scale, + const 
Optional>& output_size, const std::string& data_format) const { MutableAttrMap attrs; - JUST(attrs.SetAttr("depth_scale", depth_scale)); - JUST(attrs.SetAttr("height_scale", height_scale)); - JUST(attrs.SetAttr("width_scale", width_scale)); + JUST(attrs.SetAttr("depth_scale", depth_scale)); + JUST(attrs.SetAttr("height_scale", height_scale)); + JUST(attrs.SetAttr("width_scale", width_scale)); + if (output_size.has_value()) { + JUST(attrs.SetAttr>("output_size", *JUST(output_size))); + } JUST(attrs.SetAttr("data_format", data_format)); return OpInterpUtil::Dispatch(*op_, {dy, x}, attrs); } @@ -1704,15 +1756,20 @@ class UpsampleTrilinear3DFunctor { UpsampleTrilinear3DFunctor() { op_ = CHECK_JUST(one::OpBuilder("upsample_trilinear_3d").Input("x").Output("y").Build()); } - Maybe operator()(const std::shared_ptr& x, const float& depth_scale, - const float& height_scale, const float& width_scale, - const bool& align_corners, const std::string& data_format) const { + Maybe operator()(const std::shared_ptr& x, const double& depth_scale, + const double& height_scale, const double& width_scale, + const bool& align_corners, + const Optional>& output_size, + const std::string& data_format) const { MutableAttrMap attrs; - JUST(attrs.SetAttr("depth_scale", depth_scale)); - JUST(attrs.SetAttr("height_scale", height_scale)); - JUST(attrs.SetAttr("width_scale", width_scale)); + JUST(attrs.SetAttr("depth_scale", depth_scale)); + JUST(attrs.SetAttr("height_scale", height_scale)); + JUST(attrs.SetAttr("width_scale", width_scale)); JUST(attrs.SetAttr("align_corners", align_corners)); JUST(attrs.SetAttr("data_format", data_format)); + if (output_size.has_value()) { + JUST(attrs.SetAttr>("output_size", *JUST(output_size))); + } return OpInterpUtil::Dispatch(*op_, {x}, attrs); } @@ -1727,14 +1784,19 @@ class UpsampleTrilinear3DGradFunctor { one::OpBuilder("upsample_trilinear_3d_grad").Input("dy").Input("x").Output("dx").Build()); } Maybe operator()(const std::shared_ptr& dy, - const 
std::shared_ptr& x, const float& depth_scale, - const float& height_scale, const float& width_scale, - const bool& align_corners, const std::string& data_format) const { + const std::shared_ptr& x, const double& depth_scale, + const double& height_scale, const double& width_scale, + const bool& align_corners, + const Optional>& output_size, + const std::string& data_format) const { MutableAttrMap attrs; - JUST(attrs.SetAttr("depth_scale", depth_scale)); - JUST(attrs.SetAttr("height_scale", height_scale)); - JUST(attrs.SetAttr("width_scale", width_scale)); + JUST(attrs.SetAttr("depth_scale", depth_scale)); + JUST(attrs.SetAttr("height_scale", height_scale)); + JUST(attrs.SetAttr("width_scale", width_scale)); JUST(attrs.SetAttr("align_corners", align_corners)); + if (output_size.has_value()) { + JUST(attrs.SetAttr>("output_size", *JUST(output_size))); + } JUST(attrs.SetAttr("data_format", data_format)); return OpInterpUtil::Dispatch(*op_, {dy, x}, attrs); } diff --git a/oneflow/ir/include/OneFlow/OneFlowUserOps.td b/oneflow/ir/include/OneFlow/OneFlowUserOps.td index 4e6d1c1cd56..448c691ebfd 100644 --- a/oneflow/ir/include/OneFlow/OneFlowUserOps.td +++ b/oneflow/ir/include/OneFlow/OneFlowUserOps.td @@ -8790,9 +8790,10 @@ def OneFlow_UpsampleBicubic2DOp : OneFlow_BaseOp<"upsample_bicubic_2d", [NoSideE OneFlow_Tensor:$y ); let attrs = (ins - DefaultValuedAttr:$height_scale, - DefaultValuedAttr:$width_scale, + DefaultValuedAttr:$height_scale, + DefaultValuedAttr:$width_scale, DefaultValuedAttr:$align_corners, + SI64ArrayAttr:$output_size, StrAttr:$data_format ); let has_logical_tensor_desc_infer_fn = 1; @@ -8810,9 +8811,10 @@ def OneFlow_UpsampleBicubic2DGradOp : OneFlow_BaseOp<"upsample_bicubic_2d_grad", OneFlow_Tensor:$dx ); let attrs = (ins - DefaultValuedAttr:$height_scale, - DefaultValuedAttr:$width_scale, + DefaultValuedAttr:$height_scale, + DefaultValuedAttr:$width_scale, DefaultValuedAttr:$align_corners, + SI64ArrayAttr:$output_size, StrAttr:$data_format ); let 
has_logical_tensor_desc_infer_fn = 1; @@ -8829,9 +8831,10 @@ def OneFlow_UpsampleBilinear2DOp : OneFlow_BaseOp<"upsample_bilinear_2d", [NoSid OneFlow_Tensor:$y ); let attrs = (ins - DefaultValuedAttr:$height_scale, - DefaultValuedAttr:$width_scale, + DefaultValuedAttr:$height_scale, + DefaultValuedAttr:$width_scale, DefaultValuedAttr:$align_corners, + SI64ArrayAttr:$output_size, StrAttr:$data_format ); let has_logical_tensor_desc_infer_fn = 1; @@ -8849,9 +8852,10 @@ def OneFlow_UpsampleBilinear2DGradOp : OneFlow_BaseOp<"upsample_bilinear_2d_grad OneFlow_Tensor:$dx ); let attrs = (ins - DefaultValuedAttr:$height_scale, - DefaultValuedAttr:$width_scale, + DefaultValuedAttr:$height_scale, + DefaultValuedAttr:$width_scale, DefaultValuedAttr:$align_corners, + SI64ArrayAttr:$output_size, StrAttr:$data_format ); let has_logical_tensor_desc_infer_fn = 1; @@ -8868,8 +8872,9 @@ def OneFlow_UpsampleLinear1DOp : OneFlow_BaseOp<"upsample_linear_1d", [NoSideEff OneFlow_Tensor:$y ); let attrs = (ins - DefaultValuedAttr:$scale_factor, + DefaultValuedAttr:$scale_factor, DefaultValuedAttr:$align_corners, + SI64ArrayAttr:$output_size, StrAttr:$data_format ); let has_logical_tensor_desc_infer_fn = 1; @@ -8887,8 +8892,9 @@ def OneFlow_UpsampleLinear1DGradOp : OneFlow_BaseOp<"upsample_linear_1d_grad", [ OneFlow_Tensor:$dx ); let attrs = (ins - DefaultValuedAttr:$scale_factor, + DefaultValuedAttr:$scale_factor, DefaultValuedAttr:$align_corners, + SI64ArrayAttr:$output_size, StrAttr:$data_format ); let has_logical_tensor_desc_infer_fn = 1; @@ -8905,7 +8911,8 @@ def OneFlow_UpsampleNearest1DOp : OneFlow_BaseOp<"upsample_nearest_1d", [NoSideE OneFlow_Tensor:$y ); let attrs = (ins - DefaultValuedAttr:$scale_factor, + DefaultValuedAttr:$scale_factor, + SI64ArrayAttr:$output_size, StrAttr:$data_format ); let has_logical_tensor_desc_infer_fn = 1; @@ -8923,7 +8930,8 @@ def OneFlow_UpsampleNearest1DGradOp : OneFlow_BaseOp<"upsample_nearest_1d_grad", OneFlow_Tensor:$dx ); let attrs = (ins - 
DefaultValuedAttr:$scale_factor, + DefaultValuedAttr:$scale_factor, + SI64ArrayAttr:$output_size, StrAttr:$data_format ); let has_logical_tensor_desc_infer_fn = 1; @@ -8940,8 +8948,9 @@ def OneFlow_UpsampleNearest2DOp : OneFlow_BaseOp<"upsample_nearest_2d", [NoSideE OneFlow_Tensor:$y ); let attrs = (ins - DefaultValuedAttr:$height_scale, - DefaultValuedAttr:$width_scale, + DefaultValuedAttr:$height_scale, + DefaultValuedAttr:$width_scale, + SI64ArrayAttr:$output_size, StrAttr:$data_format ); let has_logical_tensor_desc_infer_fn = 1; @@ -8959,8 +8968,9 @@ def OneFlow_UpsampleNearest2DGradOp : OneFlow_BaseOp<"upsample_nearest_2d_grad", OneFlow_Tensor:$dx ); let attrs = (ins - DefaultValuedAttr:$height_scale, - DefaultValuedAttr:$width_scale, + DefaultValuedAttr:$height_scale, + DefaultValuedAttr:$width_scale, + SI64ArrayAttr:$output_size, StrAttr:$data_format ); let has_logical_tensor_desc_infer_fn = 1; @@ -8977,9 +8987,10 @@ def OneFlow_UpsampleNearest3DOp : OneFlow_BaseOp<"upsample_nearest_3d", [NoSideE OneFlow_Tensor:$y ); let attrs = (ins - DefaultValuedAttr:$depth_scale, - DefaultValuedAttr:$height_scale, - DefaultValuedAttr:$width_scale, + DefaultValuedAttr:$depth_scale, + DefaultValuedAttr:$height_scale, + DefaultValuedAttr:$width_scale, + SI64ArrayAttr:$output_size, StrAttr:$data_format ); let has_logical_tensor_desc_infer_fn = 1; @@ -8997,9 +9008,10 @@ def OneFlow_UpsampleNearest3DGradOp : OneFlow_BaseOp<"upsample_nearest_3d_grad", OneFlow_Tensor:$dx ); let attrs = (ins - DefaultValuedAttr:$depth_scale, - DefaultValuedAttr:$height_scale, - DefaultValuedAttr:$width_scale, + DefaultValuedAttr:$depth_scale, + DefaultValuedAttr:$height_scale, + DefaultValuedAttr:$width_scale, + SI64ArrayAttr:$output_size, StrAttr:$data_format ); let has_logical_tensor_desc_infer_fn = 1; @@ -9016,10 +9028,11 @@ def OneFlow_UpsampleTrilinear3DOp : OneFlow_BaseOp<"upsample_trilinear_3d", [NoS OneFlow_Tensor:$y ); let attrs = (ins - DefaultValuedAttr:$depth_scale, - 
DefaultValuedAttr:$height_scale, - DefaultValuedAttr:$width_scale, + DefaultValuedAttr:$depth_scale, + DefaultValuedAttr:$height_scale, + DefaultValuedAttr:$width_scale, DefaultValuedAttr:$align_corners, + SI64ArrayAttr:$output_size, StrAttr:$data_format ); let has_logical_tensor_desc_infer_fn = 1; @@ -9037,10 +9050,11 @@ def OneFlow_UpsampleTrilinear3DGradOp : OneFlow_BaseOp<"upsample_trilinear_3d_gr OneFlow_Tensor:$dx ); let attrs = (ins - DefaultValuedAttr:$depth_scale, - DefaultValuedAttr:$height_scale, - DefaultValuedAttr:$width_scale, + DefaultValuedAttr:$depth_scale, + DefaultValuedAttr:$height_scale, + DefaultValuedAttr:$width_scale, DefaultValuedAttr:$align_corners, + SI64ArrayAttr:$output_size, StrAttr:$data_format ); let has_logical_tensor_desc_infer_fn = 1; diff --git a/oneflow/user/kernels/upsample_bicubic2d_kernel.cpp b/oneflow/user/kernels/upsample_bicubic_2d_kernel.cpp similarity index 90% rename from oneflow/user/kernels/upsample_bicubic2d_kernel.cpp rename to oneflow/user/kernels/upsample_bicubic_2d_kernel.cpp index 88f4f2d22ba..e174f9d2f94 100644 --- a/oneflow/user/kernels/upsample_bicubic2d_kernel.cpp +++ b/oneflow/user/kernels/upsample_bicubic_2d_kernel.cpp @@ -30,18 +30,24 @@ class UpsampleBicubic2dCPUKernel final : public user_op::OpKernel { void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0); + const std::vector output_size = ctx->Attr>("output_size"); + double height_scale = ctx->Attr("height_scale"); + double width_scale = ctx->Attr("width_scale"); + const T* in_ptr = x_tensor->dptr(); T* out_ptr = y_tensor->mut_dptr(); - const float height_scale = ctx->Attr("height_scale"); - const float width_scale = ctx->Attr("width_scale"); const bool align_corners = ctx->Attr("align_corners"); - const int nbatch = x_tensor->shape().At(0); const int channels = x_tensor->shape().At(1); + const int64_t 
in_height = x_tensor->shape().At(2); const int64_t in_width = x_tensor->shape().At(3); const int64_t out_height = y_tensor->shape().At(2); const int64_t out_width = y_tensor->shape().At(3); + if (!output_size.empty()) { + height_scale = static_cast(out_height) / static_cast(in_height); + width_scale = static_cast(out_width) / static_cast(in_width); + } if (in_height == out_height && in_width == out_width) { memcpy(out_ptr, in_ptr, sizeof(T) * nbatch * channels * in_height * in_width); @@ -108,18 +114,23 @@ class UpsampleBicubic2dGradCPUKernel final : public user_op::OpKernel { user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0); T* in_ptr = dx_tensor->mut_dptr(); const T* out_ptr = dy_tensor->dptr(); - const float height_scale = ctx->Attr("height_scale"); - const float width_scale = ctx->Attr("width_scale"); const bool align_corners = ctx->Attr("align_corners"); - const int nbatch = dx_tensor->shape().At(0); int channels = dx_tensor->shape().At(1); channels = channels * nbatch; + const int64_t in_height = dx_tensor->shape().At(2); const int64_t in_width = dx_tensor->shape().At(3); const int64_t out_height = dy_tensor->shape().At(2); const int64_t out_width = dy_tensor->shape().At(3); + const std::vector output_size = ctx->Attr>("output_size"); + double height_scale = ctx->Attr("height_scale"); + double width_scale = ctx->Attr("width_scale"); + if (!output_size.empty()) { + height_scale = static_cast(out_height) / static_cast(in_height); + width_scale = static_cast(out_width) / static_cast(in_width); + } if (in_height == out_height && in_width == out_width) { memcpy(in_ptr, out_ptr, sizeof(T) * channels * in_height * in_width); } else { diff --git a/oneflow/user/kernels/upsample_bicubic2d_kernel.cu b/oneflow/user/kernels/upsample_bicubic_2d_kernel.cu similarity index 91% rename from oneflow/user/kernels/upsample_bicubic2d_kernel.cu rename to oneflow/user/kernels/upsample_bicubic_2d_kernel.cu index 21b7ec4ddac..ba810969160 100644 --- 
a/oneflow/user/kernels/upsample_bicubic2d_kernel.cu +++ b/oneflow/user/kernels/upsample_bicubic_2d_kernel.cu @@ -137,8 +137,6 @@ class UpsampleBicubic2dGPUKernel final : public user_op::OpKernel { user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0); const T* in_ptr = x_tensor->dptr(); T* out_ptr = y_tensor->mut_dptr(); - const float height_scale = ctx->Attr("height_scale"); - const float width_scale = ctx->Attr("width_scale"); const bool align_corners = ctx->Attr("align_corners"); const int nbatch = x_tensor->shape().At(0); @@ -147,6 +145,13 @@ class UpsampleBicubic2dGPUKernel final : public user_op::OpKernel { const int64_t in_width = x_tensor->shape().At(3); const int64_t out_height = y_tensor->shape().At(2); const int64_t out_width = y_tensor->shape().At(3); + const std::vector output_size = ctx->Attr>("output_size"); + double height_scale = ctx->Attr("height_scale"); + double width_scale = ctx->Attr("width_scale"); + if (!output_size.empty()) { + height_scale = static_cast(out_height) / static_cast(in_height); + width_scale = static_cast(out_width) / static_cast(in_width); + } const int64_t elem_cnt = out_height * out_width; if (in_height == out_height && in_width == out_width) { @@ -178,8 +183,6 @@ class UpsampleBicubic2dGradGPUKernel final : public user_op::OpKernel { Memset(ctx->stream(), dx_tensor->mut_dptr(), 0, dx_tensor->shape().elem_cnt() * sizeof(T)); const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0); - const float height_scale = ctx->Attr("height_scale"); - const float width_scale = ctx->Attr("width_scale"); const bool align_corners = ctx->Attr("align_corners"); const int nbatch = dx_tensor->shape().At(0); @@ -188,6 +191,13 @@ class UpsampleBicubic2dGradGPUKernel final : public user_op::OpKernel { const int64_t in_width = dx_tensor->shape().At(3); const int64_t out_height = dy_tensor->shape().At(2); const int64_t out_width = dy_tensor->shape().At(3); + const std::vector output_size = ctx->Attr>("output_size"); + 
double height_scale = ctx->Attr("height_scale"); + double width_scale = ctx->Attr("width_scale"); + if (!output_size.empty()) { + height_scale = static_cast(out_height) / static_cast(in_height); + width_scale = static_cast(out_width) / static_cast(in_width); + } const int64_t elem_cnt = out_height * out_width; if (in_height == out_height && in_width == out_width) { diff --git a/oneflow/user/kernels/upsample_bilinear_2d_kernel.cpp b/oneflow/user/kernels/upsample_bilinear_2d_kernel.cpp index 7c0054050f5..ea1d3637f5e 100644 --- a/oneflow/user/kernels/upsample_bilinear_2d_kernel.cpp +++ b/oneflow/user/kernels/upsample_bilinear_2d_kernel.cpp @@ -84,9 +84,10 @@ class UpsampleBilinear2DCPUKernel final : public user_op::OpKernel { void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0); - const float height_scale = ctx->Attr("height_scale"); - const float width_scale = ctx->Attr("width_scale"); const bool align_corners = ctx->Attr("align_corners"); + const std::vector output_size = ctx->Attr>("output_size"); + double height_scale = ctx->Attr("height_scale"); + double width_scale = ctx->Attr("width_scale"); const int64_t elem_cnt = y_tensor->shape().elem_cnt(); NdIndexOffsetHelper in_helper(x_tensor->shape().At(0), x_tensor->shape().At(1), x_tensor->shape().At(2), x_tensor->shape().At(3)); @@ -100,6 +101,11 @@ class UpsampleBilinear2DCPUKernel final : public user_op::OpKernel { const int64_t out_height = y_tensor->shape().At(2); const int64_t out_width = y_tensor->shape().At(3); + if (!output_size.empty()) { + height_scale = static_cast(out_height) / static_cast(in_height); + width_scale = static_cast(out_width) / static_cast(in_width); + } + if (in_height == out_height && in_width == out_width) { memcpy(y_tensor->mut_dptr(), x_tensor->dptr(), sizeof(T) * nbatch * channels * in_height * in_width); @@ -126,9 +132,10 @@ class 
UpsampleBilinear2DGradCPUKernel final : public user_op::OpKernel { Memset(ctx->stream(), dx_tensor->mut_dptr(), 0, dx_tensor->shape().elem_cnt() * sizeof(T)); const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0); - const float height_scale = ctx->Attr("height_scale"); - const float width_scale = ctx->Attr("width_scale"); const bool align_corners = ctx->Attr("align_corners"); + const std::vector output_size = ctx->Attr>("output_size"); + double height_scale = ctx->Attr("height_scale"); + double width_scale = ctx->Attr("width_scale"); const int64_t elem_cnt = dy_tensor->shape().elem_cnt(); NdIndexOffsetHelper dy_helper(dy_tensor->shape().At(0), dy_tensor->shape().At(1), dy_tensor->shape().At(2), dy_tensor->shape().At(3)); @@ -141,6 +148,11 @@ class UpsampleBilinear2DGradCPUKernel final : public user_op::OpKernel { const int64_t in_width = dx_tensor->shape().At(3); const int64_t out_height = dy_tensor->shape().At(2); const int64_t out_width = dy_tensor->shape().At(3); + if (!output_size.empty()) { + height_scale = static_cast(out_height) / static_cast(in_height); + width_scale = static_cast(out_width) / static_cast(in_width); + } + if (in_height == out_height && in_width == out_width) { memcpy(dx_tensor->mut_dptr(), dy_tensor->dptr(), sizeof(T) * nbatch * channels * in_height * in_width); diff --git a/oneflow/user/kernels/upsample_bilinear_2d_kernel.cu b/oneflow/user/kernels/upsample_bilinear_2d_kernel.cu index 0fc35b4c944..c9f3a9d7fb7 100644 --- a/oneflow/user/kernels/upsample_bilinear_2d_kernel.cu +++ b/oneflow/user/kernels/upsample_bilinear_2d_kernel.cu @@ -90,9 +90,10 @@ class UpsampleBilinear2DGPUKernel final : public user_op::OpKernel { void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0); - const float height_scale = ctx->Attr("height_scale"); - const float width_scale = ctx->Attr("width_scale"); 
const bool align_corners = ctx->Attr("align_corners"); + const std::vector output_size = ctx->Attr>("output_size"); + double height_scale = ctx->Attr("height_scale"); + double width_scale = ctx->Attr("width_scale"); const int64_t elem_cnt = y_tensor->shape().elem_cnt(); NdIndexOffsetHelper in_helper(x_tensor->shape().At(0), x_tensor->shape().At(1), x_tensor->shape().At(2), x_tensor->shape().At(3)); @@ -103,6 +104,10 @@ class UpsampleBilinear2DGPUKernel final : public user_op::OpKernel { const int64_t in_width = x_tensor->shape().At(3); const int64_t out_height = y_tensor->shape().At(2); const int64_t out_width = y_tensor->shape().At(3); + if (!output_size.empty()) { + height_scale = static_cast(out_height) / static_cast(in_height); + width_scale = static_cast(out_width) / static_cast(in_width); + } if (in_height == out_height && in_width == out_width) { Memcpy( ctx->stream(), y_tensor->mut_dptr(), x_tensor->dptr(), @@ -131,9 +136,10 @@ class UpsampleBilinear2DGradGPUKernel final : public user_op::OpKernel { Memset(ctx->stream(), dx_tensor->mut_dptr(), 0, dx_tensor->shape().elem_cnt() * sizeof(T)); const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0); - const float height_scale = ctx->Attr("height_scale"); - const float width_scale = ctx->Attr("width_scale"); const bool align_corners = ctx->Attr("align_corners"); + const std::vector output_size = ctx->Attr>("output_size"); + double height_scale = ctx->Attr("height_scale"); + double width_scale = ctx->Attr("width_scale"); const int64_t elem_cnt = dy_tensor->shape().elem_cnt(); NdIndexOffsetHelper dy_helper(dy_tensor->shape().At(0), dy_tensor->shape().At(1), dy_tensor->shape().At(2), dy_tensor->shape().At(3)); @@ -144,6 +150,10 @@ class UpsampleBilinear2DGradGPUKernel final : public user_op::OpKernel { const int64_t in_width = dx_tensor->shape().At(3); const int64_t out_height = dy_tensor->shape().At(2); const int64_t out_width = dy_tensor->shape().At(3); + if (!output_size.empty()) { + height_scale 
= static_cast(out_height) / static_cast(in_height); + width_scale = static_cast(out_width) / static_cast(in_width); + } if (in_height == out_height && in_width == out_width) { Memcpy( ctx->stream(), dx_tensor->mut_dptr(), dy_tensor->dptr(), diff --git a/oneflow/user/kernels/upsample_kernel.h b/oneflow/user/kernels/upsample_kernel.h index 5365d9a8b60..a5d8b8f95ca 100644 --- a/oneflow/user/kernels/upsample_kernel.h +++ b/oneflow/user/kernels/upsample_kernel.h @@ -16,45 +16,43 @@ limitations under the License. #include "oneflow/core/common/nd_index_offset_helper.h" #include -template -OF_DEVICE_FUNC T GetLinearInputIndex(const int64_t out_dim_idx, const T scale, bool align_corners) { +OF_DEVICE_FUNC double GetLinearInputIndex(const int64_t out_dim_idx, const double scale, + bool align_corners) { if (align_corners) { - return static_cast(scale * out_dim_idx); + return static_cast(scale * out_dim_idx); } else { - T src_idx = scale * (out_dim_idx + 0.5) - 0.5; - return static_cast(src_idx < 0 ? 0 : src_idx); + double src_idx = scale * (out_dim_idx + 0.5) - 0.5; + return static_cast(src_idx < 0 ? 0 : src_idx); } } -OF_DEVICE_FUNC static int64_t GetNearestInputIndex(const int64_t out_dim_idx, const float scale, +OF_DEVICE_FUNC static int64_t GetNearestInputIndex(const int64_t out_dim_idx, const double scale, const int64_t in_dim_size) { int64_t index = static_cast(floorf(out_dim_idx * scale)); index = index > in_dim_size - 1 ? in_dim_size - 1 : index; return index; } -template -OF_DEVICE_FUNC T GetAreaPixelScale(const int64_t input_size, const int64_t output_size, - bool align_corners, const T scale) { +OF_DEVICE_FUNC double GetAreaPixelScale(const int64_t input_size, const int64_t output_size, + bool align_corners, const double scale) { if (align_corners) { if (output_size > 1) { - return static_cast(input_size - 1) / (output_size - 1); + return static_cast(input_size - 1) / (output_size - 1); } else { return 0; } } else { - return (scale > 0. ? 
1.0 / scale : static_cast(input_size) / output_size); + return (scale > 0. ? 1.0 / scale : static_cast(input_size) / output_size); } } -template -OF_DEVICE_FUNC T GetAreaPixel(const T scale, const int64_t dst_index, bool align_corners, - bool cubic = false) { +OF_DEVICE_FUNC double GetAreaPixel(const double scale, const int64_t dst_index, bool align_corners, + bool cubic = false) { if (align_corners) { return scale * dst_index; } else { - T src_idx = scale * (dst_index + 0.5) - 0.5; - return (!cubic && src_idx < 0) ? static_cast(0) : src_idx; + double src_idx = scale * (dst_index + 0.5) - 0.5; + return (!cubic && src_idx < 0) ? static_cast(0) : src_idx; } } @@ -71,7 +69,8 @@ struct BilinearParam { template OF_DEVICE_FUNC void GetBilinearParam(const bool align_corners, const int64_t h, const int64_t w, const int64_t in_height, const int64_t in_width, - const T scale_h, const T scale_w, BilinearParam* params) { + const double scale_h, const double scale_w, + BilinearParam* params) { T h1r; if (align_corners) { h1r = scale_h * static_cast(h); diff --git a/oneflow/user/kernels/upsample_linear_1d_kernel.cpp b/oneflow/user/kernels/upsample_linear_1d_kernel.cpp index 66fab8b4775..27c7cf41d94 100644 --- a/oneflow/user/kernels/upsample_linear_1d_kernel.cpp +++ b/oneflow/user/kernels/upsample_linear_1d_kernel.cpp @@ -26,15 +26,15 @@ template static void UpsampleLinear1DForward(const int64_t elem_cnt, const T* in_dptr, NdIndexOffsetHelper in_helper, NdIndexOffsetHelper out_helper, const int in_height, - const float scale_factor, bool align_corners, T* out_dptr) { + const double scale_factor, bool align_corners, T* out_dptr) { for (int64_t index = 0; index < elem_cnt; ++index) { int64_t n, c, h; out_helper.OffsetToNdIndex(index, n, c, h); - const T h1r = GetLinearInputIndex(h, scale_factor, align_corners); + const double h1r = GetLinearInputIndex(h, scale_factor, align_corners); const int64_t h1 = h1r; const int64_t h1p = (h1 < in_height - 1) ? 
1 : 0; - const T h1lambda = h1r - h1; - const T h0lambda = static_cast(1.) - h1lambda; + const double h1lambda = h1r - h1; + const double h0lambda = static_cast(1.) - h1lambda; out_dptr[index] = h0lambda * in_dptr[in_helper.NdIndexToOffset(n, c, h1)] + h1lambda * in_dptr[in_helper.NdIndexToOffset(n, c, h1 + h1p)]; } @@ -44,15 +44,15 @@ template static void UpsampleLinear1DBackward(const int64_t elem_cnt, const T* dy_dptr, NdIndexOffsetHelper dy_helper, NdIndexOffsetHelper dx_helper, const int in_height, - const float scale_factor, bool align_corners, T* dx_dptr) { + const double scale_factor, bool align_corners, T* dx_dptr) { for (int64_t index = 0; index < elem_cnt; ++index) { int64_t n, c, h; dy_helper.OffsetToNdIndex(index, n, c, h); - const T h1r = GetLinearInputIndex(h, scale_factor, align_corners); + const double h1r = GetLinearInputIndex(h, scale_factor, align_corners); const int64_t h1 = h1r; const int64_t h1p = (h1 < in_height - 1) ? 1 : 0; - const T h1lambda = h1r - h1; - const T h0lambda = static_cast(1.) - h1lambda; + const double h1lambda = h1r - h1; + const double h0lambda = static_cast(1.) 
- h1lambda; *(dx_dptr + dx_helper.NdIndexToOffset(n, c, h1)) += h0lambda * dy_dptr[index]; *(dx_dptr + dx_helper.NdIndexToOffset(n, c, h1 + h1p)) += h1lambda * dy_dptr[index]; @@ -71,7 +71,6 @@ class UpsampleLinear1DCPUKernel final : public user_op::OpKernel { void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0); - const float height_scale = ctx->Attr("scale_factor"); const bool align_corners = ctx->Attr("align_corners"); const int64_t elem_cnt = y_tensor->shape().elem_cnt(); NdIndexOffsetHelper in_helper(x_tensor->shape().At(0), x_tensor->shape().At(1), @@ -82,6 +81,12 @@ class UpsampleLinear1DCPUKernel final : public user_op::OpKernel { const int64_t channels = x_tensor->shape().At(1); const int64_t in_height = x_tensor->shape().At(2); const int64_t out_height = y_tensor->shape().At(2); + const std::vector output_size = ctx->Attr>("output_size"); + double height_scale = ctx->Attr("scale_factor"); + if (!output_size.empty()) { + height_scale = static_cast(out_height) / static_cast(in_height); + } + if (in_height == out_height) { memcpy(y_tensor->mut_dptr(), x_tensor->dptr(), sizeof(T) * nbatch * channels * in_height); @@ -106,7 +111,6 @@ class UpsampleLinearGrad1DCPUKernel final : public user_op::OpKernel { Memset(ctx->stream(), dx_tensor->mut_dptr(), 0, dx_tensor->shape().elem_cnt() * sizeof(T)); const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0); - const float height_scale = ctx->Attr("scale_factor"); const bool align_corners = ctx->Attr("align_corners"); NdIndexOffsetHelper dy_helper(dy_tensor->shape().At(0), dy_tensor->shape().At(1), @@ -119,6 +123,12 @@ class UpsampleLinearGrad1DCPUKernel final : public user_op::OpKernel { const int64_t channels = dx_tensor->shape().At(1); const int64_t in_height = dx_tensor->shape().At(2); const int64_t out_height = dy_tensor->shape().At(2); + const 
std::vector output_size = ctx->Attr>("output_size"); + double height_scale = ctx->Attr("scale_factor"); + if (!output_size.empty()) { + height_scale = static_cast(out_height) / static_cast(in_height); + } + if (in_height == out_height) { memcpy(dx_tensor->mut_dptr(), dy_tensor->dptr(), sizeof(T) * nbatch * channels * in_height); diff --git a/oneflow/user/kernels/upsample_linear_1d_kernel.cu b/oneflow/user/kernels/upsample_linear_1d_kernel.cu index ce3f6427fd2..2c44f882baa 100644 --- a/oneflow/user/kernels/upsample_linear_1d_kernel.cu +++ b/oneflow/user/kernels/upsample_linear_1d_kernel.cu @@ -27,16 +27,16 @@ template __global__ void UpsampleLinear1DForward(const int64_t elem_cnt, const T* in_dptr, NdIndexOffsetHelper in_helper, NdIndexOffsetHelper out_helper, - const int in_height, const float scale_factor, + const int in_height, const double scale_factor, bool align_corners, T* out_dptr) { CUDA_1D_KERNEL_LOOP(index, elem_cnt) { int64_t n, c, h; out_helper.OffsetToNdIndex(index, n, c, h); - const T h1r = GetLinearInputIndex(h, scale_factor, align_corners); + const double h1r = GetLinearInputIndex(h, scale_factor, align_corners); const int64_t h1 = h1r; const int64_t h1p = (h1 < in_height - 1) ? 1 : 0; - const T h1lambda = h1r - h1; - const T h0lambda = static_cast(1.) - h1lambda; + const double h1lambda = h1r - h1; + const double h0lambda = static_cast(1.) 
- h1lambda; out_dptr[index] = h0lambda * in_dptr[in_helper.NdIndexToOffset(n, c, h1)] + h1lambda * in_dptr[in_helper.NdIndexToOffset(n, c, h1 + h1p)]; } @@ -46,16 +46,16 @@ template __global__ void UpsampleLinear1DBackward(const int64_t elem_cnt, const T* dy_dptr, NdIndexOffsetHelper dy_helper, NdIndexOffsetHelper dx_helper, - const int in_height, const float scale_factor, + const int in_height, const double scale_factor, bool align_corners, T* dx_dptr) { CUDA_1D_KERNEL_LOOP(index, elem_cnt) { int64_t n, c, h; dy_helper.OffsetToNdIndex(index, n, c, h); - const T h1r = GetLinearInputIndex(h, scale_factor, align_corners); + const double h1r = GetLinearInputIndex(h, scale_factor, align_corners); const int64_t h1 = h1r; const int64_t h1p = (h1 < in_height - 1) ? 1 : 0; - const T h1lambda = h1r - h1; - const T h0lambda = static_cast(1.) - h1lambda; + const double h1lambda = h1r - h1; + const double h0lambda = static_cast(1.) - h1lambda; cuda::atomic::Add(dx_dptr + dx_helper.NdIndexToOffset(n, c, h1), h0lambda * dy_dptr[index]); cuda::atomic::Add(dx_dptr + dx_helper.NdIndexToOffset(n, c, h1 + h1p), @@ -76,7 +76,6 @@ class UpsampleLinear1DGPUKernel final : public user_op::OpKernel { void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0); - const float height_scale = ctx->Attr("scale_factor"); const bool align_corners = ctx->Attr("align_corners"); const int64_t elem_cnt = y_tensor->shape().elem_cnt(); NdIndexOffsetHelper in_helper(x_tensor->shape().At(0), x_tensor->shape().At(1), @@ -85,6 +84,11 @@ class UpsampleLinear1DGPUKernel final : public user_op::OpKernel { y_tensor->shape().At(2)); const int64_t in_height = x_tensor->shape().At(2); const int64_t out_height = y_tensor->shape().At(2); + const std::vector output_size = ctx->Attr>("output_size"); + double height_scale = ctx->Attr("scale_factor"); + if (!output_size.empty()) { 
+ height_scale = static_cast(out_height) / static_cast(in_height); + } if (in_height == out_height) { Memcpy( ctx->stream(), y_tensor->mut_dptr(), x_tensor->dptr(), @@ -112,7 +116,6 @@ class UpsampleLinearGrad1DGPUKernel final : public user_op::OpKernel { Memset(ctx->stream(), dx_tensor->mut_dptr(), 0, dx_tensor->shape().elem_cnt() * sizeof(T)); const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0); - const float height_scale = ctx->Attr("scale_factor"); const bool align_corners = ctx->Attr("align_corners"); NdIndexOffsetHelper dy_helper(dy_tensor->shape().At(0), dy_tensor->shape().At(1), @@ -122,6 +125,11 @@ class UpsampleLinearGrad1DGPUKernel final : public user_op::OpKernel { const int64_t elem_cnt = dy_tensor->shape().elem_cnt(); const int64_t in_height = dx_tensor->shape().At(2); const int64_t out_height = dy_tensor->shape().At(2); + const std::vector output_size = ctx->Attr>("output_size"); + double height_scale = ctx->Attr("scale_factor"); + if (!output_size.empty()) { + height_scale = static_cast(out_height) / static_cast(in_height); + } if (in_height == out_height) { Memcpy( ctx->stream(), dx_tensor->mut_dptr(), dy_tensor->dptr(), diff --git a/oneflow/user/kernels/upsample_nearest_kernel.cpp b/oneflow/user/kernels/upsample_nearest_kernel.cpp index b5a0e05e628..4db78f85e5d 100644 --- a/oneflow/user/kernels/upsample_nearest_kernel.cpp +++ b/oneflow/user/kernels/upsample_nearest_kernel.cpp @@ -26,7 +26,7 @@ template static void UpsampleNearest1DForward(const int64_t elem_cnt, const T* in_dptr, NdIndexOffsetHelper in_helper, NdIndexOffsetHelper out_helper, - const int64_t in_height, const float scale_factor, + const int64_t in_height, const double scale_factor, T* out_dptr) { for (int64_t index = 0; index < elem_cnt; ++index) { int64_t n, c, h; @@ -40,7 +40,7 @@ template static void UpsampleNearest1DBackward(const int64_t elem_cnt, const T* dy_dptr, NdIndexOffsetHelper dy_helper, NdIndexOffsetHelper dx_helper, - const int64_t in_height, const 
float scale_factor, + const int64_t in_height, const double scale_factor, T* dx_dptr) { for (int64_t index = 0; index < elem_cnt; ++index) { int64_t n, c, h; @@ -55,7 +55,7 @@ static void UpsampleNearest2DForward(const int64_t elem_cnt, const T* in_dptr, NdIndexOffsetHelper in_helper, NdIndexOffsetHelper out_helper, const int64_t in_height, const int64_t in_width, - const float scale_h, const float scale_w, T* out_dptr) { + const double scale_h, const double scale_w, T* out_dptr) { for (int64_t index = 0; index < elem_cnt; ++index) { int64_t n, c, h, w; out_helper.OffsetToNdIndex(index, n, c, h, w); @@ -70,7 +70,7 @@ static void UpsampleNearest2DBackward(const int64_t elem_cnt, const T* dy_dptr, NdIndexOffsetHelper dy_helper, NdIndexOffsetHelper dx_helper, const int64_t dx_height, const int64_t dx_width, - const float scale_h, const float scale_w, T* dx_dptr) { + const double scale_h, const double scale_w, T* dx_dptr) { for (int64_t index = 0; index < elem_cnt; ++index) { int64_t n, c, h, w; dy_helper.OffsetToNdIndex(index, n, c, h, w); @@ -126,13 +126,16 @@ class UpsampleNearest1DCPUKernel final : public user_op::OpKernel { void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0); - const float height_scale = ctx->Attr("scale_factor"); const int64_t elem_cnt = y_tensor->shape().elem_cnt(); - + const std::vector output_size = ctx->Attr>("output_size"); + double height_scale = ctx->Attr("scale_factor"); const int64_t nbatch = x_tensor->shape().At(0); const int64_t channels = x_tensor->shape().At(1); const int64_t in_height = x_tensor->shape().At(2); const int64_t out_height = y_tensor->shape().At(2); + if (!output_size.empty()) { + height_scale = static_cast(out_height) / static_cast(in_height); + } if (in_height == out_height) { memcpy(y_tensor->mut_dptr(), x_tensor->dptr(), @@ -163,13 +166,16 @@ class 
UpsampleNearestGrad1DCPUKernel final : public user_op::OpKernel { Memset(ctx->stream(), dx_tensor->mut_dptr(), 0, dx_tensor->shape().elem_cnt() * sizeof(T)); const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0); - const float height_scale = ctx->Attr("scale_factor"); - + const std::vector output_size = ctx->Attr>("output_size"); + double height_scale = ctx->Attr("scale_factor"); const int64_t elem_cnt = dy_tensor->shape().elem_cnt(); const int64_t nbatch = dx_tensor->shape().At(0); const int64_t channels = dx_tensor->shape().At(1); const int64_t in_height = dx_tensor->shape().At(2); const int64_t out_height = dy_tensor->shape().At(2); + if (!output_size.empty()) { + height_scale = static_cast(out_height) / static_cast(in_height); + } if (in_height == out_height) { memcpy(dx_tensor->mut_dptr(), dy_tensor->dptr(), sizeof(T) * nbatch * channels * in_height); @@ -209,17 +215,20 @@ class UpsampleNearest2DCPUKernel final : public user_op::OpKernel { void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0); - + const std::vector output_size = ctx->Attr>("output_size"); + double height_scale = ctx->Attr("height_scale"); + double width_scale = ctx->Attr("width_scale"); const int64_t nbatch = x_tensor->shape().At(0); const int64_t channels = x_tensor->shape().At(1); const int64_t in_height = x_tensor->shape().At(2); const int64_t in_width = x_tensor->shape().At(3); const int64_t out_height = y_tensor->shape().At(2); const int64_t out_width = y_tensor->shape().At(3); - - const float height_scale = ctx->Attr("height_scale"); - const float width_scale = ctx->Attr("width_scale"); const int64_t elem_cnt = y_tensor->shape().elem_cnt(); + if (!output_size.empty()) { + height_scale = static_cast(out_height) / static_cast(in_height); + width_scale = static_cast(out_width) / static_cast(in_width); + } if (in_height == 
out_height && in_width == out_width) { memcpy(y_tensor->mut_dptr(), x_tensor->dptr(), @@ -250,17 +259,20 @@ class UpsampleNearest2DGradCPUKernel final : public user_op::OpKernel { Memset(ctx->stream(), dx_tensor->mut_dptr(), 0, dx_tensor->shape().elem_cnt() * sizeof(T)); const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0); - + const std::vector output_size = ctx->Attr>("output_size"); + double height_scale = ctx->Attr("height_scale"); + double width_scale = ctx->Attr("width_scale"); const int64_t nbatch = dx_tensor->shape().At(0); const int64_t channels = dx_tensor->shape().At(1); const int64_t in_height = dx_tensor->shape().At(2); const int64_t in_width = dx_tensor->shape().At(3); const int64_t out_height = dy_tensor->shape().At(2); const int64_t out_width = dy_tensor->shape().At(3); - - const float height_scale = ctx->Attr("height_scale"); - const float width_scale = ctx->Attr("width_scale"); const int64_t elem_cnt = dy_tensor->shape().elem_cnt(); + if (!output_size.empty()) { + height_scale = static_cast(out_height) / static_cast(in_height); + width_scale = static_cast(out_width) / static_cast(in_width); + } if (in_height == out_height && in_width == out_width) { memcpy(dx_tensor->mut_dptr(), dy_tensor->dptr(), @@ -301,10 +313,22 @@ class UpsampleNearest3DCPUKernel final : public user_op::OpKernel { void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x_blob = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y_blob = ctx->Tensor4ArgNameAndIndex("y", 0); - const float depth_scale = ctx->Attr("depth_scale"); - const float height_scale = ctx->Attr("height_scale"); - const float width_scale = ctx->Attr("width_scale"); + const std::vector output_size = ctx->Attr>("output_size"); + double depth_scale = ctx->Attr("depth_scale"); + double height_scale = ctx->Attr("height_scale"); + double width_scale = ctx->Attr("width_scale"); + const int64_t in_depth = x_blob->shape().At(2); + const int64_t in_height = 
x_blob->shape().At(3); + const int64_t in_width = x_blob->shape().At(4); + const int64_t out_depth = y_blob->shape().At(2); + const int64_t out_height = y_blob->shape().At(3); + const int64_t out_width = y_blob->shape().At(4); const int64_t elem_cnt = y_blob->shape().elem_cnt(); + if (!output_size.empty()) { + depth_scale = static_cast(out_depth) / static_cast(in_depth); + height_scale = static_cast(out_height) / static_cast(in_height); + width_scale = static_cast(out_width) / static_cast(in_width); + } NdIndexOffsetHelper in_helper(x_blob->shape().At(0), x_blob->shape().At(1), x_blob->shape().At(2), x_blob->shape().At(3), x_blob->shape().At(4)); @@ -332,10 +356,22 @@ class UpsampleNearestGrad3DCPUKernel final : public user_op::OpKernel { Memset(ctx->stream(), dx_blob->mut_dptr(), 0, dx_blob->shape().elem_cnt() * sizeof(T)); const user_op::Tensor* dy_blob = ctx->Tensor4ArgNameAndIndex("dy", 0); - const float depth_scale = ctx->Attr("depth_scale"); - const float height_scale = ctx->Attr("height_scale"); - const float width_scale = ctx->Attr("width_scale"); + const std::vector output_size = ctx->Attr>("output_size"); + double depth_scale = ctx->Attr("depth_scale"); + double height_scale = ctx->Attr("height_scale"); + double width_scale = ctx->Attr("width_scale"); + const int64_t in_depth = dx_blob->shape().At(2); + const int64_t in_height = dx_blob->shape().At(3); + const int64_t in_width = dx_blob->shape().At(4); + const int64_t out_depth = dy_blob->shape().At(2); + const int64_t out_height = dy_blob->shape().At(3); + const int64_t out_width = dy_blob->shape().At(4); const int64_t elem_cnt = dy_blob->shape().elem_cnt(); + if (!output_size.empty()) { + depth_scale = static_cast(out_depth) / static_cast(in_depth); + height_scale = static_cast(out_height) / static_cast(in_height); + width_scale = static_cast(out_width) / static_cast(in_width); + } NdIndexOffsetHelper dy_helper(dy_blob->shape().At(0), dy_blob->shape().At(1), dy_blob->shape().At(2), 
dy_blob->shape().At(3), dy_blob->shape().At(4)); diff --git a/oneflow/user/kernels/upsample_nearest_kernel.cu b/oneflow/user/kernels/upsample_nearest_kernel.cu index 2d769085335..a9fe4d557b9 100644 --- a/oneflow/user/kernels/upsample_nearest_kernel.cu +++ b/oneflow/user/kernels/upsample_nearest_kernel.cu @@ -27,7 +27,7 @@ template __global__ void UpsampleNearest1DForward(const int64_t elem_cnt, const T* in_dptr, NdIndexOffsetHelper in_helper, NdIndexOffsetHelper out_helper, - const int64_t in_height, const float scale_factor, + const int64_t in_height, const double scale_factor, T* out_dptr) { CUDA_1D_KERNEL_LOOP(index, elem_cnt) { int64_t n, c, h; @@ -41,7 +41,7 @@ template __global__ void UpsampleNearest1DBackward(const int64_t elem_cnt, const T* dy_dptr, NdIndexOffsetHelper dy_helper, NdIndexOffsetHelper dx_helper, - const int64_t in_height, const float scale_factor, + const int64_t in_height, const double scale_factor, T* dx_dptr) { CUDA_1D_KERNEL_LOOP(index, elem_cnt) { int64_t n, c, h; @@ -56,7 +56,7 @@ __global__ void UpsampleNearest2DForward(const int64_t elem_cnt, const T* in_dpt NdIndexOffsetHelper in_helper, NdIndexOffsetHelper out_helper, const int64_t in_height, const int64_t in_width, - const float scale_h, const float scale_w, T* out_dptr) { + const double scale_h, const double scale_w, T* out_dptr) { CUDA_1D_KERNEL_LOOP(index, elem_cnt) { int64_t n, c, h, w; out_helper.OffsetToNdIndex(index, n, c, h, w); @@ -71,7 +71,7 @@ __global__ void UpsampleNearest2DBackward(const int64_t elem_cnt, const T* dy_dp NdIndexOffsetHelper dy_helper, NdIndexOffsetHelper dx_helper, const int64_t dx_height, const int64_t dx_width, - const float scale_h, const float scale_w, T* dx_dptr) { + const double scale_h, const double scale_w, T* dx_dptr) { CUDA_1D_KERNEL_LOOP(index, elem_cnt) { int64_t n, c, h, w; dy_helper.OffsetToNdIndex(index, n, c, h, w); @@ -128,10 +128,14 @@ class UpsampleNearest1DGPUKernel final : public user_op::OpKernel { void 
Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0); - const float height_scale = ctx->Attr("scale_factor"); + const std::vector output_size = ctx->Attr>("output_size"); + double height_scale = ctx->Attr("scale_factor"); const int64_t elem_cnt = y_tensor->shape().elem_cnt(); const int64_t in_height = x_tensor->shape().At(2); const int64_t out_height = y_tensor->shape().At(2); + if (!output_size.empty()) { + height_scale = static_cast(out_height) / static_cast(in_height); + } if (in_height == out_height) { Memcpy( ctx->stream(), y_tensor->mut_dptr(), x_tensor->dptr(), @@ -163,10 +167,14 @@ class UpsampleNearestGrad1DGPUKernel final : public user_op::OpKernel { Memset(ctx->stream(), dx_tensor->mut_dptr(), 0, dx_tensor->shape().elem_cnt() * sizeof(T)); const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0); - const float height_scale = ctx->Attr("scale_factor"); + const std::vector output_size = ctx->Attr>("output_size"); + double height_scale = ctx->Attr("scale_factor"); const int64_t elem_cnt = dy_tensor->shape().elem_cnt(); const int64_t in_height = dx_tensor->shape().At(2); const int64_t out_height = dy_tensor->shape().At(2); + if (!output_size.empty()) { + height_scale = static_cast(out_height) / static_cast(in_height); + } if (in_height == out_height) { Memcpy( ctx->stream(), dx_tensor->mut_dptr(), dy_tensor->dptr(), @@ -208,14 +216,19 @@ class UpsampleNearest2DGPUKernel final : public user_op::OpKernel { void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0); - const float height_scale = ctx->Attr("height_scale"); - const float width_scale = ctx->Attr("width_scale"); + const std::vector output_size = ctx->Attr>("output_size"); + double height_scale = 
ctx->Attr("height_scale"); + double width_scale = ctx->Attr("width_scale"); const int64_t elem_cnt = y_tensor->shape().elem_cnt(); - const int64_t in_height = x_tensor->shape().At(2); const int64_t in_width = x_tensor->shape().At(3); const int64_t out_height = y_tensor->shape().At(2); const int64_t out_width = y_tensor->shape().At(3); + if (!output_size.empty()) { + height_scale = static_cast(out_height) / static_cast(in_height); + width_scale = static_cast(out_width) / static_cast(in_width); + } + if (in_height == out_height && in_width == out_width) { Memcpy( ctx->stream(), y_tensor->mut_dptr(), x_tensor->dptr(), @@ -248,13 +261,18 @@ class UpsampleNearest2DGradGPUKernel final : public user_op::OpKernel { Memset(ctx->stream(), dx_tensor->mut_dptr(), 0, dx_tensor->shape().elem_cnt() * sizeof(T)); const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0); - const float height_scale = ctx->Attr("height_scale"); - const float width_scale = ctx->Attr("width_scale"); + const std::vector output_size = ctx->Attr>("output_size"); + double height_scale = ctx->Attr("height_scale"); + double width_scale = ctx->Attr("width_scale"); const int64_t elem_cnt = dy_tensor->shape().elem_cnt(); const int64_t in_height = dx_tensor->shape().At(2); const int64_t in_width = dx_tensor->shape().At(3); const int64_t out_height = dy_tensor->shape().At(2); const int64_t out_width = dy_tensor->shape().At(3); + if (!output_size.empty()) { + height_scale = static_cast(out_height) / static_cast(in_height); + width_scale = static_cast(out_width) / static_cast(in_width); + } if (in_height == out_height && in_width == out_width) { Memcpy( ctx->stream(), dx_tensor->mut_dptr(), dy_tensor->dptr(), @@ -297,10 +315,22 @@ class UpsampleNearest3DGPUKernel final : public user_op::OpKernel { void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0); - 
const float height_scale = ctx->Attr("height_scale"); - const float width_scale = ctx->Attr("width_scale"); - const float depth_scale = ctx->Attr("depth_scale"); + const std::vector output_size = ctx->Attr>("output_size"); + double depth_scale = ctx->Attr("depth_scale"); + double height_scale = ctx->Attr("height_scale"); + double width_scale = ctx->Attr("width_scale"); + const int64_t in_depth = x_tensor->shape().At(2); + const int64_t in_height = x_tensor->shape().At(3); + const int64_t in_width = x_tensor->shape().At(4); + const int64_t out_depth = y_tensor->shape().At(2); + const int64_t out_height = y_tensor->shape().At(3); + const int64_t out_width = y_tensor->shape().At(4); const int64_t elem_cnt = y_tensor->shape().elem_cnt(); + if (!output_size.empty()) { + depth_scale = static_cast(out_depth) / static_cast(in_depth); + height_scale = static_cast(out_height) / static_cast(in_height); + width_scale = static_cast(out_width) / static_cast(in_width); + } NdIndexOffsetHelper in_helper(x_tensor->shape().At(0), x_tensor->shape().At(1), x_tensor->shape().At(2), x_tensor->shape().At(3), x_tensor->shape().At(4)); @@ -329,10 +359,22 @@ class UpsampleNearestGrad3DGPUKernel final : public user_op::OpKernel { Memset(ctx->stream(), dx_tensor->mut_dptr(), 0, dx_tensor->shape().elem_cnt() * sizeof(T)); const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0); - const float height_scale = ctx->Attr("height_scale"); - const float width_scale = ctx->Attr("width_scale"); - const float depth_scale = ctx->Attr("depth_scale"); + const std::vector output_size = ctx->Attr>("output_size"); + double depth_scale = ctx->Attr("depth_scale"); + double height_scale = ctx->Attr("height_scale"); + double width_scale = ctx->Attr("width_scale"); + const int64_t in_depth = dx_tensor->shape().At(2); + const int64_t in_height = dx_tensor->shape().At(3); + const int64_t in_width = dx_tensor->shape().At(4); + const int64_t out_depth = dy_tensor->shape().At(2); + const int64_t 
out_height = dy_tensor->shape().At(3); + const int64_t out_width = dy_tensor->shape().At(4); const int64_t elem_cnt = dy_tensor->shape().elem_cnt(); + if (!output_size.empty()) { + depth_scale = static_cast(out_depth) / static_cast(in_depth); + height_scale = static_cast(out_height) / static_cast(in_height); + width_scale = static_cast(out_width) / static_cast(in_width); + } NdIndexOffsetHelper dy_helper(dy_tensor->shape().At(0), dy_tensor->shape().At(1), dy_tensor->shape().At(2), dy_tensor->shape().At(3), dy_tensor->shape().At(4)); diff --git a/oneflow/user/kernels/upsample_trilinear_3d_kernel.cpp b/oneflow/user/kernels/upsample_trilinear_3d_kernel.cpp index 4873856f701..1872a901802 100644 --- a/oneflow/user/kernels/upsample_trilinear_3d_kernel.cpp +++ b/oneflow/user/kernels/upsample_trilinear_3d_kernel.cpp @@ -124,9 +124,6 @@ class UpsampleTrilinear3DCPUKernel final : public user_op::OpKernel { void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0); - const float depth_scale = ctx->Attr("depth_scale"); - const float height_scale = ctx->Attr("height_scale"); - const float width_scale = ctx->Attr("width_scale"); const bool align_corners = ctx->Attr("align_corners"); const int64_t elem_cnt = y_tensor->shape().elem_cnt(); NdIndexOffsetHelper in_helper(x_tensor->shape().At(0), x_tensor->shape().At(1), @@ -144,6 +141,16 @@ class UpsampleTrilinear3DCPUKernel final : public user_op::OpKernel { const int64_t out_height = y_tensor->shape().At(3); const int64_t out_width = y_tensor->shape().At(4); + const std::vector output_size = ctx->Attr>("output_size"); + double depth_scale = ctx->Attr("depth_scale"); + double height_scale = ctx->Attr("height_scale"); + double width_scale = ctx->Attr("width_scale"); + if (!output_size.empty()) { + depth_scale = static_cast(out_depth) / static_cast(in_depth); + height_scale = 
static_cast(out_height) / static_cast(in_height); + width_scale = static_cast(out_width) / static_cast(in_width); + } + const T scale_depth = GetAreaPixelScale(in_depth, out_depth, align_corners, depth_scale); const T scale_height = GetAreaPixelScale(in_height, out_height, align_corners, height_scale); const T scale_width = GetAreaPixelScale(in_width, out_width, align_corners, width_scale); @@ -169,9 +176,6 @@ class UpsampleTrilinearGrad3DCPUKernel final : public user_op::OpKernel { Memset(ctx->stream(), dx_tensor->mut_dptr(), 0, dx_tensor->shape().elem_cnt() * sizeof(T)); const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0); - const float depth_scale = ctx->Attr("depth_scale"); - const float height_scale = ctx->Attr("height_scale"); - const float width_scale = ctx->Attr("width_scale"); const bool align_corners = ctx->Attr("align_corners"); const int64_t elem_cnt = dy_tensor->shape().elem_cnt(); NdIndexOffsetHelper dy_helper(dy_tensor->shape().At(0), dy_tensor->shape().At(1), @@ -189,6 +193,16 @@ class UpsampleTrilinearGrad3DCPUKernel final : public user_op::OpKernel { const int64_t out_height = dy_tensor->shape().At(3); const int64_t out_width = dy_tensor->shape().At(4); + const std::vector output_size = ctx->Attr>("output_size"); + double depth_scale = ctx->Attr("depth_scale"); + double height_scale = ctx->Attr("height_scale"); + double width_scale = ctx->Attr("width_scale"); + if (!output_size.empty()) { + depth_scale = static_cast(out_depth) / static_cast(in_depth); + height_scale = static_cast(out_height) / static_cast(in_height); + width_scale = static_cast(out_width) / static_cast(in_width); + } + const T scale_depth = GetAreaPixelScale(in_depth, out_depth, align_corners, depth_scale); const T scale_height = GetAreaPixelScale(in_height, out_height, align_corners, height_scale); const T scale_width = GetAreaPixelScale(in_width, out_width, align_corners, width_scale); diff --git a/oneflow/user/kernels/upsample_trilinear_3d_kernel.cu 
b/oneflow/user/kernels/upsample_trilinear_3d_kernel.cu index 5235e787f0e..7ce58e53027 100644 --- a/oneflow/user/kernels/upsample_trilinear_3d_kernel.cu +++ b/oneflow/user/kernels/upsample_trilinear_3d_kernel.cu @@ -128,9 +128,6 @@ class UpsampleTrilinear3DGPUKernel final : public user_op::OpKernel { void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0); - const float depth_scale = ctx->Attr("depth_scale"); - const float height_scale = ctx->Attr("height_scale"); - const float width_scale = ctx->Attr("width_scale"); const bool align_corners = ctx->Attr("align_corners"); const int64_t elem_cnt = y_tensor->shape().elem_cnt(); NdIndexOffsetHelper in_helper(x_tensor->shape().At(0), x_tensor->shape().At(1), @@ -148,6 +145,16 @@ class UpsampleTrilinear3DGPUKernel final : public user_op::OpKernel { const int64_t out_height = y_tensor->shape().At(3); const int64_t out_width = y_tensor->shape().At(4); + const std::vector output_size = ctx->Attr>("output_size"); + double depth_scale = ctx->Attr("depth_scale"); + double height_scale = ctx->Attr("height_scale"); + double width_scale = ctx->Attr("width_scale"); + if (!output_size.empty()) { + depth_scale = static_cast(out_depth) / static_cast(in_depth); + height_scale = static_cast(out_height) / static_cast(in_height); + width_scale = static_cast(out_width) / static_cast(in_width); + } + const T scale_depth = GetAreaPixelScale(in_depth, out_depth, align_corners, depth_scale); const T scale_height = GetAreaPixelScale(in_height, out_height, align_corners, height_scale); const T scale_width = GetAreaPixelScale(in_width, out_width, align_corners, width_scale); @@ -174,9 +181,6 @@ class UpsampleTrilinearGrad3DGPUKernel final : public user_op::OpKernel { Memset(ctx->stream(), dx_tensor->mut_dptr(), 0, dx_tensor->shape().elem_cnt() * sizeof(T)); const user_op::Tensor* dy_tensor = 
ctx->Tensor4ArgNameAndIndex("dy", 0); - const float depth_scale = ctx->Attr("depth_scale"); - const float height_scale = ctx->Attr("height_scale"); - const float width_scale = ctx->Attr("width_scale"); const bool align_corners = ctx->Attr("align_corners"); const int64_t elem_cnt = dy_tensor->shape().elem_cnt(); NdIndexOffsetHelper dy_helper(dy_tensor->shape().At(0), dy_tensor->shape().At(1), @@ -194,6 +198,16 @@ class UpsampleTrilinearGrad3DGPUKernel final : public user_op::OpKernel { const int64_t out_height = dy_tensor->shape().At(3); const int64_t out_width = dy_tensor->shape().At(4); + const std::vector output_size = ctx->Attr>("output_size"); + double depth_scale = ctx->Attr("depth_scale"); + double height_scale = ctx->Attr("height_scale"); + double width_scale = ctx->Attr("width_scale"); + if (!output_size.empty()) { + depth_scale = static_cast(out_depth) / static_cast(in_depth); + height_scale = static_cast(out_height) / static_cast(in_height); + width_scale = static_cast(out_width) / static_cast(in_width); + } + const T scale_depth = GetAreaPixelScale(in_depth, out_depth, align_corners, depth_scale); const T scale_height = GetAreaPixelScale(in_height, out_height, align_corners, height_scale); const T scale_width = GetAreaPixelScale(in_width, out_width, align_corners, width_scale); diff --git a/oneflow/user/ops/upsample_op.cpp b/oneflow/user/ops/upsample_op.cpp index 57efaebc62e..e1d05c1b097 100644 --- a/oneflow/user/ops/upsample_op.cpp +++ b/oneflow/user/ops/upsample_op.cpp @@ -25,13 +25,18 @@ namespace oneflow { /*static*/ Maybe UpsampleLinear1DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); - const float scale_factor = ctx->Attr("scale_factor"); + const double scale_factor = ctx->Attr("scale_factor"); CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" && x_desc.shape().NumAxes() == 3) << "upsample_linear_1d only 
supports NCH"; - *y_desc->mut_shape() = Shape({x_desc.shape().At(0), x_desc.shape().At(1), - static_cast(scale_factor * x_desc.shape().At(2))}); + std::vector output_size = ctx->Attr>("output_size"); + if (output_size.size()) { + *y_desc->mut_shape() = Shape({x_desc.shape().At(0), x_desc.shape().At(1), output_size[0]}); + } else { + *y_desc->mut_shape() = Shape({x_desc.shape().At(0), x_desc.shape().At(1), + static_cast(scale_factor * x_desc.shape().At(2))}); + } return Maybe::Ok(); } /*static*/ Maybe UpsampleLinear1DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { @@ -49,12 +54,17 @@ namespace oneflow { /*static*/ Maybe UpsampleNearest1DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); - const float scale_factor = ctx->Attr("scale_factor"); + const double scale_factor = ctx->Attr("scale_factor"); CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" && x_desc.shape().NumAxes() == 3) << "upsample_nearest_1d only supports NCH"; - *y_desc->mut_shape() = Shape({x_desc.shape().At(0), x_desc.shape().At(1), - static_cast(scale_factor * x_desc.shape().At(2))}); + std::vector output_size = ctx->Attr>("output_size"); + if (output_size.size()) { + *y_desc->mut_shape() = Shape({x_desc.shape().At(0), x_desc.shape().At(1), output_size[0]}); + } else { + *y_desc->mut_shape() = Shape({x_desc.shape().At(0), x_desc.shape().At(1), + static_cast(scale_factor * x_desc.shape().At(2))}); + } return Maybe::Ok(); } /*static*/ Maybe UpsampleNearest1DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { @@ -72,14 +82,20 @@ namespace oneflow { /*static*/ Maybe UpsampleNearest2DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); - const float height_scale = ctx->Attr("height_scale"); - const float width_scale = 
ctx->Attr("width_scale"); + const double height_scale = ctx->Attr("height_scale"); + const double width_scale = ctx->Attr("width_scale"); CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" && x_desc.shape().NumAxes() == 4) << "upsample_nearest_2d only supports NCHW"; - *y_desc->mut_shape() = Shape({x_desc.shape().At(0), x_desc.shape().At(1), - static_cast(height_scale * x_desc.shape().At(2)), - static_cast(width_scale * x_desc.shape().At(3))}); + std::vector output_size = ctx->Attr>("output_size"); + if (output_size.size()) { + *y_desc->mut_shape() = + Shape({x_desc.shape().At(0), x_desc.shape().At(1), output_size[0], output_size[1]}); + } else { + *y_desc->mut_shape() = Shape({x_desc.shape().At(0), x_desc.shape().At(1), + static_cast(height_scale * x_desc.shape().At(2)), + static_cast(width_scale * x_desc.shape().At(3))}); + } return Maybe::Ok(); } /*static*/ Maybe UpsampleNearest2DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { @@ -97,14 +113,20 @@ namespace oneflow { /*static*/ Maybe UpsampleBilinear2DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); - const float height_scale = ctx->Attr("height_scale"); - const float width_scale = ctx->Attr("width_scale"); + const double height_scale = ctx->Attr("height_scale"); + const double width_scale = ctx->Attr("width_scale"); CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" && x_desc.shape().NumAxes() == 4) << "upsample_bilinear_2d only supports NCHW"; - *y_desc->mut_shape() = Shape({x_desc.shape().At(0), x_desc.shape().At(1), - static_cast(height_scale * x_desc.shape().At(2)), - static_cast(width_scale * x_desc.shape().At(3))}); + std::vector output_size = ctx->Attr>("output_size"); + if (output_size.size()) { + *y_desc->mut_shape() = + Shape({x_desc.shape().At(0), x_desc.shape().At(1), output_size[0], output_size[1]}); + } else { + *y_desc->mut_shape() = 
Shape({x_desc.shape().At(0), x_desc.shape().At(1), + static_cast(height_scale * x_desc.shape().At(2)), + static_cast(width_scale * x_desc.shape().At(3))}); + } return Maybe::Ok(); } /*static*/ Maybe UpsampleBilinear2DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { @@ -122,14 +144,20 @@ namespace oneflow { /*static*/ Maybe UpsampleBicubic2DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); - const float height_scale = ctx->Attr("height_scale"); - const float width_scale = ctx->Attr("width_scale"); + const double height_scale = ctx->Attr("height_scale"); + const double width_scale = ctx->Attr("width_scale"); CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" && x_desc.shape().NumAxes() == 4) << "upsample_bicubic_2d only supports NCHW"; - *y_desc->mut_shape() = Shape({x_desc.shape().At(0), x_desc.shape().At(1), - static_cast(height_scale * x_desc.shape().At(2)), - static_cast(width_scale * x_desc.shape().At(3))}); + std::vector output_size = ctx->Attr>("output_size"); + if (output_size.size()) { + *y_desc->mut_shape() = + Shape({x_desc.shape().At(0), x_desc.shape().At(1), output_size[0], output_size[1]}); + } else { + *y_desc->mut_shape() = Shape({x_desc.shape().At(0), x_desc.shape().At(1), + static_cast(height_scale * x_desc.shape().At(2)), + static_cast(width_scale * x_desc.shape().At(3))}); + } return Maybe::Ok(); } /*static*/ Maybe UpsampleBicubic2DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { @@ -147,16 +175,22 @@ namespace oneflow { /*static*/ Maybe UpsampleNearest3DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); - const float depth_scale = ctx->Attr("depth_scale"); - const float height_scale = ctx->Attr("height_scale"); - const float width_scale = 
ctx->Attr("width_scale"); + const double depth_scale = ctx->Attr("depth_scale"); + const double height_scale = ctx->Attr("height_scale"); + const double width_scale = ctx->Attr("width_scale"); CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" && x_desc.shape().NumAxes() == 5) << "upsample_nearest_3d only supports NCDHW"; - *y_desc->mut_shape() = Shape({x_desc.shape().At(0), x_desc.shape().At(1), - static_cast(depth_scale * x_desc.shape().At(2)), - static_cast(height_scale * x_desc.shape().At(3)), - static_cast(width_scale * x_desc.shape().At(4))}); + std::vector output_size = ctx->Attr>("output_size"); + if (output_size.size()) { + *y_desc->mut_shape() = Shape({x_desc.shape().At(0), x_desc.shape().At(1), output_size[0], + output_size[1], output_size[2]}); + } else { + *y_desc->mut_shape() = Shape({x_desc.shape().At(0), x_desc.shape().At(1), + static_cast(depth_scale * x_desc.shape().At(2)), + static_cast(height_scale * x_desc.shape().At(3)), + static_cast(width_scale * x_desc.shape().At(4))}); + } return Maybe::Ok(); } /*static*/ Maybe UpsampleNearest3DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { @@ -174,16 +208,22 @@ namespace oneflow { /*static*/ Maybe UpsampleTrilinear3DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); - const float depth_scale = ctx->Attr("depth_scale"); - const float height_scale = ctx->Attr("height_scale"); - const float width_scale = ctx->Attr("width_scale"); + const double depth_scale = ctx->Attr("depth_scale"); + const double height_scale = ctx->Attr("height_scale"); + const double width_scale = ctx->Attr("width_scale"); CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" && x_desc.shape().NumAxes() == 5) << "upsample_trilinear_3d only supports NCDHW"; - *y_desc->mut_shape() = Shape({x_desc.shape().At(0), x_desc.shape().At(1), - static_cast(depth_scale * 
x_desc.shape().At(2)), - static_cast(height_scale * x_desc.shape().At(3)), - static_cast(width_scale * x_desc.shape().At(4))}); + std::vector output_size = ctx->Attr>("output_size"); + if (output_size.size()) { + *y_desc->mut_shape() = Shape({x_desc.shape().At(0), x_desc.shape().At(1), output_size[0], + output_size[1], output_size[2]}); + } else { + *y_desc->mut_shape() = Shape({x_desc.shape().At(0), x_desc.shape().At(1), + static_cast(depth_scale * x_desc.shape().At(2)), + static_cast(height_scale * x_desc.shape().At(3)), + static_cast(width_scale * x_desc.shape().At(4))}); + } return Maybe::Ok(); } /*static*/ Maybe UpsampleTrilinear3DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { @@ -387,8 +427,9 @@ REGISTER_USER_OP_GRAD("upsample_linear_1d") .Input("dy", op.GetGradTensorWithOpOutput("y", 0)) .Input("x", op.input("x", 0)) .Output("dx") - .Attr("scale_factor", op.attr("scale_factor")) + .Attr("scale_factor", op.attr("scale_factor")) .Attr("align_corners", op.attr("align_corners")) + .Attr("output_size", op.attr>("output_size")) .Attr("data_format", op.attr("data_format")) .Build(); op.BindGradTensorWithOpInput(grad_op.output("dx", 0), "x", 0); @@ -407,7 +448,8 @@ REGISTER_USER_OP_GRAD("upsample_nearest_1d") .Input("dy", op.GetGradTensorWithOpOutput("y", 0)) .Input("x", op.input("x", 0)) .Output("dx") - .Attr("scale_factor", op.attr("scale_factor")) + .Attr("scale_factor", op.attr("scale_factor")) + .Attr("output_size", op.attr>("output_size")) .Attr("data_format", op.attr("data_format")) .Build(); op.BindGradTensorWithOpInput(grad_op.output("dx", 0), "x", 0); @@ -426,8 +468,9 @@ REGISTER_USER_OP_GRAD("upsample_nearest_2d") .Input("dy", op.GetGradTensorWithOpOutput("y", 0)) .Input("x", op.input("x", 0)) .Output("dx") - .Attr("height_scale", op.attr("height_scale")) - .Attr("width_scale", op.attr("width_scale")) + .Attr("height_scale", op.attr("height_scale")) + .Attr("width_scale", op.attr("width_scale")) + .Attr("output_size", op.attr>("output_size")) 
.Attr("data_format", op.attr("data_format")) .Build(); op.BindGradTensorWithOpInput(grad_op.output("dx", 0), "x", 0); @@ -446,9 +489,10 @@ REGISTER_USER_OP_GRAD("upsample_bilinear_2d") .Input("dy", op.GetGradTensorWithOpOutput("y", 0)) .Input("x", op.input("x", 0)) .Output("dx") - .Attr("height_scale", op.attr("height_scale")) - .Attr("width_scale", op.attr("width_scale")) + .Attr("height_scale", op.attr("height_scale")) + .Attr("width_scale", op.attr("width_scale")) .Attr("align_corners", op.attr("align_corners")) + .Attr("output_size", op.attr>("output_size")) .Attr("data_format", op.attr("data_format")) .Build(); op.BindGradTensorWithOpInput(grad_op.output("dx", 0), "x", 0); @@ -467,9 +511,10 @@ REGISTER_USER_OP_GRAD("upsample_bicubic_2d") .Input("dy", op.GetGradTensorWithOpOutput("y", 0)) .Input("x", op.input("x", 0)) .Output("dx") - .Attr("height_scale", op.attr("height_scale")) - .Attr("width_scale", op.attr("width_scale")) + .Attr("height_scale", op.attr("height_scale")) + .Attr("width_scale", op.attr("width_scale")) .Attr("align_corners", op.attr("align_corners")) + .Attr("output_size", op.attr>("output_size")) .Attr("data_format", op.attr("data_format")) .Build(); op.BindGradTensorWithOpInput(grad_op.output("dx", 0), "x", 0); @@ -488,9 +533,10 @@ REGISTER_USER_OP_GRAD("upsample_nearest_3d") .Input("dy", op.GetGradTensorWithOpOutput("y", 0)) .Input("x", op.input("x", 0)) .Output("dx") - .Attr("depth_scale", op.attr("depth_scale")) - .Attr("height_scale", op.attr("height_scale")) - .Attr("width_scale", op.attr("width_scale")) + .Attr("depth_scale", op.attr("depth_scale")) + .Attr("height_scale", op.attr("height_scale")) + .Attr("width_scale", op.attr("width_scale")) + .Attr("output_size", op.attr>("output_size")) .Attr("data_format", op.attr("data_format")) .Build(); op.BindGradTensorWithOpInput(grad_op.output("dx", 0), "x", 0); @@ -509,10 +555,11 @@ REGISTER_USER_OP_GRAD("upsample_trilinear_3d") .Input("dy", op.GetGradTensorWithOpOutput("y", 0)) .Input("x", 
op.input("x", 0)) .Output("dx") - .Attr("depth_scale", op.attr("depth_scale")) - .Attr("height_scale", op.attr("height_scale")) - .Attr("width_scale", op.attr("width_scale")) + .Attr("depth_scale", op.attr("depth_scale")) + .Attr("height_scale", op.attr("height_scale")) + .Attr("width_scale", op.attr("width_scale")) .Attr("align_corners", op.attr("align_corners")) + .Attr("output_size", op.attr>("output_size")) .Attr("data_format", op.attr("data_format")) .Build(); op.BindGradTensorWithOpInput(grad_op.output("dx", 0), "x", 0); diff --git a/python/oneflow/nn/modules/interpolate.py b/python/oneflow/nn/modules/interpolate.py index ac977712e10..e9b168ae060 100644 --- a/python/oneflow/nn/modules/interpolate.py +++ b/python/oneflow/nn/modules/interpolate.py @@ -71,6 +71,19 @@ def __init__( raise ValueError('interpolation "nearest" does not support align_corners.') def forward(self, x): + if len(x.shape) == 3 and self.mode == "bilinear": + raise NotImplementedError("Got 3D input, but bilinear mode needs 4D input") + if len(x.shape) == 3 and self.mode == "trilinear": + raise NotImplementedError("Got 3D input, but trilinear mode needs 5D input") + if len(x.shape) == 4 and self.mode == "linear": + raise NotImplementedError("Got 4D input, but linear mode needs 3D input") + if len(x.shape) == 4 and self.mode == "trilinear": + raise NotImplementedError("Got 4D input, but trilinear mode needs 5D input") + if len(x.shape) == 5 and self.mode == "linear": + raise NotImplementedError("Got 5D input, but linear mode needs 3D input") + if len(x.shape) == 5 and self.mode == "bilinear": + raise NotImplementedError("Got 5D input, but bilinear mode needs 4D input") + dim = len(x.shape) - 2 if self.size is not None and self.scale_factor is not None: raise ValueError("only one of size or scale_factor should be defined") @@ -121,13 +134,17 @@ def forward(self, x): scale_factors.append(output_size[i] / x.shape[2 + i]) if len(x.shape) == 3 and self.mode == "nearest": return 
flow._C.upsample_nearest_1d( - x, scale_factor=scale_factors[0], data_format="channels_first" + x, + scale_factor=scale_factors[0], + output_size=output_size, + data_format="channels_first", ) if len(x.shape) == 4 and self.mode == "nearest": return flow._C.upsample_nearest_2d( x, height_scale=scale_factors[0], width_scale=scale_factors[1], + output_size=output_size, data_format="channels_first", ) if len(x.shape) == 5 and self.mode == "nearest": @@ -136,6 +153,7 @@ def forward(self, x): depth_scale=scale_factors[0], height_scale=scale_factors[1], width_scale=scale_factors[2], + output_size=output_size, data_format="channels_first", ) if len(x.shape) == 3 and self.mode == "area": @@ -153,6 +171,7 @@ def forward(self, x): x, scale_factor=scale_factors[0], align_corners=self.align_corners, + output_size=output_size, data_format="channels_first", ) if len(x.shape) == 4 and self.mode == "bilinear": @@ -162,6 +181,7 @@ def forward(self, x): height_scale=scale_factors[0], width_scale=scale_factors[1], align_corners=self.align_corners, + output_size=output_size, data_format="channels_first", ) if len(x.shape) == 4 and self.mode == "bicubic": @@ -171,6 +191,7 @@ def forward(self, x): height_scale=scale_factors[0], width_scale=scale_factors[1], align_corners=self.align_corners, + output_size=output_size, data_format="channels_first", ) if len(x.shape) == 5 and self.mode == "trilinear": @@ -181,9 +202,16 @@ def forward(self, x): height_scale=scale_factors[1], width_scale=scale_factors[2], align_corners=self.align_corners, + output_size=output_size, data_format="channels_first", ) + raise NotImplementedError( + "Input Error: Only 3D, 4D and 5D input Tensors supported" + " (got {}D) for the modes: nearest | linear | bilinear | bicubic | trilinear | area" + " (got {})".format(len(x.shape), self.mode) + ) + def interpolate( input, diff --git a/python/oneflow/nn/modules/upsampling.py b/python/oneflow/nn/modules/upsampling.py index 4703505364e..ff22f9cd125 100644 --- 
a/python/oneflow/nn/modules/upsampling.py +++ b/python/oneflow/nn/modules/upsampling.py @@ -16,7 +16,6 @@ from typing import Optional, Tuple, Union import oneflow as flow -from oneflow.framework.tensor import register_tensor_op from oneflow.nn.module import Module diff --git a/python/oneflow/test/modules/test_upsample.py b/python/oneflow/test/modules/test_upsample.py index 5af1b304c2f..9cc0d679813 100644 --- a/python/oneflow/test/modules/test_upsample.py +++ b/python/oneflow/test/modules/test_upsample.py @@ -387,8 +387,8 @@ def test_upsample2d_nearest(test_case): # The forward and backward result in cpu and cuda of bilinear interpolate operation in PyTorch is different # in some corner cases. OneFlow has the same cpu and cuda results with PyTorch's cuda result. # So here we only test cuda device forward result. - @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @autotest(n=10, auto_backward=False, atol=1e-8) + @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_upsample2d_bilinear(test_case): x = random_tensor(ndim=4).to("cuda") x = x.permute(1, 3, 0, 2) @@ -400,8 +400,8 @@ def test_upsample2d_bilinear(test_case): y = m(x) return y - @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @autotest(atol=1e-5) + @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_upsample2d_bicubic(test_case): x = random_tensor(ndim=4, dim0=16, dim1=8).to("cuda") m = torch.nn.Upsample( @@ -412,6 +412,63 @@ def test_upsample2d_bicubic(test_case): y = m(x) return y + @autotest(n=5, atol=1e-5) + @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") + def test_upsample1d_nearest_output_size(test_case): + x = random_tensor(ndim=3, dim0=1, dim1=2, dim2=12).to("cuda") + m = torch.nn.Upsample(size=(13), mode="nearest") + y = m(x) + return y + + @autotest(n=5, atol=1e-5) + @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") + def 
test_upsample2d_nearest_output_size(test_case): + x = random_tensor(ndim=4, dim0=1, dim1=1, dim2=1, dim3=937).to("cuda") + m = torch.nn.Upsample(size=(1, 30), mode="nearest") + y = m(x) + return y + + @autotest(n=5, atol=1e-5) + @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") + def test_upsample3d_nearest_output_size(test_case): + x = random_tensor(ndim=5, dim0=1, dim1=1, dim2=6, dim3=12, dim4=6).to("cuda") + m = torch.nn.Upsample(size=(8, 10, 7), mode="nearest") + y = m(x) + return y + + @autotest(n=5, atol=1e-5) + @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") + def test_upsample1d_linear_output_size(test_case): + device = random_device() + x = random_tensor(ndim=3, dim0=1, dim1=2, dim2=12).to(device) + m = torch.nn.Upsample(size=(13), mode="linear") + y = m(x) + return y + + @autotest(n=5, atol=1e-5) + @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") + def test_upsample2d_bilinear_output_size(test_case): + x = random_tensor(ndim=4, dim0=1, dim1=1, dim2=12, dim3=21).to("cuda") + m = torch.nn.Upsample(size=(14, 19), mode="bilinear") + y = m(x) + return y + + @autotest(n=5, atol=1e-5) + @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") + def test_upsample2d_bicubic_output_size(test_case): + x = random_tensor(ndim=4, dim0=1, dim1=2, dim2=12, dim3=21).to("cuda") + m = torch.nn.Upsample(size=(14, 19), mode="bicubic") + y = m(x) + return y + + @autotest(n=5, atol=1e-5) + @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") + def test_upsample3d_trilinear_output_size(test_case): + x = random_tensor(ndim=5, dim0=1, dim1=2, dim2=1, dim3=12, dim4=17).to("cuda") + m = torch.nn.Upsample(size=(1, 14, 23), mode="trilinear") + y = m(x) + return y + if __name__ == "__main__": unittest.main()