diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py
index 0dc07944836d..02e6b012af72 100644
--- a/python/tvm/relay/transform/transform.py
+++ b/python/tvm/relay/transform/transform.py
@@ -1248,3 +1248,30 @@ def SplitArgs(max_function_args):
         The registered pass for constant folding.
     """
     return _ffi_api.SplitArgs(max_function_args)
+
+
+def LiftDtypeTransformation():
+    """
+    Automatic function signature transformation that folds type
+    transformations. For example, when a function has a tensor of type
+    float32 as a parameter, and the first operation on that tensor is a
+    cast or quantize operation, that operation is folded into the
+    function signature: the result type of the first operation becomes
+    the new type of the tensor parameter.
+
+    For this pass to fold a type transformation, the following conditions
+    must be met:
+
+    - The relay module must contain only a single function.
+    - The type transformation operation must be either a "cast"
+      or a "qnn.quantize".
+    - Each function parameter is used only once per program. There
+      should be no structure that looks like:
+
+            in                                   in
+           /  \\     but the following is ok:     |
+        cast    add                             cast
+
+    Returns
+    -------
+    ret : tvm.transform.Pass
+        The registered pass.
+    """
+    return _ffi_api.LiftDtypeTransformation()
diff --git a/src/relay/transforms/lift_dtype_transformation.cc b/src/relay/transforms/lift_dtype_transformation.cc
new file mode 100644
index 000000000000..8633bd83ddde
--- /dev/null
+++ b/src/relay/transforms/lift_dtype_transformation.cc
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/transforms/lift_dtype_transformation.cc
+ * \brief A pass that transforms a relay graph function signature such
+ * that when a function parameter is transformed by a subsequent cast
+ * or quantize operation, that operation is folded into the signature
+ * itself.
+ */
+
+#include <tvm/relay/attrs/transform.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/qnn/attrs.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+
+/*! \brief This class transforms a relay module's function signature
+ * such that when a function parameter is transformed by a subsequent
+ * "cast" or "qnn.quantize" operation, that operation is folded into
+ * the signature itself. For example,
+ *
+ *   def @main(%data: Tensor[(1, 3, 224, 224), float32]) {
+ *     %0 = qnn.quantize(%data, 2f, 0, out_dtype="uint8");
+ *     add(%0, %0)
+ *   }
+ *
+ * would be transformed to
+ *
+ *   def @main(%data: Tensor[(1, 3, 224, 224), uint8]) {
+ *     add(%data, %data)
+ *   }
+ *
+ * Note that it is now the user's responsibility to modify their
+ * input pre-processing pipeline to satisfy the new signature's
+ * constraints. Care should especially be taken when lifting a
+ * quantize transformation.
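+ *
+ * For the example above, callers must then quantize inputs on the
+ * host before invoking the function. A sketch of the equivalent
+ * NumPy pre-processing (assuming a float32 array `data`, with the
+ * scale 2.0 and zero point 0 used above) would be:
+ *
+ *   data_q = np.clip(np.round(data / 2.0) + 0, 0, 255).astype("uint8")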
+ *
+ * For this pass to fold a type transformation, the following conditions
+ * must be met:
+ * - The relay module must contain only a single function.
+ * - The type transformation operation must be either a "cast"
+ *   or a "qnn.quantize".
+ * - Each function parameter is used only once per program. There
+ *   should be no structure that looks like:
+ *
+ *          in                                   in
+ *         /  \      but the following is ok:     |
+ *      cast    add                             cast
+ */
+class LiftDtypeTransformationRewriter : public MixedModeMutator {
+ protected:
+  Expr Rewrite_(const CallNode* pre_call_node, const Expr& post) final {
+    // This rewrite identifies and removes the op that transforms the
+    // type of a function parameter, then updates the parameter with the
+    // expected output dtype of the removed op.
+    const CallNode* post_call_node = post.as<CallNode>();
+    ICHECK(post_call_node) << "Expected a CallNode, but got " << post;
+
+    Expr cur_op = pre_call_node->op;
+    for (auto arg : pre_call_node->args) {
+      auto maybe_var_node = arg.as<VarNode>();
+      if (maybe_var_node) {
+        auto var = Downcast<Var>(arg);
+        auto it = input_transform_map_.find(var);
+        if (it != input_transform_map_.end()) {
+          // Check that the function parameter var hasn't already been
+          // an arg to a CallNode.
+          CHECK(!it->second) << "Function param with name '" << var->name_hint()
+                             << "' is fed into more than one call; "
+                             << "aborting transformation";
+
+          it->second = pre_call_node;
+
+          // Get the dtype to transform the function signature to.
+          DataType out_dtype;
+          if (cur_op == cast_op_) {
+            auto attrs = pre_call_node->attrs.as<CastAttrs>();
+            out_dtype = attrs->dtype;
+          } else if (cur_op == quantize_op_) {
+            auto attrs = pre_call_node->attrs.as<qnn::QuantizeAttrs>();
+            out_dtype = attrs->out_dtype;
+          } else {
+            CHECK(false) << "LiftDtypeTransformation will only fold cast and "
+                         << "quantize type transformations";
+          }
+
+          // Mutate the var node's type annotation in place.
+          VarNode* var_node = const_cast<VarNode*>(maybe_var_node);
+          const TensorTypeNode* anno = var_node->type_annotation.as<TensorTypeNode>();
+          auto mut_anno = const_cast<TensorTypeNode*>(anno);
+          mut_anno->dtype = out_dtype;
+
+          return GetRef<Expr>(var_node);
+        } else {
+          LOG(WARNING) << "Variable '" << var->name_hint() << "' encountered"
+                       << " but wasn't registered as a function parameter";
+        }
+      }
+    }
+
+    return Call(cur_op, post_call_node->args, pre_call_node->attrs, pre_call_node->type_args,
+                pre_call_node->span);
+  }
+
+  Expr VisitExpr_(const FunctionNode* node) {
+    function_count_++;
+    if (function_count_ > 1) {
+      CHECK(false) << "LiftDtypeTransformation is supported only for single-function graphs";
+    }
+
+    for (auto param : node->params) {
+      input_transform_map_.insert(std::pair<Var, const CallNode*>(param, nullptr));
+    }
+    auto body = this->Mutate(node->body);
+
+    return Function(node->params, body, node->ret_type, node->type_params, node->attrs,
+                    node->span);
+  }
+
+  const Op cast_op_ = Op::Get("cast");
+  const Op quantize_op_ = Op::Get("qnn.quantize");
+
+ private:
+  // Maps a function parameter to the first-encountered call node within
+  // the function that takes that parameter as an input.
+  std::map<Var, const CallNode*> input_transform_map_;
+
+  // Tracks the number of functions in this program.
+  int function_count_ = 0;
+};
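+
+/*!
+ * \brief Rewrite \p expr, folding leading dtype transformations into
+ * the enclosing function's signature. The \p mod argument is unused.
+ */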
+Expr LiftDtypeTransformation(const Expr& expr, const IRModule& mod) {
+  return LiftDtypeTransformationRewriter().Mutate(expr);
+}
+
+namespace transform {
+
+Pass LiftDtypeTransformation() {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+      [=](Function f, IRModule m, PassContext pc) {
+        return Downcast<Function>(LiftDtypeTransformation(f, m));
+      };
+  return CreateFunctionPass(pass_func, 4, "LiftDtypeTransformation", {});
+}
+
+TVM_REGISTER_GLOBAL("relay._transform.LiftDtypeTransformation")
+    .set_body_typed(LiftDtypeTransformation);
+
+}  // namespace transform
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/tests/python/relay/test_lift_dtype_transformation.py b/tests/python/relay/test_lift_dtype_transformation.py
new file mode 100644
index 000000000000..2e7ab1fcf552
--- /dev/null
+++ b/tests/python/relay/test_lift_dtype_transformation.py
@@ -0,0 +1,66 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm import relay
+
+
+def test_simple_cast_fold():
+    data = relay.var("data", shape=[1, 3, 224, 224], dtype="float32")
+    out = relay.cast(data, "float16")
+    out = relay.add(out, out)
+    mod = tvm.IRModule.from_expr(out)
+    mod = tvm.relay.transform.InferType()(mod)
+    mod = tvm.relay.transform.LiftDtypeTransformation()(mod)
+
+    data_fp16 = relay.var("data", shape=[1, 3, 224, 224], dtype="float16")
+    out = relay.add(data_fp16, data_fp16)
+    expected_mod = tvm.IRModule.from_expr(out)
+    expected_mod = tvm.relay.transform.InferType()(expected_mod)
+
+    assert tvm.ir.structural_equal(mod, expected_mod)
+
+
+def test_simple_quantize_fold():
+    data = relay.var("data", shape=[1, 3, 224, 224], dtype="float32")
+    out = relay.qnn.op.quantize(data, relay.const(2.0), relay.const(0), out_dtype="uint8")
+    out = relay.add(out, out)
+
+    mod = tvm.IRModule.from_expr(out)
+    mod = tvm.relay.transform.InferType()(mod)
+    mod = tvm.relay.transform.LiftDtypeTransformation()(mod)
+
+    data_uint8 = relay.var("data", shape=[1, 3, 224, 224], dtype="uint8")
+    out = relay.add(data_uint8, data_uint8)
+    expected_mod = tvm.IRModule.from_expr(out)
+    expected_mod = tvm.relay.transform.InferType()(expected_mod)
+
+    assert tvm.ir.structural_equal(mod, expected_mod)
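+
+
+def test_param_fed_into_multiple_calls_rejected():
+    # A sketch of the documented single-use restriction (hypothetical
+    # test, not from the original patch): `data` feeds both a cast and
+    # an add, so the pass is expected to abort, assuming the C++ CHECK
+    # surfaces in Python as a TVMError.
+    import pytest
+
+    data = relay.var("data", shape=[1, 3, 224, 224], dtype="float32")
+    out = relay.Tuple([relay.cast(data, "float16"), relay.add(data, data)])
+    mod = tvm.IRModule.from_expr(out)
+    mod = tvm.relay.transform.InferType()(mod)
+    with pytest.raises(tvm.TVMError):
+        tvm.relay.transform.LiftDtypeTransformation()(mod)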