From ed31ec7c73f0651a241354fb05a14607e67ad4af Mon Sep 17 00:00:00 2001 From: Mark Shields Date: Fri, 17 Sep 2021 16:52:21 -0700 Subject: [PATCH] [checkpoint] nuke device_annotation.cc and python bindings Probably lots of dangling refs on the python side. Original handling of default devices and the '0' device type target not brought over and needs to be recovered. --- include/tvm/relay/analysis.h | 10 - python/tvm/relay/analysis/analysis.py | 17 - python/tvm/relay/op/_tensor.py | 1 - python/tvm/relay/transform/transform.py | 21 - src/relay/backend/build_module.cc | 56 +- src/relay/backend/te_compiler.cc | 2 + src/relay/transforms/device_annotation.cc | 402 ------------ src/relay/transforms/device_planner.cc | 52 +- tests/python/relay/test_pass_annotation.py | 671 --------------------- 9 files changed, 55 insertions(+), 1177 deletions(-) delete mode 100644 src/relay/transforms/device_annotation.cc delete mode 100644 tests/python/relay/test_pass_annotation.py diff --git a/include/tvm/relay/analysis.h b/include/tvm/relay/analysis.h index 176ff9c8cd6f2..0f85587262ac4 100644 --- a/include/tvm/relay/analysis.h +++ b/include/tvm/relay/analysis.h @@ -211,16 +211,6 @@ TVM_DLL tvm::Array AllTypeVars(const Expr& expr, const IRModule& mod); */ TVM_DLL tvm::Array AllTypeVars(const Type& t, const IRModule& mod); -/*! - * \brief Collect the device annotation operators. - * - * \param expr The expression. - * - * \return The annotated expression to device type mapping for annotation ops. - */ -TVM_DLL Map CollectDeviceAnnotationOps(const Expr& expr); -TVM_DLL Map CollectAllDeviceAnnotationOps(const IRModule& mod); - /*! * \brief Finds cases that the given match expression does not catch, if any. * diff --git a/python/tvm/relay/analysis/analysis.py b/python/tvm/relay/analysis/analysis.py index f76bd3eccab05..b627005735815 100644 --- a/python/tvm/relay/analysis/analysis.py +++ b/python/tvm/relay/analysis/analysis.py @@ -253,23 +253,6 @@ def all_dtypes(expr): return set(_ffi_api.all_dtypes(expr)) -def collect_device_annotation_ops(expr): - """Collect the device annotation ops for the given expression. - - Parameters - ---------- - expr : tvm.relay.Expr - The input expression. - - Returns - ------- - ret : Dict[tvm.relay.Expr, int] - A dictionary mapping tvm.relay.Expr to device type where the keys are - annotation expressions. - """ - return _ffi_api.CollectDeviceAnnotationOps(expr) - - def get_total_mac_number(expr): """ Count the number of MACs (multiply-accumulate) of a model diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py index d7d99c017b2bf..18ce93322f434 100644 --- a/python/tvm/relay/op/_tensor.py +++ b/python/tvm/relay/op/_tensor.py @@ -91,7 +91,6 @@ register_broadcast_schedule("fast_erf") # a fake on_device schedule. # this will not be used in actual computation -# as on_device will be removed during DeviceAnnotation pass register_injective_schedule("on_device") diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 4b09be1f2cbd4..96eb6837003ad 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -544,27 +544,6 @@ def MergeCompilerRegions(): return _ffi_api.MergeCompilerRegions() -def RewriteAnnotatedOps(fallback_device): - """Rewrite the annotated program where annotation operators, e.g. - `on_device`, mark which device an expression should be scheduled to. - This pass helps heterogeneous execution where different operators may need - to be allocated on various devices. - - Parameters - ---------- - fallback_device : int - The fallback device type. It is also used as the default device for - operators with no annotated device. - - Returns - ------- - ret: tvm.transform.Pass - The registered pass that rewrites an expression with annotated - `on_device` operators. - """ - return _ffi_api.RewriteAnnotatedOps(fallback_device) - - def ToANormalForm(): """Turn Graph Normal Form expression into A Normal Form Expression. The scope of the root expression is the global scope. diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 80db7f576fa33..34e3117cf61ed 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -325,15 +325,16 @@ class RelayBuildModule : public runtime::ModuleNode { transform::PassContext pass_ctx = PassContext::Current(); Optional opt_fallback_dev = pass_ctx->GetConfig("relay.fallback_device_type", Integer(static_cast(kDLCPU))); - DLDeviceType fallback_dev = static_cast(opt_fallback_dev.value()->value); - ICHECK_GT(fallback_dev, 0U); -#if 0 - // TODO(mbs): Remove - if (targets_.size() > 1) { - relay_module = RunDeviceAnnotationPass(relay_module, fallback_dev); - } -#endif - pass_seqs.push_back(transform::PlanDevices(fallback_dev)); + DLDeviceType default_device_type = static_cast(opt_fallback_dev.value()->value); + ICHECK_GT(default_device_type, 0U); + // What about the implied 'default' target with 'device type' 0? + UpdateHeterogeneousInputs(default_device_type); + // Make sure the 'default' target is always available keyed by 0. + // This is the current convention for conveying which of the >= 2 targets is the default. + targets_.Set(kNullDeviceType, CreateDefaultTarget(default_device_type)); + // TODO(mbs): Used to be some obsure logic for choosing a different fallback_device + // from the existing "on_device" annotations. What is that for? + pass_seqs.push_back(transform::PlanDevices(default_device_type)); // Fuse the operations if it is needed. pass_seqs.push_back(transform::FuseOps()); @@ -411,43 +412,6 @@ class RelayBuildModule : public runtime::ModuleNode { } } - /*! - * \brief Execute the device annotation passes to update the input program and - * target information. - * - * \param relay_module The input Relay module. - * \param fallback_device The fallback device for heterogeneous execution. - * - * \return updated_module The updated module after device annotation. - */ - IRModule RunDeviceAnnotationPass(const IRModule& relay_module, int fallback_device) { - UpdateHeterogeneousInputs(fallback_device); - - // If there's a unique device type used by all "on_device" CallNodes then use that - // as the fallback_device. - // TODO(mbs): This defaulting only roughly matches the original behavior. We should - // cleanup all the logic around default host and device targets. - Map annotations = CollectAllDeviceAnnotationOps(relay_module); - if (!annotations.empty()) { - std::unordered_set device_types; - for (const auto& pair : annotations) { - device_types.insert(static_cast((*annotations.begin()).second->value)); - } - if (device_types.size() == 1UL) { - fallback_device = *device_types.begin(); - } - } - // Make sure the 'default' target is always available keyed by 0. - // This is the current convention for conveying which of the >= 2 targets is the default. - targets_.Set(kNullDeviceType, CreateDefaultTarget(fallback_device)); - - // Insert "device_copy" CallNodes to account for any user-supplied "on_device" CallNodes. - auto updated_module = transform::RewriteAnnotatedOps(fallback_device)(relay_module); - ICHECK(updated_module.defined()); - - return updated_module; - } - /*! * \brief Compile a Relay IR module to runtime module. * diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 44fe7d1c7beaf..b60161999ece5 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -471,6 +471,8 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { auto device_copy = IsDeviceCopy(func); if (std::get<0>(device_copy)) { + // Record that device copy source and destination devices so the device planner can + // still follow along. auto source_device = std::get<1>(device_copy); auto dst_device = std::get<2>(device_copy); tir_call_attrs->metadata.Set("source_device", tvm::Integer(source_device)); diff --git a/src/relay/transforms/device_annotation.cc b/src/relay/transforms/device_annotation.cc deleted file mode 100644 index b890f0760fab9..0000000000000 --- a/src/relay/transforms/device_annotation.cc +++ /dev/null @@ -1,402 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file device_annotation.cc - * \brief Passes to rewrite annotated program and retrieve the device allocation - * of expression. - * - * The following passes are performed: - * 1. Validate the unnecessary and redundant annotation. - * 2. Rewrite the annotated program and insert data copy operators. - * 3. Collect the device allocation of each expression. - */ - -#include -#include -#include -#include -#include -#include - -#include -#include -#include - -#include "../op/annotation/annotation.h" -#include "../op/memory/device_copy.h" - -namespace tvm { -namespace relay { - -namespace { - -bool IsDeviceCopyNode(const ExprNode* node) { - if (!node->IsInstance()) return false; - const auto* call_node = static_cast(node); - - if (call_node->attrs.as()) { - return true; - } - - auto tir_call_attrs = call_node->attrs.as(); - if (tir_call_attrs) { - auto metadata = tir_call_attrs->metadata; - return metadata.count("source_device") == 1 && metadata.count("dst_device") == 1; - } - - return false; -} - -} // namespace - -/*! - * \brief Builds a map from expression to device type based on existing - * "on_device" CallNodes. - * - * Only "on_device" CallNodes, their single args, and tuple-projection from such will be indexed. - */ -// TODO(mbs): Retire. -class ValidateAnnotation : private ExprVisitor { - public: - static std::unordered_map Validate(const Expr& expr) { - ValidateAnnotation valid; - valid(expr); - return valid.annotation_map_; - } - - private: - void VisitExpr_(const CallNode* call_node) final { - ExprVisitor::VisitExpr_(call_node); - auto props = GetOnDeviceProps(call_node); - if (props.body.defined()) { - if (annotation_map_.count(call_node)) { - ICHECK_EQ(annotation_map_.at(call_node), props.device_type) - << "An expression node can only be annotated to one device."; - } else { - annotation_map_.insert({call_node, props.device_type}); - } - - const auto* node = props.body.get(); - if (annotation_map_.count(node)) { - ICHECK_EQ(annotation_map_.at(node), props.device_type) - << "An expression node can only be annotated to one device."; - } else { - annotation_map_.insert({node, props.device_type}); - } - } - } - - void VisitExpr_(const TupleGetItemNode* get_elem) final { - ExprVisitor::VisitExpr_(get_elem); - const auto* tn = get_elem->tuple.get(); - if (annotation_map_.count(tn)) { - annotation_map_.insert({get_elem, annotation_map_.at(tn)}); - } - } - - std::unordered_map annotation_map_; -}; - -// Replace the use of an expression with the output of a `copy_device` operator -// if the `on_device` operator takes the annotated expr as an input. -// -// This actually replaces annotation ops with device copy ops and connects any -// two dependent expressions with a `device_copy` op when needed. Note that the -// device type of a `device_copy` op is identical to that of the destination op -// since it is where the data should be copied to. -class RewriteAnnotation : public ExprMutator { - public: - Expr Rewrite(const Expr& expr, int fallback_device) { - fallback_device_ = fallback_device; - annotation_map_ = ValidateAnnotation::Validate(expr); - return this->VisitExpr(expr); - } - - Expr VisitExpr_(const LetNode* op) final { - Expr value = GetDeviceCopyExpr(op->value, op); - Expr body = GetDeviceCopyExpr(op->body, op); - - if (value.same_as(op->value) && body.same_as(op->body)) { - return ExprMutator::VisitExpr_(op); - } else { - Expr new_let = Let(op->var, value, body); - UpdateAnnotationMap(op, new_let.get()); - return this->VisitExpr(new_let); - } - } - - Expr VisitExpr_(const TupleNode* op) { - Array fields; - bool annotated = false; - for (const auto& field : op->fields) { - annotated |= NeedDeviceCopy(field.get(), op); - fields.push_back(GetDeviceCopyExpr(field, op)); - } - - if (annotated) { - Expr new_tuple = Tuple(fields); - UpdateAnnotationMap(op, new_tuple.get()); - return this->VisitExpr(new_tuple); - } else { - return ExprMutator::VisitExpr_(op); - } - } - - Expr VisitExpr_(const TupleGetItemNode* op) final { - Expr tuple = op->tuple; - if (NeedDeviceCopy(tuple.get(), op)) { - Expr new_expr = TupleGetItem(GetDeviceCopyExpr(tuple, op), op->index); - UpdateAnnotationMap(op, new_expr.get()); - return this->VisitExpr(new_expr); - } else { - return ExprMutator::VisitExpr_(op); - } - } - - Expr VisitExpr_(const IfNode* if_node) final { - Expr cond = GetDeviceCopyExpr(if_node->cond, if_node); - Expr true_br = GetDeviceCopyExpr(if_node->true_branch, if_node); - Expr false_br = GetDeviceCopyExpr(if_node->false_branch, if_node); - - if (if_node->cond.same_as(cond) && if_node->true_branch.same_as(true_br) && - if_node->false_branch.same_as(false_br)) { - return ExprMutator::VisitExpr_(if_node); - } else { - Expr new_if = If(cond, true_br, false_br); - UpdateAnnotationMap(if_node, new_if.get()); - return this->VisitExpr(new_if); - } - } - - Expr VisitExpr_(const CallNode* call_node) final { - auto props = GetOnDeviceProps(call_node); - if (props.body.defined()) { - return this->VisitExpr(props.body); - } - - if (IsDeviceCopyNode(call_node)) { - return ExprMutator::VisitExpr_(call_node); - } - - Array new_args; - bool annotated = false; - for (const auto& arg : call_node->args) { - annotated |= NeedDeviceCopy(arg.get(), call_node); - new_args.push_back(GetDeviceCopyExpr(arg, call_node)); - } - - if (annotated) { - Call new_call = Call(call_node->op, new_args, call_node->attrs, call_node->type_args); - - UpdateAnnotationMap(call_node, new_call.get()); - return this->VisitExpr(new_call); - } else { - return ExprMutator::VisitExpr_(call_node); - } - } - - private: - void UpdateAnnotationMap(const ExprNode* old_node, const ExprNode* new_node) { - const auto it = annotation_map_.find(old_node); - if (it == annotation_map_.end()) { - annotation_map_.insert({new_node, fallback_device_}); - } else { - annotation_map_.insert({new_node, it->second}); - } - this->memo_[GetRef(old_node)] = GetRef(new_node); - } - - Expr GetDeviceCopyExpr(const Expr& src, const ExprNode* dst) { - const auto* src_node = src.get(); - if (!NeedDeviceCopy(src_node, dst)) return src; - - const auto sit = annotation_map_.find(src_node); - if (sit == annotation_map_.end()) { - const auto dit = annotation_map_.find(dst); - ICHECK(dit != annotation_map_.end()) - << "Device copy op is not required when both src and dst ops are not " - "annotated."; - return CreateDeviceCopy(src, fallback_device_, dit->second); - } else { - const auto dit = annotation_map_.find(dst); - int dst_dev_type = dit == annotation_map_.end() ? fallback_device_ : dit->second; - return CreateDeviceCopy(src, sit->second, dst_dev_type); - } - } - - // Check if a device copy op is need between two ops. - bool NeedDeviceCopy(const ExprNode* src, const ExprNode* dst) { - if (annotation_map_.count(src)) { - int src_dev_type = annotation_map_.at(src); - if (annotation_map_.count(dst)) { - return src_dev_type != annotation_map_.at(dst); - } else { - return src_dev_type != fallback_device_; - } - } else { - if (annotation_map_.count(dst)) { - // Though data copy op could be inserted whenever the `src` and `dst` - // ops are annotated to different devices, it leads to high overhead. - // - // Here we need across device data transferring only when `src` is a - // CallNode or FunctionNode and the `dst` is annotated with any device - // id other than fallback_device_. - if (src->IsInstance() || src->IsInstance()) { - return annotation_map_.at(dst) != fallback_device_; - } else { - // There shouldn't be any copy nodes between var/constant and another - // expression. - return !(src->IsInstance() || src->IsInstance()); - } - } else { - return false; - } - } - } - - /* - * \brief Create an operator to copy data from the source device to the - * destination device. - * \param src The source expression that produces data to be copied. - * \param src_dev_type The device type where the data is copied from. - * \param dst_dev_type The device type where the data is copied to. - * \return The created call node. - */ - Expr CreateDeviceCopy(const Expr& src, int src_dev_type, int dst_dev_type) { - Expr device_copy = DeviceCopy(src, static_cast(src_dev_type), - static_cast(dst_dev_type)); - annotation_map_.insert({device_copy.get(), dst_dev_type}); - return device_copy; - } - - const Op& device_copy_op_ = Op::Get("device_copy"); - std::unordered_map annotation_map_; - int fallback_device_; -}; - -/*! \brief Builds a map from "on_device" CallNodes to their device types. - * - * No other expression appear in the result map. - */ -class AnnotationVisitor : private ExprVisitor { - public: - static void AccumAnnotations(const Expr& expr, Map* annotations) { - AnnotationVisitor visitor(annotations); - visitor(expr); - } - - private: - explicit AnnotationVisitor(Map* annotations) : annotations_(annotations) {} - - void VisitExpr_(const CallNode* call_node) final { - auto props = GetOnDeviceProps(call_node); - if (props.body.defined()) { - annotations_->Set(GetRef(call_node), props.device_type); - } - ExprVisitor::VisitExpr_(call_node); - } - - Map* annotations_; -}; - -/*! - * \brief Inserts "device_copy" CallNodes where an existing "on_device" CallNode suggests - * a transition between device domains. All existing "on_device" CallNodes are removed. - */ -Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device) { - RewriteAnnotation rewrote = RewriteAnnotation(); - Expr new_expr = rewrote.Rewrite(expr, fallback_device); - - // Remove OnDevice operators. Note that these operators are only present at the - // leaves after annotation. Therefore, we can simply reconstruct the - // Function/Expr by removing them directly. - if (const FunctionNode* fn = new_expr.as()) { - auto params = fn->params; - auto body = fn->body; - std::vector new_body; - if (const TupleNode* tuple = body.as()) { - for (const auto& field : tuple->fields) { - if (!IsOnDeviceCall(field)) { - new_body.push_back(field); - } - } - ICHECK_GT(new_body.size(), 0U); - if (new_body.size() == 1) { - return Function(params, new_body[0], Type(nullptr), fn->type_params, fn->attrs); - } else if (tuple->fields.size() == new_body.size()) { - return new_expr; - } else { - Tuple tuple_body = Tuple(new_body); - return Function(params, tuple_body, Type(nullptr), fn->type_params, fn->attrs); - } - } else { - return new_expr; - } - } else if (const TupleNode* tuple = new_expr.as()) { - std::vector new_fields; - for (const auto& field : tuple->fields) { - if (!IsOnDeviceCall(field)) { - new_fields.push_back(field); - } - } - ICHECK_GT(new_fields.size(), 0U); - if (tuple->fields.size() == new_fields.size()) { - return new_fields.size() == 1 ? new_fields[0] : new_expr; - } else { - return new_fields.size() == 1 ? new_fields[0] : Tuple(new_fields); - } - } else { - return new_expr; - } -} - -Map CollectDeviceAnnotationOps(const Expr& expr) { - Map annotations; - AnnotationVisitor::AccumAnnotations(expr, &annotations); - return annotations; -} - -Map CollectAllDeviceAnnotationOps(const IRModule& mod) { - Map annotations; - for (const auto& pair : mod->functions) { - AnnotationVisitor::AccumAnnotations(pair.second, &annotations); - } - return annotations; -} - -TVM_REGISTER_GLOBAL("relay.analysis.CollectDeviceAnnotationOps") - .set_body_typed(CollectDeviceAnnotationOps); - -namespace transform { - -Pass RewriteAnnotatedOps(int fallback_device) { - runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(relay::RewriteAnnotatedOps(f, fallback_device)); - }; - return CreateFunctionPass(pass_func, 1, "RewriteAnnotatedOps", {"InferType"}); -} - -TVM_REGISTER_GLOBAL("relay._transform.RewriteAnnotatedOps").set_body_typed(RewriteAnnotatedOps); - -} // namespace transform - -} // namespace relay -} // namespace tvm diff --git a/src/relay/transforms/device_planner.cc b/src/relay/transforms/device_planner.cc index c68884fe0b90d..64f6b93263d7d 100644 --- a/src/relay/transforms/device_planner.cc +++ b/src/relay/transforms/device_planner.cc @@ -245,6 +245,9 @@ * * Might be simpler to just let every type have a device annotation rather than work in * a separate domain? * * Switch to expr.CopyWith(...) form once implemented to avoid unnecessary copies. + * * The original device_annotation.cc RewriteAnnotatedOps removed all "on_device" calls + * in tuples at the top level of function bodies or main expression, irrespective of the + * "on_device" body. What's up with that? */ #include "./device_planner.h" @@ -273,6 +276,29 @@ namespace transform { namespace { +/*! + * \brief As for GetDeviceCopyProps, but for the call to the lowered TIR primitives rather + * than the original "device_copy" operator. + * + * See te_compiler.cc for where this rewriting occurs. + */ +DeviceCopyProps GetPrimitiveDeviceCopyProps(const CallNode* call_node) { + auto tir_call_attrs = call_node->attrs.as(); + if (tir_call_attrs == nullptr) { + return {}; + } + if (tir_call_attrs->metadata.count("source_device") != 1 || + tir_call_attrs->metadata.count("dst_device") != 1) { + return {}; + } + ICHECK_EQ(call_node->args.size(), 1) << "device_copy is of arity 1"; + return { + call_node->args[0], + static_cast( + Downcast(tir_call_attrs->metadata["source_device"])->value), + static_cast(Downcast(tir_call_attrs->metadata["dst_device"])->value)}; +} + class DeviceDomain; using DeviceDomainPtr = std::shared_ptr; @@ -581,25 +607,33 @@ class DeviceDomains { return Lookup(itr->second); } std::vector args_and_result; - if (call->op == OnDeviceOp()) { + + auto on_device_props = GetOnDeviceProps(call.get()); + auto device_copy_props = GetDeviceCopyProps(call.get()); + if (!device_copy_props.body.defined()) { + device_copy_props = GetPrimitiveDeviceCopyProps(call.get()); + } + + if (on_device_props.body.defined()) { // on_device(expr, device_type=, is_fixed=false) // on_device : fn():?x? // // on_device(expr, device_type=, is_fixed=true) // on_device: fn(): - auto props = GetOnDeviceProps(call.get()); - args_and_result.emplace_back(ForDeviceType(props.body->checked_type(), props.device_type)); - if (props.is_fixed) { + args_and_result.emplace_back( + ForDeviceType(on_device_props.body->checked_type(), on_device_props.device_type)); + if (on_device_props.is_fixed) { args_and_result.emplace_back(args_and_result.front()); } else { - args_and_result.emplace_back(Free(props.body->checked_type())); + args_and_result.emplace_back(Free(on_device_props.body->checked_type())); } - } else if (call->op == DeviceCopyOp()) { + } else if (device_copy_props.body.defined()) { // device_copy(expr, src_dev_type=, dst_dev_type=) // device_copy: fn(): - auto props = GetDeviceCopyProps(call.get()); - args_and_result.emplace_back(ForDeviceType(props.body->checked_type(), props.src_dev_type)); - args_and_result.emplace_back(ForDeviceType(props.body->checked_type(), props.dst_dev_type)); + args_and_result.emplace_back( + ForDeviceType(device_copy_props.body->checked_type(), device_copy_props.src_dev_type)); + args_and_result.emplace_back( + ForDeviceType(device_copy_props.body->checked_type(), device_copy_props.dst_dev_type)); } else if (call->op == alloc_storage_op) { ICHECK_EQ(call->args.size(), 2U); // alloc_storage(size, alignment, device_type=) diff --git a/tests/python/relay/test_pass_annotation.py b/tests/python/relay/test_pass_annotation.py deleted file mode 100644 index bcb5895194620..0000000000000 --- a/tests/python/relay/test_pass_annotation.py +++ /dev/null @@ -1,671 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Unit tests for heterogeneous compilation and execution.""" -import json -import numpy as np - -import tvm -from tvm import relay -from tvm.contrib import graph_executor -from tvm.relay.expr_functor import ExprMutator -from tvm.relay import transform -from tvm.ir.instrument import pass_instrument -import tvm.testing - - -@tvm.instrument.pass_instrument -class Trace: - def run_before_pass(self, module, pass_info): - if pass_info.name == "ManifestAlloc": - pass # import pdb; pdb.set_trace() - - def run_after_pass(self, module, pass_info): - if pass_info.name == "ManifestAlloc": - pass # import pdb; pdb.set_trace() - - -def check_graph_executor( - target, ref_res, device, func, params, config, opt_level, expected_index=None -): - with tvm.transform.PassContext(opt_level=opt_level, config=config): - graph_executor_factory = relay.build(func, target, params=params) - - devices = [tvm.cpu(0), tvm.device(device)] - graph_json = json.loads(graph_executor_factory.graph_json) - if "device_index" in graph_json["attrs"]: - device_index = graph_json["attrs"]["device_index"][1] - assert device_index == expected_index - mod = graph_executor.GraphModule(graph_executor_factory["default"](*devices)) - mod.run() - res = mod.get_output(0).numpy() - tvm.testing.assert_allclose(res, ref_res, rtol=1e-5, atol=1e-5) - - -def check_vm_runtime(target, ref_res, device, func, params, config, opt_level, expected_index=None): - with tvm.transform.PassContext(opt_level=opt_level, instruments=[Trace()], config=config): - mod = tvm.IRModule() - mod["main"] = func - exe = relay.vm.compile(mod, target) - dev = [tvm.cpu(0), tvm.device(device)] - vm = tvm.runtime.vm.VirtualMachine(exe, dev) - res = vm.invoke("main", **params) - tvm.testing.assert_allclose(res.numpy(), ref_res, rtol=1e-5, atol=1e-5) - - -def run_opt_pass(expr, passes): - passes = passes if isinstance(passes, list) else [passes] - mod = tvm.IRModule.from_expr(expr) - seq = tvm.transform.Sequential(passes) - with tvm.transform.PassContext(opt_level=3): - mod = seq(mod) - return mod["main"] - - -def test_redundant_annotation(): - dev1 = tvm.device(1) - dev2 = tvm.device(2) - x = relay.var("x", shape=(3,)) - y = relay.var("y", shape=(3,)) - z = relay.var("z", shape=(3,)) - - def annotated(): - add = relay.add(x, y) - _add1 = relay.annotation.on_device(add, dev2) - _add2 = relay.annotation.on_device(add, dev2) - sub1 = relay.subtract(_add1, z) - sub2 = relay.subtract(_add2, z) - - func = relay.Function([x, y, z], relay.Tuple([sub1, sub2])) - func = run_opt_pass(func, transform.RewriteAnnotatedOps(dev1.device_type)) - return func - - def expected(): - add = relay.add(x, y) - copy_add_sub1 = relay.device_copy(add, dev2, dev1) - sub1 = relay.subtract(copy_add_sub1, z) - copy_add_sub2 = relay.device_copy(add, dev2, dev1) - sub2 = relay.subtract(copy_add_sub2, z) - func = relay.Function([x, y, z], relay.Tuple([sub1, sub2])) - return func - - annotated_func = annotated() - expected_func = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(annotated_func, expected_func) - - -def test_annotate_expr(): - dev1 = tvm.device(1) - dev2 = tvm.device(2) - x = relay.var("x", shape=(3,)) - y = relay.var("y", shape=(3,)) - z = relay.var("z", shape=(3,)) - - def annotated(): - add = relay.add(x, y) - _add = relay.annotation.on_device(add, dev1) - sub = relay.subtract(_add, z) - _sub = relay.annotation.on_device(sub, dev2) - expr = run_opt_pass(_sub, transform.RewriteAnnotatedOps(dev1.device_type)) - return expr - - def expected(): - add = relay.add(x, y) - copy_add_sub = relay.device_copy(add, dev1, dev2) - sub = relay.subtract(copy_add_sub, z) - return sub - - annotated_expr = annotated() - expected_expr = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(annotated_expr, expected_expr) - - -def test_annotate_all(): - dev1 = tvm.device(1) - dev2 = tvm.device(2) - x = relay.var("x", shape=(3,)) - y = relay.var("y", shape=(3,)) - z = relay.var("z", shape=(3,)) - - def annotated(): - add = relay.add(x, y) - _add = relay.annotation.on_device(add, dev2) - sub = relay.subtract(_add, z) - _sub = relay.annotation.on_device(sub, dev2) - - func = relay.Function([x, y, z], _sub) - func = run_opt_pass(func, transform.RewriteAnnotatedOps(dev1.device_type)) - return func - - def expected(): - add = relay.add(x, y) - sub = relay.subtract(add, z) - func = relay.Function([x, y, z], sub) - return func - - annotated_func = annotated() - expected_func = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(annotated_func, expected_func) - - -def test_annotate_none(): - dev1 = tvm.device(1) - dev2 = tvm.device(2) - x = relay.var("x", shape=(3,)) - y = relay.var("y", shape=(3,)) - z = relay.var("z", shape=(3,)) - - def annotated(): - add = relay.add(x, y) - sub = relay.subtract(add, z) - func = relay.Function([x, y, z], sub) - func = run_opt_pass(func, transform.RewriteAnnotatedOps(dev1.device_type)) - return func - - def expected(): - add = relay.add(x, y) - sub = relay.subtract(add, z) - func = relay.Function([x, y, z], sub) - return func - - annotated_func = annotated() - expected_func = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(annotated_func, expected_func) - - -def check_annotated_graph(annotated_func, expected_func): - annotated_func = run_opt_pass(annotated_func, transform.InferType()) - expected_func = run_opt_pass(expected_func, transform.InferType()) - assert tvm.ir.structural_equal(annotated_func, expected_func) - - -def test_conv_network(): - r"""The network is as following: - data1 data2 - | | - conv2d conv2d - \ / - add - | - conv2d - """ - batch_size = 1 - dshape = (batch_size, 64, 56, 56) - weight = relay.var("weight", shape=(64, 64, 3, 3)) - data1 = relay.var("data1", shape=dshape) - data2 = relay.var("data2", shape=dshape) - dev1 = tvm.device(1) - dev2 = tvm.device(2) - - def original(): - conv2d_1 = relay.nn.conv2d(data1, weight, channels=64, kernel_size=(3, 3), padding=(1, 1)) - conv2d_2 = relay.nn.conv2d(data2, weight, channels=64, kernel_size=(3, 3), padding=(1, 1)) - add = relay.add(conv2d_1, conv2d_2) - conv2d_3 = relay.nn.conv2d(add, weight, channels=64, kernel_size=(3, 3), padding=(1, 1)) - - func = relay.Function([data1, data2, weight], conv2d_3) - func = run_opt_pass(func, transform.RewriteAnnotatedOps(tvm.device(3).device_type)) - return func - - def annotated(): - conv2d_1 = relay.nn.conv2d(data1, weight, channels=64, kernel_size=(3, 3), padding=(1, 1)) - _conv2d_1 = relay.annotation.on_device(conv2d_1, dev2) - conv2d_2 = relay.nn.conv2d(data2, weight, channels=64, kernel_size=(3, 3), padding=(1, 1)) - _conv2d_2 = relay.annotation.on_device(conv2d_2, dev2) - add = relay.add(_conv2d_1, _conv2d_2) - _add = relay.annotation.on_device(add, dev1) - conv2d_3 = relay.nn.conv2d(_add, weight, channels=64, kernel_size=(3, 3), padding=(1, 1)) - _conv2d_3 = relay.annotation.on_device(conv2d_3, dev2) - - func = relay.Function([data1, data2, weight], _conv2d_3) - func = run_opt_pass(func, transform.RewriteAnnotatedOps(tvm.device(3).device_type)) - return func - - class ScheduleConv2d(ExprMutator): - def __init__(self, device): - self.device = device - super().__init__() - - def visit_call(self, expr): - visit = super().visit_call(expr) - if expr.op == tvm.relay.op.get("nn.conv2d"): - return relay.annotation.on_device(visit, self.device) - else: - return visit - - def annotate_with_visitor(func): - sched = ScheduleConv2d(dev2) - func = sched.visit(func) - func = run_opt_pass(func, transform.RewriteAnnotatedOps(dev1.device_type)) - return func - - def expected(): - conv2d_1 = relay.nn.conv2d(data1, weight, channels=64, kernel_size=(3, 3), padding=(1, 1)) - device_copy1 = relay.device_copy(conv2d_1, dev2, dev1) - conv2d_2 = relay.nn.conv2d(data2, weight, channels=64, kernel_size=(3, 3), padding=(1, 1)) - device_copy2 = relay.device_copy(conv2d_2, dev2, dev1) - add = relay.add(device_copy1, device_copy2) - device_copy3 = relay.device_copy(add, dev1, dev2) - conv2d_3 = relay.nn.conv2d( - device_copy3, weight, channels=64, kernel_size=(3, 3), padding=(1, 1) - ) - - func = relay.Function([data1, data2, weight], conv2d_3) - return func - - def check_storage_and_device_types(): - default_device_type = 3 - func = annotated() - func = run_opt_pass( - func, - [ - transform.RewriteAnnotatedOps(default_device_type), - transform.PlanDevices(default_device_type), - transform.FuseOps(2), - ], - ) - smap = relay.backend._backend.GraphPlanMemory(func) - storage_ids = [] - device_types = [] - for _, storage_info in smap.expr_to_storage_info.items(): - - for sid in storage_info.storage_ids: - storage_ids.append(sid.value) - - for did in storage_info.device_types: - device_types.append(did.value) - - assert len(storage_ids) == 10 - assert len(set(storage_ids)) == 8 - assert len(set(device_types)) == 2 - assert set(device_types) == {1, 2} - - def test_manual_annotation(): - annotated_func = annotated() - expected_func = expected() - check_annotated_graph(annotated_func, expected_func) - check_storage_and_device_types() - - def test_visitor_annotation(): - annotated_func = annotate_with_visitor(original()) - expected_func = expected() - check_annotated_graph(annotated_func, expected_func) - - test_manual_annotation() - test_visitor_annotation() - - -def test_propogation(): - R""" The network and device type is as following: - x 1 - | - log 1 - / \ - log2 log10 2 - \ / - add 2 - | - tan 1 - """ - dev1 = tvm.device(1) - dev2 = tvm.device(2) - - expected_dev_type = {"log": dev1, "log2": dev2, "log10": dev2, "add": dev2, "tan": dev1} - - x = relay.var("x", shape=(3,)) - - def annotated(): - log = relay.log(x) - _log = relay.annotation.on_device(log, expected_dev_type["log"]) - log2 = relay.log2(_log) - _log2 = relay.annotation.on_device(log2, expected_dev_type["log2"]) - log10 = relay.log10(_log) - _log10 = relay.annotation.on_device(log10, expected_dev_type["log10"]) - add = relay.add(_log2, _log10) - _add = relay.annotation.on_device(add, expected_dev_type["add"]) - tan = relay.tan(_add) - _tan = relay.annotation.on_device(tan, expected_dev_type["tan"]) - - func = run_opt_pass(_tan, transform.RewriteAnnotatedOps(dev1.device_type)) - return func - - def expected(): - log = relay.log(x) - _log_left = relay.device_copy(log, dev1, dev2) - _log_right = relay.device_copy(log, dev1, dev2) - log2 = relay.log2(_log_left) - log10 = relay.log10(_log_right) - add = relay.add(log2, log10) - _add = relay.device_copy(add, dev2, dev1) - tan = relay.tan(_add) - - func = run_opt_pass(tan, transform.InferType()) - return func - - annotated_expr = annotated() - expected_expr = expected() - assert tvm.ir.structural_equal(annotated_expr, expected_expr) - - smap = relay.backend._backend.GraphPlanMemory(annotated_expr) - for expr, storage_info in smap.expr_to_storage_info.items(): - # x is dev1 as output is dev1 - if isinstance(expr, tvm.relay.expr.Var): - assert storage_info.device_types[0] == dev1.device_type - else: - # device_copy op should be its dst_dev_type - if isinstance(expr.attrs, tvm.relay.op.op_attrs.DeviceCopyAttrs): - assert storage_info.device_types[0] == expr.attrs.dst_dev_type - else: - assert storage_info.device_types[0] == expected_dev_type[expr.op.name].device_type - - -def run_fusible_network(dev, tgt): - R""" The network is as following: - x y - \ / - add - / \ - sqrt log - \ / - subtract - | - exp - """ - x = relay.var("x", shape=(1, 10)) - y = relay.var("y", shape=(10, 10)) - x_data = np.random.rand(1, 10).astype("float32") - y_data = np.random.rand(10, 10).astype("float32") - tmp_add = x_data + y_data - tmp_sqrt = np.sqrt(tmp_add) - tmp_log = np.log(tmp_add) - tmp_sub = np.subtract(tmp_sqrt, tmp_log) - ref_res = np.exp(tmp_sub) - params = {"x": x_data, "y": y_data} - - def get_func(): - add = relay.add(x, y) - sqrt = relay.sqrt(add) - log = relay.log(add) - subtract = relay.subtract(sqrt, log) - exp = relay.exp(subtract) - - func = relay.Function([x, y], exp) - return func - - def test_fuse_log_add(device, tgt): - """Only log and add are fused.""" - fallback_device = tvm.device("cpu") - target = {"cpu": "llvm", device: tgt} - cpu_dev = fallback_device - dev_dev = tvm.device(device) - - def annotated(): - add = relay.add(x, y) - sqrt = relay.sqrt(add) - _sqrt = relay.annotation.on_device(sqrt, dev_dev) - log = relay.log(add) - subtract = relay.subtract(_sqrt, log) - exp = relay.exp(subtract) - _exp = relay.annotation.on_device(exp, dev_dev) - - func = relay.Function([x, y], _exp) - func = run_opt_pass(func, transform.RewriteAnnotatedOps(cpu_dev.device_type)) - return func - - def expected(): - add = relay.add(x, y) - copy_add_sqrt = relay.device_copy(add, cpu_dev, dev_dev) - sqrt = relay.sqrt(copy_add_sqrt) - log = relay.log(add) - copy_sqrt_subtract = relay.device_copy(sqrt, dev_dev, cpu_dev) - subtract = relay.subtract(copy_sqrt_subtract, log) - copy_sub_exp = relay.device_copy(subtract, cpu_dev, dev_dev) - exp = relay.exp(copy_sub_exp) - - func = relay.Function([x, y], exp) - return func - - annotated_func = annotated() - expected_func = expected() - dev = tvm.device(device, 0) - dev_idx = dev.device_type - expected_index = [1, 1, 1, dev_idx, dev_idx, 1, 1, dev_idx, dev_idx] - check_annotated_graph(annotated_func, expected_func) - opt_level = 1 - config = {"relay.fallback_device_type": fallback_device.device_type} - check_graph_executor( - target, ref_res, device, annotated_func, params, config, opt_level, expected_index - ) - opt_level = 2 - check_vm_runtime( - target, ref_res, device, annotated_func, params, config, opt_level, expected_index - ) - - def test_fuse_all(device, tgt): - """Fuse all operators.""" - fallback_device = tvm.device("cpu") - target = {"cpu": "llvm", device: tgt} - cpu_dev = fallback_device - dev_dev = tvm.device(device) - - def annotated(): - add = relay.add(x, y) - _add = relay.annotation.on_device(add, dev_dev) - sqrt = relay.sqrt(_add) - _sqrt = relay.annotation.on_device(sqrt, dev_dev) - log = relay.log(_add) - _log = relay.annotation.on_device(log, dev_dev) - subtract = relay.subtract(_sqrt, _log) - _subtract = relay.annotation.on_device(subtract, dev_dev) - exp = relay.exp(_subtract) - _exp = relay.annotation.on_device(exp, dev_dev) - - func = relay.Function([x, y], _exp) - func = run_opt_pass(func, transform.RewriteAnnotatedOps(cpu_dev.device_type)) - return func - - annotated_func = annotated() - expected_func = get_func() - check_annotated_graph(annotated_func, expected_func) - opt_level = 1 - config = {"relay.fallback_device_type": fallback_device.device_type} - check_graph_executor(target, ref_res, device, annotated_func, params, config, opt_level) - opt_level = 2 - check_vm_runtime(target, ref_res, device, annotated_func, params, config, opt_level) - - def test_fallback_exp(device, tgt): - fallback_device = tvm.device("cpu") - target = {"cpu": "llvm", device: tgt} - cpu_dev = fallback_device - dev_dev = tvm.device(device) - - def annotated(): - add = relay.add(x, y) - sqrt = relay.sqrt(add) - log = relay.log(add) - subtract = relay.subtract(sqrt, log) - exp = relay.exp(subtract) - _exp = relay.annotation.on_device(exp, cpu_dev) - - func = relay.Function([x, y], _exp) - func = run_opt_pass(func, transform.RewriteAnnotatedOps(dev_dev.device_type)) - return func - - def expected(): - add = relay.add(x, y) - sqrt = relay.sqrt(add) - log = relay.log(add) - subtract = relay.subtract(sqrt, log) - copy_sub_exp = relay.device_copy(subtract, dev_dev, cpu_dev) - exp = relay.exp(copy_sub_exp) - - func = relay.Function([x, y], exp) - return func - - annotated_func = annotated() - expected_func = expected() - dev = tvm.device(device, 0) - dev_idx = dev.device_type - expected_index = [dev_idx, dev_idx, dev_idx, 1, 1] - opt_level = 1 - config = {"relay.fallback_device_type": fallback_device.device_type} - check_annotated_graph(annotated_func, expected_func) - check_graph_executor( - target, ref_res, device, annotated_func, params, config, opt_level, expected_index - ) - opt_level = 2 - check_vm_runtime( - target, ref_res, device, annotated_func, params, config, opt_level, expected_index - ) - - def test_fallback_all_operators(device, tgt): - target = {device: tgt, "cpu": "llvm"} - annotated_func = get_func() - expected_func = get_func() - check_annotated_graph(annotated_func, expected_func) - opt_level = 2 - check_graph_executor(target, ref_res, device, annotated_func, params, {}, opt_level) - check_vm_runtime(target, ref_res, device, annotated_func, params, {}, opt_level) - - test_fuse_log_add(dev, tgt) - test_fuse_all(dev, tgt) - test_fallback_exp(dev, tgt) - test_fallback_all_operators(dev, tgt) - - -def run_unpropagatable_graph(dev, tgt): - r"""The network is as following: - a b c d - \ / \ / - add mul - \ / - subtract - """ - - a = relay.var("a", shape=(10, 10)) - b = relay.var("b", shape=(10, 10)) - c = relay.var("c", shape=(10, 10)) - d = relay.var("d", shape=(10, 10)) - a_data = np.random.rand(10, 10).astype("float32") - b_data = np.random.rand(10, 10).astype("float32") - c_data = np.random.rand(10, 10).astype("float32") - d_data = np.random.rand(10, 10).astype("float32") - tmp_add = a_data + b_data - tmp_mul = np.multiply(c_data, d_data) - ref_res = np.subtract(tmp_add, tmp_mul) - - fallback_device = tvm.device("cpu") - target = {"cpu": "llvm", dev: tgt} - cpu_dev = fallback_device - dev_dev = tvm.device(dev) - - def annotated(): - add = relay.add(a, b) - _add = relay.annotation.on_device(add, dev_dev) - mul = relay.multiply(c, d) - _mul = relay.annotation.on_device(mul, cpu_dev) - sub = relay.subtract(_add, _mul) - _sub = relay.annotation.on_device(sub, dev_dev) - func = relay.Function([a, b, c, d], _sub) - func = run_opt_pass(func, transform.RewriteAnnotatedOps(dev_dev.device_type)) - return func - - def expected(): - add = relay.add(a, b) - mul = relay.multiply(c, d) - copy_mul_sub = relay.device_copy(mul, cpu_dev, dev_dev) - sub = relay.subtract(add, copy_mul_sub) - func = relay.Function([a, b, c, d], sub) - return func - - annotated_func = annotated() - expected_func = expected() - expected_index = [2, 2, 2, 1, 1, 1, 2, 2] - check_annotated_graph(annotated_func, expected_func) - params = {"a": a_data, "b": b_data, "c": c_data, "d": d_data} - opt_level = 0 - config = {"relay.fallback_device_type": fallback_device.device_type} - - check_graph_executor( - target, ref_res, dev, annotated_func, params, config, opt_level, expected_index - ) - - opt_level = 2 - check_vm_runtime(target, ref_res, dev, annotated_func, params, config, opt_level) - - -@tvm.testing.requires_opencl -def test_check_run_opencl(): - dev = "opencl" - tgt = "opencl" - run_fusible_network(dev, tgt) - run_unpropagatable_graph(dev, tgt) - - -@tvm.testing.requires_opencl -def test_check_run_opencl_intel(): - dev = "opencl" - tgt = str(tvm.target.intel_graphics()) - run_fusible_network(dev, tgt) - run_unpropagatable_graph(dev, tgt) - - -@tvm.testing.requires_cuda -def test_check_run_cuda(): - dev = "cuda" - tgt = "cuda" - run_fusible_network(dev, tgt) - run_unpropagatable_graph(dev, tgt) - - -@tvm.testing.requires_cuda -def test_tuple_get_item(): - dev = "cuda" - cpu_dev = tvm.cpu(0) - gpu_dev = tvm.device(dev) - - def expected(): - x = relay.var("x", relay.ty.TensorType((3, 3, 4), "float32")) - split = relay.op.split(x, 3) - elem0 = relay.device_copy(split[0], gpu_dev, cpu_dev) - elem1 = relay.device_copy(split[1], gpu_dev, cpu_dev) - sub = elem0 - elem1 - func = relay.Function(relay.analysis.free_vars(sub), sub) - return func - - def annotated(): - x = relay.var("x", relay.ty.TensorType((3, 3, 4), "float32")) - split = relay.op.split(x, 3) - split = split.astuple() - split = relay.annotation.on_device(split, gpu_dev) - split = relay.TupleWrapper(split, 3) - sub = split[0] - split[1] - func = relay.Function(relay.analysis.free_vars(sub), sub) - func = run_opt_pass(func, transform.RewriteAnnotatedOps(cpu_dev.device_type)) - return func - - annotated_func = annotated() - expected_func = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(annotated_func, expected_func) - - -if __name__ == "__main__": - test_redundant_annotation() - test_annotate_expr() - test_annotate_all() - test_annotate_none() - test_conv_network() - test_tuple_get_item()