diff --git a/src/relay/transforms/alter_op_layout.cc b/src/relay/transforms/alter_op_layout.cc index 9afdb7210cba..f347eddae760 100644 --- a/src/relay/transforms/alter_op_layout.cc +++ b/src/relay/transforms/alter_op_layout.cc @@ -50,19 +50,6 @@ namespace alter_op_layout { class AlterTransformMemorizerNode : public TransformMemorizerNode { public: static constexpr const char* _type_key = "relay.alter_op_layout.AlterTransformMemorizerNode"; -}; - -/*! - * \brief Container that provides the transformation function for alter layout.. - */ -class AlterTransformMemorizer : public TransformMemorizer { - public: - AlterTransformMemorizer() {} - explicit AlterTransformMemorizer(ObjectPtr n) : TransformMemorizer(n) {} - - AlterTransformMemorizerNode* operator->() { - return static_cast(get_mutable()); - } /*! * \brief Defines the call transformation for AlterOpLayout pass. The new layouts are defined by @@ -102,7 +89,23 @@ class AlterTransformMemorizer : public TransformMemorizer { return GetRef(new_call); } - using TransformMemorizer::CallWithNewLayouts; + Call CallWithNewLayouts(const Call& ref_call, const std::vector& new_args) override { + return CallWithNewLayouts(ref_call, ref_call->attrs, new_args); + } +}; + +/*! + * \brief Container that provides the transformation function for alter layout.. + */ +class AlterTransformMemorizer : public TransformMemorizer { + public: + AlterTransformMemorizer() = default; + explicit AlterTransformMemorizer(ObjectPtr n) : TransformMemorizer(n) {} + + AlterTransformMemorizerNode* operator->() { + return static_cast(get_mutable()); + } + using ContainerType = AlterTransformMemorizerNode; }; @@ -113,10 +116,12 @@ class AlterTransformMemorizer : public TransformMemorizer { */ Expr AlterOpLayout(const Expr& expr) { // TODO(@icemelon9): need to rerun type inference after applying an alter op. - AlterTransformMemorizer alterMemorizer(make_object()); - auto fcontext = [&](const Call& call) -> ObjectRef { return alterMemorizer; }; - - return ForwardRewrite(expr, LayoutRewriter, fcontext); + AlterTransformMemorizer alter_memorizer(make_object()); + std::function fcontext = [=](const Call& call) -> ObjectRef { + return alter_memorizer; + }; + FForwardRewrite rewrite_func = LayoutRewriter; + return ForwardRewrite(expr, rewrite_func, fcontext); } } // namespace alter_op_layout diff --git a/src/relay/transforms/convert_layout.cc b/src/relay/transforms/convert_layout.cc index e74ea0115857..e10be508529e 100644 --- a/src/relay/transforms/convert_layout.cc +++ b/src/relay/transforms/convert_layout.cc @@ -58,22 +58,6 @@ class ConvertTransformMemorizerNode : public TransformMemorizerNode { explicit ConvertTransformMemorizerNode(Map> desired_layouts) : desired_layouts_(std::move(desired_layouts)) {} - /*! \brief A mapping of op_name to array of desired layouts for each input. */ - Map> desired_layouts_; -}; - -/*! - * \brief Container that provides the transformation function for convert layout. - */ -class ConvertTransformMemorizer : public TransformMemorizer { - public: - ConvertTransformMemorizer() {} - explicit ConvertTransformMemorizer(ObjectPtr n) : TransformMemorizer(n) {} - - ConvertTransformMemorizerNode* operator->() { - return static_cast(get_mutable()); - } - /*! * \brief Defines the call transformation for ConvertLayout pass. The new layouts should be the * desired layout as specified by the user. @@ -89,7 +73,7 @@ class ConvertTransformMemorizer : public TransformMemorizer { Expr new_e; bool modified = false; if (fconvert_layout.count(op)) { - auto desired_layouts = operator->()->desired_layouts_; + auto desired_layouts = desired_layouts_; if (desired_layouts.find(op->name) != desired_layouts.end()) { tvm::Array tinfos; for (auto& expr : ref_call->args) { @@ -124,7 +108,26 @@ class ConvertTransformMemorizer : public TransformMemorizer { return Call(new_call->op, new_call->args, new_call->attrs, new_call->type_args, ref_call->span); } - using TransformMemorizer::CallWithNewLayouts; + Call CallWithNewLayouts(const Call& ref_call, const std::vector& new_args) override { + return CallWithNewLayouts(ref_call, ref_call->attrs, new_args); + } + + /*! \brief A mapping of op_name to array of desired layouts for each input. */ + Map> desired_layouts_; +}; + +/*! + * \brief Container that provides the transformation function for convert layout. + */ +class ConvertTransformMemorizer : public TransformMemorizer { + public: + ConvertTransformMemorizer() = default; + explicit ConvertTransformMemorizer(ObjectPtr n) : TransformMemorizer(n) {} + + ConvertTransformMemorizerNode* operator->() { + return static_cast(get_mutable()); + } + using ContainerType = ConvertTransformMemorizerNode; }; diff --git a/src/relay/transforms/transform_layout.h b/src/relay/transforms/transform_layout.h index fbb7bc9cd985..7bfb31a299ad 100644 --- a/src/relay/transforms/transform_layout.h +++ b/src/relay/transforms/transform_layout.h @@ -57,6 +57,21 @@ class TransformMemorizerNode : public Object { } }; + /*! + * \brief Defines the call transformation for derived passes. The new layouts are defined by + * used for different targets using a packed func. + * \param ref_call The original call. + * \param new_attrs Updated attributes consistent with new layouts. + * \param new_args The traversed/recursed args to the call. + * \return The new Call after calling the packed func. + */ + virtual Call CallWithNewLayouts(const Call& ref_call, Attrs new_attrs, + const std::vector& new_args) = 0; + + virtual Call CallWithNewLayouts(const Call& ref_call, const std::vector& new_args) { + return CallWithNewLayouts(ref_call, ref_call->attrs, new_args); + } + /*! \brief The memorizer map. */ std::unordered_map memo; @@ -69,11 +84,9 @@ class TransformMemorizerNode : public Object { */ class TransformMemorizer : public ObjectRef { public: - TransformMemorizer() {} + TransformMemorizer() = default; explicit TransformMemorizer(ObjectPtr n) : ObjectRef(n) {} - virtual ~TransformMemorizer() {} - TransformMemorizerNode* operator->() { return static_cast(get_mutable()); } @@ -146,19 +159,6 @@ class TransformMemorizer : public ObjectRef { return MakeLayoutTransform(input_expr, new_src_layout.name(), dst_layout.name()); } - /*! - * \brief Defines the call transformation for derived passes. The new layouts are defined by - * used for different targets using a packed func. - * \param ref_call The original call. - * \param new_attrs Updated attributes consistent with new layouts. - * \param new_args The traversed/recursed args to the call. - * \return The new Call after calling the packed func. - */ - virtual Call CallWithNewLayouts(const Call& ref_call, Attrs new_attrs, - const std::vector& new_args) = 0; - virtual Call CallWithNewLayouts(const Call& ref_call, const std::vector& new_args) { - return CallWithNewLayouts(ref_call, ref_call->attrs, new_args); - } using ContainerType = TransformMemorizerNode; }; @@ -312,7 +312,7 @@ Expr LayoutRewriter(const Call& ref_call, const Array& new_args, const Obj if (ref_call->op.as()) { Op op = Downcast(ref_call->op); if (falter_layout.count(op) && !finfer_layout.count(op)) { - return memorizer.CallWithNewLayouts(ref_call, normal_new_args); + return memorizer->CallWithNewLayouts(ref_call, normal_new_args); } } } @@ -349,7 +349,7 @@ Expr LayoutRewriter(const Call& ref_call, const Array& new_args, const Obj } // new_op = alter(op) - Call new_call = memorizer.CallWithNewLayouts(ref_call, infer_out->new_attrs, normal_new_args); + Call new_call = memorizer->CallWithNewLayouts(ref_call, infer_out->new_attrs, normal_new_args); // new_in2, new_out = op.infer(new_in) if (new_call->op->IsInstance()) { diff --git a/tutorials/dev/use_pass_infra.py b/tutorials/dev/use_pass_infra.py index 468c4d40b942..67cdfdedce0e 100644 --- a/tutorials/dev/use_pass_infra.py +++ b/tutorials/dev/use_pass_infra.py @@ -69,20 +69,6 @@ def example(): return relay.Function([x, weight], z2) -############################################################################### -# Let us register layout alteration for a conv2d op so that we can apply the -# layout alteration pass on the example. How alter layout pass works is out -# the scope of this tutorial. - - -@relay.op.register_alter_op_layout("nn.conv2d", level=101) -def alter_conv2d(attrs, inputs, tinfos, out_type): - data, weight = inputs - new_attrs = dict(attrs) - new_attrs["data_layout"] = "NCHW16c" - return relay.nn.conv2d(data, weight, **new_attrs) - - ############################################################################### # Optimize the Program # -------------------- @@ -188,21 +174,6 @@ def alter_conv2d(attrs, inputs, tinfos, out_type): mod3 = seq(mod) print(mod3) -############################################################################### -# The passes applied so far are target independent. The pass infra also -# provides a means to make pass target-aware. For example, the layout -# alteration pass falls in such category. - -with tvm.transform.PassContext(opt_level=3): - mod4 = seq(mod) -print(mod4) - -seq1 = tvm.transform.Sequential([relay.transform.AlterOpLayout()]) -with tvm.transform.PassContext(opt_level=3): - with tvm.target.Target("llvm"): - mod5 = seq1(mod) -print(mod5) - ############################################################################## # Implement a Pass Using Python Decorator # ------------------------------------------ @@ -257,7 +228,6 @@ def visit_constant(self, c): tvm.transform.PrintIR(), relay.transform.EliminateCommonSubexpr(), relay.transform.FuseOps(), - relay.transform.AlterOpLayout(), ] )