From d6413548ab5f50e292fe76ef8cd13770aa2c30be Mon Sep 17 00:00:00 2001 From: Siva Date: Sun, 19 Jan 2025 23:42:04 +0530 Subject: [PATCH] [RELAX][PASS] Convert layout pass and ops enhanced to support sub indexing (#17568) Convert layout pass and ops enhanced to support sub indexing Majority of the operations made compatible with custom layouts. Incompatible ops will fallback to regular layout. Conv1D, Conv3D, Pool1D, Pool3D, AdaptiveAvgPool1D, AdaptiveAvgPool3D are left unchanged now. 2D networks are expected to work now. --- python/tvm/script/ir_builder/tir/ir.py | 3 +- src/relax/op/image/resize.cc | 4 + src/relax/op/nn/convolution.cc | 75 +- src/relax/op/nn/nn.cc | 26 +- src/relax/op/nn/pooling.cc | 32 + src/relax/op/op_common.cc | 22 + src/relax/op/op_common.h | 10 + src/relax/op/tensor/binary.cc | 15 + src/relax/op/tensor/index.cc | 4 + src/relax/op/tensor/manipulate.cc | 41 +- src/relax/op/tensor/statistical.cc | 24 +- src/relax/transform/convert_layout.cc | 40 +- src/relax/transform/infer_layout_utils.cc | 35 +- src/relax/transform/infer_layout_utils.h | 19 + .../relax/test_transform_convert_layout.py | 3090 ++++++++++++++++- 15 files changed, 3398 insertions(+), 42 deletions(-) diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 59548634fc4a..104cc843e398 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1688,9 +1688,10 @@ def index_map( mapping: Callable, *, inverse_index_map: Optional[Callable] = None, + index_dtype: str = "int64", ) -> IndexMap: """Create a TIR Index mapping""" - return IndexMap.from_func(mapping, inverse_index_map=inverse_index_map) + return IndexMap.from_func(mapping, inverse_index_map=inverse_index_map, index_dtype=index_dtype) def target( diff --git a/src/relax/op/image/resize.cc b/src/relax/op/image/resize.cc index 202702d78746..344e27456551 100644 --- a/src/relax/op/image/resize.cc +++ b/src/relax/op/image/resize.cc @@ -121,6 +121,10 @@ InferLayoutOutput InferLayoutResize2d(const Call& call, } else { // We dont have a desired layout for resize2d, propagate from the input instead. data_layout = GetLayoutDecision(var_layout_map, call->args[0]); + // Not handling sub indexing now. + if (data_layout->layout.ndim() != data_layout->layout.ndim_primal()) { + data_layout = LayoutDecision(InitialLayout(4)); + } new_attrs->layout = TransposeLike(attrs->layout, InitialLayout(4), data_layout->layout).name(); } return InferLayoutOutput({data_layout, InitialNLayout(call->args[1])}, {data_layout}, diff --git a/src/relax/op/nn/convolution.cc b/src/relax/op/nn/convolution.cc index 7c7718b837d9..3ebbc544f470 100644 --- a/src/relax/op/nn/convolution.cc +++ b/src/relax/op/nn/convolution.cc @@ -308,30 +308,59 @@ InferLayoutOutput InferLayoutConv2d(const Call& call, Layout desired_data_layout = (*it).second[0]; Layout desired_weight_layout = (*it).second[1]; Layout desired_output_layout = (*it).second.size() == 3 ? 
(*it).second[2] : (*it).second[0]; - ICHECK_EQ(desired_data_layout.ndim(), desired_data_layout.ndim_primal()) << "Axis swap only"; - ICHECK_EQ(desired_weight_layout.ndim(), desired_weight_layout.ndim_primal()) - << "Axis swap only"; - ICHECK_EQ(desired_output_layout.ndim(), desired_output_layout.ndim_primal()) - << "Axis swap only"; - data_layout = TransposeLike(InitialLayout(4), attrs->data_layout, desired_data_layout); - weight_layout = TransposeLike(InitialLayout(4), attrs->kernel_layout, desired_weight_layout); - output_layout = TransposeLike(InitialLayout(4), attrs->out_layout, desired_output_layout); - new_attrs->data_layout = (*it).second[0]; - new_attrs->kernel_layout = (*it).second[1]; - new_attrs->out_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0]; - } else { - // We don't have a desired layout for conv2d. - // We can just propagate the layout from the input. - data_layout = GetLayoutDecision(var_layout_map, call->args[0]); - weight_layout = GetLayoutDecision(var_layout_map, call->args[1]); - output_layout = data_layout; - new_attrs->data_layout = - TransposeLike(attrs->data_layout, InitialLayout(4), data_layout->layout).name(); - new_attrs->kernel_layout = - TransposeLike(attrs->kernel_layout, InitialLayout(4), weight_layout->layout).name(); - new_attrs->out_layout = - TransposeLike(attrs->out_layout, InitialLayout(4), output_layout->layout).name(); + tir::Layout input_layout(attrs->data_layout, DataType::Int(64)); + tir::Layout kernel_layout(attrs->kernel_layout, DataType::Int(64)); + tir::Layout out_layout(attrs->out_layout, DataType::Int(64)); + + if ((desired_data_layout.ndim() == input_layout.ndim()) && + (desired_weight_layout.ndim() == kernel_layout.ndim()) && + (desired_output_layout.ndim() == out_layout.ndim())) { + // Just a transpose + data_layout = TransposeLike(InitialLayout(4), attrs->data_layout, desired_data_layout); + weight_layout = TransposeLike(InitialLayout(4), attrs->kernel_layout, desired_weight_layout); + output_layout = TransposeLike(InitialLayout(4), attrs->out_layout, desired_output_layout); + new_attrs->data_layout = (*it).second[0]; + new_attrs->kernel_layout = (*it).second[1]; + new_attrs->out_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0]; + return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs)); + } else { + // Layout Transform + auto data_si = GetStructInfo(call->args[0]); + auto kernel_si = GetStructInfo(call->args[1]); + TensorStructInfo data_sinfo = data_si.as().value(); + TensorStructInfo kernel_sinfo = kernel_si.as().value(); + Optional data_shape = GetRef(data_sinfo->shape.as()); + Optional kernel_shape = GetRef(kernel_sinfo->shape.as()); + + bool can_data_proved = + CanProveLayoutTransform(input_layout, desired_data_layout, data_shape.value()->values); + bool can_kernel_proved = CanProveLayoutTransform(kernel_layout, desired_weight_layout, + kernel_shape.value()->values); + + if (can_data_proved && can_kernel_proved) { + data_layout = TransposeSubLayoutLike(InitialLayout(4), input_layout, desired_data_layout); + weight_layout = + TransposeSubLayoutLike(InitialLayout(4), kernel_layout, desired_weight_layout); + output_layout = TransposeSubLayoutLike(InitialLayout(4), out_layout, desired_output_layout); + new_attrs->data_layout = (*it).second[0]; + new_attrs->kernel_layout = (*it).second[1]; + new_attrs->out_layout = (*it).second.size() == 3 ? 
(*it).second[2] : (*it).second[0]; + return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs)); + } + } } + + // We don't have a desired layout for conv2d or desired layouts not compatible. + // We can just propagate the layout from the input. + data_layout = GetLayoutDecision(var_layout_map, call->args[0]); + weight_layout = GetLayoutDecision(var_layout_map, call->args[1]); + output_layout = data_layout; + new_attrs->data_layout = + TransposeLike(attrs->data_layout, InitialLayout(4), data_layout->layout).name(); + new_attrs->kernel_layout = + TransposeLike(attrs->kernel_layout, InitialLayout(4), weight_layout->layout).name(); + new_attrs->out_layout = + TransposeLike(attrs->out_layout, InitialLayout(4), output_layout->layout).name(); return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs)); } diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index 6f8a90e3cba9..7eccf47e4b06 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -93,6 +93,16 @@ InferLayoutOutput InferLayoutSoftmax(const Call& call, ICHECK(attrs) << "Invalid Call"; LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); + + // TODO(Siva): We could handle if the axis is not the sub indexed one. + if (layout->layout.ndim() != layout->layout.ndim_primal()) { + const auto* tensor_sinfo = GetStructInfoAs(call->args[0]); + ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; + ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now"; + int ndim = tensor_sinfo->ndim; + layout = LayoutDecision(InitialLayout(ndim)); + } + ObjectPtr new_attrs = make_object(*attrs); new_attrs->axis = FindAxis(layout->layout, attrs->axis); return InferLayoutOutput({layout}, {layout}, Attrs(new_attrs)); @@ -290,8 +300,18 @@ InferLayoutOutput InferLayoutBatchNorm(const Call& call, ICHECK(attrs) << "Invalid Call"; LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); + + // While dealing with sub layouts, its adviced to deal with batchnorm + // on other ways like decomposing or fusion methods. + // This handling is fail safe fallback. 
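+  // For example, a propagated decision such as "NCHW4c" is dropped here and the
+  // initial 4-D layout is used for the data instead.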
+ const auto* input_sinfo = GetStructInfoAs(call->args[0]); + int ndim = input_sinfo->ndim; + if (layout->layout.ndim() != layout->layout.ndim_primal()) { + layout = LayoutDecision(InitialLayout(ndim)); + } + ObjectPtr new_attrs = make_object(*attrs); - new_attrs->axis = FindAxis(layout->layout, attrs->axis); + new_attrs->axis = FindAxis(layout->layout, (attrs->axis + ndim) % ndim); return InferLayoutOutput( {layout, initial_layouts[1], initial_layouts[2], initial_layouts[3], initial_layouts[4]}, {{layout, initial_layouts[3], initial_layouts[4]}}, Attrs(new_attrs)); @@ -353,9 +373,11 @@ InferLayoutOutput InferLayoutLayerNorm(const Call& call, LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); ObjectPtr new_attrs = make_object(*attrs); + const auto* input_sinfo = GetStructInfoAs(call->args[0]); + int ndim = input_sinfo->ndim; std::vector new_axis; for (const auto& axis : attrs->axes) { - new_axis.push_back(FindAxis(layout->layout, axis->value)); + new_axis.push_back(FindAxis(layout->layout, (axis->value + ndim) % ndim)); } new_attrs->axes = std::move(new_axis); return InferLayoutOutput({layout, initial_layouts[1], initial_layouts[2]}, {layout}, diff --git a/src/relax/op/nn/pooling.cc b/src/relax/op/nn/pooling.cc index 7fdcefc00bd0..565e6a00c60d 100644 --- a/src/relax/op/nn/pooling.cc +++ b/src/relax/op/nn/pooling.cc @@ -234,6 +234,23 @@ InferLayoutOutput InferLayoutPool2d(const Call& call, LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); ObjectPtr new_attrs = make_object(*attrs); + + if (layout->layout.ndim() != layout->layout.ndim_primal()) { + tir::Layout in_layout(attrs->layout, DataType::Int(64)); + auto desired_layout = TransposeSubLayoutLike(attrs->layout, InitialLayout(4), layout->layout); + auto data_si = GetStructInfo(call->args[0]); + TensorStructInfo data_sinfo = data_si.as().value(); + Optional data_shape = GetRef(data_sinfo->shape.as()); + if (CanProveLayoutTransform(in_layout, desired_layout, data_shape.value()->values)) { + // Not handling out_layout being different from in_layout now. Any use case ? + new_attrs->layout = desired_layout.name(); + new_attrs->out_layout = desired_layout.name(); + return InferLayoutOutput({layout}, {layout}, Attrs(new_attrs)); + } else { + layout = InitialLayout(4); + } + } + new_attrs->layout = TransposeLike(attrs->layout, InitialLayout(4), layout->layout).name(); new_attrs->out_layout = TransposeLike(attrs->out_layout, InitialLayout(4), layout->layout).name(); return InferLayoutOutput({layout}, {layout}, Attrs(new_attrs)); @@ -583,6 +600,21 @@ InferLayoutOutput InferLayoutAdaptiveAvgPool2D(const Call& call, LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); ObjectPtr new_attrs = make_object(*attrs); + if (layout->layout.ndim() != layout->layout.ndim_primal()) { + tir::Layout in_layout(attrs->layout, DataType::Int(64)); + auto desired_layout = TransposeSubLayoutLike(attrs->layout, InitialLayout(4), layout->layout); + auto data_si = GetStructInfo(call->args[0]); + TensorStructInfo data_sinfo = data_si.as().value(); + Optional data_shape = GetRef(data_sinfo->shape.as()); + if (CanProveLayoutTransform(in_layout, desired_layout, data_shape.value()->values)) { + // Not handling out_layout being different from in_layout now. Any use case ? 
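+      // The provable sub indexed layout (e.g. "NCHW4c") is applied to both the
+      // input layout and the output layout of the op.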
+ new_attrs->layout = desired_layout.name(); + new_attrs->out_layout = desired_layout.name(); + return InferLayoutOutput({layout}, {layout}, Attrs(new_attrs)); + } else { + layout = InitialLayout(4); + } + } new_attrs->layout = TransposeLike(attrs->layout, InitialLayout(4), layout->layout).name(); new_attrs->out_layout = TransposeLike(attrs->out_layout, InitialLayout(4), layout->layout).name(); return InferLayoutOutput({layout}, {layout}, Attrs(new_attrs)); diff --git a/src/relax/op/op_common.cc b/src/relax/op/op_common.cc index 56bf708f5e06..f9c1ece38c18 100644 --- a/src/relax/op/op_common.cc +++ b/src/relax/op/op_common.cc @@ -185,5 +185,27 @@ InferLayoutOutput InferLayoutUnaryEwise(const Call& call, return InferLayoutOutput({layout}, {layout}, Attrs(call->attrs)); } +bool CanProveLayoutTransform(const Layout& input_layout, const Layout& desired_layout, + Array shape) { + bool can_prove = true; + try { + tir::BijectiveLayout todesired(input_layout, desired_layout); + Array desired_shape = todesired.ForwardShape(shape); + Array back_shape = todesired.BackwardShape(desired_shape); + arith::Analyzer analyzer; + for (size_t i = 0; i < shape.size(); ++i) { + if (tir::is_const_int(shape[i])) { + if (!analyzer.CanProveEqual(shape[i], back_shape[i])) { + can_prove = false; + break; + } + } + } + } catch (std::exception& err) { + return false; + } + return can_prove; +} + } // namespace relax } // namespace tvm diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h index ed6725e27012..eb9caae4b9e1 100644 --- a/src/relax/op/op_common.h +++ b/src/relax/op/op_common.h @@ -570,6 +570,16 @@ Expr MakeAllocTensor(Expr shape, DataTypeImm dtype, PrimValue runtime_device_ind */ Array GetCallArgs(const Call& call); +/** + * \brief Checks the given shape can be proved from the source layout to dst layout + * \param input_layout is the layout of given shape + * \param desired_layout is the target layout the shape to be transformed + * \param shape array + * \return true or false depending on the compatibility + */ +bool CanProveLayoutTransform(const Layout& input_layout, const Layout& desired_layout, + Array shape); + } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc index bd4c681c7925..4a63993d507c 100644 --- a/src/relax/op/tensor/binary.cc +++ b/src/relax/op/tensor/binary.cc @@ -155,6 +155,21 @@ InferLayoutOutput InferLayoutBinaryEwise(const Call& call, ICHECK(!x1_sinfo->IsUnknownNdim() && !x2_sinfo->IsUnknownNdim()) << "Unknown dim tensors should not be handled by this function"; + Optional shape1 = GetRef(x1_sinfo->shape.as()); + Optional shape2 = GetRef(x2_sinfo->shape.as()); + // Lets handle sub indexing as long as primal dims are matching + if (layout1->layout.ndim_primal() == layout2->layout.ndim_primal()) { + if ((layout1->layout.ndim() >= layout2->layout.ndim()) && shape2.defined()) { + if (CanProveLayoutTransform(layout2->layout, layout1->layout, shape2.value()->values)) { + return InferLayoutOutput({layout1, layout1}, {layout1}, Attrs(call->attrs)); + } + } else if (shape1.defined()) { + if (CanProveLayoutTransform(layout1->layout, layout2->layout, shape1.value()->values)) { + return InferLayoutOutput({layout2, layout2}, {layout2}, Attrs(call->attrs)); + } + } + } + if (x1_sinfo->ndim <= x2_sinfo->ndim) { if (x1_sinfo->ndim == 0) { LayoutDecision out_layout = layout2; diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc index e62dbe89d08a..8a8dc6de40e4 100644 --- a/src/relax/op/tensor/index.cc +++ 
b/src/relax/op/tensor/index.cc @@ -438,6 +438,10 @@ InferLayoutOutput InferLayoutStridedSlice(const Call& call, << "but expression " << call << " has argument " << call->args[0] << " of unknown dimensionality."; LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, call->args[0]); + // Can't handle sub indexed layouts. + if (existing_layout->layout.ndim() != existing_layout->layout.ndim_primal()) { + existing_layout = LayoutDecision(InitialLayout(tensor_sinfo->ndim)); + } auto opt_axes_tuple = UnpackTupleOfPrimValue(GetStructInfo(call->args[1])); CHECK(opt_axes_tuple) << "Layout inference of " << call->op diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index f64b3ec4f979..452b1f223a80 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -393,6 +393,10 @@ InferLayoutOutput InferLayoutExpandDims(const Call& call, LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, call->args[0]); int ndim = tensor_sinfo->ndim; + // Can't handle sub indexed layouts. + if (existing_layout->layout.ndim() != existing_layout->layout.ndim_primal()) { + existing_layout = LayoutDecision(InitialLayout(ndim)); + } int n_new_dim = attrs->axis.size(); int output_ndim = ndim + n_new_dim; std::vector is_new_dim(output_ndim, false); @@ -622,6 +626,12 @@ InferLayoutOutput InferLayoutPermuteDims(const Call& call, int ndim = tensor_sinfo->ndim; LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, call->args[0]); + + // permute_dims can't handle sub indexed layouts. + if (existing_layout->layout.ndim() != existing_layout->layout.ndim_primal()) { + existing_layout = LayoutDecision(InitialLayout(ndim)); + } + Array order; if (attrs->axes.defined()) { order = attrs->axes.value(); @@ -942,10 +952,33 @@ InferLayoutOutput InferLayoutSplit(const Call& call, ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim"; LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, call->args[0]); - ObjectPtr new_attrs = make_object(*attrs); - new_attrs->axis = FindAxis(existing_layout->layout, attrs->axis); StructInfo out_sinfo = InferStructInfoSplit(call, BlockBuilder::Create(IRModule())); const auto* out_tuple = out_sinfo.as(); + + /* + * Fallback if the outputs can't be represented in input sub indexed layout + * This can happen after sub indexing, if we can't split the corresponding primal axis + */ + if (existing_layout->layout.ndim() != existing_layout->layout.ndim_primal()) { + for (const auto& si : out_tuple->fields) { + ICHECK(si->IsInstance()) + << "Fields of TupleStructInfo must be TensorStructInfo" + "output structinfo, but got " + << si; + auto sinfo = Downcast(si); + Optional shape_expr = GetRef(sinfo->shape.as()); + CHECK(shape_expr.defined()); + auto shape_arr = shape_expr.value(); + if (!CanProveLayoutTransform(InitialLayout(tensor_sinfo->ndim), existing_layout->layout, + shape_arr->values)) { + existing_layout = InitialLayout(tensor_sinfo->ndim); + break; + } + } + } + + ObjectPtr new_attrs = make_object(*attrs); + new_attrs->axis = FindAxis(existing_layout->layout, attrs->axis); ICHECK(out_tuple != nullptr) << "Invalid Call"; NLayout tuple_layouts(Array(out_tuple->fields.size(), existing_layout)); return InferLayoutOutput({existing_layout}, {tuple_layouts}, Attrs(new_attrs)); @@ -1092,6 +1125,10 @@ InferLayoutOutput InferLayoutSqueeze(const Call& call, } LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, call->args[0]); + // Can't handle sub indexed layouts. 
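+  // If the propagated decision carries packed (sub indexed) axes, reset to the
+  // plain initial layout before the squeeze axes are remapped below.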
+ if (existing_layout->layout.ndim() != existing_layout->layout.ndim_primal()) { + existing_layout = LayoutDecision(InitialLayout(ndim)); + } String new_axis_str = TransposeStrLike(axis_str, InitialLayout(ndim), existing_layout->layout); Array new_axis; for (size_t i = 0; i < new_axis_str.size(); ++i) { diff --git a/src/relax/op/tensor/statistical.cc b/src/relax/op/tensor/statistical.cc index 24ccde4559e6..606001dfbf3f 100644 --- a/src/relax/op/tensor/statistical.cc +++ b/src/relax/op/tensor/statistical.cc @@ -108,25 +108,35 @@ InferLayoutOutput InferLayoutStatistical(const Call& call, std::string axis_str(ndim, '0'); for (const auto& iter : axis) { - axis_str[(iter->value + ndim) % ndim] = '1'; + axis_str[(iter->value + ndim) % ndim] = '#'; } for (int i = 0, j = 0; i < ndim; ++i) { - if (axis_str[i] != '1') { + if (axis_str[i] != '#') { axis_str[i] = 'A' + j++; } } LayoutDecision exisiting_layout = GetLayoutDecision(var_layout_map, call->args[0]); - String new_axis_str = TransposeStrLike(axis_str, InitialLayout(ndim), exisiting_layout->layout); + auto new_axis_str = TransposeSubLayoutStrLike(axis_str, InitialLayout(ndim).name(), + exisiting_layout->layout.name()); + std::string output_layout_ref = new_axis_str; + new_axis_str.erase(std::remove_if(new_axis_str.begin(), new_axis_str.end(), + [](unsigned char c) { return std::isdigit(c); }), + new_axis_str.end()); + Array new_axis; for (size_t i = 0; i < new_axis_str.size(); ++i) { - if (new_axis_str.at(i) == '1') { + if (new_axis_str.at(i) == '#') { new_axis.push_back(Integer(i)); } } - std::string output_layout = new_axis_str; - output_layout.erase(std::remove(output_layout.begin(), output_layout.end(), '1'), - output_layout.end()); + std::string output_layout; + for (size_t i = 0; i < output_layout_ref.length(); ++i) { + if ((isdigit(output_layout_ref[i]) && (output_layout_ref[i + 1] == '#')) || + (output_layout_ref[i] == '#')) + continue; + output_layout.push_back(output_layout_ref[i]); + } ObjectPtr new_attrs = make_object(*attrs); new_attrs->axis = new_axis; diff --git a/src/relax/transform/convert_layout.cc b/src/relax/transform/convert_layout.cc index 3798bdff351d..0898af1d7636 100644 --- a/src/relax/transform/convert_layout.cc +++ b/src/relax/transform/convert_layout.cc @@ -21,10 +21,12 @@ * \brief Automatic layout conversion pass, especially for axis swapping. */ +#include #include #include #include #include +#include #include "../op/tensor/manipulate.h" #include "infer_layout_utils.h" @@ -33,6 +35,7 @@ namespace tvm { namespace relax { +using tir::IndexMap; using tir::Layout; /*! 
@@ -87,6 +90,22 @@ class LayoutConvertMutator : public ExprMutator { return ret; } + IndexMap LayoutIndexMap(int ndim, const Layout& src_layout, const Layout& desired_layout) { + tir::BijectiveLayout todesired(src_layout, desired_layout); + Optional inverse_index_map; + + Array initial_indices; + Array initial_indices_expr; + initial_indices.reserve(ndim); + for (int i = 0; i < ndim; ++i) { + auto var = tvm::tir::Var("i" + std::to_string(i), DataType::Int(32)); + initial_indices.push_back(var); + initial_indices_expr.push_back(var); + } + Array desired_shape = todesired.ForwardIndex(initial_indices_expr); + return IndexMap(initial_indices, desired_shape, std::move(inverse_index_map)); + } + Expr RewriteExpr(const Expr& expr, const NLayout& to) { auto fvisitleaf = [&](const Expr& expr, std::array layouts) -> Expr { NLayout from = layouts[0], to = layouts[1]; @@ -97,9 +116,24 @@ class LayoutConvertMutator : public ExprMutator { << "Cannot convert when exactly one of the layouts is unknown"; const auto* tensor = GetStructInfoAs(expr); ICHECK(tensor != nullptr) << "Expect a tensor, but got: " << expr; - Layout axes = TransposeLike(InitialLayoutDecision(tensor->ndim)->layout, - from.LeafValue()->layout, to.LeafValue()->layout); - return permute_dims(expr, LayoutToIntegers(axes)); + + if (from.LeafValue()->layout.ndim() == to.LeafValue()->layout.ndim()) { + Layout axes = TransposeLike(InitialLayoutDecision(tensor->ndim)->layout, + from.LeafValue()->layout, to.LeafValue()->layout); + return permute_dims(expr, LayoutToIntegers(axes)); + } else { + auto index_map = LayoutIndexMap(from.LeafValue()->layout.ndim(), from.LeafValue()->layout, + to.LeafValue()->layout); + ObjectPtr attrs = make_object(); + Array axis_separator; + Array input_axis_separator; + attrs->index_map = std::move(Downcast(LoadJSON(SaveJSON(index_map)))); + attrs->axis_separators = std::move(axis_separator); + attrs->input_axis_separators = std::move(input_axis_separator); + const Op& layout_transform_op_ = Op::Get("relax.layout_transform"); + auto ret_expr = Call(layout_transform_op_, {expr}, Attrs{std::move(attrs)}, {}); + return ret_expr; + } }; return TransformTupleLeaf( VarReplacer::Replace(expr, var_remap_), diff --git a/src/relax/transform/infer_layout_utils.cc b/src/relax/transform/infer_layout_utils.cc index d746f9394a75..aca048820996 100644 --- a/src/relax/transform/infer_layout_utils.cc +++ b/src/relax/transform/infer_layout_utils.cc @@ -27,6 +27,35 @@ namespace relax { using tir::IterVar; using tir::Layout; +std::string TransposeSubLayoutStrLike(const std::string ref_str, const std::string& src_str, + const std::string& desired_str) { + std::string out; + for (const char& c : desired_str) { + if (std::isupper(c)) { + auto res = src_str.find(c, 0); + ICHECK(res != std::string::npos) << "Invalid Layout:" + << "can't find " << c << " in source layout" << src_str; + out.push_back(ref_str[res]); + } else if (isdigit(c)) { + out.push_back(c); + } else if (std::islower(c)) { + auto res = src_str.find(std::toupper(c), 0); + ICHECK(res != std::string::npos) << "Invalid Layout:" + << "can't find " << c << " in source layout" << src_str; + out.push_back(std::tolower(ref_str[res])); + } + } + return out; +} + +Layout TransposeSubLayoutLike(const Layout& ref, const Layout& src, const Layout& desired) { + std::string ref_str = ref.name(); + std::string src_str = src.name(); + std::string desired_str = desired.name(); + std::string out = TransposeSubLayoutStrLike(ref_str, src_str, desired_str); + return Layout(out); +} + Layout 
TransposeLike(const Layout& input, const Layout& src, const Layout& dst) { ICHECK(src.ndim() == dst.ndim() && input.ndim() == src.ndim()) << "Layouts must have the same size"; @@ -49,7 +78,11 @@ String TransposeStrLike(const String& input, const Layout& src, const Layout& ds int FindAxis(const Layout& dst, int axis) { axis = (axis + dst.ndim()) % dst.ndim(); - return dst.name().find('A' + axis); + std::string layout_name = dst.name(); + layout_name.erase(std::remove_if(layout_name.begin(), layout_name.end(), + [](unsigned char c) { return std::isdigit(c); }), + layout_name.end()); + return layout_name.find('A' + axis); } Layout InitialLayout(int ndim) { diff --git a/src/relax/transform/infer_layout_utils.h b/src/relax/transform/infer_layout_utils.h index 4e54d925446e..951fe92cb8ac 100644 --- a/src/relax/transform/infer_layout_utils.h +++ b/src/relax/transform/infer_layout_utils.h @@ -181,6 +181,25 @@ NLayout InitialNLayout(const StructInfo& sinfo); */ NLayout InitialNLayout(const Expr& expr); +/*! + * \brief Transposing given layout with subindexing + * \param ref The layout to be transformed. + * \param src The source layout. + * \param dst The destination layout. + * \return The transposed dst layout. + */ +Layout TransposeSubLayoutLike(const Layout& ref, const Layout& src, const Layout& desired); + +/*! + * \brief Transposing given layout in string format with subindexing + * \param ref The layout to be transformed. + * \param src The source layout. + * \param dst The destination layout. + * \return The transposed dst layout. + */ +std::string TransposeSubLayoutStrLike(const std::string ref_str, const std::string& src_str, + const std::string& desired_str); + /*! * \brief Transpose the input layout like the src layout to the dst layout. * \param input The input layout. 
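# A minimal usage sketch (assumes an existing Relax IRModule `mod` containing
# relax.nn.conv2d calls): with this patch, ConvertLayout accepts sub indexed
# desired layouts per op and falls back to the regular layout where the packed
# form cannot be proven, as exercised by the tests below.
#
#   from tvm import relax
#
#   desired_layouts = {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}
#   mod = relax.transform.ConvertLayout(desired_layouts)(mod)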
diff --git a/tests/python/relax/test_transform_convert_layout.py b/tests/python/relax/test_transform_convert_layout.py index 56b59ba23867..db4130f947d1 100644 --- a/tests/python/relax/test_transform_convert_layout.py +++ b/tests/python/relax/test_transform_convert_layout.py @@ -69,6 +69,9 @@ def main( return gv verify(Input, Expected) + # Channel not a proper multiple shouldn't alter the mod + verify(Input, Input, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + verify(Input, Input, {"relax.nn.conv2d": ["NHWC4c", "OHWI4o"]}) def test_conv2d_onlydim(): @@ -203,9 +206,10 @@ def main( lv2: R.Tensor((N, H, W, C), dtype="float32") = R.match_cast( lv0, R.Tensor((N, H, W, C), dtype="float32") ) - lv3: R.Tensor(dtype="float32", ndim=4) = R.permute_dims(w, axes=[0, 2, 3, 1]) - lv4: R.Tensor(dtype="float32", ndim=4) = R.add(lv2, lv3) - gv: R.Tensor(dtype="float32", ndim=4) = R.permute_dims(lv4, axes=[0, 3, 1, 2]) + lv3: R.Tensor((N, C, H, W), dtype="float32") = R.permute_dims( + lv2, axes=[0, 3, 1, 2] + ) + gv: R.Tensor(dtype="float32", ndim=4) = R.add(lv3, w) R.output(gv) return gv @@ -467,6 +471,145 @@ def main( verify(Input, Expected) + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 4, 28, 28), dtype="float32"), + w: R.Tensor((4, 4, 3, 3), dtype="float32"), + bias: R.Tensor((2, 4, 26, 26), dtype="float32"), + ) -> R.Tensor((2, 4, 24, 24), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 1, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 4, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.layout_transform( + bias, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + gv2: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.add(gv, lv2) + gv3: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.relu(gv2) + lv3: R.Tensor((1, 4, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + lv4: R.Tensor((2, 1, 24, 24, 4), dtype="float32") = R.nn.conv2d( + gv3, + lv3, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + gv4: R.Tensor((2, 4, 24, 24), dtype="float32") = R.layout_transform( + lv4, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + R.output(gv4) + return gv4 + + @I.ir_module + class Expected_NHWC4c: + @R.function + def main( + x: R.Tensor((2, 4, 28, 28), dtype="float32"), + w: R.Tensor((4, 4, 3, 3), dtype="float32"), + bias: R.Tensor((2, 4, 26, 26), dtype="float32"), + ) -> R.Tensor((2, 4, 24, 24), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 1, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i2, i3, i1 // 4, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 3, 3, 4, 4), 
dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i2, i3, i1, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC4c", + kernel_layout="OHWI4o", + out_layout="NHWC4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.layout_transform( + bias, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i2, i3, i1 // 4, i1 % 4), index_dtype="int32" + ), + ) + gv2: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.add(gv, lv2) + gv3: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.nn.relu(gv2) + lv3: R.Tensor((1, 3, 3, 4, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i2, i3, i1, i0 % 4), index_dtype="int32" + ), + ) + lv4: R.Tensor((2, 24, 24, 1, 4), dtype="float32") = R.nn.conv2d( + gv3, + lv3, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC4c", + kernel_layout="OHWI4o", + out_layout="NHWC4c", + out_dtype="float32", + ) + gv4: R.Tensor((2, 4, 24, 24), dtype="float32") = R.layout_transform( + lv4, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i3 * 4 + i4, i1, i2), index_dtype="int32" + ), + ) + R.output(gv4) + return gv4 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + verify(Input, Expected_NHWC4c, {"relax.nn.conv2d": ["NHWC4c", "OHWI4o"]}) + def test_conv2d_fma_relu_conv2d(): @I.ir_module @@ -1501,5 +1644,2946 @@ def main( verify(Input, Expected) +def test_conv2d_NCHW_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d( + x, + w, + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="float32", + ) + R.output(gv) + return gv + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 4, 26, 26), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + pad_value=None, + axis_separators=[], + input_axis_separators=[], + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + pad_value=None, + axis_separators=[], + input_axis_separators=[], + ) + lv2: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + gv: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + pad_value=None, + axis_separators=[], + input_axis_separators=[], + ) + R.output(gv) + return gv + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + + @I.ir_module + class Expected_NHWC4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: 
R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 4, 26, 26), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 4, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i2, i3, i1 // 4, i1 % 4), index_dtype="int32" + ), + pad_value=None, + axis_separators=[], + input_axis_separators=[], + ) + lv1: R.Tensor((1, 3, 3, 16, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i2, i3, i1, i0 % 4), index_dtype="int32" + ), + pad_value=None, + axis_separators=[], + input_axis_separators=[], + ) + lv2: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC4c", + kernel_layout="OHWI4o", + out_layout="NHWC4c", + out_dtype="float32", + ) + gv: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i3 * 4 + i4, i1, i2), index_dtype="int32" + ), + pad_value=None, + axis_separators=[], + input_axis_separators=[], + ) + R.output(gv) + return gv + + verify(Input, Expected_NHWC4c, {"relax.nn.conv2d": ["NHWC4c", "OHWI4o"]}) + + +def test_conv2d_NHWC_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 28, 28, 16), "float32"), w: R.Tensor((4, 3, 3, 16), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 26, 26, 4), "float32") = R.nn.conv2d( + x, + w, + data_layout="NHWC", + kernel_layout="OHWI", + out_dtype="float32", + ) + R.output(gv) + return gv + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 28, 28, 16), dtype="float32"), + w: R.Tensor((4, 3, 3, 16), dtype="float32"), + ) -> R.Tensor((2, 26, 26, 4), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i3 // 4, i1, i2, i3 % 4), index_dtype="int32" + ), + pad_value=None, + axis_separators=[], + input_axis_separators=[], + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i3, i1, i2, i0 % 4), index_dtype="int32" + ), + pad_value=None, + axis_separators=[], + input_axis_separators=[], + ) + lv2: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i2, i3, i1 * 4 + i4), index_dtype="int32" + ), + pad_value=None, + axis_separators=[], + input_axis_separators=[], + ) + R.output(gv) + return gv + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + + @I.ir_module + class Expected_NHWC4c: + @R.function + def main( + x: R.Tensor((2, 28, 28, 16), dtype="float32"), + w: R.Tensor((4, 3, 3, 16), dtype="float32"), + ) -> R.Tensor((2, 26, 26, 4), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 4, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1, i2, i3 // 4, i3 % 4), index_dtype="int32" + ), + pad_value=None, + axis_separators=[], + input_axis_separators=[], + ) + lv1: R.Tensor((1, 3, 3, 16, 4), dtype="float32") = 
R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + pad_value=None, + axis_separators=[], + input_axis_separators=[], + ) + lv2: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC4c", + kernel_layout="OHWI4o", + out_layout="NHWC4c", + out_dtype="float32", + ) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1, i2, i3 * 4 + i4), index_dtype="int32" + ), + pad_value=None, + axis_separators=[], + input_axis_separators=[], + ) + R.output(gv) + return gv + + verify(Input, Expected_NHWC4c, {"relax.nn.conv2d": ["NHWC4c", "OHWI4o"]}) + + @I.ir_module + class Expected_N2nHWC4c: + @R.function + def main( + x: R.Tensor((2, 28, 28, 16), dtype="float32"), + w: R.Tensor((4, 3, 3, 16), dtype="float32"), + ) -> R.Tensor((2, 26, 26, 4), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((1, 2, 28, 28, 4, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 2, i0 % 2, i1, i2, i3 // 4, i3 % 4), + index_dtype="int32", + ), + ) + lv1: R.Tensor((1, 3, 3, 8, 2, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3 // 2, i3 % 2, i0 % 4), + index_dtype="int32", + ), + ) + lv2: R.Tensor((1, 2, 26, 26, 1, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="N2nHWC4c", + kernel_layout="OHWI2i4o", + out_layout="N2nHWC4c", + out_dtype="float32", + ) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4, i5: (i0 * 2 + i1, i2, i3, i4 * 4 + i5), + index_dtype="int32", + ), + ) + R.output(gv) + return gv + + verify(Input, Expected_N2nHWC4c, {"relax.nn.conv2d": ["N2nHWC4c", "OHWI2i4o"]}) + + +def test_conv2d_symbolic_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor("float32", ndim=4), w: R.Tensor("float32", ndim=4) + ) -> R.Tensor("float32", ndim=4): + with R.dataflow(): + N, C, H, W = T.int64(), T.int64(16), T.int64(), T.int64() + Nw, Cw, Hw, Ww = T.int64(4), T.int64(16), T.int64(), T.int64() + lv0 = R.match_cast(x, R.Tensor((N, C, H, W), "float32")) + lv1 = R.match_cast(w, R.Tensor((Nw, Cw, Hw, Ww), "float32")) + gv: R.Tensor( + (N, T.int64(4), H + T.int64(1) - Hw, W + T.int64(1) - Ww), "float32" + ) = R.nn.conv2d(lv0, lv1, out_dtype="float32") + R.output(gv) + return gv + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor(dtype="float32", ndim=4), w: R.Tensor(dtype="float32", ndim=4) + ) -> R.Tensor(dtype="float32", ndim=4): + N = T.int64() + H = T.int64() + W = T.int64() + Hw = T.int64() + Ww = T.int64() + with R.dataflow(): + lv0: R.Tensor((N, 16, H, W), dtype="float32") = R.match_cast( + x, R.Tensor((N, 16, H, W), dtype="float32") + ) + lv1: R.Tensor((4, 16, Hw, Ww), dtype="float32") = R.match_cast( + w, R.Tensor((4, 16, Hw, Ww), dtype="float32") + ) + lv: R.Tensor((N, 4, H, W, 4), dtype="float32") = R.layout_transform( + lv0, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + pad_value=None, + axis_separators=[], + input_axis_separators=[], + ) + lv1_1: R.Tensor((1, 16, Hw, Ww, 4), dtype="float32") = R.layout_transform( + lv1, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), 
index_dtype="int32" + ), + pad_value=None, + axis_separators=[], + input_axis_separators=[], + ) + lv2: R.Tensor((N, 1, H + 1 - Hw, W + 1 - Ww, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1_1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + gv: R.Tensor((N, 4, H + 1 - Hw, W + 1 - Ww), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + pad_value=None, + axis_separators=[], + input_axis_separators=[], + ) + R.output(gv) + return gv + + verify(Input, Expected, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + + +def test_conv2d_matchcast_bias_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor("float32", ndim=4), + w: R.Tensor("float32", ndim=4), + bias: R.Tensor("float32", ndim=4), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + N, C, H, W = T.int64(), T.int64(16), T.int64(), T.int64() + Nw, Cw, Hw, Ww = T.int64(4), T.int64(16), T.int64(), T.int64() + lv0 = R.match_cast(x, R.Tensor((N, C, H, W), "float32")) + lv1 = R.match_cast(w, R.Tensor((Nw, Cw, Hw, Ww), "float32")) + lv2: R.Tensor("float32", ndim=4) = R.nn.conv2d(lv0, lv1, out_dtype="float32") + Nb, Cb, Hb, Wb = T.int64(), T.int64(), T.int64(), T.int64() + lv_bias = R.match_cast(bias, R.Tensor((Nb, Cb, Hb, Wb), "float32")) + gv = R.add(lv2, lv_bias) + R.output(gv) + return gv + + @I.ir_module + class Expected_NHWC4c: + @R.function + def main( + x: R.Tensor(dtype="float32", ndim=4), + w: R.Tensor(dtype="float32", ndim=4), + bias: R.Tensor(dtype="float32", ndim=4), + ) -> R.Tensor(dtype="float32", ndim=4): + N, C, H, W = T.int64(), T.int64(16), T.int64(), T.int64() + Nw, Cw, Hw, Ww = T.int64(4), T.int64(16), T.int64(), T.int64() + Nb, Cb, Hb, Wb = T.int64(), T.int64(), T.int64(), T.int64() + with R.dataflow(): + lv0: R.Tensor((N, 16, H, W), dtype="float32") = R.match_cast( + x, R.Tensor((N, 16, H, W), dtype="float32") + ) + lv1: R.Tensor((4, 16, Hw, Ww), dtype="float32") = R.match_cast( + w, R.Tensor((4, 16, Hw, Ww), dtype="float32") + ) + lv: R.Tensor((N, H, W, 4, 4), dtype="float32") = R.layout_transform( + lv0, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i2, i3, i1 // 4, i1 % 4), index_dtype="int32" + ), + pad_value=None, + axis_separators=[], + input_axis_separators=[], + ) + lv1_1: R.Tensor((1, Hw, Ww, 16, 4), dtype="float32") = R.layout_transform( + lv1, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i2, i3, i1, i0 % 4), index_dtype="int32" + ), + pad_value=None, + axis_separators=[], + input_axis_separators=[], + ) + lv2: R.Tensor((N, H + 1 - Hw, W + 1 - Ww, 1, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1_1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC4c", + kernel_layout="OHWI4o", + out_layout="NHWC4c", + out_dtype="float32", + ) + lv_bias: R.Tensor((Nb, Cb, Hb, Wb), dtype="float32") = R.match_cast( + bias, R.Tensor((Nb, Cb, Hb, Wb), dtype="float32") + ) + lv2_1: R.Tensor( + (Nb, Hb, Wb, (Cb - Cb % -4) // 4, 4), dtype="float32" + ) = R.layout_transform( + lv_bias, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i2, i3, i1 // 4, i1 % 4), index_dtype="int32" + ), + pad_value=None, + axis_separators=[], + input_axis_separators=[], + ) + lv3: R.Tensor(dtype="float32", ndim=5) = R.add(lv2, lv2_1) + gv: R.Tensor(dtype="float32", ndim=4) = R.layout_transform( + lv3, + 
index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i3 * 4 + i4, i1, i2), index_dtype="int32" + ), + pad_value=None, + axis_separators=[], + input_axis_separators=[], + ) + R.output(gv) + return gv + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor(dtype="float32", ndim=4), + w: R.Tensor(dtype="float32", ndim=4), + bias: R.Tensor(dtype="float32", ndim=4), + ) -> R.Tensor(dtype="float32", ndim=4): + N, C, H, W = T.int64(), T.int64(16), T.int64(), T.int64() + Nw, Cw, Hw, Ww = T.int64(4), T.int64(16), T.int64(), T.int64() + Nb, Cb, Hb, Wb = T.int64(), T.int64(), T.int64(), T.int64() + with R.dataflow(): + lv0: R.Tensor((N, 16, H, W), dtype="float32") = R.match_cast( + x, R.Tensor((N, 16, H, W), dtype="float32") + ) + lv1: R.Tensor((4, 16, Hw, Ww), dtype="float32") = R.match_cast( + w, R.Tensor((4, 16, Hw, Ww), dtype="float32") + ) + lv: R.Tensor((N, 4, H, W, 4), dtype="float32") = R.layout_transform( + lv0, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1_1: R.Tensor((1, 16, Hw, Ww, 4), dtype="float32") = R.layout_transform( + lv1, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + lv2: R.Tensor((N, 1, H + 1 - Hw, W + 1 - Ww, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1_1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv_bias: R.Tensor((Nb, Cb, Hb, Wb), dtype="float32") = R.match_cast( + bias, R.Tensor((Nb, Cb, Hb, Wb), dtype="float32") + ) + lv2_1: R.Tensor( + (Nb, (Cb - Cb % -4) // 4, Hb, Wb, 4), dtype="float32" + ) = R.layout_transform( + lv_bias, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv3: R.Tensor(dtype="float32", ndim=5) = R.add(lv2, lv2_1) + gv: R.Tensor(dtype="float32", ndim=4) = R.layout_transform( + lv3, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + R.output(gv) + return gv + + verify(Input, Expected_NHWC4c, {"relax.nn.conv2d": ["NHWC4c", "OHWI4o"]}) + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + + +def test_conv2d_layout_incompatible_fallback(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor("float32", ndim=4), + w: R.Tensor("float32", ndim=4), + bias: R.Tensor("float32", ndim=4), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + N, C, H, W = T.int64(), T.int64(15), T.int64(), T.int64() + Nw, Cw, Hw, Ww = T.int64(4), T.int64(15), T.int64(), T.int64() + lv0 = R.match_cast(x, R.Tensor((N, C, H, W), "float32")) + lv1 = R.match_cast(w, R.Tensor((Nw, Cw, Hw, Ww), "float32")) + lv2: R.Tensor("float32", ndim=4) = R.nn.conv2d(lv0, lv1, out_dtype="float32") + Nb, Cb, Hb, Wb = T.int64(), T.int64(), T.int64(), T.int64() + lv_bias = R.match_cast(bias, R.Tensor((Nb, Cb, Hb, Wb), "float32")) + gv = R.add(lv2, lv_bias) + R.output(gv) + return gv + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor(dtype="float32", ndim=4), + w: R.Tensor(dtype="float32", ndim=4), + bias: R.Tensor(dtype="float32", ndim=4), + ) -> R.Tensor(dtype="float32", ndim=4): + N, C, H, W = T.int64(), T.int64(15), T.int64(), T.int64() + Nw, Cw, Hw, Ww = T.int64(4), T.int64(15), T.int64(), T.int64() + Nb, Cb, Hb, Wb = T.int64(), T.int64(), T.int64(), T.int64() + with R.dataflow(): + lv0: R.Tensor((N, 15, H, W), 
dtype="float32") = R.match_cast( + x, R.Tensor((N, 15, H, W), dtype="float32") + ) + lv1: R.Tensor((4, 15, Hw, Ww), dtype="float32") = R.match_cast( + w, R.Tensor((4, 15, Hw, Ww), dtype="float32") + ) + lv2: R.Tensor((N, 4, H + 1 - Hw, W + 1 - Ww), dtype="float32") = R.nn.conv2d( + lv0, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float32", + ) + lv_bias: R.Tensor((Nb, Cb, Hb, Wb), dtype="float32") = R.match_cast( + bias, R.Tensor((Nb, Cb, Hb, Wb), dtype="float32") + ) + gv: R.Tensor(dtype="float32", ndim=4) = R.add(lv2, lv_bias) + R.output(gv) + return gv + + verify(Input, Expected, {"relax.nn.conv2d": ["NHWC4c", "OHWI4o"]}) + verify(Input, Expected, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + + +def test_conv2d_relu_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 4, 26, 26), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.relu(gv) + gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NHWC4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 4, 26, 26), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 4, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i2, i3, i1 // 4, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 3, 3, 16, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i2, i3, i1, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NHWC4c", + kernel_layout="OHWI4o", + out_layout="NHWC4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.nn.relu(gv) + gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i3 * 4 + i4, i1, i2), index_dtype="int32" + ), + ) + R.output(gv2) + return gv2 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + verify(Input, Expected_NHWC4c, {"relax.nn.conv2d": ["NHWC4c", "OHWI4o"]}) + + +def 
test_relu_conv2d_relu_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + x0: R.Tensor((2, 16, 28, 28), "float32") = R.nn.relu(x) + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x0, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + R.output(gv2) + return gv2 + + @tvm.script.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 4, 26, 26), dtype="float32"): + with R.dataflow(): + x0: R.Tensor((2, 16, 28, 28), dtype="float32") = R.nn.relu(x) + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x0, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.relu(gv) + gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + R.output(gv2) + return gv2 + + @tvm.script.ir_module + class Expected_NHWC4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 4, 26, 26), dtype="float32"): + with R.dataflow(): + x0: R.Tensor((2, 16, 28, 28), dtype="float32") = R.nn.relu(x) + lv: R.Tensor((2, 28, 28, 4, 4), dtype="float32") = R.layout_transform( + x0, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i2, i3, i1 // 4, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 3, 3, 16, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i2, i3, i1, i0 % 4), index_dtype="int32" + ), + pad_value=None, + axis_separators=[], + input_axis_separators=[], + ) + gv: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NHWC4c", + kernel_layout="OHWI4o", + out_layout="NHWC4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.nn.relu(gv) + gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i3 * 4 + i4, i1, i2), index_dtype="int32" + ), + pad_value=None, + axis_separators=[], + input_axis_separators=[], + ) + R.output(gv2) + return gv2 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + verify(Input, Expected_NHWC4c, {"relax.nn.conv2d": ["NHWC4c", "OHWI4o"]}) + + +def test_conv2d_relu_tanh_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + gv3: R.Tensor((2, 4, 26, 26), "float32") = R.tanh(gv2) + R.output(gv3) + return gv3 + + @I.ir_module + class 
Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 4, 26, 26), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + gv2: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.relu(gv) + lv2: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.tanh(gv2) + gv3: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + R.output(gv3) + return gv3 + + @I.ir_module + class Expected_NHWC4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 4, 26, 26), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 4, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i2, i3, i1 // 4, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 3, 3, 16, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i2, i3, i1, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NHWC4c", + kernel_layout="OHWI4o", + out_layout="NHWC4c", + out_dtype="float32", + ) + gv2: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.nn.relu(gv) + lv2: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.tanh(gv2) + gv3: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i3 * 4 + i4, i1, i2), index_dtype="int32" + ), + ) + R.output(gv3) + return gv3 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + verify(Input, Expected_NHWC4c, {"relax.nn.conv2d": ["NHWC4c", "OHWI4o"]}) + + +def test_conv2d_add_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), + w: R.Tensor((4, 16, 3, 3), "float32"), + bias: R.Tensor((2, 4, 26, 26), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.add(gv, bias) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + bias: R.Tensor((2, 4, 26, 26), dtype="float32"), + ) -> R.Tensor((2, 4, 26, 26), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 
1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.layout_transform( + bias, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv3: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.add(gv, lv2) + gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + lv3, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NHWC4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + bias: R.Tensor((2, 4, 26, 26), dtype="float32"), + ) -> R.Tensor((2, 4, 26, 26), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 4, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i2, i3, i1 // 4, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 3, 3, 16, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i2, i3, i1, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NHWC4c", + kernel_layout="OHWI4o", + out_layout="NHWC4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.layout_transform( + bias, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i2, i3, i1 // 4, i1 % 4), index_dtype="int32" + ), + ) + lv3: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.add(gv, lv2) + gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + lv3, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i3 * 4 + i4, i1, i2), index_dtype="int32" + ), + ) + R.output(gv2) + return gv2 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + verify(Input, Expected_NHWC4c, {"relax.nn.conv2d": ["NHWC4c", "OHWI4o"]}) + + +def test_conv2d_fma_relu_conv2d_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 4, 28, 28), "float32"), + w: R.Tensor((4, 4, 3, 3), "float32"), + scale: R.Tensor((2, 4, 26, 26), dtype="float32"), + bias: R.Tensor((2, 4, 26, 26), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.ewise_fma(gv, scale, bias) + gv3: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv2) + gv4: R.Tensor((2, 4, 24, 24), "float32") = R.nn.conv2d(gv3, w, out_dtype="float32") + R.output(gv4) + return gv4 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 4, 28, 28), dtype="float32"), + w: R.Tensor((4, 4, 3, 3), dtype="float32"), + scale: R.Tensor((2, 4, 26, 26), dtype="float32"), + bias: R.Tensor((2, 4, 26, 26), dtype="float32"), + ) -> R.Tensor((2, 4, 24, 24), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 1, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 4, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") 
= R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + gv, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.ewise_fma(lv2, scale, bias) + gv3: R.Tensor((2, 4, 26, 26), dtype="float32") = R.nn.relu(gv2) + lv3: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.layout_transform( + gv3, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv4: R.Tensor((1, 4, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + lv5: R.Tensor((2, 1, 24, 24, 4), dtype="float32") = R.nn.conv2d( + lv3, + lv4, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + gv4: R.Tensor((2, 4, 24, 24), dtype="float32") = R.layout_transform( + lv5, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + R.output(gv4) + return gv4 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + + +def test_conv2d_sum_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=2): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4), "float32") = R.sum(gv, axis=[2, 3]) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 4), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 1, 4), dtype="float32") = R.sum(gv, axis=[2, 3], keepdims=False) + gv2: R.Tensor((2, 4), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2: (i0, i1 * 4 + i2), index_dtype="int32" + ), + ) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NHWC4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 4), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 4, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i2, i3, i1 // 4, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 3, 3, 16, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i2, i3, i1, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NHWC4c", + kernel_layout="OHWI4o", + out_layout="NHWC4c", + 
out_dtype="float32", + ) + lv2: R.Tensor((2, 1, 4), dtype="float32") = R.sum(gv, axis=[1, 2], keepdims=False) + gv2: R.Tensor((2, 4), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2: (i0, i1 * 4 + i2), index_dtype="int32" + ), + ) + R.output(gv2) + return gv2 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + verify(Input, Expected_NHWC4c, {"relax.nn.conv2d": ["NHWC4c", "OHWI4o"]}) + + +def test_conv2d_sum_keepdims_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=2): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 1, 1), "float32") = R.sum(gv, axis=[2, 3], keepdims=True) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 4, 1, 1), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 1, 1, 1, 4), dtype="float32") = R.sum( + gv, axis=[2, 3], keepdims=True + ) + gv2: R.Tensor((2, 4, 1, 1), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NHWC4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 4, 1, 1), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 4, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i2, i3, i1 // 4, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 3, 3, 16, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i2, i3, i1, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NHWC4c", + kernel_layout="OHWI4o", + out_layout="NHWC4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 1, 1, 1, 4), dtype="float32") = R.sum( + gv, axis=[1, 2], keepdims=True + ) + gv2: R.Tensor((2, 4, 1, 1), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i3 * 4 + i4, i1, i2), index_dtype="int32" + ), + ) + R.output(gv2) + return gv2 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + verify(Input, Expected_NHWC4c, {"relax.nn.conv2d": ["NHWC4c", "OHWI4o"]}) + + +def test_conv2d_sum_reduce_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=2): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = 
R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 26), "float32") = R.sum(gv, axis=[1, 2]) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 26), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + gv2: R.Tensor((2, 26), dtype="float32") = R.sum(gv, axis=[1, 2, 4], keepdims=False) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NHWC4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 26), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 4, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i2, i3, i1 // 4, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 3, 3, 16, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i2, i3, i1, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NHWC4c", + kernel_layout="OHWI4o", + out_layout="NHWC4c", + out_dtype="float32", + ) + gv2: R.Tensor((2, 26), dtype="float32") = R.sum(gv, axis=[1, 3, 4], keepdims=False) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NCHW2n4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 26), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((1, 4, 28, 28, 2, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 2, i1 // 4, i2, i3, i0 % 2, i1 % 4), + index_dtype="int32", + ), + ) + lv1: R.Tensor((1, 8, 3, 3, 2, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1 // 2, i2, i3, i1 % 2, i0 % 4), + index_dtype="int32", + ), + ) + gv: R.Tensor((1, 1, 26, 26, 2, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW2n4c", + kernel_layout="OIHW2i4o", + out_layout="NCHW2n4c", + out_dtype="float32", + ) + lv2: R.Tensor((1, 26, 2), dtype="float32") = R.sum( + gv, axis=[1, 2, 5], keepdims=False + ) + gv2: R.Tensor((2, 26), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2: (i0 * 2 + i2, i1), index_dtype="int32" + ), + ) + R.output(gv2) + return gv2 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + verify(Input, Expected_NHWC4c, {"relax.nn.conv2d": ["NHWC4c", "OHWI4o"]}) + verify(Input, Expected_NCHW2n4c, {"relax.nn.conv2d": ["NCHW2n4c", "OIHW2i4o"]}) + + +def test_conv2d_sum_negative_dims_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main(x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32")): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, 
out_dtype="float32") + gv2: R.Tensor((2, 4), "float32") = R.sum(gv, axis=[-2, -1]) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 4), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 1, 4), dtype="float32") = R.sum(gv, axis=[2, 3], keepdims=False) + gv2: R.Tensor((2, 4), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2: (i0, i1 * 4 + i2), index_dtype="int32" + ), + ) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NHWC4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 4), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 4, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i2, i3, i1 // 4, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 3, 3, 16, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i2, i3, i1, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NHWC4c", + kernel_layout="OHWI4o", + out_layout="NHWC4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 1, 4), dtype="float32") = R.sum(gv, axis=[1, 2], keepdims=False) + gv2: R.Tensor((2, 4), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2: (i0, i1 * 4 + i2), index_dtype="int32" + ), + ) + R.output(gv2) + return gv2 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + verify(Input, Expected_NHWC4c, {"relax.nn.conv2d": ["NHWC4c", "OHWI4o"]}) + + +def test_conv2d_transpose_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((26, 26, 4, 2), "float32") = R.permute_dims(gv, axes=[3, 2, 1, 0]) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((26, 26, 4, 2), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + 
kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + gv, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + gv2: R.Tensor((26, 26, 4, 2), dtype="float32") = R.permute_dims( + lv2, axes=[3, 2, 1, 0] + ) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NHWC4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((26, 26, 4, 2), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 4, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i2, i3, i1 // 4, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 3, 3, 16, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i2, i3, i1, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NHWC4c", + kernel_layout="OHWI4o", + out_layout="NHWC4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + gv, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i3 * 4 + i4, i1, i2), index_dtype="int32" + ), + ) + gv2: R.Tensor((26, 26, 4, 2), dtype="float32") = R.permute_dims( + lv2, axes=[3, 2, 1, 0] + ) + R.output(gv2) + return gv2 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + verify(Input, Expected_NHWC4c, {"relax.nn.conv2d": ["NHWC4c", "OHWI4o"]}) + + +def test_conv2d_expand_dims_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=6): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 1, 4, 1, 26, 26), "float32") = R.expand_dims(gv, axis=(-3, 1)) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 1, 4, 1, 26, 26), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + gv, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + gv2: R.Tensor((2, 1, 4, 1, 26, 26), dtype="float32") = R.expand_dims( + lv2, axis=[-3, 1] + ) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NHWC4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 1, 4, 1, 26, 26), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 4, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, 
i2, i3: (i0, i2, i3, i1 // 4, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 3, 3, 16, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i2, i3, i1, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NHWC4c", + kernel_layout="OHWI4o", + out_layout="NHWC4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + gv, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i3 * 4 + i4, i1, i2), index_dtype="int32" + ), + ) + gv2: R.Tensor((2, 1, 4, 1, 26, 26), dtype="float32") = R.expand_dims( + lv2, axis=[-3, 1] + ) + R.output(gv2) + return gv2 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + verify(Input, Expected_NHWC4c, {"relax.nn.conv2d": ["NHWC4c", "OHWI4o"]}) + + +def test_conv2d_squeeze_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((1, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=3): + with R.dataflow(): + gv: R.Tensor((1, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((4, 26, 26), "float32") = R.squeeze(gv, axis=[0]) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((1, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((4, 26, 26), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((1, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((1, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv2: R.Tensor((1, 4, 26, 26), dtype="float32") = R.layout_transform( + gv, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + gv2: R.Tensor((4, 26, 26), dtype="float32") = R.squeeze(lv2, axis=[0]) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NHWC4c: + @R.function + def main( + x: R.Tensor((1, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((4, 26, 26), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((1, 28, 28, 4, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i2, i3, i1 // 4, i1 % 4), index_dtype="int32" + ), + pad_value=None, + axis_separators=[], + input_axis_separators=[], + ) + lv1: R.Tensor((1, 3, 3, 16, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i2, i3, i1, i0 % 4), index_dtype="int32" + ), + pad_value=None, + axis_separators=[], + input_axis_separators=[], + ) + gv: R.Tensor((1, 26, 26, 1, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC4c", + kernel_layout="OHWI4o", + out_layout="NHWC4c", + out_dtype="float32", + ) + lv2: R.Tensor((1, 4, 26, 26), dtype="float32") = R.layout_transform( + gv, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i3 * 4 + i4, i1, i2), 
index_dtype="int32" + ), + pad_value=None, + axis_separators=[], + input_axis_separators=[], + ) + gv2: R.Tensor((4, 26, 26), dtype="float32") = R.squeeze(lv2, axis=[0]) + R.output(gv2) + return gv2 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + verify(Input, Expected_NHWC4c, {"relax.nn.conv2d": ["NHWC4c", "OHWI4o"]}) + + +def test_conv2d_strided_slice_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 2, 9, 7), dtype="float32") = R.strided_slice( + gv, begin=[0, 0, 0], end=[4, 26, 26], strides=[2, 3, 4], axes=[1, 2, 3] + ) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 2, 9, 7), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + gv, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + gv2: R.Tensor((2, 2, 9, 7), dtype="float32") = R.strided_slice( + lv2, + (R.prim_value(1), R.prim_value(2), R.prim_value(3)), + (R.prim_value(0), R.prim_value(0), R.prim_value(0)), + (R.prim_value(4), R.prim_value(26), R.prim_value(26)), + (R.prim_value(2), R.prim_value(3), R.prim_value(4)), + assume_inbound=False, + ) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NHWC4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 2, 9, 7), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 4, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i2, i3, i1 // 4, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 3, 3, 16, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i2, i3, i1, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC4c", + kernel_layout="OHWI4o", + out_layout="NHWC4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + gv, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i3 * 4 + i4, i1, i2), index_dtype="int32" + ), + ) + gv2: R.Tensor((2, 2, 9, 7), dtype="float32") = R.strided_slice( + lv2, + (R.prim_value(1), R.prim_value(2), R.prim_value(3)), + (R.prim_value(0), R.prim_value(0), R.prim_value(0)), + (R.prim_value(4), R.prim_value(26), R.prim_value(26)), + (R.prim_value(2), R.prim_value(3), R.prim_value(4)), + assume_inbound=False, + 
) + R.output(gv2) + return gv2 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + verify(Input, Expected_NHWC4c, {"relax.nn.conv2d": ["NHWC4c", "OHWI4o"]}) + + +def test_conv2d_relu_concat_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + gv3: R.Tensor((2, 8, 26, 26), "float32") = R.concat((gv, gv2), axis=1) + R.output(gv3) + return gv3 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 8, 26, 26), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + gv2: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.relu(gv) + lv2: R.Tensor((2, 2, 26, 26, 4), dtype="float32") = R.concat((gv, gv2), axis=1) + gv3: R.Tensor((2, 8, 26, 26), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + R.output(gv3) + return gv3 + + @I.ir_module + class Expected_NHWC4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 8, 26, 26), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 4, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i2, i3, i1 // 4, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 3, 3, 16, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i2, i3, i1, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NHWC4c", + kernel_layout="OHWI4o", + out_layout="NHWC4c", + out_dtype="float32", + ) + gv2: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.nn.relu(gv) + lv2: R.Tensor((2, 26, 26, 2, 4), dtype="float32") = R.concat((gv, gv2), axis=3) + gv3: R.Tensor((2, 8, 26, 26), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i3 * 4 + i4, i1, i2), index_dtype="int32" + ), + ) + R.output(gv3) + return gv3 + + @I.ir_module + class Expected_N4cHWC: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 8, 26, 26), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 % 4, i2, i3, i1 // 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 4, 3, 3, 16), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i0 
% 4, i2, i3, i1), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 4, 26, 26, 1), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="N4cHWC", + kernel_layout="O4oHWI", + out_layout="N4cHWC", + out_dtype="float32", + ) + gv2: R.Tensor((2, 4, 26, 26, 1), dtype="float32") = R.nn.relu(gv) + lv2: R.Tensor((2, 4, 26, 26, 2), dtype="float32") = R.concat((gv, gv2), axis=4) + gv3: R.Tensor((2, 8, 26, 26), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i4 * 4 + i1, i2, i3), index_dtype="int32" + ), + ) + R.output(gv3) + return gv3 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + verify(Input, Expected_NHWC4c, {"relax.nn.conv2d": ["NHWC4c", "OHWI4o"]}) + # Concat axis after sub index + verify(Input, Expected_N4cHWC, {"relax.nn.conv2d": ["N4cHWC", "O4oHWI"]}) + + +def test_conv2d_relu_concat_split_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main(x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32")): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + gv3: R.Tensor((2, 8, 26, 26), "float32") = R.concat((gv, gv2), axis=1) + gv4 = R.split(gv3, indices_or_sections=2, axis=1) + R.output(gv4) + return gv4 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tuple( + R.Tensor((2, 4, 26, 26), dtype="float32"), R.Tensor((2, 4, 26, 26), dtype="float32") + ): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + gv2: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.relu(gv) + gv3: R.Tensor((2, 2, 26, 26, 4), dtype="float32") = R.concat((gv, gv2), axis=1) + lv2: R.Tuple( + R.Tensor((2, 1, 26, 26, 4), dtype="float32"), + R.Tensor((2, 1, 26, 26, 4), dtype="float32"), + ) = R.split(gv3, indices_or_sections=2, axis=1) + lv3: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = lv2[0] + lv4: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + lv3, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + lv5: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = lv2[1] + lv6: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + lv5, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + gv4: R.Tuple( + R.Tensor((2, 4, 26, 26), dtype="float32"), + R.Tensor((2, 4, 26, 26), dtype="float32"), + ) = (lv4, lv6) + R.output(gv4) + return gv4 + + @I.ir_module + class Expected_NHWC4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tuple( + R.Tensor((2, 4, 26, 26), dtype="float32"), R.Tensor((2, 4, 26, 26), dtype="float32") + ): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 4, 4), dtype="float32") = R.layout_transform( + x, + 
index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i2, i3, i1 // 4, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 3, 3, 16, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i2, i3, i1, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NHWC4c", + kernel_layout="OHWI4o", + out_layout="NHWC4c", + out_dtype="float32", + ) + gv2: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.nn.relu(gv) + gv3: R.Tensor((2, 26, 26, 2, 4), dtype="float32") = R.concat((gv, gv2), axis=3) + lv2: R.Tuple( + R.Tensor((2, 26, 26, 1, 4), dtype="float32"), + R.Tensor((2, 26, 26, 1, 4), dtype="float32"), + ) = R.split(gv3, indices_or_sections=2, axis=3) + lv3: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = lv2[0] + lv4: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + lv3, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i3 * 4 + i4, i1, i2), index_dtype="int32" + ), + ) + lv5: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = lv2[1] + lv6: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + lv5, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i3 * 4 + i4, i1, i2), index_dtype="int32" + ), + ) + gv4: R.Tuple( + R.Tensor((2, 4, 26, 26), dtype="float32"), + R.Tensor((2, 4, 26, 26), dtype="float32"), + ) = (lv4, lv6) + R.output(gv4) + return gv4 + + @I.ir_module + class Expected_N4cHWC: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tuple( + R.Tensor((2, 4, 26, 26), dtype="float32"), R.Tensor((2, 4, 26, 26), dtype="float32") + ): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 % 4, i2, i3, i1 // 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 4, 3, 3, 16), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i0 % 4, i2, i3, i1), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 4, 26, 26, 1), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="N4cHWC", + kernel_layout="O4oHWI", + out_layout="N4cHWC", + out_dtype="float32", + ) + gv2: R.Tensor((2, 4, 26, 26, 1), dtype="float32") = R.nn.relu(gv) + gv3: R.Tensor((2, 4, 26, 26, 2), dtype="float32") = R.concat((gv, gv2), axis=4) + lv2: R.Tuple( + R.Tensor((2, 4, 26, 26, 1), dtype="float32"), + R.Tensor((2, 4, 26, 26, 1), dtype="float32"), + ) = R.split(gv3, indices_or_sections=2, axis=4) + lv3: R.Tensor((2, 4, 26, 26, 1), dtype="float32") = lv2[0] + lv4: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + lv3, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i4 * 4 + i1, i2, i3), index_dtype="int32" + ), + ) + lv5: R.Tensor((2, 4, 26, 26, 1), dtype="float32") = lv2[1] + lv6: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + lv5, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i4 * 4 + i1, i2, i3), index_dtype="int32" + ), + ) + gv4: R.Tuple( + R.Tensor((2, 4, 26, 26), dtype="float32"), + R.Tensor((2, 4, 26, 26), dtype="float32"), + ) = (lv4, lv6) + R.output(gv4) + return gv4 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + verify(Input, Expected_NHWC4c, {"relax.nn.conv2d": ["NHWC4c", "OHWI4o"]}) + verify(Input, Expected_N4cHWC, {"relax.nn.conv2d": ["N4cHWC", "O4oHWI"]}) + + +def test_conv2d_relu_concat_split_sub_indexed_div_exception(): + 
@I.ir_module + class Input: + @R.function + def main(x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32")): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + gv3: R.Tensor((2, 8, 26, 26), "float32") = R.concat((gv, gv2), axis=1) + gv4 = R.split(gv3, indices_or_sections=4, axis=1) + R.output(gv4) + return gv4 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tuple( + R.Tensor((2, 2, 26, 26), dtype="float32"), + R.Tensor((2, 2, 26, 26), dtype="float32"), + R.Tensor((2, 2, 26, 26), dtype="float32"), + R.Tensor((2, 2, 26, 26), dtype="float32"), + ): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + gv2: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.relu(gv) + gv3: R.Tensor((2, 2, 26, 26, 4), dtype="float32") = R.concat((gv, gv2), axis=1) + lv2: R.Tensor((2, 8, 26, 26), dtype="float32") = R.layout_transform( + gv3, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + gv4: R.Tuple( + R.Tensor((2, 2, 26, 26), dtype="float32"), + R.Tensor((2, 2, 26, 26), dtype="float32"), + R.Tensor((2, 2, 26, 26), dtype="float32"), + R.Tensor((2, 2, 26, 26), dtype="float32"), + ) = R.split(lv2, indices_or_sections=4, axis=1) + R.output(gv4) + return gv4 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + + +def test_conv2d_maxpool2d_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2 = R.nn.max_pool2d( + gv, + pool_size=[2, 2], + strides=[2, 2], + padding=[0, 0], + layout="NCHW", + out_layout="NCHW", + ) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 4, 13, 13), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 1, 13, 13, 4), dtype="float32") = R.nn.max_pool2d( + gv, + pool_size=[2, 2], + strides=[2, 2], + padding=[0, 0, 0, 0], + layout="NCHW4c", + out_layout="NCHW4c", + ) + gv2: 
R.Tensor((2, 4, 13, 13), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NHWC4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 4, 13, 13), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 4, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i2, i3, i1 // 4, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 3, 3, 16, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i2, i3, i1, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NHWC4c", + kernel_layout="OHWI4o", + out_layout="NHWC4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 13, 13, 1, 4), dtype="float32") = R.nn.max_pool2d( + gv, + pool_size=[2, 2], + strides=[2, 2], + padding=[0, 0, 0, 0], + layout="NHWC4c", + out_layout="NHWC4c", + ) + gv2: R.Tensor((2, 4, 13, 13), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i3 * 4 + i4, i1, i2), index_dtype="int32" + ), + ) + R.output(gv2) + return gv2 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + verify(Input, Expected_NHWC4c, {"relax.nn.conv2d": ["NHWC4c", "OHWI4o"]}) + + +def test_conv2d_avgpool2d_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2 = R.nn.adaptive_avg_pool2d(gv, output_size=[13, 13], layout="NCHW") + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 4, 13, 13), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 1, 13, 13, 4), dtype="float32") = R.nn.adaptive_avg_pool2d( + gv, output_size=[13, 13], layout="NCHW4c", out_layout="NCHW4c" + ) + gv2: R.Tensor((2, 4, 13, 13), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NHWC4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 4, 13, 13), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 4, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i2, i3, i1 // 4, i1 % 4), 
index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 3, 3, 16, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i2, i3, i1, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 26, 26, 1, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NHWC4c", + kernel_layout="OHWI4o", + out_layout="NHWC4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 13, 13, 1, 4), dtype="float32") = R.nn.adaptive_avg_pool2d( + gv, output_size=[13, 13], layout="NHWC4c", out_layout="NHWC4c" + ) + gv2: R.Tensor((2, 4, 13, 13), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i3 * 4 + i4, i1, i2), index_dtype="int32" + ), + ) + R.output(gv2) + return gv2 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + verify(Input, Expected_NHWC4c, {"relax.nn.conv2d": ["NHWC4c", "OHWI4o"]}) + + +def test_conv2d_softmax_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2 = R.nn.softmax(gv, axis=1) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 4, 26, 26), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + gv, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.nn.softmax(lv2, axis=1) + R.output(gv2) + return gv2 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + + +def test_conv2d_batchnorm_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), + w: R.Tensor((4, 16, 3, 3), "float32"), + gamma: R.Tensor((4,), dtype="float32"), + beta: R.Tensor((4,), dtype="float32"), + moving_mean: R.Tensor((4,), dtype="float32"), + moving_var: R.Tensor((4,), dtype="float32"), + ): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tuple( + R.Tensor((2, 4, 26, 26), dtype="float32"), + R.Tensor((4,), dtype="float32"), + R.Tensor((4,), dtype="float32"), + ) = R.nn.batch_norm(gv, gamma, beta, moving_mean, moving_var, axis=1) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + gamma: R.Tensor((4,), dtype="float32"), + beta: R.Tensor((4,), dtype="float32"), + moving_mean: R.Tensor((4,), dtype="float32"), + moving_var: R.Tensor((4,), dtype="float32"), + ) -> R.Tuple( + R.Tensor((2, 4, 26, 
26), dtype="float32"), + R.Tensor((4,), dtype="float32"), + R.Tensor((4,), dtype="float32"), + ): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + gv, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + gv2: R.Tuple( + R.Tensor((2, 4, 26, 26), dtype="float32"), + R.Tensor((4,), dtype="float32"), + R.Tensor((4,), dtype="float32"), + ) = R.nn.batch_norm( + lv2, + gamma, + beta, + moving_mean, + moving_var, + axis=1, + epsilon=1.0000000000000001e-05, + center=True, + scale=True, + momentum=0.10000000000000001, + ) + R.output(gv2) + return gv2 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + + +def test_conv2d_layernorm_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), + w: R.Tensor((4, 16, 3, 3), "float32"), + gamma: R.Tensor((26, 26), dtype="float32"), + beta: R.Tensor((26, 26), dtype="float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.nn.layer_norm( + gv, gamma, beta, axes=[-2, -1] + ) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + gamma: R.Tensor((26, 26), dtype="float32"), + beta: R.Tensor((26, 26), dtype="float32"), + ) -> R.Tensor((2, 4, 26, 26), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.layer_norm( + gv, + gamma, + beta, + axes=[2, 3], + ) + gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + R.output(gv2) + return gv2 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + + +def test_conv2d_resize2d_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2 = R.image.resize2d(gv, (52, 52), layout="NCHW") + R.output(gv2) + return gv2 + + 
@I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 4, 52, 52), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + gv, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + gv2: R.Tensor((2, 4, 52, 52), dtype="float32") = R.image.resize2d( + lv2, + R.shape([52, 52]), + layout="NCHW", + ) + R.output(gv2) + return gv2 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + + +def test_conv2d_unknown_bias_dim_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), + w: R.Tensor((4, 16, 3, 3), "float32"), + w2: R.Tensor(dtype="float32"), + ) -> R.Tensor(None, "float32"): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2 = w2 + gv + R.output(gv2) + return gv2 + + @tvm.script.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + w2: R.Tensor(dtype="float32"), + ) -> R.Tensor(dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + gv, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + gv2: R.Tensor(dtype="float32") = R.add(w2, lv2) + R.output(gv2) + return gv2 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + + +def test_binary_broadcast_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), + w: R.Tensor((4, 16, 3, 3), "float32"), + bias: R.Tensor((26, 26), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.add(gv, bias) + R.output(gv2) + return gv2 + + @tvm.script.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + bias: R.Tensor((26, 26), dtype="float32"), + ) -> R.Tensor((2, 4, 26, 26), 
dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + gv, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.add(lv2, bias) + R.output(gv2) + return gv2 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + + +def test_binary_ewise_scalar_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.add(gv, R.const(1, "float32")) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32"), + w: R.Tensor((4, 16, 3, 3), dtype="float32"), + ) -> R.Tensor((2, 4, 26, 26), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 4, 28, 28, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((1, 16, 3, 3, 4), dtype="float32") = R.layout_transform( + w, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv2: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.add( + gv, R.const(1.0, "float32") + ) + gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.layout_transform( + lv2, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + R.output(gv2) + return gv2 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + + if __name__ == "__main__": tvm.testing.main()