diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 59548634fc4a..104cc843e398 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1688,9 +1688,10 @@ def index_map( mapping: Callable, *, inverse_index_map: Optional[Callable] = None, + index_dtype: str = "int64", ) -> IndexMap: """Create a TIR Index mapping""" - return IndexMap.from_func(mapping, inverse_index_map=inverse_index_map) + return IndexMap.from_func(mapping, inverse_index_map=inverse_index_map, index_dtype=index_dtype) def target( diff --git a/src/relax/op/image/resize.cc b/src/relax/op/image/resize.cc index 202702d78746..344e27456551 100644 --- a/src/relax/op/image/resize.cc +++ b/src/relax/op/image/resize.cc @@ -121,6 +121,10 @@ InferLayoutOutput InferLayoutResize2d(const Call& call, } else { // We dont have a desired layout for resize2d, propagate from the input instead. data_layout = GetLayoutDecision(var_layout_map, call->args[0]); + // Not handling sub indexing now. + if (data_layout->layout.ndim() != data_layout->layout.ndim_primal()) { + data_layout = LayoutDecision(InitialLayout(4)); + } new_attrs->layout = TransposeLike(attrs->layout, InitialLayout(4), data_layout->layout).name(); } return InferLayoutOutput({data_layout, InitialNLayout(call->args[1])}, {data_layout}, diff --git a/src/relax/op/nn/convolution.cc b/src/relax/op/nn/convolution.cc index 7c7718b837d9..3ebbc544f470 100644 --- a/src/relax/op/nn/convolution.cc +++ b/src/relax/op/nn/convolution.cc @@ -308,30 +308,59 @@ InferLayoutOutput InferLayoutConv2d(const Call& call, Layout desired_data_layout = (*it).second[0]; Layout desired_weight_layout = (*it).second[1]; Layout desired_output_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0]; - ICHECK_EQ(desired_data_layout.ndim(), desired_data_layout.ndim_primal()) << "Axis swap only"; - ICHECK_EQ(desired_weight_layout.ndim(), desired_weight_layout.ndim_primal()) - << "Axis swap only"; - ICHECK_EQ(desired_output_layout.ndim(), desired_output_layout.ndim_primal()) - << "Axis swap only"; - data_layout = TransposeLike(InitialLayout(4), attrs->data_layout, desired_data_layout); - weight_layout = TransposeLike(InitialLayout(4), attrs->kernel_layout, desired_weight_layout); - output_layout = TransposeLike(InitialLayout(4), attrs->out_layout, desired_output_layout); - new_attrs->data_layout = (*it).second[0]; - new_attrs->kernel_layout = (*it).second[1]; - new_attrs->out_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0]; - } else { - // We don't have a desired layout for conv2d. - // We can just propagate the layout from the input. 
-    data_layout = GetLayoutDecision(var_layout_map, call->args[0]);
-    weight_layout = GetLayoutDecision(var_layout_map, call->args[1]);
-    output_layout = data_layout;
-    new_attrs->data_layout =
-        TransposeLike(attrs->data_layout, InitialLayout(4), data_layout->layout).name();
-    new_attrs->kernel_layout =
-        TransposeLike(attrs->kernel_layout, InitialLayout(4), weight_layout->layout).name();
-    new_attrs->out_layout =
-        TransposeLike(attrs->out_layout, InitialLayout(4), output_layout->layout).name();
+    tir::Layout input_layout(attrs->data_layout, DataType::Int(64));
+    tir::Layout kernel_layout(attrs->kernel_layout, DataType::Int(64));
+    tir::Layout out_layout(attrs->out_layout, DataType::Int(64));
+
+    if ((desired_data_layout.ndim() == input_layout.ndim()) &&
+        (desired_weight_layout.ndim() == kernel_layout.ndim()) &&
+        (desired_output_layout.ndim() == out_layout.ndim())) {
+      // Just a transpose
+      data_layout = TransposeLike(InitialLayout(4), attrs->data_layout, desired_data_layout);
+      weight_layout = TransposeLike(InitialLayout(4), attrs->kernel_layout, desired_weight_layout);
+      output_layout = TransposeLike(InitialLayout(4), attrs->out_layout, desired_output_layout);
+      new_attrs->data_layout = (*it).second[0];
+      new_attrs->kernel_layout = (*it).second[1];
+      new_attrs->out_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0];
+      return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs));
+    } else {
+      // Layout transform
+      auto data_si = GetStructInfo(call->args[0]);
+      auto kernel_si = GetStructInfo(call->args[1]);
+      TensorStructInfo data_sinfo = data_si.as<TensorStructInfo>().value();
+      TensorStructInfo kernel_sinfo = kernel_si.as<TensorStructInfo>().value();
+      Optional<ShapeExpr> data_shape = GetRef<ShapeExpr>(data_sinfo->shape.as<ShapeExprNode>());
+      Optional<ShapeExpr> kernel_shape = GetRef<ShapeExpr>(kernel_sinfo->shape.as<ShapeExprNode>());
+
+      bool can_data_proved =
+          CanProveLayoutTransform(input_layout, desired_data_layout, data_shape.value()->values);
+      bool can_kernel_proved = CanProveLayoutTransform(kernel_layout, desired_weight_layout,
+                                                       kernel_shape.value()->values);
+
+      if (can_data_proved && can_kernel_proved) {
+        data_layout = TransposeSubLayoutLike(InitialLayout(4), input_layout, desired_data_layout);
+        weight_layout =
+            TransposeSubLayoutLike(InitialLayout(4), kernel_layout, desired_weight_layout);
+        output_layout = TransposeSubLayoutLike(InitialLayout(4), out_layout, desired_output_layout);
+        new_attrs->data_layout = (*it).second[0];
+        new_attrs->kernel_layout = (*it).second[1];
+        new_attrs->out_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0];
+        return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs));
+      }
+    }
   }
+
+  // We don't have a desired layout for conv2d, or the desired layouts are not compatible.
+  // We can just propagate the layout from the input.
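+  // e.g. a desired NCHW4c layout for a conv2d whose channel extent is not a multiple of 4 ends up here.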
+  data_layout = GetLayoutDecision(var_layout_map, call->args[0]);
+  weight_layout = GetLayoutDecision(var_layout_map, call->args[1]);
+  output_layout = data_layout;
+  new_attrs->data_layout =
+      TransposeLike(attrs->data_layout, InitialLayout(4), data_layout->layout).name();
+  new_attrs->kernel_layout =
+      TransposeLike(attrs->kernel_layout, InitialLayout(4), weight_layout->layout).name();
+  new_attrs->out_layout =
+      TransposeLike(attrs->out_layout, InitialLayout(4), output_layout->layout).name();
   return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs));
 }
diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc
index 6f8a90e3cba9..7eccf47e4b06 100644
--- a/src/relax/op/nn/nn.cc
+++ b/src/relax/op/nn/nn.cc
@@ -93,6 +93,16 @@ InferLayoutOutput InferLayoutSoftmax(const Call& call,
   ICHECK(attrs) << "Invalid Call";
   LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);
+
+  // TODO(Siva): We could handle the case where the axis is not the sub indexed one.
+  if (layout->layout.ndim() != layout->layout.ndim_primal()) {
+    const auto* tensor_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+    ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
+    ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now";
+    int ndim = tensor_sinfo->ndim;
+    layout = LayoutDecision(InitialLayout(ndim));
+  }
+
   ObjectPtr<SoftmaxAttrs> new_attrs = make_object<SoftmaxAttrs>(*attrs);
   new_attrs->axis = FindAxis(layout->layout, attrs->axis);
   return InferLayoutOutput({layout}, {layout}, Attrs(new_attrs));
@@ -290,8 +300,18 @@ InferLayoutOutput InferLayoutBatchNorm(const Call& call,
   ICHECK(attrs) << "Invalid Call";
   LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);
+
+  // When dealing with sub layouts, it is advisable to handle batch_norm
+  // in other ways, such as decomposing it or using fusion methods.
+  // This handling is a fail-safe fallback.
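+  // A sub indexed input layout (e.g. NCHW4c) splits the channel axis that attrs->axis refers to,
+  // so fall back to the initial layout below.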
+  const auto* input_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  int ndim = input_sinfo->ndim;
+  if (layout->layout.ndim() != layout->layout.ndim_primal()) {
+    layout = LayoutDecision(InitialLayout(ndim));
+  }
+
   ObjectPtr<BatchNormAttrs> new_attrs = make_object<BatchNormAttrs>(*attrs);
-  new_attrs->axis = FindAxis(layout->layout, attrs->axis);
+  new_attrs->axis = FindAxis(layout->layout, (attrs->axis + ndim) % ndim);
   return InferLayoutOutput(
       {layout, initial_layouts[1], initial_layouts[2], initial_layouts[3], initial_layouts[4]},
       {{layout, initial_layouts[3], initial_layouts[4]}}, Attrs(new_attrs));
@@ -353,9 +373,11 @@ InferLayoutOutput InferLayoutLayerNorm(const Call& call,
   LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);
   ObjectPtr<LayerNormAttrs> new_attrs = make_object<LayerNormAttrs>(*attrs);
+  const auto* input_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  int ndim = input_sinfo->ndim;
   std::vector<Integer> new_axis;
   for (const auto& axis : attrs->axes) {
-    new_axis.push_back(FindAxis(layout->layout, axis->value));
+    new_axis.push_back(FindAxis(layout->layout, (axis->value + ndim) % ndim));
   }
   new_attrs->axes = std::move(new_axis);
   return InferLayoutOutput({layout, initial_layouts[1], initial_layouts[2]}, {layout},
diff --git a/src/relax/op/nn/pooling.cc b/src/relax/op/nn/pooling.cc
index 7fdcefc00bd0..565e6a00c60d 100644
--- a/src/relax/op/nn/pooling.cc
+++ b/src/relax/op/nn/pooling.cc
@@ -234,6 +234,23 @@ InferLayoutOutput InferLayoutPool2d(const Call& call,
   LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);
   ObjectPtr<Pool2DAttrs> new_attrs = make_object<Pool2DAttrs>(*attrs);
+
+  if (layout->layout.ndim() != layout->layout.ndim_primal()) {
+    tir::Layout in_layout(attrs->layout, DataType::Int(64));
+    auto desired_layout = TransposeSubLayoutLike(attrs->layout, InitialLayout(4), layout->layout);
+    auto data_si = GetStructInfo(call->args[0]);
+    TensorStructInfo data_sinfo = data_si.as<TensorStructInfo>().value();
+    Optional<ShapeExpr> data_shape = GetRef<ShapeExpr>(data_sinfo->shape.as<ShapeExprNode>());
+    if (CanProveLayoutTransform(in_layout, desired_layout, data_shape.value()->values)) {
+      // Not handling out_layout different from in_layout for now. Is there a use case?
+      new_attrs->layout = desired_layout.name();
+      new_attrs->out_layout = desired_layout.name();
+      return InferLayoutOutput({layout}, {layout}, Attrs(new_attrs));
+    } else {
+      layout = InitialLayout(4);
+    }
+  }
+
   new_attrs->layout = TransposeLike(attrs->layout, InitialLayout(4), layout->layout).name();
   new_attrs->out_layout = TransposeLike(attrs->out_layout, InitialLayout(4), layout->layout).name();
   return InferLayoutOutput({layout}, {layout}, Attrs(new_attrs));
@@ -583,6 +600,21 @@ InferLayoutOutput InferLayoutAdaptiveAvgPool2D(const Call& call,
   LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);
   ObjectPtr<AdaptivePool2DAttrs> new_attrs = make_object<AdaptivePool2DAttrs>(*attrs);
+  if (layout->layout.ndim() != layout->layout.ndim_primal()) {
+    tir::Layout in_layout(attrs->layout, DataType::Int(64));
+    auto desired_layout = TransposeSubLayoutLike(attrs->layout, InitialLayout(4), layout->layout);
+    auto data_si = GetStructInfo(call->args[0]);
+    TensorStructInfo data_sinfo = data_si.as<TensorStructInfo>().value();
+    Optional<ShapeExpr> data_shape = GetRef<ShapeExpr>(data_sinfo->shape.as<ShapeExprNode>());
+    if (CanProveLayoutTransform(in_layout, desired_layout, data_shape.value()->values)) {
+      // Not handling out_layout different from in_layout for now. Is there a use case?
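+      // e.g. an NCHW16c input keeps pooling in NCHW16c instead of being transposed back to NCHW.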
+      new_attrs->layout = desired_layout.name();
+      new_attrs->out_layout = desired_layout.name();
+      return InferLayoutOutput({layout}, {layout}, Attrs(new_attrs));
+    } else {
+      layout = InitialLayout(4);
+    }
+  }
   new_attrs->layout = TransposeLike(attrs->layout, InitialLayout(4), layout->layout).name();
   new_attrs->out_layout = TransposeLike(attrs->out_layout, InitialLayout(4), layout->layout).name();
   return InferLayoutOutput({layout}, {layout}, Attrs(new_attrs));
diff --git a/src/relax/op/op_common.cc b/src/relax/op/op_common.cc
index 56bf708f5e06..f9c1ece38c18 100644
--- a/src/relax/op/op_common.cc
+++ b/src/relax/op/op_common.cc
@@ -185,5 +185,27 @@ InferLayoutOutput InferLayoutUnaryEwise(const Call& call,
   return InferLayoutOutput({layout}, {layout}, Attrs(call->attrs));
 }
+bool CanProveLayoutTransform(const Layout& input_layout, const Layout& desired_layout,
+                             Array<PrimExpr> shape) {
+  bool can_prove = true;
+  try {
+    tir::BijectiveLayout todesired(input_layout, desired_layout);
+    Array<PrimExpr> desired_shape = todesired.ForwardShape(shape);
+    Array<PrimExpr> back_shape = todesired.BackwardShape(desired_shape);
+    arith::Analyzer analyzer;
+    for (size_t i = 0; i < shape.size(); ++i) {
+      if (tir::is_const_int(shape[i])) {
+        if (!analyzer.CanProveEqual(shape[i], back_shape[i])) {
+          can_prove = false;
+          break;
+        }
+      }
+    }
+  } catch (std::exception& err) {
+    return false;
+  }
+  return can_prove;
+}
+
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h
index ed6725e27012..eb9caae4b9e1 100644
--- a/src/relax/op/op_common.h
+++ b/src/relax/op/op_common.h
@@ -570,6 +570,16 @@ Expr MakeAllocTensor(Expr shape, DataTypeImm dtype, PrimValue runtime_device_ind
  */
 Array<Expr> GetCallArgs(const Call& call);
+/**
+ * \brief Check whether the given shape can be provably transformed from the source layout to the desired layout.
+ * \param input_layout The layout of the given shape.
+ * \param desired_layout The target layout the shape is to be transformed into.
+ * \param shape The shape array to check.
+ * \return true if the transform can be proven, false otherwise.
+ */
+bool CanProveLayoutTransform(const Layout& input_layout, const Layout& desired_layout,
+                             Array<PrimExpr> shape);
+
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc
index bd4c681c7925..4a63993d507c 100644
--- a/src/relax/op/tensor/binary.cc
+++ b/src/relax/op/tensor/binary.cc
@@ -155,6 +155,21 @@ InferLayoutOutput InferLayoutBinaryEwise(const Call& call,
   ICHECK(!x1_sinfo->IsUnknownNdim() && !x2_sinfo->IsUnknownNdim())
       << "Unknown dim tensors should not be handled by this function";
+  Optional<ShapeExpr> shape1 = GetRef<ShapeExpr>(x1_sinfo->shape.as<ShapeExprNode>());
+  Optional<ShapeExpr> shape2 = GetRef<ShapeExpr>(x2_sinfo->shape.as<ShapeExprNode>());
+  // Let's handle sub indexing as long as the primal dims match
+  if (layout1->layout.ndim_primal() == layout2->layout.ndim_primal()) {
+    if ((layout1->layout.ndim() >= layout2->layout.ndim()) && shape2.defined()) {
+      if (CanProveLayoutTransform(layout2->layout, layout1->layout, shape2.value()->values)) {
+        return InferLayoutOutput({layout1, layout1}, {layout1}, Attrs(call->attrs));
+      }
+    } else if (shape1.defined()) {
+      if (CanProveLayoutTransform(layout1->layout, layout2->layout, shape1.value()->values)) {
+        return InferLayoutOutput({layout2, layout2}, {layout2}, Attrs(call->attrs));
+      }
+    }
+  }
+
   if (x1_sinfo->ndim <= x2_sinfo->ndim) {
     if (x1_sinfo->ndim == 0) {
       LayoutDecision out_layout = layout2;
diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc
index e62dbe89d08a..8a8dc6de40e4 100644
--- a/src/relax/op/tensor/index.cc
+++ b/src/relax/op/tensor/index.cc
@@ -438,6 +438,10 @@ InferLayoutOutput InferLayoutStridedSlice(const Call& call,
       << "but expression " << call << " has argument " << call->args[0]
       << " of unknown dimensionality.";
   LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, call->args[0]);
+  // Can't handle sub indexed layouts.
+  if (existing_layout->layout.ndim() != existing_layout->layout.ndim_primal()) {
+    existing_layout = LayoutDecision(InitialLayout(tensor_sinfo->ndim));
+  }
   auto opt_axes_tuple = UnpackTupleOfPrimValue(GetStructInfo(call->args[1]));
   CHECK(opt_axes_tuple) << "Layout inference of " << call->op
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index f64b3ec4f979..452b1f223a80 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -393,6 +393,10 @@ InferLayoutOutput InferLayoutExpandDims(const Call& call,
   LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, call->args[0]);
   int ndim = tensor_sinfo->ndim;
+  // Can't handle sub indexed layouts.
+  if (existing_layout->layout.ndim() != existing_layout->layout.ndim_primal()) {
+    existing_layout = LayoutDecision(InitialLayout(ndim));
+  }
   int n_new_dim = attrs->axis.size();
   int output_ndim = ndim + n_new_dim;
   std::vector<bool> is_new_dim(output_ndim, false);
@@ -622,6 +626,12 @@ InferLayoutOutput InferLayoutPermuteDims(const Call& call,
   int ndim = tensor_sinfo->ndim;
   LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, call->args[0]);
+
+  // permute_dims can't handle sub indexed layouts.
+  if (existing_layout->layout.ndim() != existing_layout->layout.ndim_primal()) {
+    existing_layout = LayoutDecision(InitialLayout(ndim));
+  }
+
   Array<Integer> order;
   if (attrs->axes.defined()) {
     order = attrs->axes.value();
@@ -942,10 +952,33 @@ InferLayoutOutput InferLayoutSplit(const Call& call,
   ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim";
   LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, call->args[0]);
-  ObjectPtr<SplitAttrs> new_attrs = make_object<SplitAttrs>(*attrs);
-  new_attrs->axis = FindAxis(existing_layout->layout, attrs->axis);
   StructInfo out_sinfo = InferStructInfoSplit(call, BlockBuilder::Create(IRModule()));
   const auto* out_tuple = out_sinfo.as<TupleStructInfoNode>();
+
+  /*
+   * Fall back if the outputs can't be represented in the input's sub indexed layout.
+   * This can happen after sub indexing, if we can't split the corresponding primal axis.
+   */
+  if (existing_layout->layout.ndim() != existing_layout->layout.ndim_primal()) {
+    for (const auto& si : out_tuple->fields) {
+      ICHECK(si->IsInstance<TensorStructInfoNode>())
+          << "Fields of TupleStructInfo must be TensorStructInfo"
+             "output structinfo, but got "
+          << si;
+      auto sinfo = Downcast<TensorStructInfo>(si);
+      Optional<ShapeExpr> shape_expr = GetRef<ShapeExpr>(sinfo->shape.as<ShapeExprNode>());
+      CHECK(shape_expr.defined());
+      auto shape_arr = shape_expr.value();
+      if (!CanProveLayoutTransform(InitialLayout(tensor_sinfo->ndim), existing_layout->layout,
+                                   shape_arr->values)) {
+        existing_layout = InitialLayout(tensor_sinfo->ndim);
+        break;
+      }
+    }
+  }
+
+  ObjectPtr<SplitAttrs> new_attrs = make_object<SplitAttrs>(*attrs);
+  new_attrs->axis = FindAxis(existing_layout->layout, attrs->axis);
   ICHECK(out_tuple != nullptr) << "Invalid Call";
   NLayout tuple_layouts(Array<NLayout>(out_tuple->fields.size(), existing_layout));
   return InferLayoutOutput({existing_layout}, {tuple_layouts}, Attrs(new_attrs));
@@ -1092,6 +1125,10 @@ InferLayoutOutput InferLayoutSqueeze(const Call& call,
   }
   LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, call->args[0]);
+  // Can't handle sub indexed layouts.
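+  // Mapping squeeze axes through a packed (sub indexed) layout is not handled, so revert to the initial layout.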
+ if (existing_layout->layout.ndim() != existing_layout->layout.ndim_primal()) { + existing_layout = LayoutDecision(InitialLayout(ndim)); + } String new_axis_str = TransposeStrLike(axis_str, InitialLayout(ndim), existing_layout->layout); Array new_axis; for (size_t i = 0; i < new_axis_str.size(); ++i) { diff --git a/src/relax/op/tensor/statistical.cc b/src/relax/op/tensor/statistical.cc index 24ccde4559e6..606001dfbf3f 100644 --- a/src/relax/op/tensor/statistical.cc +++ b/src/relax/op/tensor/statistical.cc @@ -108,25 +108,35 @@ InferLayoutOutput InferLayoutStatistical(const Call& call, std::string axis_str(ndim, '0'); for (const auto& iter : axis) { - axis_str[(iter->value + ndim) % ndim] = '1'; + axis_str[(iter->value + ndim) % ndim] = '#'; } for (int i = 0, j = 0; i < ndim; ++i) { - if (axis_str[i] != '1') { + if (axis_str[i] != '#') { axis_str[i] = 'A' + j++; } } LayoutDecision exisiting_layout = GetLayoutDecision(var_layout_map, call->args[0]); - String new_axis_str = TransposeStrLike(axis_str, InitialLayout(ndim), exisiting_layout->layout); + auto new_axis_str = TransposeSubLayoutStrLike(axis_str, InitialLayout(ndim).name(), + exisiting_layout->layout.name()); + std::string output_layout_ref = new_axis_str; + new_axis_str.erase(std::remove_if(new_axis_str.begin(), new_axis_str.end(), + [](unsigned char c) { return std::isdigit(c); }), + new_axis_str.end()); + Array new_axis; for (size_t i = 0; i < new_axis_str.size(); ++i) { - if (new_axis_str.at(i) == '1') { + if (new_axis_str.at(i) == '#') { new_axis.push_back(Integer(i)); } } - std::string output_layout = new_axis_str; - output_layout.erase(std::remove(output_layout.begin(), output_layout.end(), '1'), - output_layout.end()); + std::string output_layout; + for (size_t i = 0; i < output_layout_ref.length(); ++i) { + if ((isdigit(output_layout_ref[i]) && (output_layout_ref[i + 1] == '#')) || + (output_layout_ref[i] == '#')) + continue; + output_layout.push_back(output_layout_ref[i]); + } ObjectPtr new_attrs = make_object(*attrs); new_attrs->axis = new_axis; diff --git a/src/relax/transform/convert_layout.cc b/src/relax/transform/convert_layout.cc index 3798bdff351d..0898af1d7636 100644 --- a/src/relax/transform/convert_layout.cc +++ b/src/relax/transform/convert_layout.cc @@ -21,10 +21,12 @@ * \brief Automatic layout conversion pass, especially for axis swapping. */ +#include #include #include #include #include +#include #include "../op/tensor/manipulate.h" #include "infer_layout_utils.h" @@ -33,6 +35,7 @@ namespace tvm { namespace relax { +using tir::IndexMap; using tir::Layout; /*! 
@@ -87,6 +90,22 @@ class LayoutConvertMutator : public ExprMutator {
     return ret;
   }
+  IndexMap LayoutIndexMap(int ndim, const Layout& src_layout, const Layout& desired_layout) {
+    tir::BijectiveLayout todesired(src_layout, desired_layout);
+    Optional<IndexMap> inverse_index_map;
+
+    Array<tir::Var> initial_indices;
+    Array<PrimExpr> initial_indices_expr;
+    initial_indices.reserve(ndim);
+    for (int i = 0; i < ndim; ++i) {
+      auto var = tvm::tir::Var("i" + std::to_string(i), DataType::Int(32));
+      initial_indices.push_back(var);
+      initial_indices_expr.push_back(var);
+    }
+    Array<PrimExpr> desired_shape = todesired.ForwardIndex(initial_indices_expr);
+    return IndexMap(initial_indices, desired_shape, std::move(inverse_index_map));
+  }
+
   Expr RewriteExpr(const Expr& expr, const NLayout& to) {
     auto fvisitleaf = [&](const Expr& expr, std::array<NLayout, 2> layouts) -> Expr {
       NLayout from = layouts[0], to = layouts[1];
@@ -97,9 +116,24 @@ class LayoutConvertMutator : public ExprMutator {
           << "Cannot convert when exactly one of the layouts is unknown";
       const auto* tensor = GetStructInfoAs<TensorStructInfoNode>(expr);
       ICHECK(tensor != nullptr) << "Expect a tensor, but got: " << expr;
-      Layout axes = TransposeLike(InitialLayoutDecision(tensor->ndim)->layout,
-                                  from.LeafValue()->layout, to.LeafValue()->layout);
-      return permute_dims(expr, LayoutToIntegers(axes));
+
+      if (from.LeafValue()->layout.ndim() == to.LeafValue()->layout.ndim()) {
+        Layout axes = TransposeLike(InitialLayoutDecision(tensor->ndim)->layout,
+                                    from.LeafValue()->layout, to.LeafValue()->layout);
+        return permute_dims(expr, LayoutToIntegers(axes));
+      } else {
+        auto index_map = LayoutIndexMap(from.LeafValue()->layout.ndim(), from.LeafValue()->layout,
+                                        to.LeafValue()->layout);
+        ObjectPtr<LayoutTransformAttrs> attrs = make_object<LayoutTransformAttrs>();
+        Array<IntImm> axis_separator;
+        Array<IntImm> input_axis_separator;
+        attrs->index_map = std::move(Downcast<IndexMap>(LoadJSON(SaveJSON(index_map))));
+        attrs->axis_separators = std::move(axis_separator);
+        attrs->input_axis_separators = std::move(input_axis_separator);
+        const Op& layout_transform_op_ = Op::Get("relax.layout_transform");
+        auto ret_expr = Call(layout_transform_op_, {expr}, Attrs{std::move(attrs)}, {});
+        return ret_expr;
+      }
     };
     return TransformTupleLeaf<LayoutDecision>(
         VarReplacer::Replace(expr, var_remap_),
diff --git a/src/relax/transform/infer_layout_utils.cc b/src/relax/transform/infer_layout_utils.cc
index d746f9394a75..aca048820996 100644
--- a/src/relax/transform/infer_layout_utils.cc
+++ b/src/relax/transform/infer_layout_utils.cc
@@ -27,6 +27,35 @@ namespace relax {
 using tir::IterVar;
 using tir::Layout;
+std::string TransposeSubLayoutStrLike(const std::string ref_str, const std::string& src_str,
+                                      const std::string& desired_str) {
+  std::string out;
+  for (const char& c : desired_str) {
+    if (std::isupper(c)) {
+      auto res = src_str.find(c, 0);
+      ICHECK(res != std::string::npos) << "Invalid Layout: "
+                                       << "can't find " << c << " in source layout " << src_str;
+      out.push_back(ref_str[res]);
+    } else if (isdigit(c)) {
+      out.push_back(c);
+    } else if (std::islower(c)) {
+      auto res = src_str.find(std::toupper(c), 0);
+      ICHECK(res != std::string::npos) << "Invalid Layout: "
+                                       << "can't find " << c << " in source layout " << src_str;
+      out.push_back(std::tolower(ref_str[res]));
+    }
+  }
+  return out;
+}
+
+Layout TransposeSubLayoutLike(const Layout& ref, const Layout& src, const Layout& desired) {
+  std::string ref_str = ref.name();
+  std::string src_str = src.name();
+  std::string desired_str = desired.name();
+  std::string out = TransposeSubLayoutStrLike(ref_str, src_str, desired_str);
+  return Layout(out);
+}
+
 Layout
TransposeLike(const Layout& input, const Layout& src, const Layout& dst) { ICHECK(src.ndim() == dst.ndim() && input.ndim() == src.ndim()) << "Layouts must have the same size"; @@ -49,7 +78,11 @@ String TransposeStrLike(const String& input, const Layout& src, const Layout& ds int FindAxis(const Layout& dst, int axis) { axis = (axis + dst.ndim()) % dst.ndim(); - return dst.name().find('A' + axis); + std::string layout_name = dst.name(); + layout_name.erase(std::remove_if(layout_name.begin(), layout_name.end(), + [](unsigned char c) { return std::isdigit(c); }), + layout_name.end()); + return layout_name.find('A' + axis); } Layout InitialLayout(int ndim) { diff --git a/src/relax/transform/infer_layout_utils.h b/src/relax/transform/infer_layout_utils.h index 4e54d925446e..951fe92cb8ac 100644 --- a/src/relax/transform/infer_layout_utils.h +++ b/src/relax/transform/infer_layout_utils.h @@ -181,6 +181,25 @@ NLayout InitialNLayout(const StructInfo& sinfo); */ NLayout InitialNLayout(const Expr& expr); +/*! + * \brief Transposing given layout with subindexing + * \param ref The layout to be transformed. + * \param src The source layout. + * \param dst The destination layout. + * \return The transposed dst layout. + */ +Layout TransposeSubLayoutLike(const Layout& ref, const Layout& src, const Layout& desired); + +/*! + * \brief Transposing given layout in string format with subindexing + * \param ref The layout to be transformed. + * \param src The source layout. + * \param dst The destination layout. + * \return The transposed dst layout. + */ +std::string TransposeSubLayoutStrLike(const std::string ref_str, const std::string& src_str, + const std::string& desired_str); + /*! * \brief Transpose the input layout like the src layout to the dst layout. * \param input The input layout. 
diff --git a/tests/python/relax/test_transform_convert_layout.py b/tests/python/relax/test_transform_convert_layout.py index 56b59ba23867..db4130f947d1 100644 --- a/tests/python/relax/test_transform_convert_layout.py +++ b/tests/python/relax/test_transform_convert_layout.py @@ -69,6 +69,9 @@ def main( return gv verify(Input, Expected) + # Channel not a proper multiple shouldn't alter the mod + verify(Input, Input, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + verify(Input, Input, {"relax.nn.conv2d": ["NHWC4c", "OHWI4o"]}) def test_conv2d_onlydim(): @@ -203,9 +206,10 @@ def main( lv2: R.Tensor((N, H, W, C), dtype="float32") = R.match_cast( lv0, R.Tensor((N, H, W, C), dtype="float32") ) - lv3: R.Tensor(dtype="float32", ndim=4) = R.permute_dims(w, axes=[0, 2, 3, 1]) - lv4: R.Tensor(dtype="float32", ndim=4) = R.add(lv2, lv3) - gv: R.Tensor(dtype="float32", ndim=4) = R.permute_dims(lv4, axes=[0, 3, 1, 2]) + lv3: R.Tensor((N, C, H, W), dtype="float32") = R.permute_dims( + lv2, axes=[0, 3, 1, 2] + ) + gv: R.Tensor(dtype="float32", ndim=4) = R.add(lv3, w) R.output(gv) return gv @@ -467,6 +471,145 @@ def main( verify(Input, Expected) + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 4, 28, 28), dtype="float32"), + w: R.Tensor((4, 4, 3, 3), dtype="float32"), + bias: R.Tensor((2, 4, 26, 26), dtype="float32"), + ) -> R.Tensor((2, 4, 24, 24), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 1, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 4, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.layout_transform( + bias, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + gv2: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.add(gv, lv2) + gv3: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.relu(gv2) + lv3: R.Tensor((1, 4, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + lv4: R.Tensor((2, 1, 24, 24, 4), dtype="float32") = R.nn.conv2d( + gv3, + lv3, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + gv4: R.Tensor((2, 4, 24, 24), dtype="float32") = R.layout_transform( + lv4, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + R.output(gv4) + return gv4 + + @I.ir_module + class Expected_NHWC4c: + @R.function + def main( + x: R.Tensor((2, 4, 28, 28), dtype="float32"), + w: R.Tensor((4, 4, 3, 3), dtype="float32"), + bias: R.Tensor((2, 4, 26, 26), dtype="float32"), + ) -> R.Tensor((2, 4, 24, 24), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 1, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i2, i3, i1 // 4, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 3, 3, 4, 4), 
dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i2, i3, i1, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC4c", + kernel_layout="OHWI4o", + out_layout="NHWC4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.layout_transform( + bias, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i2, i3, i1 // 4, i1 % 4), index_dtype="int32" + ), + ) + gv2: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.add(gv, lv2) + gv3: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.nn.relu(gv2) + lv3: R.Tensor((1, 3, 3, 4, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i2, i3, i1, i0 % 4), index_dtype="int32" + ), + ) + lv4: R.Tensor((2, 24, 24, 1, 4), dtype="float32") = R.nn.conv2d( + gv3, + lv3, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC4c", + kernel_layout="OHWI4o", + out_layout="NHWC4c", + out_dtype="float32", + ) + gv4: R.Tensor((2, 4, 24, 24), dtype="float32") = R.layout_transform( + lv4, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i3 * 4 + i4, i1, i2), index_dtype="int32" + ), + ) + R.output(gv4) + return gv4 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + verify(Input, Expected_NHWC4c, {"relax.nn.conv2d": ["NHWC4c", "OHWI4o"]}) + def test_conv2d_fma_relu_conv2d(): @I.ir_module @@ -1501,5 +1644,2946 @@ def main( verify(Input, Expected) +def test_conv2d_NCHW_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d( + x, + w, + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="float32", + ) + R.output(gv) + return gv + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 4, 26, 26), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + pad_value=None, + axis_separators=[], + input_axis_separators=[], + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + pad_value=None, + axis_separators=[], + input_axis_separators=[], + ) + lv2: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + gv: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + pad_value=None, + axis_separators=[], + input_axis_separators=[], + ) + R.output(gv) + return gv + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + + @I.ir_module + class Expected_NHWC4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: 
R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 4, 26, 26), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 4, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i2, i3, i1 // 4, i1 % 4), index_dtype="int32" + ), + pad_value=None, + axis_separators=[], + input_axis_separators=[], + ) + lv1: R.Tensor((1, 3, 3, 16, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i2, i3, i1, i0 % 4), index_dtype="int32" + ), + pad_value=None, + axis_separators=[], + input_axis_separators=[], + ) + lv2: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC4c", + kernel_layout="OHWI4o", + out_layout="NHWC4c", + out_dtype="float32", + ) + gv: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i3 * 4 + i4, i1, i2), index_dtype="int32" + ), + pad_value=None, + axis_separators=[], + input_axis_separators=[], + ) + R.output(gv) + return gv + + verify(Input, Expected_NHWC4c, {"relax.nn.conv2d": ["NHWC4c", "OHWI4o"]}) + + +def test_conv2d_NHWC_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 28, 28, 16), "float32"), w: R.Tensor((4, 3, 3, 16), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 26, 26, 4), "float32") = R.nn.conv2d( + x, + w, + data_layout="NHWC", + kernel_layout="OHWI", + out_dtype="float32", + ) + R.output(gv) + return gv + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 28, 28, 16), dtype="float32"), + w: R.Tensor((4, 3, 3, 16), dtype="float32"), + ) -> R.Tensor((2, 26, 26, 4), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i3 // 4, i1, i2, i3 % 4), index_dtype="int32" + ), + pad_value=None, + axis_separators=[], + input_axis_separators=[], + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i3, i1, i2, i0 % 4), index_dtype="int32" + ), + pad_value=None, + axis_separators=[], + input_axis_separators=[], + ) + lv2: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i2, i3, i1 * 4 + i4), index_dtype="int32" + ), + pad_value=None, + axis_separators=[], + input_axis_separators=[], + ) + R.output(gv) + return gv + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + + @I.ir_module + class Expected_NHWC4c: + @R.function + def main( + x: R.Tensor((2, 28, 28, 16), dtype="float32"), + w: R.Tensor((4, 3, 3, 16), dtype="float32"), + ) -> R.Tensor((2, 26, 26, 4), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 4, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1, i2, i3 // 4, i3 % 4), index_dtype="int32" + ), + pad_value=None, + axis_separators=[], + input_axis_separators=[], + ) + lv1: R.Tensor((1, 3, 3, 16, 4), dtype="float32") = 
R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + pad_value=None, + axis_separators=[], + input_axis_separators=[], + ) + lv2: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC4c", + kernel_layout="OHWI4o", + out_layout="NHWC4c", + out_dtype="float32", + ) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1, i2, i3 * 4 + i4), index_dtype="int32" + ), + pad_value=None, + axis_separators=[], + input_axis_separators=[], + ) + R.output(gv) + return gv + + verify(Input, Expected_NHWC4c, {"relax.nn.conv2d": ["NHWC4c", "OHWI4o"]}) + + @I.ir_module + class Expected_N2nHWC4c: + @R.function + def main( + x: R.Tensor((2, 28, 28, 16), dtype="float32"), + w: R.Tensor((4, 3, 3, 16), dtype="float32"), + ) -> R.Tensor((2, 26, 26, 4), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((1, 2, 28, 28, 4, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 2, i0 % 2, i1, i2, i3 // 4, i3 % 4), + index_dtype="int32", + ), + ) + lv1: R.Tensor((1, 3, 3, 8, 2, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3 // 2, i3 % 2, i0 % 4), + index_dtype="int32", + ), + ) + lv2: R.Tensor((1, 2, 26, 26, 1, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="N2nHWC4c", + kernel_layout="OHWI2i4o", + out_layout="N2nHWC4c", + out_dtype="float32", + ) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4, i5: (i0 * 2 + i1, i2, i3, i4 * 4 + i5), + index_dtype="int32", + ), + ) + R.output(gv) + return gv + + verify(Input, Expected_N2nHWC4c, {"relax.nn.conv2d": ["N2nHWC4c", "OHWI2i4o"]}) + + +def test_conv2d_symbolic_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor("float32", ndim=4), w: R.Tensor("float32", ndim=4) + ) -> R.Tensor("float32", ndim=4): + with R.dataflow(): + N, C, H, W = T.int64(), T.int64(16), T.int64(), T.int64() + Nw, Cw, Hw, Ww = T.int64(4), T.int64(16), T.int64(), T.int64() + lv0 = R.match_cast(x, R.Tensor((N, C, H, W), "float32")) + lv1 = R.match_cast(w, R.Tensor((Nw, Cw, Hw, Ww), "float32")) + gv: R.Tensor( + (N, T.int64(4), H + T.int64(1) - Hw, W + T.int64(1) - Ww), "float32" + ) = R.nn.conv2d(lv0, lv1, out_dtype="float32") + R.output(gv) + return gv + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor(dtype="float32", ndim=4), w: R.Tensor(dtype="float32", ndim=4) + ) -> R.Tensor(dtype="float32", ndim=4): + N = T.int64() + H = T.int64() + W = T.int64() + Hw = T.int64() + Ww = T.int64() + with R.dataflow(): + lv0: R.Tensor((N, 16, H, W), dtype="float32") = R.match_cast( + x, R.Tensor((N, 16, H, W), dtype="float32") + ) + lv1: R.Tensor((4, 16, Hw, Ww), dtype="float32") = R.match_cast( + w, R.Tensor((4, 16, Hw, Ww), dtype="float32") + ) + lv: R.Tensor((N, 4, H, W, 4), dtype="float32") = R.layout_transform( + lv0, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + pad_value=None, + axis_separators=[], + input_axis_separators=[], + ) + lv1_1: R.Tensor((1, 16, Hw, Ww, 4), dtype="float32") = R.layout_transform( + lv1, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), 
index_dtype="int32" + ), + pad_value=None, + axis_separators=[], + input_axis_separators=[], + ) + lv2: R.Tensor((N, 1, H + 1 - Hw, W + 1 - Ww, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1_1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + gv: R.Tensor((N, 4, H + 1 - Hw, W + 1 - Ww), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + pad_value=None, + axis_separators=[], + input_axis_separators=[], + ) + R.output(gv) + return gv + + verify(Input, Expected, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + + +def test_conv2d_matchcast_bias_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor("float32", ndim=4), + w: R.Tensor("float32", ndim=4), + bias: R.Tensor("float32", ndim=4), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + N, C, H, W = T.int64(), T.int64(16), T.int64(), T.int64() + Nw, Cw, Hw, Ww = T.int64(4), T.int64(16), T.int64(), T.int64() + lv0 = R.match_cast(x, R.Tensor((N, C, H, W), "float32")) + lv1 = R.match_cast(w, R.Tensor((Nw, Cw, Hw, Ww), "float32")) + lv2: R.Tensor("float32", ndim=4) = R.nn.conv2d(lv0, lv1, out_dtype="float32") + Nb, Cb, Hb, Wb = T.int64(), T.int64(), T.int64(), T.int64() + lv_bias = R.match_cast(bias, R.Tensor((Nb, Cb, Hb, Wb), "float32")) + gv = R.add(lv2, lv_bias) + R.output(gv) + return gv + + @I.ir_module + class Expected_NHWC4c: + @R.function + def main( + x: R.Tensor(dtype="float32", ndim=4), + w: R.Tensor(dtype="float32", ndim=4), + bias: R.Tensor(dtype="float32", ndim=4), + ) -> R.Tensor(dtype="float32", ndim=4): + N, C, H, W = T.int64(), T.int64(16), T.int64(), T.int64() + Nw, Cw, Hw, Ww = T.int64(4), T.int64(16), T.int64(), T.int64() + Nb, Cb, Hb, Wb = T.int64(), T.int64(), T.int64(), T.int64() + with R.dataflow(): + lv0: R.Tensor((N, 16, H, W), dtype="float32") = R.match_cast( + x, R.Tensor((N, 16, H, W), dtype="float32") + ) + lv1: R.Tensor((4, 16, Hw, Ww), dtype="float32") = R.match_cast( + w, R.Tensor((4, 16, Hw, Ww), dtype="float32") + ) + lv: R.Tensor((N, H, W, 4, 4), dtype="float32") = R.layout_transform( + lv0, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i2, i3, i1 // 4, i1 % 4), index_dtype="int32" + ), + pad_value=None, + axis_separators=[], + input_axis_separators=[], + ) + lv1_1: R.Tensor((1, Hw, Ww, 16, 4), dtype="float32") = R.layout_transform( + lv1, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i2, i3, i1, i0 % 4), index_dtype="int32" + ), + pad_value=None, + axis_separators=[], + input_axis_separators=[], + ) + lv2: R.Tensor((N, H + 1 - Hw, W + 1 - Ww, 1, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1_1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC4c", + kernel_layout="OHWI4o", + out_layout="NHWC4c", + out_dtype="float32", + ) + lv_bias: R.Tensor((Nb, Cb, Hb, Wb), dtype="float32") = R.match_cast( + bias, R.Tensor((Nb, Cb, Hb, Wb), dtype="float32") + ) + lv2_1: R.Tensor( + (Nb, Hb, Wb, (Cb - Cb % -4) // 4, 4), dtype="float32" + ) = R.layout_transform( + lv_bias, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i2, i3, i1 // 4, i1 % 4), index_dtype="int32" + ), + pad_value=None, + axis_separators=[], + input_axis_separators=[], + ) + lv3: R.Tensor(dtype="float32", ndim=5) = R.add(lv2, lv2_1) + gv: R.Tensor(dtype="float32", ndim=4) = R.layout_transform( + lv3, + 
index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i3 * 4 + i4, i1, i2), index_dtype="int32" + ), + pad_value=None, + axis_separators=[], + input_axis_separators=[], + ) + R.output(gv) + return gv + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor(dtype="float32", ndim=4), + w: R.Tensor(dtype="float32", ndim=4), + bias: R.Tensor(dtype="float32", ndim=4), + ) -> R.Tensor(dtype="float32", ndim=4): + N, C, H, W = T.int64(), T.int64(16), T.int64(), T.int64() + Nw, Cw, Hw, Ww = T.int64(4), T.int64(16), T.int64(), T.int64() + Nb, Cb, Hb, Wb = T.int64(), T.int64(), T.int64(), T.int64() + with R.dataflow(): + lv0: R.Tensor((N, 16, H, W), dtype="float32") = R.match_cast( + x, R.Tensor((N, 16, H, W), dtype="float32") + ) + lv1: R.Tensor((4, 16, Hw, Ww), dtype="float32") = R.match_cast( + w, R.Tensor((4, 16, Hw, Ww), dtype="float32") + ) + lv: R.Tensor((N, 4, H, W, 4), dtype="float32") = R.layout_transform( + lv0, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1_1: R.Tensor((1, 16, Hw, Ww, 4), dtype="float32") = R.layout_transform( + lv1, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + lv2: R.Tensor((N, 1, H + 1 - Hw, W + 1 - Ww, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1_1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv_bias: R.Tensor((Nb, Cb, Hb, Wb), dtype="float32") = R.match_cast( + bias, R.Tensor((Nb, Cb, Hb, Wb), dtype="float32") + ) + lv2_1: R.Tensor( + (Nb, (Cb - Cb % -4) // 4, Hb, Wb, 4), dtype="float32" + ) = R.layout_transform( + lv_bias, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv3: R.Tensor(dtype="float32", ndim=5) = R.add(lv2, lv2_1) + gv: R.Tensor(dtype="float32", ndim=4) = R.layout_transform( + lv3, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + R.output(gv) + return gv + + verify(Input, Expected_NHWC4c, {"relax.nn.conv2d": ["NHWC4c", "OHWI4o"]}) + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + + +def test_conv2d_layout_incompatible_fallback(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor("float32", ndim=4), + w: R.Tensor("float32", ndim=4), + bias: R.Tensor("float32", ndim=4), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + N, C, H, W = T.int64(), T.int64(15), T.int64(), T.int64() + Nw, Cw, Hw, Ww = T.int64(4), T.int64(15), T.int64(), T.int64() + lv0 = R.match_cast(x, R.Tensor((N, C, H, W), "float32")) + lv1 = R.match_cast(w, R.Tensor((Nw, Cw, Hw, Ww), "float32")) + lv2: R.Tensor("float32", ndim=4) = R.nn.conv2d(lv0, lv1, out_dtype="float32") + Nb, Cb, Hb, Wb = T.int64(), T.int64(), T.int64(), T.int64() + lv_bias = R.match_cast(bias, R.Tensor((Nb, Cb, Hb, Wb), "float32")) + gv = R.add(lv2, lv_bias) + R.output(gv) + return gv + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor(dtype="float32", ndim=4), + w: R.Tensor(dtype="float32", ndim=4), + bias: R.Tensor(dtype="float32", ndim=4), + ) -> R.Tensor(dtype="float32", ndim=4): + N, C, H, W = T.int64(), T.int64(15), T.int64(), T.int64() + Nw, Cw, Hw, Ww = T.int64(4), T.int64(15), T.int64(), T.int64() + Nb, Cb, Hb, Wb = T.int64(), T.int64(), T.int64(), T.int64() + with R.dataflow(): + lv0: R.Tensor((N, 15, H, W), 
dtype="float32") = R.match_cast( + x, R.Tensor((N, 15, H, W), dtype="float32") + ) + lv1: R.Tensor((4, 15, Hw, Ww), dtype="float32") = R.match_cast( + w, R.Tensor((4, 15, Hw, Ww), dtype="float32") + ) + lv2: R.Tensor((N, 4, H + 1 - Hw, W + 1 - Ww), dtype="float32") = R.nn.conv2d( + lv0, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float32", + ) + lv_bias: R.Tensor((Nb, Cb, Hb, Wb), dtype="float32") = R.match_cast( + bias, R.Tensor((Nb, Cb, Hb, Wb), dtype="float32") + ) + gv: R.Tensor(dtype="float32", ndim=4) = R.add(lv2, lv_bias) + R.output(gv) + return gv + + verify(Input, Expected, {"relax.nn.conv2d": ["NHWC4c", "OHWI4o"]}) + verify(Input, Expected, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + + +def test_conv2d_relu_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 4, 26, 26), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.relu(gv) + gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NHWC4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 4, 26, 26), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 4, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i2, i3, i1 // 4, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 3, 3, 16, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i2, i3, i1, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NHWC4c", + kernel_layout="OHWI4o", + out_layout="NHWC4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.nn.relu(gv) + gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i3 * 4 + i4, i1, i2), index_dtype="int32" + ), + ) + R.output(gv2) + return gv2 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + verify(Input, Expected_NHWC4c, {"relax.nn.conv2d": ["NHWC4c", "OHWI4o"]}) + + +def 
test_relu_conv2d_relu_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + x0: R.Tensor((2, 16, 28, 28), "float32") = R.nn.relu(x) + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x0, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + R.output(gv2) + return gv2 + + @tvm.script.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 4, 26, 26), dtype="float32"): + with R.dataflow(): + x0: R.Tensor((2, 16, 28, 28), dtype="float32") = R.nn.relu(x) + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x0, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.relu(gv) + gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + R.output(gv2) + return gv2 + + @tvm.script.ir_module + class Expected_NHWC4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 4, 26, 26), dtype="float32"): + with R.dataflow(): + x0: R.Tensor((2, 16, 28, 28), dtype="float32") = R.nn.relu(x) + lv: R.Tensor((2, 28, 28, 4, 4), dtype="float32") = R.layout_transform( + x0, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i2, i3, i1 // 4, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 3, 3, 16, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i2, i3, i1, i0 % 4), index_dtype="int32" + ), + pad_value=None, + axis_separators=[], + input_axis_separators=[], + ) + gv: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NHWC4c", + kernel_layout="OHWI4o", + out_layout="NHWC4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.nn.relu(gv) + gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i3 * 4 + i4, i1, i2), index_dtype="int32" + ), + pad_value=None, + axis_separators=[], + input_axis_separators=[], + ) + R.output(gv2) + return gv2 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + verify(Input, Expected_NHWC4c, {"relax.nn.conv2d": ["NHWC4c", "OHWI4o"]}) + + +def test_conv2d_relu_tanh_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + gv3: R.Tensor((2, 4, 26, 26), "float32") = R.tanh(gv2) + R.output(gv3) + return gv3 + + @I.ir_module + class 
Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 4, 26, 26), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + gv2: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.relu(gv) + lv2: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.tanh(gv2) + gv3: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + R.output(gv3) + return gv3 + + @I.ir_module + class Expected_NHWC4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 4, 26, 26), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 4, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i2, i3, i1 // 4, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 3, 3, 16, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i2, i3, i1, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NHWC4c", + kernel_layout="OHWI4o", + out_layout="NHWC4c", + out_dtype="float32", + ) + gv2: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.nn.relu(gv) + lv2: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.tanh(gv2) + gv3: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i3 * 4 + i4, i1, i2), index_dtype="int32" + ), + ) + R.output(gv3) + return gv3 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + verify(Input, Expected_NHWC4c, {"relax.nn.conv2d": ["NHWC4c", "OHWI4o"]}) + + +def test_conv2d_add_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), + w: R.Tensor((4, 16, 3, 3), "float32"), + bias: R.Tensor((2, 4, 26, 26), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.add(gv, bias) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + bias: R.Tensor((2, 4, 26, 26), dtype="float32"), + ) -> R.Tensor((2, 4, 26, 26), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 
1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.layout_transform( + bias, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv3: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.add(gv, lv2) + gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + lv3, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NHWC4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + bias: R.Tensor((2, 4, 26, 26), dtype="float32"), + ) -> R.Tensor((2, 4, 26, 26), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 4, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i2, i3, i1 // 4, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 3, 3, 16, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i2, i3, i1, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NHWC4c", + kernel_layout="OHWI4o", + out_layout="NHWC4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.layout_transform( + bias, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i2, i3, i1 // 4, i1 % 4), index_dtype="int32" + ), + ) + lv3: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.add(gv, lv2) + gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + lv3, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i3 * 4 + i4, i1, i2), index_dtype="int32" + ), + ) + R.output(gv2) + return gv2 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + verify(Input, Expected_NHWC4c, {"relax.nn.conv2d": ["NHWC4c", "OHWI4o"]}) + + +def test_conv2d_fma_relu_conv2d_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 4, 28, 28), "float32"), + w: R.Tensor((4, 4, 3, 3), "float32"), + scale: R.Tensor((2, 4, 26, 26), dtype="float32"), + bias: R.Tensor((2, 4, 26, 26), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.ewise_fma(gv, scale, bias) + gv3: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv2) + gv4: R.Tensor((2, 4, 24, 24), "float32") = R.nn.conv2d(gv3, w, out_dtype="float32") + R.output(gv4) + return gv4 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 4, 28, 28), dtype="float32"), + w: R.Tensor((4, 4, 3, 3), dtype="float32"), + scale: R.Tensor((2, 4, 26, 26), dtype="float32"), + bias: R.Tensor((2, 4, 26, 26), dtype="float32"), + ) -> R.Tensor((2, 4, 24, 24), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 1, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 4, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") 
= R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + gv, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.ewise_fma(lv2, scale, bias) + gv3: R.Tensor((2, 4, 26, 26), dtype="float32") = R.nn.relu(gv2) + lv3: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.layout_transform( + gv3, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv4: R.Tensor((1, 4, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + lv5: R.Tensor((2, 1, 24, 24, 4), dtype="float32") = R.nn.conv2d( + lv3, + lv4, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + gv4: R.Tensor((2, 4, 24, 24), dtype="float32") = R.layout_transform( + lv5, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + R.output(gv4) + return gv4 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + + +def test_conv2d_sum_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=2): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4), "float32") = R.sum(gv, axis=[2, 3]) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 4), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 1, 4), dtype="float32") = R.sum(gv, axis=[2, 3], keepdims=False) + gv2: R.Tensor((2, 4), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2: (i0, i1 * 4 + i2), index_dtype="int32" + ), + ) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NHWC4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 4), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 4, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i2, i3, i1 // 4, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 3, 3, 16, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i2, i3, i1, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NHWC4c", + kernel_layout="OHWI4o", + out_layout="NHWC4c", + 
out_dtype="float32", + ) + lv2: R.Tensor((2, 1, 4), dtype="float32") = R.sum(gv, axis=[1, 2], keepdims=False) + gv2: R.Tensor((2, 4), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2: (i0, i1 * 4 + i2), index_dtype="int32" + ), + ) + R.output(gv2) + return gv2 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + verify(Input, Expected_NHWC4c, {"relax.nn.conv2d": ["NHWC4c", "OHWI4o"]}) + + +def test_conv2d_sum_keepdims_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=2): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 1, 1), "float32") = R.sum(gv, axis=[2, 3], keepdims=True) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 4, 1, 1), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 1, 1, 1, 4), dtype="float32") = R.sum( + gv, axis=[2, 3], keepdims=True + ) + gv2: R.Tensor((2, 4, 1, 1), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NHWC4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 4, 1, 1), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 4, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i2, i3, i1 // 4, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 3, 3, 16, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i2, i3, i1, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NHWC4c", + kernel_layout="OHWI4o", + out_layout="NHWC4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 1, 1, 1, 4), dtype="float32") = R.sum( + gv, axis=[1, 2], keepdims=True + ) + gv2: R.Tensor((2, 4, 1, 1), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i3 * 4 + i4, i1, i2), index_dtype="int32" + ), + ) + R.output(gv2) + return gv2 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + verify(Input, Expected_NHWC4c, {"relax.nn.conv2d": ["NHWC4c", "OHWI4o"]}) + + +def test_conv2d_sum_reduce_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=2): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = 
R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 26), "float32") = R.sum(gv, axis=[1, 2]) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 26), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + gv2: R.Tensor((2, 26), dtype="float32") = R.sum(gv, axis=[1, 2, 4], keepdims=False) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NHWC4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 26), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 4, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i2, i3, i1 // 4, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 3, 3, 16, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i2, i3, i1, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NHWC4c", + kernel_layout="OHWI4o", + out_layout="NHWC4c", + out_dtype="float32", + ) + gv2: R.Tensor((2, 26), dtype="float32") = R.sum(gv, axis=[1, 3, 4], keepdims=False) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NCHW2n4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 26), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((1, 4, 28, 28, 2, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 2, i1 // 4, i2, i3, i0 % 2, i1 % 4), + index_dtype="int32", + ), + ) + lv1: R.Tensor((1, 8, 3, 3, 2, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1 // 2, i2, i3, i1 % 2, i0 % 4), + index_dtype="int32", + ), + ) + gv: R.Tensor((1, 1, 26, 26, 2, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW2n4c", + kernel_layout="OIHW2i4o", + out_layout="NCHW2n4c", + out_dtype="float32", + ) + lv2: R.Tensor((1, 26, 2), dtype="float32") = R.sum( + gv, axis=[1, 2, 5], keepdims=False + ) + gv2: R.Tensor((2, 26), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2: (i0 * 2 + i2, i1), index_dtype="int32" + ), + ) + R.output(gv2) + return gv2 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + verify(Input, Expected_NHWC4c, {"relax.nn.conv2d": ["NHWC4c", "OHWI4o"]}) + verify(Input, Expected_NCHW2n4c, {"relax.nn.conv2d": ["NCHW2n4c", "OIHW2i4o"]}) + + +def test_conv2d_sum_negative_dims_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main(x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32")): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, 
out_dtype="float32") + gv2: R.Tensor((2, 4), "float32") = R.sum(gv, axis=[-2, -1]) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 4), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 1, 4), dtype="float32") = R.sum(gv, axis=[2, 3], keepdims=False) + gv2: R.Tensor((2, 4), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2: (i0, i1 * 4 + i2), index_dtype="int32" + ), + ) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NHWC4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 4), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 4, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i2, i3, i1 // 4, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 3, 3, 16, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i2, i3, i1, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NHWC4c", + kernel_layout="OHWI4o", + out_layout="NHWC4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 1, 4), dtype="float32") = R.sum(gv, axis=[1, 2], keepdims=False) + gv2: R.Tensor((2, 4), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2: (i0, i1 * 4 + i2), index_dtype="int32" + ), + ) + R.output(gv2) + return gv2 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + verify(Input, Expected_NHWC4c, {"relax.nn.conv2d": ["NHWC4c", "OHWI4o"]}) + + +def test_conv2d_transpose_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((26, 26, 4, 2), "float32") = R.permute_dims(gv, axes=[3, 2, 1, 0]) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((26, 26, 4, 2), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + 
kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + gv, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + gv2: R.Tensor((26, 26, 4, 2), dtype="float32") = R.permute_dims( + lv2, axes=[3, 2, 1, 0] + ) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NHWC4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((26, 26, 4, 2), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 4, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i2, i3, i1 // 4, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 3, 3, 16, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i2, i3, i1, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NHWC4c", + kernel_layout="OHWI4o", + out_layout="NHWC4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + gv, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i3 * 4 + i4, i1, i2), index_dtype="int32" + ), + ) + gv2: R.Tensor((26, 26, 4, 2), dtype="float32") = R.permute_dims( + lv2, axes=[3, 2, 1, 0] + ) + R.output(gv2) + return gv2 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + verify(Input, Expected_NHWC4c, {"relax.nn.conv2d": ["NHWC4c", "OHWI4o"]}) + + +def test_conv2d_expand_dims_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=6): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 1, 4, 1, 26, 26), "float32") = R.expand_dims(gv, axis=(-3, 1)) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 1, 4, 1, 26, 26), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + gv, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + gv2: R.Tensor((2, 1, 4, 1, 26, 26), dtype="float32") = R.expand_dims( + lv2, axis=[-3, 1] + ) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NHWC4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 1, 4, 1, 26, 26), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 4, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, 
i2, i3: (i0, i2, i3, i1 // 4, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 3, 3, 16, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i2, i3, i1, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NHWC4c", + kernel_layout="OHWI4o", + out_layout="NHWC4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + gv, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i3 * 4 + i4, i1, i2), index_dtype="int32" + ), + ) + gv2: R.Tensor((2, 1, 4, 1, 26, 26), dtype="float32") = R.expand_dims( + lv2, axis=[-3, 1] + ) + R.output(gv2) + return gv2 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + verify(Input, Expected_NHWC4c, {"relax.nn.conv2d": ["NHWC4c", "OHWI4o"]}) + + +def test_conv2d_squeeze_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((1, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=3): + with R.dataflow(): + gv: R.Tensor((1, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((4, 26, 26), "float32") = R.squeeze(gv, axis=[0]) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((1, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((4, 26, 26), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((1, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((1, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv2: R.Tensor((1, 4, 26, 26), dtype="float32") = R.layout_transform( + gv, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + gv2: R.Tensor((4, 26, 26), dtype="float32") = R.squeeze(lv2, axis=[0]) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NHWC4c: + @R.function + def main( + x: R.Tensor((1, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((4, 26, 26), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((1, 28, 28, 4, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i2, i3, i1 // 4, i1 % 4), index_dtype="int32" + ), + pad_value=None, + axis_separators=[], + input_axis_separators=[], + ) + lv1: R.Tensor((1, 3, 3, 16, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i2, i3, i1, i0 % 4), index_dtype="int32" + ), + pad_value=None, + axis_separators=[], + input_axis_separators=[], + ) + gv: R.Tensor((1, 26, 26, 1, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC4c", + kernel_layout="OHWI4o", + out_layout="NHWC4c", + out_dtype="float32", + ) + lv2: R.Tensor((1, 4, 26, 26), dtype="float32") = R.layout_transform( + gv, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i3 * 4 + i4, i1, i2), 
index_dtype="int32" + ), + pad_value=None, + axis_separators=[], + input_axis_separators=[], + ) + gv2: R.Tensor((4, 26, 26), dtype="float32") = R.squeeze(lv2, axis=[0]) + R.output(gv2) + return gv2 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + verify(Input, Expected_NHWC4c, {"relax.nn.conv2d": ["NHWC4c", "OHWI4o"]}) + + +def test_conv2d_strided_slice_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 2, 9, 7), dtype="float32") = R.strided_slice( + gv, begin=[0, 0, 0], end=[4, 26, 26], strides=[2, 3, 4], axes=[1, 2, 3] + ) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 2, 9, 7), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + gv, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + gv2: R.Tensor((2, 2, 9, 7), dtype="float32") = R.strided_slice( + lv2, + (R.prim_value(1), R.prim_value(2), R.prim_value(3)), + (R.prim_value(0), R.prim_value(0), R.prim_value(0)), + (R.prim_value(4), R.prim_value(26), R.prim_value(26)), + (R.prim_value(2), R.prim_value(3), R.prim_value(4)), + assume_inbound=False, + ) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NHWC4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 2, 9, 7), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 4, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i2, i3, i1 // 4, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 3, 3, 16, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i2, i3, i1, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC4c", + kernel_layout="OHWI4o", + out_layout="NHWC4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + gv, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i3 * 4 + i4, i1, i2), index_dtype="int32" + ), + ) + gv2: R.Tensor((2, 2, 9, 7), dtype="float32") = R.strided_slice( + lv2, + (R.prim_value(1), R.prim_value(2), R.prim_value(3)), + (R.prim_value(0), R.prim_value(0), R.prim_value(0)), + (R.prim_value(4), R.prim_value(26), R.prim_value(26)), + (R.prim_value(2), R.prim_value(3), R.prim_value(4)), + assume_inbound=False, + 
) + R.output(gv2) + return gv2 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + verify(Input, Expected_NHWC4c, {"relax.nn.conv2d": ["NHWC4c", "OHWI4o"]}) + + +def test_conv2d_relu_concat_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + gv3: R.Tensor((2, 8, 26, 26), "float32") = R.concat((gv, gv2), axis=1) + R.output(gv3) + return gv3 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 8, 26, 26), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + gv2: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.relu(gv) + lv2: R.Tensor((2, 2, 26, 26, 4), dtype="float32") = R.concat((gv, gv2), axis=1) + gv3: R.Tensor((2, 8, 26, 26), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + R.output(gv3) + return gv3 + + @I.ir_module + class Expected_NHWC4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 8, 26, 26), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 4, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i2, i3, i1 // 4, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 3, 3, 16, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i2, i3, i1, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NHWC4c", + kernel_layout="OHWI4o", + out_layout="NHWC4c", + out_dtype="float32", + ) + gv2: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.nn.relu(gv) + lv2: R.Tensor((2, 26, 26, 2, 4), dtype="float32") = R.concat((gv, gv2), axis=3) + gv3: R.Tensor((2, 8, 26, 26), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i3 * 4 + i4, i1, i2), index_dtype="int32" + ), + ) + R.output(gv3) + return gv3 + + @I.ir_module + class Expected_N4cHWC: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 8, 26, 26), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 % 4, i2, i3, i1 // 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 4, 3, 3, 16), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i0 
% 4, i2, i3, i1), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 4, 26, 26, 1), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="N4cHWC", + kernel_layout="O4oHWI", + out_layout="N4cHWC", + out_dtype="float32", + ) + gv2: R.Tensor((2, 4, 26, 26, 1), dtype="float32") = R.nn.relu(gv) + lv2: R.Tensor((2, 4, 26, 26, 2), dtype="float32") = R.concat((gv, gv2), axis=4) + gv3: R.Tensor((2, 8, 26, 26), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i4 * 4 + i1, i2, i3), index_dtype="int32" + ), + ) + R.output(gv3) + return gv3 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + verify(Input, Expected_NHWC4c, {"relax.nn.conv2d": ["NHWC4c", "OHWI4o"]}) + # Concat axis after sub index + verify(Input, Expected_N4cHWC, {"relax.nn.conv2d": ["N4cHWC", "O4oHWI"]}) + + +def test_conv2d_relu_concat_split_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main(x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32")): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + gv3: R.Tensor((2, 8, 26, 26), "float32") = R.concat((gv, gv2), axis=1) + gv4 = R.split(gv3, indices_or_sections=2, axis=1) + R.output(gv4) + return gv4 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tuple( + R.Tensor((2, 4, 26, 26), dtype="float32"), R.Tensor((2, 4, 26, 26), dtype="float32") + ): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + gv2: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.relu(gv) + gv3: R.Tensor((2, 2, 26, 26, 4), dtype="float32") = R.concat((gv, gv2), axis=1) + lv2: R.Tuple( + R.Tensor((2, 1, 26, 26, 4), dtype="float32"), + R.Tensor((2, 1, 26, 26, 4), dtype="float32"), + ) = R.split(gv3, indices_or_sections=2, axis=1) + lv3: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = lv2[0] + lv4: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + lv3, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + lv5: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = lv2[1] + lv6: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + lv5, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + gv4: R.Tuple( + R.Tensor((2, 4, 26, 26), dtype="float32"), + R.Tensor((2, 4, 26, 26), dtype="float32"), + ) = (lv4, lv6) + R.output(gv4) + return gv4 + + @I.ir_module + class Expected_NHWC4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tuple( + R.Tensor((2, 4, 26, 26), dtype="float32"), R.Tensor((2, 4, 26, 26), dtype="float32") + ): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 4, 4), dtype="float32") = R.layout_transform( + x, + 
index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i2, i3, i1 // 4, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 3, 3, 16, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i2, i3, i1, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NHWC4c", + kernel_layout="OHWI4o", + out_layout="NHWC4c", + out_dtype="float32", + ) + gv2: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.nn.relu(gv) + gv3: R.Tensor((2, 26, 26, 2, 4), dtype="float32") = R.concat((gv, gv2), axis=3) + lv2: R.Tuple( + R.Tensor((2, 26, 26, 1, 4), dtype="float32"), + R.Tensor((2, 26, 26, 1, 4), dtype="float32"), + ) = R.split(gv3, indices_or_sections=2, axis=3) + lv3: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = lv2[0] + lv4: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + lv3, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i3 * 4 + i4, i1, i2), index_dtype="int32" + ), + ) + lv5: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = lv2[1] + lv6: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + lv5, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i3 * 4 + i4, i1, i2), index_dtype="int32" + ), + ) + gv4: R.Tuple( + R.Tensor((2, 4, 26, 26), dtype="float32"), + R.Tensor((2, 4, 26, 26), dtype="float32"), + ) = (lv4, lv6) + R.output(gv4) + return gv4 + + @I.ir_module + class Expected_N4cHWC: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tuple( + R.Tensor((2, 4, 26, 26), dtype="float32"), R.Tensor((2, 4, 26, 26), dtype="float32") + ): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 % 4, i2, i3, i1 // 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 4, 3, 3, 16), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i0 % 4, i2, i3, i1), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 4, 26, 26, 1), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="N4cHWC", + kernel_layout="O4oHWI", + out_layout="N4cHWC", + out_dtype="float32", + ) + gv2: R.Tensor((2, 4, 26, 26, 1), dtype="float32") = R.nn.relu(gv) + gv3: R.Tensor((2, 4, 26, 26, 2), dtype="float32") = R.concat((gv, gv2), axis=4) + lv2: R.Tuple( + R.Tensor((2, 4, 26, 26, 1), dtype="float32"), + R.Tensor((2, 4, 26, 26, 1), dtype="float32"), + ) = R.split(gv3, indices_or_sections=2, axis=4) + lv3: R.Tensor((2, 4, 26, 26, 1), dtype="float32") = lv2[0] + lv4: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + lv3, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i4 * 4 + i1, i2, i3), index_dtype="int32" + ), + ) + lv5: R.Tensor((2, 4, 26, 26, 1), dtype="float32") = lv2[1] + lv6: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + lv5, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i4 * 4 + i1, i2, i3), index_dtype="int32" + ), + ) + gv4: R.Tuple( + R.Tensor((2, 4, 26, 26), dtype="float32"), + R.Tensor((2, 4, 26, 26), dtype="float32"), + ) = (lv4, lv6) + R.output(gv4) + return gv4 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + verify(Input, Expected_NHWC4c, {"relax.nn.conv2d": ["NHWC4c", "OHWI4o"]}) + verify(Input, Expected_N4cHWC, {"relax.nn.conv2d": ["N4cHWC", "O4oHWI"]}) + + +def test_conv2d_relu_concat_split_sub_indexed_div_exception(): + 
@I.ir_module + class Input: + @R.function + def main(x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32")): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + gv3: R.Tensor((2, 8, 26, 26), "float32") = R.concat((gv, gv2), axis=1) + gv4 = R.split(gv3, indices_or_sections=4, axis=1) + R.output(gv4) + return gv4 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tuple( + R.Tensor((2, 2, 26, 26), dtype="float32"), + R.Tensor((2, 2, 26, 26), dtype="float32"), + R.Tensor((2, 2, 26, 26), dtype="float32"), + R.Tensor((2, 2, 26, 26), dtype="float32"), + ): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + gv2: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.relu(gv) + gv3: R.Tensor((2, 2, 26, 26, 4), dtype="float32") = R.concat((gv, gv2), axis=1) + lv2: R.Tensor((2, 8, 26, 26), dtype="float32") = R.layout_transform( + gv3, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + gv4: R.Tuple( + R.Tensor((2, 2, 26, 26), dtype="float32"), + R.Tensor((2, 2, 26, 26), dtype="float32"), + R.Tensor((2, 2, 26, 26), dtype="float32"), + R.Tensor((2, 2, 26, 26), dtype="float32"), + ) = R.split(lv2, indices_or_sections=4, axis=1) + R.output(gv4) + return gv4 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + + +def test_conv2d_maxpool2d_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2 = R.nn.max_pool2d( + gv, + pool_size=[2, 2], + strides=[2, 2], + padding=[0, 0], + layout="NCHW", + out_layout="NCHW", + ) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 4, 13, 13), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 1, 13, 13, 4), dtype="float32") = R.nn.max_pool2d( + gv, + pool_size=[2, 2], + strides=[2, 2], + padding=[0, 0, 0, 0], + layout="NCHW4c", + out_layout="NCHW4c", + ) + gv2: 
R.Tensor((2, 4, 13, 13), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NHWC4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 4, 13, 13), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 4, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i2, i3, i1 // 4, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 3, 3, 16, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i2, i3, i1, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NHWC4c", + kernel_layout="OHWI4o", + out_layout="NHWC4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 13, 13, 1, 4), dtype="float32") = R.nn.max_pool2d( + gv, + pool_size=[2, 2], + strides=[2, 2], + padding=[0, 0, 0, 0], + layout="NHWC4c", + out_layout="NHWC4c", + ) + gv2: R.Tensor((2, 4, 13, 13), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i3 * 4 + i4, i1, i2), index_dtype="int32" + ), + ) + R.output(gv2) + return gv2 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + verify(Input, Expected_NHWC4c, {"relax.nn.conv2d": ["NHWC4c", "OHWI4o"]}) + + +def test_conv2d_avgpool2d_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2 = R.nn.adaptive_avg_pool2d(gv, output_size=[13, 13], layout="NCHW") + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 4, 13, 13), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 1, 13, 13, 4), dtype="float32") = R.nn.adaptive_avg_pool2d( + gv, output_size=[13, 13], layout="NCHW4c", out_layout="NCHW4c" + ) + gv2: R.Tensor((2, 4, 13, 13), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NHWC4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 4, 13, 13), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 4, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i2, i3, i1 // 4, i1 % 4), 
index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 3, 3, 16, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i2, i3, i1, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NHWC4c", + kernel_layout="OHWI4o", + out_layout="NHWC4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 13, 13, 1, 4), dtype="float32") = R.nn.adaptive_avg_pool2d( + gv, output_size=[13, 13], layout="NHWC4c", out_layout="NHWC4c" + ) + gv2: R.Tensor((2, 4, 13, 13), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i3 * 4 + i4, i1, i2), index_dtype="int32" + ), + ) + R.output(gv2) + return gv2 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + verify(Input, Expected_NHWC4c, {"relax.nn.conv2d": ["NHWC4c", "OHWI4o"]}) + + +def test_conv2d_softmax_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2 = R.nn.softmax(gv, axis=1) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 4, 26, 26), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + gv, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.nn.softmax(lv2, axis=1) + R.output(gv2) + return gv2 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + + +def test_conv2d_batchnorm_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), + w: R.Tensor((4, 16, 3, 3), "float32"), + gamma: R.Tensor((4,), dtype="float32"), + beta: R.Tensor((4,), dtype="float32"), + moving_mean: R.Tensor((4,), dtype="float32"), + moving_var: R.Tensor((4,), dtype="float32"), + ): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tuple( + R.Tensor((2, 4, 26, 26), dtype="float32"), + R.Tensor((4,), dtype="float32"), + R.Tensor((4,), dtype="float32"), + ) = R.nn.batch_norm(gv, gamma, beta, moving_mean, moving_var, axis=1) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + gamma: R.Tensor((4,), dtype="float32"), + beta: R.Tensor((4,), dtype="float32"), + moving_mean: R.Tensor((4,), dtype="float32"), + moving_var: R.Tensor((4,), dtype="float32"), + ) -> R.Tuple( + R.Tensor((2, 4, 26, 
26), dtype="float32"), + R.Tensor((4,), dtype="float32"), + R.Tensor((4,), dtype="float32"), + ): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + gv, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + gv2: R.Tuple( + R.Tensor((2, 4, 26, 26), dtype="float32"), + R.Tensor((4,), dtype="float32"), + R.Tensor((4,), dtype="float32"), + ) = R.nn.batch_norm( + lv2, + gamma, + beta, + moving_mean, + moving_var, + axis=1, + epsilon=1.0000000000000001e-05, + center=True, + scale=True, + momentum=0.10000000000000001, + ) + R.output(gv2) + return gv2 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + + +def test_conv2d_layernorm_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), + w: R.Tensor((4, 16, 3, 3), "float32"), + gamma: R.Tensor((26, 26), dtype="float32"), + beta: R.Tensor((26, 26), dtype="float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.nn.layer_norm( + gv, gamma, beta, axes=[-2, -1] + ) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + gamma: R.Tensor((26, 26), dtype="float32"), + beta: R.Tensor((26, 26), dtype="float32"), + ) -> R.Tensor((2, 4, 26, 26), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.layer_norm( + gv, + gamma, + beta, + axes=[2, 3], + ) + gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + R.output(gv2) + return gv2 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + + +def test_conv2d_resize2d_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2 = R.image.resize2d(gv, (52, 52), layout="NCHW") + R.output(gv2) + return gv2 + + 
@I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 4, 52, 52), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + gv, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + gv2: R.Tensor((2, 4, 52, 52), dtype="float32") = R.image.resize2d( + lv2, + R.shape([52, 52]), + layout="NCHW", + ) + R.output(gv2) + return gv2 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + + +def test_conv2d_unknown_bias_dim_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), + w: R.Tensor((4, 16, 3, 3), "float32"), + w2: R.Tensor(dtype="float32"), + ) -> R.Tensor(None, "float32"): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2 = w2 + gv + R.output(gv2) + return gv2 + + @tvm.script.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + w2: R.Tensor(dtype="float32"), + ) -> R.Tensor(dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + gv, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + gv2: R.Tensor(dtype="float32") = R.add(w2, lv2) + R.output(gv2) + return gv2 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + + +def test_binary_broadcast_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), + w: R.Tensor((4, 16, 3, 3), "float32"), + bias: R.Tensor((26, 26), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.add(gv, bias) + R.output(gv2) + return gv2 + + @tvm.script.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + bias: R.Tensor((26, 26), dtype="float32"), + ) -> R.Tensor((2, 4, 26, 26), 
dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + gv, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.add(lv2, bias) + R.output(gv2) + return gv2 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + + +def test_binary_ewise_scalar_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.add(gv, R.const(1, "float32")) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 4, 26, 26), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.add( + gv, R.const(1.0, "float32") + ) + gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + R.output(gv2) + return gv2 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + + if __name__ == "__main__": tvm.testing.main()