From fd8a9d0bcef806785508de77a8a2dfd6f2cacb18 Mon Sep 17 00:00:00 2001 From: An Wang Date: Wed, 20 Oct 2021 13:59:49 -0700 Subject: [PATCH 01/13] before I delete a bunch of stuff --- python/tvm/relay/transform/transform.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 0dc07944836d..cca208b0d970 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -1248,3 +1248,10 @@ def SplitArgs(max_function_args): The registered pass for constant folding. """ return _ffi_api.SplitArgs(max_function_args) + + +def FoldTypeTransformation(): + """ + Automatic function signature transformation + """ + return _ffi_api.FoldTypeTransformation() \ No newline at end of file From 4f6aaf7feb24370cfac0c8b2c2d19fcd2783731b Mon Sep 17 00:00:00 2001 From: An Wang Date: Fri, 22 Oct 2021 16:03:53 -0700 Subject: [PATCH 02/13] it works halfway don't touch it --- .../transforms/fold_type_transformation.cc | 251 ++++++++++++++++++ 1 file changed, 251 insertions(+) create mode 100644 src/relay/transforms/fold_type_transformation.cc diff --git a/src/relay/transforms/fold_type_transformation.cc b/src/relay/transforms/fold_type_transformation.cc new file mode 100644 index 000000000000..1191a52bb68c --- /dev/null +++ b/src/relay/transforms/fold_type_transformation.cc @@ -0,0 +1,251 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/transforms/fold_type_transformation.cc + * \brief A pass for taking transforming relay graph function + * signatures. 
+ */ + +#include +#include +#include +#include + +#include + + +namespace tvm { +namespace relay { + +/* Description of FoldTypeTransformation +TODO +*/ + +// class HeaderMutator : public ExprMutator { + +// } +using namespace tvm::tir; + +class FoldTypeTransformationRewriter : public MixedModeMutator { + int count = 0; + protected: + Expr Rewrite_(const CallNode* pre_call_node, const Expr& post) final { + const CallNode* post_call_node = post.as(); + CHECK(post_call_node) << "Expected a CallNode, but got " << post; + + // std::cout << "pre call node " << pre_call_node->op << std::endl; + // std::cout << "pre call node " << pre_call_node->args << std::endl; + // std::cout << "post expr " << post << std::endl; + // CHECK(false) << "temp"; + + Expr cur_op = post_call_node->op; + + for (auto arg : pre_call_node->args) { + auto maybe_var_node = arg.as(); + if (maybe_var_node) { + std::string var_name = maybe_var_node->name_hint(); + + std::cout << "num map elements START " << input_transform_map_.size() << std::endl; + auto var = Downcast(arg); + input_transform_map_.insert(std::pair(var, pre_call_node)); + + auto it = input_transform_map_.find(var); + if (it != input_transform_map_.end()) { + // Checks that the function-level input var hasn't been an arg + // to a CallNode yet. + CHECK(!it->second) << "input with name '" << var->name_hint() << "' is fed into more than one call, aborting transformation"; + + it->second = pre_call_node; + + // Get the type to transform the function signature to + DataType out_dtype; + if (cur_op == cast_op_) { + auto attrs = pre_call_node->attrs.as(); + out_dtype = attrs->dtype; + } else if (cur_op == quantize_op_) { + auto attrs = pre_call_node->attrs.as(); + out_dtype = attrs->out_dtype; + } else { + CHECK(false) << "FoldTypeTransformation will only fold cast and quantize type transformations for function inputs."; + } + + // Mutate the var node type + VarNode* var_node = (VarNode*)maybe_var_node; + const TensorTypeNode* anno = var_node->type_annotation.as(); + auto mut_anno = (TensorTypeNode*) anno; + auto shape = anno->shape; + mut_anno->dtype = out_dtype; + + // TODO: Instead of mutating the var node in-place, create a new var node. + // This also requires updating the function signature. Need to store the var node + // in the input_transform_map_ probably, then update the function once all + // Rewrite_ calls are complete. 
+ + return GetRef(var_node); + } else { + std::cout << "Did not find var with name " << var->name_hint() << " in the map" << std::endl; + } + } + } + + return Call(cur_op, post_call_node->args, pre_call_node->attrs, pre_call_node->type_args, pre_call_node->span); + } + + + // Expr VisitExpr_(const CallNode* node) { + // // this iterates from the bottom of the program up + // Op op = Downcast(node->op); + // std::cout << "op name " << op->name << std::endl; + + // for (auto arg : pre_call_node->args) { + // auto maybe_var_node = arg.as(); + // if (maybe_var_node) { + // std::string var_name = maybe_var_node->name_hint(); + // auto it = unvisited_input_names_.find(var_name); + // if (it != unvisited_input_names_.end()) { + // CHECK(cur_op == cast_op_) << "Expected a cast op, but got " << cur_op; + + // std::cout << "call attrs " << pre_call_node->attrs << std::endl; + // auto attrs = pre_call_node->attrs.as(); + // auto dtype = attrs->dtype; + + // auto this_is_a_thing = DataType::Int(32); + + // unvisited_input_names_.erase(it); + // std::cout << "Removing " << var_name << " from unvisited input names" << std::endl; + // } + // } + // } + + // Expr expr; + // if (op == quantize_op_) {// || op == cast_op_) { + // expr = GetRef(node); + // std::cout << "at a quantize op" << std::endl; + // // Get the type input names of the op + // auto inputs = node->args; + // std::cout << "INPUTS SI<<<<<<<<<<<<<<<<<<<<(); + // // auto node = expr.as(); + + // std::cout << "node ptr " << tensor_node << std::endl; + + // expr = ExprMutator::VisitExpr_(node); + // } else { + // expr = ExprMutator::VisitExpr_(node); + // } + + // // static const Op& op = Op::Get("nn.batch_flatten"); + // // return Call(oexpr + // } + + Expr VisitExpr_(const FunctionNode* node) { + function_count_++; + if (function_count_ > 1) { + CHECK(false) << "FoldTypeTransformation is supported for only single-function graphs"; + } + + tvm::Array ty_params; + bool all_ty_params_unchanged = true; + + for (auto ty_param : node->type_params) { + TypeVar new_ty_param = Downcast(VisitType(ty_param)); + ty_params.push_back(new_ty_param); + all_ty_params_unchanged &= new_ty_param.same_as(ty_param); + + std::cout << "type param" << ty_param << std::endl; + std::cout << "all params unchanged " << all_ty_params_unchanged << std::endl; + } + + tvm::Array params; + bool all_params_unchanged = true; + for (auto param : node->params) { + Var new_param = Downcast(this->Mutate(param)); + params.push_back(new_param); + all_params_unchanged &= param.same_as(new_param); + // std::cout << "param " << param << std::endl; + std::string name = param->name_hint(); + unvisited_input_names_.insert(name); + + input_transform_map_.insert(std::pair(param, NULL)); + + std::cout << "all params unchanked " << all_params_unchanged << std::endl; + } + + auto ret_type = this->VisitType(node->ret_type); + auto body = this->Mutate(node->body); + + // std::cout << "ret type" << node->ret_type << std::endl; + // std::cout << "num type params" << params.size() << std::endl; + // std::cout << "num type params" << node->params.size() << std::endl; + + std::cout << "params unchanged ? " << all_params_unchanged << " " << all_ty_params_unchanged << std::endl; + std::cout << "body same? 
" << body.same_as(node->body) << std::endl; + if (all_ty_params_unchanged && all_params_unchanged && ret_type.same_as(node->ret_type) && + body.same_as(node->body)) { + return GetRef(node); + } else { + auto f = Function(params, body, ret_type, ty_params, node->attrs, node->span); + std::cout << "are we in here" << std::endl; + return f; + } + } + + const Op cast_op_ = Op::Get("cast"); + const Op quantize_op_ = Op::Get("qnn.quantize"); + const Op dequantize_op_ = Op::Get("qnn.dequantize"); + + private: + // An input name is removed from this set when we visit a call node that + // references the corresponding input. For this pass, we expect that + // program-level inputs are only referenced once. + std::unordered_set unvisited_input_names_; + + // Maps function-level input to the first-encountered call node within + // the function that takes in that input. + std::map input_transform_map_; + // std::map> input_transform_map_; + + // Tracks number of functions in this program. + int function_count_; +}; + +Expr FoldTypeTransformation(const Expr& expr, const IRModule& mod) { + return FoldTypeTransformationRewriter().Mutate(expr); +} + +namespace transform { + +Pass FoldTypeTransformation() { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { + return Downcast(FoldTypeTransformation(f, m)); + }; + return CreateFunctionPass(pass_func, 0, "FoldTypeTransformation", {}); +} + +TVM_REGISTER_GLOBAL("relay._transform.FoldTypeTransformation") + .set_body_typed(FoldTypeTransformation); + +} // namespace transform + +} // namespace relay +} // namespace tvm \ No newline at end of file From 83a2f2b53e3d30a0eb55948e486942c46d3b052a Mon Sep 17 00:00:00 2001 From: An Wang Date: Fri, 22 Oct 2021 16:56:40 -0700 Subject: [PATCH 03/13] fold --- .../transforms/fold_type_transformation.cc | 170 +++++------------- .../relay/test_fold_type_transformation.py | 51 ++++++ 2 files changed, 92 insertions(+), 129 deletions(-) create mode 100644 tests/python/relay/test_fold_type_transformation.py diff --git a/src/relay/transforms/fold_type_transformation.cc b/src/relay/transforms/fold_type_transformation.cc index 1191a52bb68c..7d6f3f059d83 100644 --- a/src/relay/transforms/fold_type_transformation.cc +++ b/src/relay/transforms/fold_type_transformation.cc @@ -19,8 +19,10 @@ /*! * \file src/relay/transforms/fold_type_transformation.cc - * \brief A pass for taking transforming relay graph function - * signatures. + * \brief A pass for transforming relay graph function + * signatures such that when a function-level inputs is + * transformed by a subsequent cast or quantize operation, + * that operation is folded into the signature itself. */ #include @@ -28,21 +30,37 @@ #include #include -#include - - namespace tvm { namespace relay { -/* Description of FoldTypeTransformation -TODO -*/ - -// class HeaderMutator : public ExprMutator { - -// } -using namespace tvm::tir; - +/*! \brief This class transforms a relay module's function signature + * such that when a function-level input is transformed by a subsequent + * "cast" or "qnn.quantize" operation, that operation is folded into + * the signature itself. 
For example, + * + * def @main(%data: Tensor[(1, 3, 224, 224), float32]) { + * %0 = qnn.quantize(%data, 2f, 0, out_dtype="uint8"); + * add(%0, %0) + * } + * + * would be transformed to + * + * def @main(%data: Tensor[(1, 3, 224, 224), uint8]) { + * add(%0, %0) + * } + * + * Note that now it is the user's responsibility to modify their + * input pre-processing pipeline to satisfy the new signature's + * constraints. + * + * For this pass to fold a type transformation, the following conditions + * must be met: + * - The relay module must contain only a single function. + * - The type of each function-level input is transformed only once + * per program. + * - The type transformation operation must be either a "cast" + * or "qnn.quantize". + */ class FoldTypeTransformationRewriter : public MixedModeMutator { int count = 0; protected: @@ -50,27 +68,18 @@ class FoldTypeTransformationRewriter : public MixedModeMutator { const CallNode* post_call_node = post.as(); CHECK(post_call_node) << "Expected a CallNode, but got " << post; - // std::cout << "pre call node " << pre_call_node->op << std::endl; - // std::cout << "pre call node " << pre_call_node->args << std::endl; - // std::cout << "post expr " << post << std::endl; - // CHECK(false) << "temp"; - - Expr cur_op = post_call_node->op; - + Expr cur_op = pre_call_node->op; for (auto arg : pre_call_node->args) { auto maybe_var_node = arg.as(); if (maybe_var_node) { - std::string var_name = maybe_var_node->name_hint(); - - std::cout << "num map elements START " << input_transform_map_.size() << std::endl; auto var = Downcast(arg); - input_transform_map_.insert(std::pair(var, pre_call_node)); - auto it = input_transform_map_.find(var); if (it != input_transform_map_.end()) { // Checks that the function-level input var hasn't been an arg // to a CallNode yet. - CHECK(!it->second) << "input with name '" << var->name_hint() << "' is fed into more than one call, aborting transformation"; + CHECK(!it->second) << "Function input with name '" << var->name_hint() + << "' is fed into more than one call; " + << "aborting transformation"; it->second = pre_call_node; @@ -83,7 +92,8 @@ class FoldTypeTransformationRewriter : public MixedModeMutator { auto attrs = pre_call_node->attrs.as(); out_dtype = attrs->out_dtype; } else { - CHECK(false) << "FoldTypeTransformation will only fold cast and quantize type transformations for function inputs."; + CHECK(false) << "FoldTypeTransformation will only fold cast and " + << "quantize type transformations"; } // Mutate the var node type @@ -93,14 +103,10 @@ class FoldTypeTransformationRewriter : public MixedModeMutator { auto shape = anno->shape; mut_anno->dtype = out_dtype; - // TODO: Instead of mutating the var node in-place, create a new var node. - // This also requires updating the function signature. Need to store the var node - // in the input_transform_map_ probably, then update the function once all - // Rewrite_ calls are complete. 
- return GetRef(var_node); } else { - std::cout << "Did not find var with name " << var->name_hint() << " in the map" << std::endl; + LOG(WARNING) << "Variable '" << var->name_hint() << "' encountered" + << " but wasn't registered as a function-level input"; } } } @@ -108,121 +114,27 @@ class FoldTypeTransformationRewriter : public MixedModeMutator { return Call(cur_op, post_call_node->args, pre_call_node->attrs, pre_call_node->type_args, pre_call_node->span); } - - // Expr VisitExpr_(const CallNode* node) { - // // this iterates from the bottom of the program up - // Op op = Downcast(node->op); - // std::cout << "op name " << op->name << std::endl; - - // for (auto arg : pre_call_node->args) { - // auto maybe_var_node = arg.as(); - // if (maybe_var_node) { - // std::string var_name = maybe_var_node->name_hint(); - // auto it = unvisited_input_names_.find(var_name); - // if (it != unvisited_input_names_.end()) { - // CHECK(cur_op == cast_op_) << "Expected a cast op, but got " << cur_op; - - // std::cout << "call attrs " << pre_call_node->attrs << std::endl; - // auto attrs = pre_call_node->attrs.as(); - // auto dtype = attrs->dtype; - - // auto this_is_a_thing = DataType::Int(32); - - // unvisited_input_names_.erase(it); - // std::cout << "Removing " << var_name << " from unvisited input names" << std::endl; - // } - // } - // } - - // Expr expr; - // if (op == quantize_op_) {// || op == cast_op_) { - // expr = GetRef(node); - // std::cout << "at a quantize op" << std::endl; - // // Get the type input names of the op - // auto inputs = node->args; - // std::cout << "INPUTS SI<<<<<<<<<<<<<<<<<<<<(); - // // auto node = expr.as(); - - // std::cout << "node ptr " << tensor_node << std::endl; - - // expr = ExprMutator::VisitExpr_(node); - // } else { - // expr = ExprMutator::VisitExpr_(node); - // } - - // // static const Op& op = Op::Get("nn.batch_flatten"); - // // return Call(oexpr - // } - Expr VisitExpr_(const FunctionNode* node) { function_count_++; if (function_count_ > 1) { CHECK(false) << "FoldTypeTransformation is supported for only single-function graphs"; } - tvm::Array ty_params; - bool all_ty_params_unchanged = true; - - for (auto ty_param : node->type_params) { - TypeVar new_ty_param = Downcast(VisitType(ty_param)); - ty_params.push_back(new_ty_param); - all_ty_params_unchanged &= new_ty_param.same_as(ty_param); - - std::cout << "type param" << ty_param << std::endl; - std::cout << "all params unchanged " << all_ty_params_unchanged << std::endl; - } - - tvm::Array params; - bool all_params_unchanged = true; for (auto param : node->params) { - Var new_param = Downcast(this->Mutate(param)); - params.push_back(new_param); - all_params_unchanged &= param.same_as(new_param); - // std::cout << "param " << param << std::endl; - std::string name = param->name_hint(); - unvisited_input_names_.insert(name); - input_transform_map_.insert(std::pair(param, NULL)); - - std::cout << "all params unchanked " << all_params_unchanged << std::endl; } - - auto ret_type = this->VisitType(node->ret_type); auto body = this->Mutate(node->body); - // std::cout << "ret type" << node->ret_type << std::endl; - // std::cout << "num type params" << params.size() << std::endl; - // std::cout << "num type params" << node->params.size() << std::endl; - - std::cout << "params unchanged ? " << all_params_unchanged << " " << all_ty_params_unchanged << std::endl; - std::cout << "body same? 
" << body.same_as(node->body) << std::endl; - if (all_ty_params_unchanged && all_params_unchanged && ret_type.same_as(node->ret_type) && - body.same_as(node->body)) { - return GetRef(node); - } else { - auto f = Function(params, body, ret_type, ty_params, node->attrs, node->span); - std::cout << "are we in here" << std::endl; - return f; - } + return Function(node->params, body, node->ret_type, node->type_params, node->attrs, node->span); } const Op cast_op_ = Op::Get("cast"); const Op quantize_op_ = Op::Get("qnn.quantize"); - const Op dequantize_op_ = Op::Get("qnn.dequantize"); private: - // An input name is removed from this set when we visit a call node that - // references the corresponding input. For this pass, we expect that - // program-level inputs are only referenced once. - std::unordered_set unvisited_input_names_; - // Maps function-level input to the first-encountered call node within // the function that takes in that input. std::map input_transform_map_; - // std::map> input_transform_map_; // Tracks number of functions in this program. int function_count_; diff --git a/tests/python/relay/test_fold_type_transformation.py b/tests/python/relay/test_fold_type_transformation.py new file mode 100644 index 000000000000..a4980316b81e --- /dev/null +++ b/tests/python/relay/test_fold_type_transformation.py @@ -0,0 +1,51 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import tvm +from tvm import relay + + +def test_simple_cast_fold(): + data = relay.var("data", shape=[1, 3, 224, 224], dtype="float32") + out = relay.cast(data, "float16") + out = relay.add(out, out) + mod = tvm.IRModule.from_expr(out) + mod = tvm.relay.transform.InferType()(mod) + mod = tvm.relay.transform.FoldTypeTransformation()(mod) + + data_fp16 = relay.var("data", shape=[1, 3, 224, 224], dtype="float16") + out = relay.add(data_fp16, data_fp16) + expected_mod = tvm.IRModule.from_expr(out) + expected_mod = tvm.relay.transform.InferType()(expected_mod) + + assert tvm.ir.structural_equal(mod, expected_mod) + + +def test_simple_quantize_fold(): + data = relay.var("data", shape=[1, 3, 224, 224], dtype="float32") + out = relay.qnn.op.quantize(data, relay.const(2.0), relay.const(0), out_dtype="uint8") + out = relay.add(out, out) + + mod = tvm.IRModule.from_expr(out) + mod = tvm.relay.transform.InferType()(mod) + mod = tvm.relay.transform.FoldTypeTransformation()(mod) + + data_fp16 = relay.var("data", shape=[1, 3, 224, 224], dtype="uint8") + out = relay.add(data_fp16, data_fp16) + expected_mod = tvm.IRModule.from_expr(out) + expected_mod = tvm.relay.transform.InferType()(expected_mod) + + assert tvm.ir.structural_equal(mod, expected_mod) From f43b2cfe633928c659579f5e36cc405c8768462c Mon Sep 17 00:00:00 2001 From: An Wang Date: Mon, 25 Oct 2021 13:02:39 -0700 Subject: [PATCH 04/13] lint --- python/tvm/relay/transform/transform.py | 3 ++- src/relay/transforms/fold_type_transformation.cc | 12 +++++++----- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index cca208b0d970..018e7364028f 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -1254,4 +1254,5 @@ def FoldTypeTransformation(): """ Automatic function signature transformation """ - return _ffi_api.FoldTypeTransformation() \ No newline at end of file + return _ffi_api.FoldTypeTransformation() + diff --git a/src/relay/transforms/fold_type_transformation.cc b/src/relay/transforms/fold_type_transformation.cc index 7d6f3f059d83..03a917d045a3 100644 --- a/src/relay/transforms/fold_type_transformation.cc +++ b/src/relay/transforms/fold_type_transformation.cc @@ -62,7 +62,7 @@ namespace relay { * or "qnn.quantize". 
*/ class FoldTypeTransformationRewriter : public MixedModeMutator { - int count = 0; + protected: Expr Rewrite_(const CallNode* pre_call_node, const Expr& post) final { const CallNode* post_call_node = post.as(); @@ -97,9 +97,9 @@ class FoldTypeTransformationRewriter : public MixedModeMutator { } // Mutate the var node type - VarNode* var_node = (VarNode*)maybe_var_node; + VarNode* var_node = reinterpret_cast(maybe_var_node); const TensorTypeNode* anno = var_node->type_annotation.as(); - auto mut_anno = (TensorTypeNode*) anno; + auto mut_anno = reinterpret_cast(anno); auto shape = anno->shape; mut_anno->dtype = out_dtype; @@ -111,7 +111,8 @@ class FoldTypeTransformationRewriter : public MixedModeMutator { } } - return Call(cur_op, post_call_node->args, pre_call_node->attrs, pre_call_node->type_args, pre_call_node->span); + return Call(cur_op, post_call_node->args, pre_call_node->attrs, pre_call_node->type_args, + pre_call_node->span); } Expr VisitExpr_(const FunctionNode* node) { @@ -160,4 +161,5 @@ TVM_REGISTER_GLOBAL("relay._transform.FoldTypeTransformation") } // namespace transform } // namespace relay -} // namespace tvm \ No newline at end of file +} // namespace tvm + From 79e953f2d940ea2f68aa21317ac0331b5ad8f054 Mon Sep 17 00:00:00 2001 From: An Wang Date: Mon, 25 Oct 2021 13:14:43 -0700 Subject: [PATCH 05/13] lint --- src/relay/transforms/fold_type_transformation.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/transforms/fold_type_transformation.cc b/src/relay/transforms/fold_type_transformation.cc index 03a917d045a3..ff3fb98b2c58 100644 --- a/src/relay/transforms/fold_type_transformation.cc +++ b/src/relay/transforms/fold_type_transformation.cc @@ -112,7 +112,7 @@ class FoldTypeTransformationRewriter : public MixedModeMutator { } return Call(cur_op, post_call_node->args, pre_call_node->attrs, pre_call_node->type_args, - pre_call_node->span); + pre_call_node->span); } Expr VisitExpr_(const FunctionNode* node) { From 5b380472b74f603742aa00cffeb409d1b242d837 Mon Sep 17 00:00:00 2001 From: An Wang Date: Mon, 25 Oct 2021 15:51:38 -0700 Subject: [PATCH 06/13] lint --- src/relay/transforms/fold_type_transformation.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/relay/transforms/fold_type_transformation.cc b/src/relay/transforms/fold_type_transformation.cc index ff3fb98b2c58..b6138f13fa2b 100644 --- a/src/relay/transforms/fold_type_transformation.cc +++ b/src/relay/transforms/fold_type_transformation.cc @@ -62,7 +62,6 @@ namespace relay { * or "qnn.quantize". 
*/ class FoldTypeTransformationRewriter : public MixedModeMutator { - protected: Expr Rewrite_(const CallNode* pre_call_node, const Expr& post) final { const CallNode* post_call_node = post.as(); From 5b9c865d15a56a124a20401304c8bd486bc8dc63 Mon Sep 17 00:00:00 2001 From: An Wang Date: Mon, 25 Oct 2021 16:22:53 -0700 Subject: [PATCH 07/13] lint --- src/relay/transforms/fold_type_transformation.cc | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/relay/transforms/fold_type_transformation.cc b/src/relay/transforms/fold_type_transformation.cc index b6138f13fa2b..69f44b06d4e0 100644 --- a/src/relay/transforms/fold_type_transformation.cc +++ b/src/relay/transforms/fold_type_transformation.cc @@ -27,8 +27,8 @@ #include #include -#include #include +#include namespace tvm { namespace relay { @@ -37,22 +37,22 @@ namespace relay { * such that when a function-level input is transformed by a subsequent * "cast" or "qnn.quantize" operation, that operation is folded into * the signature itself. For example, - * + * * def @main(%data: Tensor[(1, 3, 224, 224), float32]) { * %0 = qnn.quantize(%data, 2f, 0, out_dtype="uint8"); * add(%0, %0) * } - * + * * would be transformed to - * + * * def @main(%data: Tensor[(1, 3, 224, 224), uint8]) { * add(%0, %0) * } - * + * * Note that now it is the user's responsibility to modify their * input pre-processing pipeline to satisfy the new signature's - * constraints. - * + * constraints. + * * For this pass to fold a type transformation, the following conditions * must be met: * - The relay module must contain only a single function. @@ -161,4 +161,3 @@ TVM_REGISTER_GLOBAL("relay._transform.FoldTypeTransformation") } // namespace relay } // namespace tvm - From caa434321d2c4224c77321b9cdcad964ee016f9d Mon Sep 17 00:00:00 2001 From: An Wang Date: Mon, 25 Oct 2021 16:38:42 -0700 Subject: [PATCH 08/13] no newlines --- python/tvm/relay/transform/transform.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 018e7364028f..10e1fb93cdc8 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -1255,4 +1255,3 @@ def FoldTypeTransformation(): Automatic function signature transformation """ return _ffi_api.FoldTypeTransformation() - From a8900c8cf76919ff23054f16f4f67a93466cac10 Mon Sep 17 00:00:00 2001 From: An Wang Date: Tue, 26 Oct 2021 09:17:59 -0700 Subject: [PATCH 09/13] lint --- src/relay/transforms/fold_type_transformation.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/relay/transforms/fold_type_transformation.cc b/src/relay/transforms/fold_type_transformation.cc index 69f44b06d4e0..45e5ca953da1 100644 --- a/src/relay/transforms/fold_type_transformation.cc +++ b/src/relay/transforms/fold_type_transformation.cc @@ -96,9 +96,9 @@ class FoldTypeTransformationRewriter : public MixedModeMutator { } // Mutate the var node type - VarNode* var_node = reinterpret_cast(maybe_var_node); + VarNode* var_node = const_cast(maybe_var_node); const TensorTypeNode* anno = var_node->type_annotation.as(); - auto mut_anno = reinterpret_cast(anno); + auto mut_anno = const_cast(anno); auto shape = anno->shape; mut_anno->dtype = out_dtype; From 255a21ec6df51f4b5bc50711a33427de4233187e Mon Sep 17 00:00:00 2001 From: An Wang Date: Tue, 26 Oct 2021 16:21:04 -0700 Subject: [PATCH 10/13] renames --- python/tvm/relay/transform/transform.py | 30 +++++++++++-- ...mation.cc => lift_dtype_transformation.cc} | 44 
++++++++++--------- ...on.py => test_lift_dype_transformation.py} | 4 +- 3 files changed, 52 insertions(+), 26 deletions(-) rename src/relay/transforms/{fold_type_transformation.cc => lift_dtype_transformation.cc} (76%) rename tests/python/relay/{test_fold_type_transformation.py => test_lift_dype_transformation.py} (94%) diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 10e1fb93cdc8..20e94d8ade8a 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -1250,8 +1250,30 @@ def SplitArgs(max_function_args): return _ffi_api.SplitArgs(max_function_args) -def FoldTypeTransformation(): - """ - Automatic function signature transformation +def LiftDtypeTransformation(): + """ + Automatic function signature transformation to fold type transformations. + For example, when a function has a tensor of type float32 as a + parameter, and the first operation on that tensor is a cast or quantize + operation, that operation is folded into the function signature -- + the resultant type of the first operation is the new type of the tensor + parameter. + + For this pass to fold a type transformation, the following conditions + must be met: + - The relay module must contain only a single function. + - The type transformation operation must be either a "cast" + or "qnn.quantize". + - Each function parameter is used only once + per program. There should be no structure that looks like: + + in in + / \ but the following is ok: | + cast add cast + + Returns + ------- + ret : tvm.transform.Pass + The registered pass. """ - return _ffi_api.FoldTypeTransformation() + return _ffi_api.LiftDtypeTransformation() diff --git a/src/relay/transforms/fold_type_transformation.cc b/src/relay/transforms/lift_dtype_transformation.cc similarity index 76% rename from src/relay/transforms/fold_type_transformation.cc rename to src/relay/transforms/lift_dtype_transformation.cc index 45e5ca953da1..b63eeca3474c 100644 --- a/src/relay/transforms/fold_type_transformation.cc +++ b/src/relay/transforms/lift_dtype_transformation.cc @@ -18,9 +18,9 @@ */ /*! - * \file src/relay/transforms/fold_type_transformation.cc + * \file src/relay/transforms/lift_dtype_transformation.cc * \brief A pass for transforming relay graph function - * signatures such that when a function-level inputs is + * signatures such that when a function parameter is * transformed by a subsequent cast or quantize operation, * that operation is folded into the signature itself. */ @@ -34,7 +34,7 @@ namespace tvm { namespace relay { /*! \brief This class transforms a relay module's function signature - * such that when a function-level input is transformed by a subsequent + * such that when a function parameter is transformed by a subsequent * "cast" or "qnn.quantize" operation, that operation is folded into * the signature itself. For example, * @@ -56,16 +56,20 @@ namespace relay { * For this pass to fold a type transformation, the following conditions * must be met: * - The relay module must contain only a single function. - * - The type of each function-level input is transformed only once - * per program. * - The type transformation operation must be either a "cast" * or "qnn.quantize". + * - Each function parameter is used only once + * per program. 
There should be no structure that looks like: + * + * in in + * / \ but the following is ok: | + * cast add cast */ -class FoldTypeTransformationRewriter : public MixedModeMutator { +class LiftDtypeTransformationRewriter : public MixedModeMutator { protected: Expr Rewrite_(const CallNode* pre_call_node, const Expr& post) final { const CallNode* post_call_node = post.as(); - CHECK(post_call_node) << "Expected a CallNode, but got " << post; + ICHECK(post_call_node) << "Expected a CallNode, but got " << post; Expr cur_op = pre_call_node->op; for (auto arg : pre_call_node->args) { @@ -74,9 +78,9 @@ class FoldTypeTransformationRewriter : public MixedModeMutator { auto var = Downcast(arg); auto it = input_transform_map_.find(var); if (it != input_transform_map_.end()) { - // Checks that the function-level input var hasn't been an arg + // Checks that the function parameter var hasn't been an arg // to a CallNode yet. - CHECK(!it->second) << "Function input with name '" << var->name_hint() + CHECK(!it->second) << "Function param with name '" << var->name_hint() << "' is fed into more than one call; " << "aborting transformation"; @@ -91,7 +95,7 @@ class FoldTypeTransformationRewriter : public MixedModeMutator { auto attrs = pre_call_node->attrs.as(); out_dtype = attrs->out_dtype; } else { - CHECK(false) << "FoldTypeTransformation will only fold cast and " + CHECK(false) << "LiftDtypeTransformation will only fold cast and " << "quantize type transformations"; } @@ -105,7 +109,7 @@ class FoldTypeTransformationRewriter : public MixedModeMutator { return GetRef(var_node); } else { LOG(WARNING) << "Variable '" << var->name_hint() << "' encountered" - << " but wasn't registered as a function-level input"; + << " but wasn't registered as a function parameter"; } } } @@ -117,7 +121,7 @@ class FoldTypeTransformationRewriter : public MixedModeMutator { Expr VisitExpr_(const FunctionNode* node) { function_count_++; if (function_count_ > 1) { - CHECK(false) << "FoldTypeTransformation is supported for only single-function graphs"; + CHECK(false) << "LiftDtypeTransformation is supported for only single-function graphs"; } for (auto param : node->params) { @@ -132,7 +136,7 @@ class FoldTypeTransformationRewriter : public MixedModeMutator { const Op quantize_op_ = Op::Get("qnn.quantize"); private: - // Maps function-level input to the first-encountered call node within + // Maps function parameter to the first-encountered call node within // the function that takes in that input. 
std::map input_transform_map_; @@ -140,22 +144,22 @@ class FoldTypeTransformationRewriter : public MixedModeMutator { int function_count_; }; -Expr FoldTypeTransformation(const Expr& expr, const IRModule& mod) { - return FoldTypeTransformationRewriter().Mutate(expr); +Expr LiftDtypeTransformation(const Expr& expr, const IRModule& mod) { + return LiftDtypeTransformationRewriter().Mutate(expr); } namespace transform { -Pass FoldTypeTransformation() { +Pass LiftDtypeTransformation() { runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { - return Downcast(FoldTypeTransformation(f, m)); + return Downcast(LiftDtypeTransformation(f, m)); }; - return CreateFunctionPass(pass_func, 0, "FoldTypeTransformation", {}); + return CreateFunctionPass(pass_func, 0, "LiftDtypeTransformation", {}); } -TVM_REGISTER_GLOBAL("relay._transform.FoldTypeTransformation") - .set_body_typed(FoldTypeTransformation); +TVM_REGISTER_GLOBAL("relay._transform.LiftDtypeTransformation") + .set_body_typed(LiftDtypeTransformation); } // namespace transform diff --git a/tests/python/relay/test_fold_type_transformation.py b/tests/python/relay/test_lift_dype_transformation.py similarity index 94% rename from tests/python/relay/test_fold_type_transformation.py rename to tests/python/relay/test_lift_dype_transformation.py index a4980316b81e..2e7ab1fcf552 100644 --- a/tests/python/relay/test_fold_type_transformation.py +++ b/tests/python/relay/test_lift_dype_transformation.py @@ -24,7 +24,7 @@ def test_simple_cast_fold(): out = relay.add(out, out) mod = tvm.IRModule.from_expr(out) mod = tvm.relay.transform.InferType()(mod) - mod = tvm.relay.transform.FoldTypeTransformation()(mod) + mod = tvm.relay.transform.LiftDtypeTransformation()(mod) data_fp16 = relay.var("data", shape=[1, 3, 224, 224], dtype="float16") out = relay.add(data_fp16, data_fp16) @@ -41,7 +41,7 @@ def test_simple_quantize_fold(): mod = tvm.IRModule.from_expr(out) mod = tvm.relay.transform.InferType()(mod) - mod = tvm.relay.transform.FoldTypeTransformation()(mod) + mod = tvm.relay.transform.LiftDtypeTransformation()(mod) data_fp16 = relay.var("data", shape=[1, 3, 224, 224], dtype="uint8") out = relay.add(data_fp16, data_fp16) From 93bf3261a1e4360adbd37611fc933d98f1c8b58f Mon Sep 17 00:00:00 2001 From: An Wang Date: Thu, 28 Oct 2021 15:42:12 -0700 Subject: [PATCH 11/13] increase opt level to 4, address feedback --- src/relay/transforms/lift_dtype_transformation.cc | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/relay/transforms/lift_dtype_transformation.cc b/src/relay/transforms/lift_dtype_transformation.cc index b63eeca3474c..8633bd83ddde 100644 --- a/src/relay/transforms/lift_dtype_transformation.cc +++ b/src/relay/transforms/lift_dtype_transformation.cc @@ -51,7 +51,8 @@ namespace relay { * * Note that now it is the user's responsibility to modify their * input pre-processing pipeline to satisfy the new signature's - * constraints. + * constraints. Care should especially be taken when lifting a + * quantize transformation. * * For this pass to fold a type transformation, the following conditions * must be met: @@ -60,7 +61,7 @@ namespace relay { * or "qnn.quantize". * - Each function parameter is used only once * per program. 
There should be no structure that looks like: - * + * * in in * / \ but the following is ok: | * cast add cast @@ -68,6 +69,9 @@ namespace relay { class LiftDtypeTransformationRewriter : public MixedModeMutator { protected: Expr Rewrite_(const CallNode* pre_call_node, const Expr& post) final { + // This rewrite identifies and removes the op that transforms the + // type of a function parameter, then updates the parameter with the + // expected output dtype of the removed op. const CallNode* post_call_node = post.as(); ICHECK(post_call_node) << "Expected a CallNode, but got " << post; @@ -155,7 +159,7 @@ Pass LiftDtypeTransformation() { [=](Function f, IRModule m, PassContext pc) { return Downcast(LiftDtypeTransformation(f, m)); }; - return CreateFunctionPass(pass_func, 0, "LiftDtypeTransformation", {}); + return CreateFunctionPass(pass_func, 4, "LiftDtypeTransformation", {}); } TVM_REGISTER_GLOBAL("relay._transform.LiftDtypeTransformation") From d4c8fba48c546df190a8d3c090fcd58f27cb8df7 Mon Sep 17 00:00:00 2001 From: An Wang Date: Fri, 29 Oct 2021 15:44:28 -0700 Subject: [PATCH 12/13] whitespace --- python/tvm/relay/transform/transform.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 20e94d8ade8a..4c5002f987b0 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -1255,7 +1255,7 @@ def LiftDtypeTransformation(): Automatic function signature transformation to fold type transformations. For example, when a function has a tensor of type float32 as a parameter, and the first operation on that tensor is a cast or quantize - operation, that operation is folded into the function signature -- + operation, that operation is folded into the function signature -- the resultant type of the first operation is the new type of the tensor parameter. @@ -1266,7 +1266,7 @@ def LiftDtypeTransformation(): or "qnn.quantize". - Each function parameter is used only once per program. There should be no structure that looks like: - + in in / \ but the following is ok: | cast add cast From 398f7f77f40d883822af77bdee48d6c9f4641e4e Mon Sep 17 00:00:00 2001 From: An Wang Date: Tue, 2 Nov 2021 10:24:42 -0700 Subject: [PATCH 13/13] lint --- python/tvm/relay/transform/transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 4c5002f987b0..02e6b012af72 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -1268,7 +1268,7 @@ def LiftDtypeTransformation(): per program. There should be no structure that looks like: in in - / \ but the following is ok: | + | \ but the following is ok: | cast add cast Returns
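
A minimal usage sketch, distilled from the test added in PATCH 03/13 and renamed in
PATCH 10/13 (tests/python/relay/test_lift_dype_transformation.py): it assumes TVM has
been rebuilt with this series applied so that relay.transform.LiftDtypeTransformation
is registered, and the shape, scale and zero point below are illustrative only.

    import tvm
    from tvm import relay

    # Build a graph whose first op on the function parameter is a qnn.quantize;
    # the pass folds the resulting dtype (uint8) into the function signature and
    # removes the quantize call itself.
    data = relay.var("data", shape=[1, 3, 224, 224], dtype="float32")
    q = relay.qnn.op.quantize(data, relay.const(2.0), relay.const(0), out_dtype="uint8")
    out = relay.add(q, q)

    mod = tvm.IRModule.from_expr(out)
    mod = relay.transform.InferType()(mod)
    mod = relay.transform.LiftDtypeTransformation()(mod)

    # %data is now Tensor[(1, 3, 224, 224), uint8] and the qnn.quantize call is
    # gone; as the docstring notes, feeding already-quantized uint8 data is now
    # the responsibility of the caller's pre-processing pipeline.
    print(mod)

The tests in the series apply the pass directly in the same way rather than relying on
a pass sequence; PATCH 11/13 registers it as a function pass at opt level 4.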