From 36b411f8e55a74d280e09d8050219a4dcb633b0b Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Fri, 19 Feb 2021 19:15:31 +0800 Subject: [PATCH] [AutoScheduler] Fix the type inference for conv3d (#7475) --- src/relay/op/nn/convolution.h | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/relay/op/nn/convolution.h b/src/relay/op/nn/convolution.h index c08d3553e4ccd..5b4850ec6653f 100644 --- a/src/relay/op/nn/convolution.h +++ b/src/relay/op/nn/convolution.h @@ -24,6 +24,7 @@ #ifndef TVM_RELAY_OP_NN_CONVOLUTION_H_ #define TVM_RELAY_OP_NN_CONVOLUTION_H_ +#include #include #include @@ -369,7 +370,18 @@ bool Conv3DRel(const Array& types, int num_inputs, const Attrs& attrs, } else { // use weight to infer the conv shape. if (weight == nullptr) return false; - auto wshape = trans_kernel_layout.ForwardShape(weight->shape); + + Array wshape; + if (param->auto_scheduler_rewritten_layout.size() == 0) { + wshape = weight->shape; + } else { + // works for the default kernel layout "DHWIO" + ICHECK_EQ(param->kernel_layout, "DHWIO"); + wshape = auto_scheduler::GetShapeFromRewrittenLayout(param->auto_scheduler_rewritten_layout, + {"rd", "rh", "rw", "rc", "cc"}); + } + + wshape = trans_kernel_layout.ForwardShape(wshape); if (param->kernel_size.defined()) { ICHECK_EQ(param->kernel_size.size(), 3); // check the size