From ebf87bead62e2089b7e4496715b6264f7c0b63f5 Mon Sep 17 00:00:00 2001
From: Mark Shields
Date: Fri, 27 Aug 2021 16:06:59 -0700
Subject: [PATCH] [checkpoint] cross-device example working on vm!

[checkpoint] Can't wrap on_device around big-lambdas.
[checkpoint] ToANF working I think
- Cleanup pairs-of-pairs for OnDeviceProps etc.
- Don't wrap OnDevice around expressions that don't need it.
[checkpoint] ANF working again
[checkpoint] Visitor helpers, lambda lifting tracks devices
1/2 way through ANF tracking devices but currently very broken
[checkpoint] Rollback ANF scope changes, need to revisit
[checkpoint] Standalone pass unit tests all pass :-)
[checkpoint] TupleGetItem is working
[checkpoint] Giving up on FindFixedAndFreeExpressions, will introduce rewrite 'phase 0' instead
[checkpoint] Unit tests starting to work
- Add 'is_fixed' field and the 'implicit is_fixed=true' rule
- Start porting original context planning and annotation tests
[checkpoint] comment polish, get rid of no-op overrides.
[checkpoint] handle pattern-bound vars
[checkpoint] Rework intro comment. Introduce UnifyCollapsed.
[checkpoint] fix 'stored on' vs 'executes on' confusion
[checkpoint] more tests and bug fixes
[checkpoint] improve test
[checkpoint] basic tests passing
[checkpoint] get basic test going again
[checkpoint] Fix merge snafu
[checkpoint] Rename to device_planner / PlanDevices, improve comments, kind checking
[checkpoint] Switch to higher-order domains, add defaulting visitor.
[checkpoint] builds again
[checkpoint] (Does not build) Rework handling of let- and param-bound functions.
[checkpoint] (won't build) device_copy can capture scope, cleanup var->function tracking
[checkpoint] Python uses Devices not types.
[checkpoint] Renames, restore lost make_devices_explicit.cc
[checkpoint] rename test_pass_context_analysis.py to test_pass_make_devices_explicit.py
[checkpoint] few more rollbacks
[checkpoint] rollback bogus rename
[checkpoint] Cleanup default device handling.
[checkpoint] bug fixes, working on trivial example again
The default device stuff is messed up.
[checkpoint] Cleanup on_device handling. Fix param device lookups.
[checkpoint] Merged LowerTE Pass
[checkpoint] Get going with interpreter.
- ToANormalForm considers the arg to on_device an inner scope.
- FuseOps does not consider on_device a primitive
- Interpreter knows on_device is id
[checkpoint] undo accidental rename
[checkpoint] starting unit test
[checkpoint] Get rid of device_map from LowerTE
- Inserted the transform in, I think, the right place for VM, AOT, Interpreter and GraphExecutor.
- LowerTE still needs the memory plan, so still a lot of re-doing of memory planning going on. But at least the device map does not need to be rebuilt.
- Add logging context help -- preparing for the long climb to get all the tests going.
Still need to figure out all the default device stuff, I don't think that's being handled correctly.
[checkpoint] Mixin helper, capture OnDeviceAttrs for params.
TODO:
- Make sure device pass actually runs.
- Handle default device when targets_.size() == 1.
- Device vs int vs DLDeviceType confusion everywhere
[checkpoint] Make device assignment a pass. VM compiler still needs explicit map.
All very rough. Lots of mismatches between Device and DLDeviceType as unit of annotation.
[checkpoint] better messages
[checkpoint] Merge in VLOG so can try it out with larger cl. Will need to split it out again.
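
For orientation, a minimal sketch of how the on_device / PlanDevices flow built up in the checkpoints above is meant to be driven from Python. This is illustrative only: the shapes, variable names and the CPU/GPU device choices are placeholders, and the real coverage lives in tests/python/relay/test_pass_plan_devices.py.

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(5, 7))
    y = relay.var("y", shape=(5, 7))
    # Ask for the addition to be evaluated (and its result stored) on the CPU.
    add = relay.annotation.on_device(relay.add(x, y), tvm.cpu())
    func = relay.Function([x, y], relay.multiply(add, y))
    mod = tvm.IRModule.from_expr(func)
    mod = relay.transform.InferType()(mod)
    # Infer a device for every remaining sub-expression, falling back to the default
    # device (its device_id is ignored), and capture the result with new
    # "on_device" / "device_copy" calls.
    mod = relay.transform.PlanDevices(tvm.cuda())(mod)
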
[checkpoint] rollback WithAttr node since seems using CallNode is the pattern [checkpoint] Got rid of CollectDeviceInfo [checkpoint] fiddling with WithAttr [checkpoint] trivial [checkpoint] rename context_analysis.cc to make_devices_explicit.cc and move to transforms/ --- CMakeLists.txt | 1 + include/tvm/relay/analysis.h | 23 +- include/tvm/relay/attrs/annotation.h | 45 +- include/tvm/relay/attrs/device_copy.h | 1 + include/tvm/relay/expr.h | 13 + include/tvm/relay/expr_functor.h | 3 +- include/tvm/relay/transform.h | 15 +- include/tvm/runtime/container/array.h | 2 +- include/tvm/runtime/ndarray.h | 26 +- include/tvm/runtime/vm/vm.h | 4 +- python/tvm/relay/analysis/analysis.py | 32 - python/tvm/relay/op/annotation/annotation.py | 39 +- python/tvm/relay/transform/transform.py | 23 +- src/node/structural_equal.cc | 7 +- src/relay/analysis/context_analysis.cc | 719 ------ .../{op/annotation => attrs}/annotation.cc | 133 +- src/relay/attrs/annotation.h | 114 + src/relay/attrs/device_copy.cc | 117 + src/relay/attrs/device_copy.h | 79 + src/relay/backend/aot_executor_codegen.cc | 141 +- src/relay/backend/build_module.cc | 81 +- src/relay/backend/graph_executor_codegen.cc | 65 +- src/relay/backend/graph_plan_memory.cc | 158 +- src/relay/backend/interpreter.cc | 97 +- src/relay/backend/te_compiler.cc | 110 +- src/relay/backend/te_compiler.h | 9 +- src/relay/backend/vm/compiler.cc | 234 +- src/relay/backend/vm/compiler.h | 12 +- src/relay/backend/vm/lambda_lift.cc | 85 +- src/relay/ir/dataflow_matcher.cc | 8 +- src/relay/ir/expr.cc | 25 + src/relay/ir/expr_functor.cc | 22 +- src/relay/ir/indexed_graph.cc | 92 +- src/relay/ir/indexed_graph.h | 19 +- src/relay/op/memory/memory.cc | 44 +- src/relay/op/memory/memory.h | 1 - src/relay/quantize/partition.cc | 2 +- src/relay/quantize/realize.cc | 2 +- src/relay/transforms/device_annotation.cc | 311 +-- src/relay/transforms/device_planner.cc | 1941 +++++++++++++++++ src/relay/transforms/device_planner.h | 188 ++ src/relay/transforms/fold_scale_axis.cc | 4 + src/relay/transforms/fuse_ops.cc | 36 +- src/relay/transforms/let_list.h | 2 +- src/relay/transforms/memory_alloc.cc | 125 +- src/relay/transforms/pass_utils.h | 65 +- src/relay/transforms/pattern_utils.h | 4 - src/relay/transforms/split_args.cc | 3 +- src/relay/transforms/to_a_normal_form.cc | 400 ++-- .../transforms/to_basic_block_normal_form.cc | 36 +- src/relay/transforms/type_infer.cc | 6 +- src/runtime/ndarray.cc | 4 +- src/runtime/vm/serialize_utils.h | 4 +- src/runtime/vm/vm.cc | 2 +- tests/python/relay/test_pass_annotation.py | 9 +- .../relay/test_pass_context_analysis.py | 205 -- tests/python/relay/test_pass_lambda_lift.py | 5 +- tests/python/relay/test_pass_plan_devices.py | 1062 +++++++++ 58 files changed, 4818 insertions(+), 2197 deletions(-) delete mode 100644 src/relay/analysis/context_analysis.cc rename src/relay/{op/annotation => attrs}/annotation.cc (63%) create mode 100644 src/relay/attrs/annotation.h create mode 100644 src/relay/attrs/device_copy.cc create mode 100644 src/relay/attrs/device_copy.h create mode 100644 src/relay/transforms/device_planner.cc create mode 100644 src/relay/transforms/device_planner.h delete mode 100644 tests/python/relay/test_pass_context_analysis.py create mode 100644 tests/python/relay/test_pass_plan_devices.py diff --git a/CMakeLists.txt b/CMakeLists.txt index a403d9462d450..f4ef16166b17e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -267,6 +267,7 @@ file(GLOB_RECURSE RELAY_PASS_SRCS src/relay/analysis/*.cc src/relay/transforms/*.cc 
src/relay/quantize/*.cc + src/relay/attrs/*.cc ) file(GLOB RELAY_BACKEND_SRCS src/relay/backend/*.cc diff --git a/include/tvm/relay/analysis.h b/include/tvm/relay/analysis.h index 264f2609a4b6b..176ff9c8cd6f2 100644 --- a/include/tvm/relay/analysis.h +++ b/include/tvm/relay/analysis.h @@ -212,22 +212,14 @@ TVM_DLL tvm::Array AllTypeVars(const Expr& expr, const IRModule& mod); TVM_DLL tvm::Array AllTypeVars(const Type& t, const IRModule& mod); /*! - * \brief Collect the device mapping information of each expression. - * - * \param expr The expression. - * - * \return The device mapping. - */ -TVM_DLL Map CollectDeviceInfo(const Expr& expr); - -/*! - * \brief Collect the device anntation operators. + * \brief Collect the device annotation operators. * * \param expr The expression. * * \return The annotated expression to device type mapping for annotation ops. */ TVM_DLL Map CollectDeviceAnnotationOps(const Expr& expr); +TVM_DLL Map CollectAllDeviceAnnotationOps(const IRModule& mod); /*! * \brief Finds cases that the given match expression does not catch, if any. @@ -268,17 +260,6 @@ TVM_DLL IRModule GetCalibrateModule(IRModule mod); */ TVM_DLL Map> GetCalibrateOutputMap(const IRModule& mod); -/*! - * \brief Analyze the device context of each IR node in a given relay module. - * - * \param mod The module for analysis. - * \param default_device The default device used by unassigned IR nodes. - * - * \return The mapping between an IR node and its associated device. - */ -TVM_DLL std::unordered_map -ContextAnalysis(const IRModule& mod, const Device& default_device); - } // namespace relay } // namespace tvm diff --git a/include/tvm/relay/attrs/annotation.h b/include/tvm/relay/attrs/annotation.h index 8379e6471561d..aa8507f6c03b1 100644 --- a/include/tvm/relay/attrs/annotation.h +++ b/include/tvm/relay/attrs/annotation.h @@ -32,14 +32,55 @@ namespace tvm { namespace relay { /*! - * \brief Options for the device annotation operators. + * \brief Attributes for the "on_device" operator. + * + * The relay call + * \code + * on_device(expr, device_type=2) + * \endcode + * denotes that the result of \p expr should be stored on the device with \p DLDeviceType 2 + * (i.e. \p kDLCuda). Semantically the operator is the identity function. */ struct OnDeviceAttrs : public tvm::AttrsNode { + // TODO(mbs): Replace device types with TargetDevice. + /*! \brief Device type on which argument expression should be evaluated. */ int device_type; + /*! + * \brief If true, the result device must also be \p device_type and device planning should + * not insert any "device_copy" calls to respect this annotation. + * + * This is used by the device planning pass itself when annotating the planned program. + */ + bool is_fixed; TVM_DECLARE_ATTRS(OnDeviceAttrs, "relay.attrs.OnDeviceAttrs") { TVM_ATTR_FIELD(device_type) - .describe("The virutal device/context type that an expression is annotated with.") + .describe("The type of the virtual device which should hold the expression result.") + .set_default(0); + TVM_ATTR_FIELD(is_fixed) + .describe("If true, do not insert a \"device_copy\" call to respect this annotation.") + .set_default(false); + } +}; + +/*! + * \brief Attributes for Relay function definitions which capture the devices for the + * function parameters and result. + */ +struct FunctionOnDeviceAttrs : public tvm::AttrsNode { + constexpr static const char* kFunctionAttrsKey = "on_device"; + + /*! \brief Device type on which each of the function's arguments already resides. 
*/ + Array param_device_types; + // TODO(mbs): Replace device types with TargetDevice. + /*! \brief Device type on which function body should be evaluated. */ + int result_device_type; + + TVM_DECLARE_ATTRS(FunctionOnDeviceAttrs, "relay.attrs.FunctionOnDeviceAttrs") { + TVM_ATTR_FIELD(param_device_types) + .describe("The type of the virtual device which holds each function parameters."); + TVM_ATTR_FIELD(result_device_type) + .describe("The type of the virtual device which will hold the function's result.") .set_default(0); } }; diff --git a/include/tvm/relay/attrs/device_copy.h b/include/tvm/relay/attrs/device_copy.h index 7da92b3ff7639..f7b0a04f45fa8 100644 --- a/include/tvm/relay/attrs/device_copy.h +++ b/include/tvm/relay/attrs/device_copy.h @@ -35,6 +35,7 @@ namespace relay { * \brief Options for the device copy operators. */ struct DeviceCopyAttrs : public tvm::AttrsNode { + // TODO(mbs): Should be TargetDevice. int dst_dev_type; int src_dev_type; diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index daad8514f9ff5..1d98e583f092e 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -316,7 +316,20 @@ class Call : public Expr { TVM_DLL Call(Expr op, Array args, Attrs attrs = Attrs(), Array type_args = Array(), Span span = Span()); + /*! + * \brief Returns a copy of this with given properties. A null property denotes 'no change'. Returns + * this if all properties are unchanged. Returns a modified this if this is the only reference + * to the underlying node. + */ + // TODO(mbs): Extend to all node types. + Call CopyWith(Optional opt_op = Optional(), + Optional> opt_args = Optional>(nullptr), + Optional opt_attrs = Optional(nullptr), + Optional> opt_type_args = Optional>(nullptr), + Optional opt_span = Optional(nullptr)); + TVM_DEFINE_OBJECT_REF_METHODS(Call, RelayExpr, CallNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(CallNode); }; /*! diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index 688ad8254fa85..f96faffb24f4f 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -37,6 +37,7 @@ #include #include #include + namespace tvm { namespace relay { @@ -227,7 +228,7 @@ class ExprMutator : public ::tvm::relay::ExprFunctor { * * MixedModeVisitor provides the same recursive API as ExprVisitor, and uses * recursion to traverse most forms of the IR, but under the hood it expands nested dataflow regions - * of the graph and processes them iteratatively to prevent stack overflows + * of the graph and processes them iteratively to prevent stack overflows */ class MixedModeVisitor : public ::tvm::relay::ExprVisitor { public: diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index bdc46d71a77de..86ea8be92979b 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -430,13 +430,24 @@ TVM_DLL Pass SimplifyExpr(); * \brief A pass for manifesting explicit memory allocations and rewriting * specific dialects. * - * \param target_host The target used by the host for compliation. - * \param targets The device type and target pairs for compliation. + * \param target_host The target used by the host for compilation. + * \param targets The device type and target pairs for compilation. * * \return The pass. */ TVM_DLL Pass ManifestAlloc(Target target_host, Map targets); +/*! + * \brief Uses existing "on_device" and "device_copy" CallNodes to infer the device on which + * every Relay sub-expression should run (and the result stored). 
Captures the result of that + * analysis using new "on_device" and "device_copy" CallNodes. See + * tvm::relay::transform::{LexicalOnDeviceMixin,DeviceAwareExprVisitor,DeviceAwareExprMutator} + * for help recovering the device for an arbitrary sub-expression in downstream transformations. + * + * \param default_device_type DLDeviceType for default device. + */ +TVM_DLL Pass PlanDevices(DLDeviceType default_device_type); + } // namespace transform /*! diff --git a/include/tvm/runtime/container/array.h b/include/tvm/runtime/container/array.h index 8830653da88cc..26f4e545deb75 100644 --- a/include/tvm/runtime/container/array.h +++ b/include/tvm/runtime/container/array.h @@ -249,7 +249,7 @@ class ArrayNode : public Object, public InplaceArrayBase { }; /*! - * \brief Array, container representing a contigious sequence of ObjectRefs. + * \brief Array, container representing a contiguous sequence of ObjectRefs. * * Array implements in-place copy-on-write semantics. * diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h index 1127a9ae732cd..a4c285e3dd086 100644 --- a/include/tvm/runtime/ndarray.h +++ b/include/tvm/runtime/ndarray.h @@ -38,9 +38,19 @@ #include namespace tvm { -namespace runtime { -typedef DLDevice Device; +// alias DLDevice +using Device = DLDevice; + +// A 'null' device type, does not correspond to any DLDeviceType enum. +// TODO(mbs): This is to help us as we transition away from representing the 'homogenous' case +// as a singleton target map indexed by the invalid DLDeviceType '0'. +constexpr DLDeviceType kNullDeviceType = static_cast(0); + +// An 'invalid' device type, does not correspond to any DLDeviceType enum. +constexpr DLDeviceType kInvalidDeviceType = static_cast(-1); + +namespace runtime { /*! * \brief Managed NDArray. @@ -481,23 +491,19 @@ inline bool NDArray::Load(dmlc::Stream* strm) { } } // namespace runtime - -// alias Device -using tvm::runtime::Device; - } // namespace tvm namespace std { template <> -struct hash { - std::size_t operator()(const tvm::runtime::Device& dev) const { +struct hash { + std::size_t operator()(const tvm::Device& dev) const { return ((dev.device_id << 8) | dev.device_type); } }; template <> -struct equal_to { - bool operator()(const tvm::runtime::Device& lhs, const tvm::runtime::Device& rhs) const { +struct equal_to { + bool operator()(const tvm::Device& lhs, const tvm::Device& rhs) const { return (lhs.device_type == rhs.device_type && lhs.device_id == rhs.device_id); } }; diff --git a/include/tvm/runtime/vm/vm.h b/include/tvm/runtime/vm/vm.h index 2fdfec9452af5..831336b9dbfe6 100644 --- a/include/tvm/runtime/vm/vm.h +++ b/include/tvm/runtime/vm/vm.h @@ -84,11 +84,11 @@ struct VMFunction { /*! \brief The size of the frame for this function */ Index register_file_size; /*! \brief The device type of each parameter for this function. 
*/ - std::vector params_device_type; + std::vector params_device_type; VMFunction(const std::string& name, std::vector params, const std::vector& instructions, Index register_file_size, - const std::vector params_device_type = {}) + const std::vector params_device_type = {}) : name(name), params(params), instructions(instructions), diff --git a/python/tvm/relay/analysis/analysis.py b/python/tvm/relay/analysis/analysis.py index c7b6c60849a14..e50c8405feb37 100644 --- a/python/tvm/relay/analysis/analysis.py +++ b/python/tvm/relay/analysis/analysis.py @@ -28,21 +28,6 @@ from .feature import Feature -def context_analysis(mod, default_device): - """Analyze the device context information of each IR node in a Relay - program. - - Parameters - ---------- - mod : tvm.IRModule - The input module. - - default_device : tvm.runtime.Device - The default context allocated to an IR node. - """ - return _ffi_api.ContextAnalysis(mod, default_device) - - def post_order_visit(expr, fvisit): """Recursively visit the ir in post DFS order node, apply fvisit. Each node is guaranteed to be visited @@ -268,23 +253,6 @@ def all_dtypes(expr): return set(_ffi_api.all_dtypes(expr)) -def collect_device_info(expr): - """Collect the device allocation map for the given expression. The device - ids are propagated from the `device_copy` operators. - - Parameters - ---------- - expr : tvm.relay.Expr - The input expression. - - Returns - ------- - ret : Dict[tvm.relay.ir.expr, int] - A dictionary mapping tvm.relay.Expr to device type. - """ - return _ffi_api.CollectDeviceInfo(expr) - - def collect_device_annotation_ops(expr): """Collect the device annotation ops for the given expression. diff --git a/python/tvm/relay/op/annotation/annotation.py b/python/tvm/relay/op/annotation/annotation.py index 809b6369b0854..a7e09294a66b0 100644 --- a/python/tvm/relay/op/annotation/annotation.py +++ b/python/tvm/relay/op/annotation/annotation.py @@ -22,7 +22,19 @@ from .. import op as reg -def on_device(data, device): +def _device_to_int(device): + if isinstance(device, _Device): + return device.device_type + elif isinstance(device, str): + return _nd.device(device).device_type + else: + raise ValueError( + "device is expected to be the type of Device or " + "str, but received %s" % (type(device)) + ) + + +def on_device(data, device, is_fixed=False): """Annotate an expression with a certain device type. Parameters @@ -33,21 +45,26 @@ def on_device(data, device): device : Union[:py:class:`Device`, str] The device type to annotate. + is_fixed : bool + If true, annotation does not imply a device_copy may be inserted. + (This parameter is used internally by the compiler and unit tests and + should not need to be set in user programs.) + Returns ------- result : tvm.relay.Expr The annotated expression. """ - if isinstance(device, _Device): - device = device.device_type - elif isinstance(device, str): - device = _nd.device(device).device_type - else: - raise ValueError( - "device is expected to be the type of Device or " - "str, but received %s" % (type(device)) - ) - return _make.on_device(data, device) + return _make.on_device(data, _device_to_int(device), is_fixed) + + +# for testing only +def function_on_device(function, param_devices, result_device): + """Attaches attribute to function indicating the devices for its parameters and result. 
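+
+    Parameters
+    ----------
+    function : tvm.relay.Function
+        The function to be annotated.
+
+    param_devices : List[Union[:py:class:`Device`, str]]
+        The device for each of the function's parameters.
+
+    result_device : Union[:py:class:`Device`, str]
+        The device for the function's result.
+
+    Returns
+    -------
+    result : tvm.relay.Function
+        The annotated function.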
+ """ + + return _make.function_on_device(function, + [_device_to_int(d) for d in param_devices], _device_to_int(result_device)) def stop_fusion(data): diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 9a7857a01fe68..400b75f9a5fe4 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -546,7 +546,7 @@ def MergeCompilerRegions(): def RewriteAnnotatedOps(fallback_device): """Rewrite the annotated program where annotation operators, e.g. - `on_deivce`, mark which device an expression should be scheduled to. + `on_device`, mark which device an expression should be scheduled to. This pass helps heterogeneous execution where different operators may need to be allocated on various devices. @@ -562,7 +562,7 @@ def RewriteAnnotatedOps(fallback_device): The registered pass that rewrites an expression with annotated `on_device` operators. """ - return _ffi_api.RewriteDeviceAnnotation(fallback_device) + return _ffi_api.RewriteAnnotatedOps(fallback_device) def ToANormalForm(): @@ -801,15 +801,11 @@ def gradient(expr, mod=None, mode="higher_order"): The transformed expression. """ if mode == "first_order": - warnings.warn( - "using transform.gradient for first-order AD is deprecated, please use the" - "FirstOrderGradient module pass", - DeprecationWarning, - ) + warnings.warn("using transform.gradient for first-order AD is deprecated, please use the" + "FirstOrderGradient module pass", DeprecationWarning, ) if mod is not None: raise RuntimeError( - "to run first-order AD on a module, please use the FirstOrderGradient module pass." - ) + "to run first-order AD on a module, please use the FirstOrderGradient module pass.") return FirstOrderGradient()(tvm.IRModule.from_expr(expr))["main"] if mode == "higher_order": return _ffi_api.gradient(expr, mod) @@ -1167,6 +1163,15 @@ def SimplifyExpr(): return _ffi_api.SimplifyExpr() +def PlanDevices(default_device): + """ + Uses existing "on_device" and "device_copy" CallNodes to infer the device on which + every Relay sub-expression should run (and the result stored). Captures the result of that + analysis using new "on_device" and "device_copy" CallNodes. Note that the device_id of + the default_device is ignored. + """ + return _ffi_api.PlanDevices(default_device) + def FoldExplicitPadding(): """ FoldExplicitPadding finds explict padding before an op that can support diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc index 1fa72c92b6fc1..8e52af60d2351 100644 --- a/src/node/structural_equal.cc +++ b/src/node/structural_equal.cc @@ -19,6 +19,7 @@ /*! * \file src/node/structural_equal.cc */ +#include #include #include #include @@ -119,8 +120,10 @@ class RemapVarSEqualHandler : public SEqualReducer::Handler { // Check the result. 
bool CheckResult(bool result, const ObjectRef& lhs, const ObjectRef& rhs) { if (assert_mode_ && !result) { - LOG(FATAL) << "ValueError: StructuralEqual check failed, caused by\n" - << "lhs = " << lhs << "\nrhs = " << rhs; + LOG(FATAL) << "ValueError: StructuralEqual check failed, caused by lhs:" << std::endl + << PrettyPrint(lhs) << std::endl + << "and rhs:" << std::endl + << PrettyPrint(rhs); } return result; } diff --git a/src/relay/analysis/context_analysis.cc b/src/relay/analysis/context_analysis.cc deleted file mode 100644 index 35813f67d0948..0000000000000 --- a/src/relay/analysis/context_analysis.cc +++ /dev/null @@ -1,719 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/relay/analysis/context_analysis.cc - * \brief A pass for analyzing device attribute of each IR node. - * - * We use union-find data structures to analyze the context information of each - * sub-expression in a Relay program in this pass. Only the device copy node in - * Relay directly contains bidiretional device information. We use it to - * bidirectionally propagate the device info of its inputs and outputs. - * - * However, to support dynamism (e.g dynamic inputs), Relay introduces several - * concepts to compute the shape of tensors and operators at runtime, i.e. - * shape_of, shape_func, and reshape_tensor. These nodes are also referred to as - * VM dialects as we have native VM instructions for them. These dialects are - * intrinsically CPU friendly, therefore, they are only designed to be - * executed on CPU. We, hence, unify their inputs and outputs to CPU as well. - * Note the input of shape_of is a tensor and we only need the tensor shape. - * Therefore, the input could be sitting on GPU as well since no real data is - * needed. The context of the input would be propagated from its other - * consumers or fallback to the default device. - * - * Another type of dialect is used fo memory allocation, namely, alloc_storage - * and alloc_tensor. alloc_storage contains a context field to indicate where - * the chunk of memory is allocated. Therefore, we unify the context of - * alloc_storage with the context field. Other inputs, such as size and - * alignment, are left on CPU. - * - * Based on the above rules, we keep unifying the connected expressions and - * propagating their device information. An error will be raised whenever there - * is a unification conflict. All IR nodes that are not propagated with device - * context will fallback to the specified device. 
- */ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace tvm { -namespace relay { - -using PackedAnalysisResultMap = Map>; -using AnalysisResultMap = - std::unordered_map; - -namespace analysis { - -// Cache ops -static const Op& device_copy_op = Op::Get("device_copy"); -static const Op& alloc_storage_op = Op::Get("memory.alloc_storage"); -static const Op& alloc_tensor_op = Op::Get("memory.alloc_tensor"); -static const Op& shape_of_op = Op::Get("vm.shape_of"); -static const Op& invoke_tvm_op = Op::Get("vm.invoke_tvm_op"); -static const Op& shape_func_of = Op::Get("vm.shape_func"); -static const Op& reshape_tensor_op = Op::Get("vm.reshape_tensor"); - -class DeviceDomain; -using DeviceDomainPtr = std::shared_ptr; - -/* - * \brief A class to represent the device of a domain, i.e. a segment of relay program. - */ -class DeviceDomain { - public: - // Construct an empty domain. - DeviceDomain() { - device_.device_type = static_cast(-1); - device_.device_id = -1; - } - - // Construct a domain based on a given context. - explicit DeviceDomain(const Device& dev) : device_(dev) {} - - // Check if the current domain is empty. - bool IsEmptyDomain() const { - return static_cast(device_.device_type) == -1 && device_.device_id == -1; - } - - // Check if the current domain equals the other one. - bool operator==(const DeviceDomain& other) const { - return device_.device_type == other.device_.device_type && - device_.device_id == other.device_.device_id; - } - - bool operator!=(const DeviceDomain& other) const { return !(*this == other); } - - private: - // Create a hash for a domain. - struct Hash { - size_t operator()(const DeviceDomainPtr& domain) const { - if (domain->IsEmptyDomain()) { - return static_cast(reinterpret_cast(domain.get())); - } else { - size_t const h1(std::hash()(static_cast(domain->device_.device_type))); - size_t const h2(std::hash()(domain->device_.device_id)); - return h1 ^ (h2 << 1); - } - } - }; - - // Create an equality for domains. - struct Equal { - public: - bool operator()(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) const { - // We compare the pointer for empty domains. - if (lhs->IsEmptyDomain() && rhs->IsEmptyDomain()) return lhs.get() == rhs.get(); - - // Otherwise device type and id are used to check equality. - return (*lhs.get() == *rhs.get()); - } - }; - - /* \brief The device to be assigned to the current domain. */ - Device device_; - - friend DeviceDomainPtr Join(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs); - friend class ContextAnalyzer; -}; - -// Join two domains. -DeviceDomainPtr Join(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) { - if (lhs->IsEmptyDomain() && rhs->IsEmptyDomain()) { - return lhs; - } else if (lhs->IsEmptyDomain()) { - return rhs; - } else if (rhs->IsEmptyDomain()) { - return lhs; - } else { - ICHECK(*lhs.get() == *rhs.get()) << "All expressions must have a singular device to unify"; - return lhs; - } -} - -/* - * \brief Compute on which device each sub-expression will execute. A union find - * algorithm is used to assign and merge the context domains. 
- */ -class ContextAnalyzer : public MixedModeVisitor { - public: - ContextAnalyzer(const IRModule& mod, const GlobalVar& current_func, - const Device& default_device) - : MixedModeVisitor(9), // the number of repeated visits a node can perform - mod_(mod), - current_func_(current_func), - default_device_(default_device) { - cpu_dev_.device_type = kDLCPU; - cpu_dev_.device_id = 0; - } - - // Create an empty domain. - // This usually happens when we enter a new scope, i.e. Function. - DeviceDomainPtr Bottom() { return std::make_shared(DeviceDomain()); } - - // Create a domain with the given device context. - DeviceDomainPtr DeviceType(const Device& dev) { - return std::make_shared(DeviceDomain(dev)); - } - - // Find the root of a device. - DeviceDomainPtr Lookup(DeviceDomainPtr device) { - while (device_uf_.count(device) && device != device_uf_[device]) { - // Path compression - if (device_uf_.count(device_uf_[device])) { - device_uf_[device] = device_uf_[device_uf_[device]]; - } - device = device_uf_[device]; - } - return device; - } - - // Unify two domains. - DeviceDomainPtr Unify(DeviceDomainPtr lhs, DeviceDomainPtr rhs) { - lhs = Lookup(lhs); - rhs = Lookup(rhs); - auto unified_device = Join(lhs, rhs); - if (lhs != unified_device) { - device_uf_[lhs] = unified_device; - } - - if (rhs != unified_device) { - device_uf_[rhs] = unified_device; - } - - return unified_device; - } - - // Unify the domain for two IR nodes. - DeviceDomainPtr UnifyExpr(const Expr& lhs, const Expr& rhs) { - auto lhs_dom = DeviceFor(lhs); - auto rhs_dom = DeviceFor(rhs); - return Unify(lhs_dom, rhs_dom); - } - - // Lookup or insert an IR node to device domain map. - DeviceDomainPtr DeviceFor(const Expr& expr) { - auto it = expr_to_device_.find(expr); - if (it == expr_to_device_.end()) { - auto bottom = Bottom(); - expr_to_device_[expr] = bottom; - return bottom; - } else { - return it->second; - } - } - - // Unify the device context for a device copy node. Device copy node is - // the only node that carries bidirectional devices in the input program. The device - // attribute of other nodes can be propagated from it. - void UnifyDeviceCopy(const std::vector& inps, const std::vector& outputs, - DLDeviceType src_dev_type, DLDeviceType dst_dev_type) { - Device src_dev; - src_dev.device_type = src_dev_type; - src_dev.device_id = 0; - auto src_domain = DeviceType(src_dev); - for (const auto& it : inps) { - auto lhs = DeviceFor(it); - Unify(lhs, src_domain); - } - - Device dst_dev; - dst_dev.device_type = dst_dev_type; - dst_dev.device_id = 0; - auto dst_domain = DeviceType(dst_dev); - for (const auto& it : outputs) { - auto lhs = DeviceFor(it); - Unify(lhs, dst_domain); - } - } - - // Unify the domain of inputs and outputs of a relay call. - // - // For most call nodes, the op, inputs, and outputs should all be in the - // same domain, i.e. having the same context. However, device_copy call node - // needs to be handled differently as it copies data from one device to - // another. 
- DeviceDomainPtr UnifyCall(const Expr& call_op, const Array& inps, - const Array& outputs, DeviceDomainPtr device) { - device = Unify(device, DeviceFor(call_op)); - - for (const auto& it : inps) { - device = Unify(device, DeviceFor(it)); - } - - for (const auto& it : outputs) { - device = Unify(device, DeviceFor(it)); - } - - return device; - } - - void VisitExpr_(const CallNode* cn) final { - Call call = GetRef(cn); - - if (IsDeviceCopy(call)) { - UnifyDeviceCopyCall(cn); - } else if (call->op == alloc_storage_op) { - UnifyAllocStorageCall(cn); - } else if (call->op == alloc_tensor_op) { - UnifyAllocTensorCall(cn); - } else if (call->op == shape_func_of) { - UnifyShapeFuncCall(cn); - } else if (call->op == shape_of_op) { - UnifyShapeOfCall(cn); - } else if (call->op == invoke_tvm_op) { - UnifyInvokeTVMOpCall(cn); - } else if (call->op == reshape_tensor_op) { - UnifyReshapeTensorCall(cn); - } else if (call->op.as()) { - UnifyFunctionCall(cn); - } else if (call->op.as()) { - UnifyGlobalVarCall(cn); - } else if (call->op.as()) { - UnifyVarCall(cn); - } else { - UnifyCall(call, cn->args, {call}, Bottom()); - MixedModeVisitor::VisitExpr_(cn); - } - } - - void VisitExpr_(const LetNode* ln) final { - Expr expr = GetRef(ln); - // Iteratively visit let nodes to avoid stack overflow. - while (expr->IsInstance()) { - Let let = Downcast(expr); - // Save currying/closures since they will be invoked later - auto ty = let->value->checked_type(); - if (ty->IsInstance()) { - auto gv = ExtractClosure(let); - ICHECK(gv.defined() && gv->IsInstance()); - closures_[let->var] = Downcast(gv); - } - - // Unify let var, value, and body - Unify(DeviceFor(let->var), DeviceFor(let->value)); - UnifyExpr(let, let->body); - MixedModeVisitor::VisitExpr(let->value); - expr = let->body; - } - // Visit the last body - MixedModeVisitor::VisitExpr(expr); - } - - void VisitExpr_(const FunctionNode* fn) final { - auto func = GetRef(fn); - // No need to step into fused primitive functions as they are handled as - // a whole. - if (fn->HasNonzeroAttr(attr::kPrimitive)) { - return; - } - - auto device = Unify(DeviceFor(func), DeviceFor(fn->body)); - for (const auto& it : fn->params) { - DeviceFor(it); - } - MixedModeVisitor::VisitExpr(fn->body); - } - - void VisitExpr_(const TupleNode* tn) final { - // We only support tuple with the same of device. - Tuple tup = GetRef(tn); - if (tn->fields.size() > 0) { - auto device = DeviceFor(tup->fields[0]); - for (size_t i = 1; i < tup->fields.size(); i++) { - device = Unify(device, DeviceFor(tup->fields[i])); - } - Unify(device, DeviceFor(tup)); - } - MixedModeVisitor::VisitExpr_(tn); - } - - void VisitExpr_(const TupleGetItemNode* tn) final { - TupleGetItem item = GetRef(tn); - - Unify(DeviceFor(item), DeviceFor(item->tuple)); - - MixedModeVisitor::VisitExpr_(tn); - } - - void VisitExpr_(const MatchNode* mn) final { - // For match node, we unify the value and the rhs of each clause - Match m = GetRef(mn); - auto device = Unify(DeviceFor(m), DeviceFor(m->data)); - for (const auto& c : m->clauses) { - device = Unify(device, DeviceFor(c->rhs)); - } - MixedModeVisitor::VisitLeaf(mn->data); - for (const Clause& c : mn->clauses) { - this->VisitClause(c); - MixedModeVisitor::VisitLeaf(c->rhs); - } - } - - void VisitExpr_(const GlobalVarNode* gvn) final { DeviceFor(GetRef(gvn)); } - - void VisitExpr_(const VarNode* vn) { DeviceFor(GetRef(vn)); } - - void VisitExpr_(const ConstantNode* cn) final { DeviceFor(GetRef(cn)); } - - // Return the analysis results. 
- AnalysisResultMap Results() { - AnalysisResultMap ret; - for (const auto& it : expr_to_device_) { - auto device = Lookup(it.second); - if (device->IsEmptyDomain()) { - ret[it.first] = default_device_; - } else { - ret[it.first] = device->device_; - } - } - - return ret; - } - - private: - Expr ExtractClosure(Expr expr) const { - while (expr->IsInstance()) { - Let let = Downcast(expr); - expr = let->value; - if (expr->IsInstance()) { - return expr; - } else { - const auto* cn = expr.as(); - if (cn && cn->op->IsInstance()) { - return cn->op; - } - } - } - return Expr(nullptr); - } - - // Check if an expression is a device copy call. - bool IsDeviceCopy(const Expr& expr) const { - if (!expr->IsInstance()) return false; - - Call call = Downcast(expr); - if (call->op == device_copy_op) return true; - - // Fused function with device copy op as the body - // device copy op is opaque therefore the fused function only has one node. - if (const FunctionNode* fn = call->op.as()) { - if (const CallNode* cn = fn->body.as()) { - return cn->op == device_copy_op; - } - } - - return false; - } - - // Check if a function is a closure. - bool IsClosure(const Function& func) { return func->GetAttr(attr::kClosure, 0) != 0; } - - // Check if a function is a currying function. - bool IsCurrying(const Function& func) { - if (const auto* let = func->body.as()) { - return closures_.find(let->var) != closures_.end(); - } - return false; - } - - // Process device copy call node - void UnifyDeviceCopyCall(const CallNode* call) { - ICHECK_EQ(call->args.size(), 1U); - - std::vector inps{call->args[0]}; - std::vector outs{GetRef(call)}; - DLDeviceType src_dev_type, dst_dev_type; - const DeviceCopyAttrs* attrs = nullptr; - if (const auto* fn = call->op.as()) { - // device_copy is fused, propagate device to the fused function. - inps.push_back(fn->params[0]); - outs.push_back(call->op); - Expr body = fn->body; - ICHECK(body->IsInstance() && IsDeviceCopy(body)); - Call call_body = Downcast(body); - attrs = call_body->attrs.as(); - } else { - attrs = call->attrs.as(); - } - ICHECK(attrs != nullptr); - src_dev_type = static_cast(attrs->src_dev_type); - dst_dev_type = static_cast(attrs->dst_dev_type); - - // Device copy op only has one input which is now annotated with the - // same device to the source device type of the device copy op. - // The call itself has the same device type to the destination. - UnifyDeviceCopy(inps, outs, src_dev_type, dst_dev_type); - MixedModeVisitor::VisitExpr_(call); - } - - void UnifyAllocStorageCall(const CallNode* call) { - // [size, alignment] - ICHECK_EQ(call->args.size(), 2U); - - // The arguments of alloc storage should be on CPU. - for (int i = 0; i < 2; i++) { - Unify(DeviceFor(call->args[i]), DeviceType(cpu_dev_)); - MixedModeVisitor::VisitExpr(call->args[i]); - } - Device dev; - const auto* attrs = call->attrs.as(); - dev.device_type = static_cast(attrs->device_type); - dev.device_id = attrs->device_id; - Unify(DeviceFor(GetRef(call)), DeviceType(dev)); - } - - void UnifyAllocTensorCall(const CallNode* call) { - // [storage, offset, shape] - ICHECK_EQ(call->args.size(), 3U); - - Expr storage = call->args[0]; - Expr shape = call->args[1]; - Unify(DeviceFor(storage), DeviceFor(GetRef(call))); - - // The shape for alloc_tensor should be on CPU. 
- Unify(DeviceFor(shape), DeviceType(cpu_dev_)); - MixedModeVisitor::VisitExpr(shape); - } - - void UnifyShapeFuncCall(const CallNode* call) { - // [func, inputs, outputs] - ICHECK_EQ(call->args.size(), 3U); - auto shape_func_domain = DeviceType(cpu_dev_); - - // No need to unify the op of a shape_func as shape_func doesn't - // invoke the op itself. It should be handled by invoke_tvm_op. - // Therefore, we skip call.args[0] here. - Tuple inps = Downcast(call->args[1]); - Tuple outputs = Downcast(call->args[2]); - UnifyCall(GetRef(call), inps->fields, outputs->fields, shape_func_domain); - for (const auto& it : inps->fields) { - MixedModeVisitor::VisitExpr(it); - } - - for (const auto& it : outputs->fields) { - MixedModeVisitor::VisitExpr(it); - } - } - - void UnifyInvokeTVMOpCall(const CallNode* call) { - // [op, inputs, outputs] - ICHECK_EQ(call->args.size(), 3U); - Tuple inps = Downcast(call->args[1]); - Tuple outputs = Downcast(call->args[2]); - UnifyCall(call->args[0], inps->fields, outputs->fields, Bottom()); - MixedModeVisitor::VisitExpr_(call); - } - - void UnifyShapeOfCall(const CallNode* call) { - // vm shape_of is always on the CPU. - ICHECK_EQ(call->args.size(), 1U); - MixedModeVisitor::VisitExpr(call->args[0]); - // Note we don't unify the input of a shape_of with the cpu domain. This is - // because vm.shape_of has a native instruction to compute the shape of - // a tensor regardless its device type. - // Instead, the device type of the input is left for its other consumers to - // unify or it will fallback to the default context. - Unify(DeviceFor(GetRef(call)), DeviceType(cpu_dev_)); - } - - void UnifyReshapeTensorCall(const CallNode* call) { - // [data, shape] - ICHECK_EQ(call->args.size(), 2U); - Expr data = call->args[0]; - Expr shape = call->args[1]; - Unify(DeviceFor(GetRef(call)), DeviceFor(data)); - - // The shape field of reshape_tensor is always on the CPU. - Unify(DeviceFor(shape), DeviceType(cpu_dev_)); - MixedModeVisitor::VisitExpr(data); - MixedModeVisitor::VisitExpr(shape); - } - - void UnifyFunctionCall(const CallNode* call) { - auto device = DeviceFor(GetRef(call)); - // Unify the arguments of the caller. - for (const auto& arg : call->args) { - device = Unify(device, DeviceFor(arg)); - MixedModeVisitor::VisitExpr(arg); - } - - // Unify the parameters of the callee. - if (!call->op->IsInstance()) return; - Function func = Downcast(call->op); - for (const auto& param : func->params) { - device = Unify(device, DeviceFor(param)); - MixedModeVisitor::VisitExpr(param); - } - - // Unify the function expression and its body - Unify(device, DeviceFor(call->op)); - Unify(device, DeviceFor(func->body)); - - // Step into the callee. It will be skipped if the callee if a primitive - // function - MixedModeVisitor::VisitExpr(call->op); - } - - // Invoke a global function. - void UnifyGlobalVarCall(const CallNode* call) { - auto device = DeviceFor(GetRef(call)); - ICHECK(mod_.defined()) << "Cannot analyze context on a globalvar without module"; - GlobalVar gv = Downcast(call->op); - auto func = Downcast(mod_->Lookup(gv)); - ICHECK_EQ(call->args.size(), func->params.size()) - << "The number of arguments doesn't match the number of parameters of the function."; - - for (size_t i = 0; i < call->args.size(); i++) { - Expr arg = call->args[i]; - Expr param = func->params[i]; - MixedModeVisitor::VisitExpr(arg); - - // Save the the arg to function mapping for closures as it will - // be invoked/unified later. 
- ICHECK(arg->checked_type().defined()) - << "Type inference is required to run the context analysis passes."; - if (arg->checked_type()->IsInstance()) { - auto it = closures_.find(arg); - if (it != closures_.end()) { - closures_[param] = it->second; - } else { - ICHECK(arg->IsInstance()); - closures_[param] = Downcast(arg); - } - } - Unify(DeviceFor(arg), DeviceFor(param)); - } - device = Unify(device, DeviceFor(call->op)); - device = Unify(device, DeviceFor(func)); - device = Unify(device, DeviceFor(func->body)); - - // Step into the callee. We need to skip recursive calls, otherwise, it - // would be a infinite loop. - // - // TODO(@zhiics) This may cause problem for mutual recursive calls as well. - auto cur_func = current_func_; - current_func_ = gv; - if (cur_func->name_hint != gv->name_hint) { - MixedModeVisitor::VisitExpr(func); - } - // Exit the frame. - current_func_ = cur_func; - } - - void UnifyVarCall(const CallNode* call) { - // It is a closure when we call a var. - // Unify the corresponding arguement and parameter. - auto device = DeviceFor(GetRef(call)); - auto it = closures_.find(call->op); - ICHECK(it != closures_.end()) << "Cannot find var: " << call->op; - auto glb_var = it->second; - ICHECK(mod_.defined()) << "Cannot analyze context on a globalvar without module"; - Function func = Downcast(mod_->Lookup(glb_var)); - // Unify the underlying function for clousre or currying functions. - while (IsClosure(func) || IsCurrying(func)) { - device = Unify(device, DeviceFor(func)); - if (IsClosure(func)) { - func = Downcast(func->body); - } else if (IsCurrying(func)) { - Let let = Downcast(func->body); - func = Downcast(mod_->Lookup(closures_[let->var])); - } else { - LOG(FATAL) << "func is expected to be a closure or a currying function"; - } - } - - ICHECK_EQ(call->args.size(), func->params.size()); - for (size_t i = 0; i < call->args.size(); i++) { - Unify(DeviceFor(call->args[i]), DeviceFor(func->params[i])); - MixedModeVisitor::VisitExpr(call->args[i]); - } - device = Unify(device, DeviceFor(call->op)); - device = Unify(device, DeviceFor(glb_var)); - device = Unify(device, DeviceFor(func)); - - // Step into the global function. - auto cur_func = current_func_; - current_func_ = glb_var; - if (cur_func->name_hint != glb_var->name_hint) { - MixedModeVisitor::VisitExpr(func); - } - current_func_ = cur_func; - } - - private: - /* \brief The cpu context. */ - Device cpu_dev_; - /* \brief The module that helps context analysis. */ - const IRModule& mod_; - /* \brief The current function that is being analyzed. */ - GlobalVar current_func_; - /* \brief The default device that could be attached to an expression. */ - const Device& default_device_; - /* \brief The IR node to device domain mapping. */ - std::unordered_map - expr_to_device_; - /* \brief The domain map for union-find. */ - std::unordered_map - device_uf_; - /* - * \brief The expr to global var map. It saves the closures/currying that - * will be invoked lazily. - */ - std::unordered_map closures_; -}; - -} // namespace analysis - -AnalysisResultMap ContextAnalysis(const IRModule& mod, const Device& default_device) { - // TODO(@zhiics) Apply the pass to all functions/entries - auto entry = mod->GetGlobalVar("main"); - auto ca = analysis::ContextAnalyzer(mod, entry, default_device); - auto expr = mod->Lookup(entry); - ca.VisitExpr(expr); - return ca.Results(); -} - -// Unpack the device type and deivce id fields in Device for PackedFunc calls -// as Device is not in the object system. 
-PackedAnalysisResultMap ContextAnalysisPacked(const IRModule& mod, const Device& default_device) { - PackedAnalysisResultMap ret; - auto res = ContextAnalysis(mod, default_device); - for (const auto& it : res) { - Integer dev_ty = static_cast(it.second.device_type); - Integer dev_id = it.second.device_id; - ret.Set(it.first, {dev_ty, dev_id}); - } - - return ret; -} - -TVM_REGISTER_GLOBAL("relay.analysis.ContextAnalysis").set_body_typed(ContextAnalysisPacked); - -} // namespace relay -} // namespace tvm diff --git a/src/relay/op/annotation/annotation.cc b/src/relay/attrs/annotation.cc similarity index 63% rename from src/relay/op/annotation/annotation.cc rename to src/relay/attrs/annotation.cc index b59c5a3e9ff3f..99ea8f4d5989e 100644 --- a/src/relay/op/annotation/annotation.cc +++ b/src/relay/attrs/annotation.cc @@ -19,10 +19,12 @@ /*! * - * \file src/relay/op/annotation/annotation.cc - * \brief Registration of annotation operators. + * \file src/relay/attrs/annotation.cc + * \brief Helpers for working with various 'annotations' attributes. */ +#include "./annotation.h" + #include #include #include @@ -30,21 +32,46 @@ #include #include -#include "../../transforms/infer_layout_utils.h" -#include "../type_relations.h" +#include "../op/type_relations.h" +#include "../transforms/infer_layout_utils.h" namespace tvm { namespace relay { -// relay.annotation.on_device TVM_REGISTER_NODE_TYPE(OnDeviceAttrs); +const Op& OnDeviceOp() { + static const Op& op = Op::Get("on_device"); + return op; +} + +Expr OnDevice(Expr expr, DLDeviceType device_type, bool is_fixed) { + auto attrs = make_object(); + attrs->device_type = device_type; + attrs->is_fixed = is_fixed; + Span span = expr->span; + return Call(OnDeviceOp(), {std::move(expr)}, Attrs(std::move(attrs)), /*type_args=*/{}, span); +} + +Expr OptOnDevice(Expr expr, DLDeviceType device_type, bool is_fixed) { + if (device_type == kInvalidDeviceType) { + return expr; + } + if (expr->IsInstance() || expr->IsInstance() || + expr->IsInstance() || expr->IsInstance()) { + return expr; + } + if (const auto* function_node = expr.as()) { + if (function_node->HasNonzeroAttr(attr::kPrimitive)) { + return expr; + } + } + return OnDevice(expr, device_type, is_fixed); +} + TVM_REGISTER_GLOBAL("relay.op.annotation._make.on_device") - .set_body_typed([](Expr data, int device_type) { - auto attrs = make_object(); - attrs->device_type = device_type; - static const Op& op = Op::Get("on_device"); - return Call(op, {data}, Attrs(attrs), {}); + .set_body_typed([](Expr expr, int device_type, bool is_fixed) { + return OnDevice(expr, static_cast(device_type), is_fixed); }); RELAY_REGISTER_OP("on_device") @@ -56,12 +83,98 @@ RELAY_REGISTER_OP("on_device") .set_attr("TOpPattern", kOpaque) .set_attr("TOpIsStateful", false) .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .set_attr("TNonComputational", true) .set_attr("FTVMCompute", [](const Attrs& attrs, const Array& inputs, const Type& out_type) -> Array { return {topi::identity(inputs[0])}; }); +OnDeviceProps GetOnDeviceProps(const CallNode* call_node) { + if (call_node->op == OnDeviceOp()) { + ICHECK_EQ(call_node->args.size(), 1) << "on_device expects one argument"; + ICHECK(call_node->attrs.defined()) << "on_device requires attributes"; + const auto* on_device_attrs = call_node->attrs.as(); + ICHECK(on_device_attrs != nullptr) << "on_device requires OnDeviceAttrs"; + auto device_type = static_cast(on_device_attrs->device_type); + // Follow nesting: + // on_device(on_device(expr, device_type=1), device_type=2) == 
{expr, 1} + auto inner = GetOnDeviceProps(call_node->args[0]); + if (inner.body.defined()) { + return {inner.body, inner.device_type, on_device_attrs->is_fixed || inner.is_fixed}; + } else { + return {call_node->args[0], device_type, on_device_attrs->is_fixed}; + } + } + return {}; +} + +OnDeviceProps GetOnDeviceProps(const Expr& expr) { + if (const auto* call_node = expr.as()) { + return GetOnDeviceProps(call_node); + } + return {}; +} + +TVM_REGISTER_NODE_TYPE(FunctionOnDeviceAttrs); + +Function FunctionOnDevice(Function function, Array param_device_types, + DLDeviceType result_device_type) { + auto attrs = make_object(); + attrs->param_device_types = std::move(param_device_types); + attrs->result_device_type = result_device_type; + return WithAttr(std::move(function), FunctionOnDeviceAttrs::kFunctionAttrsKey, + Attrs(std::move(attrs))); +} + +Function FunctionOnDevice(Function function, const std::vector& param_device_types, + DLDeviceType result_device_type) { + Array arr; + arr.reserve(param_device_types.size()); + for (const auto device_type : param_device_types) { + arr.push_back(static_cast(device_type)); + } + return FunctionOnDevice(function, arr, result_device_type); +} + +TVM_REGISTER_GLOBAL("relay.op.annotation._make.function_on_device") + .set_body_typed([](Function function, Array param_device_types, + int result_device_type) { + return FunctionOnDevice(function, param_device_types, + static_cast(result_device_type)); + }); + +DLDeviceType GetFunctionResultDeviceType(const FunctionNode* function_node) { + auto opt_attrs = function_node->GetAttr(FunctionOnDeviceAttrs::kFunctionAttrsKey); + if (!opt_attrs) { + // No annotation. + return kInvalidDeviceType; + } + const auto* opt_function_on_device_attrs = opt_attrs.value().as(); + ICHECK(opt_function_on_device_attrs != nullptr) + << "function '" << FunctionOnDeviceAttrs::kFunctionAttrsKey + << "' annotation must be a FunctionOnDeviceAttrs"; + return static_cast(opt_function_on_device_attrs->result_device_type); +} + +DLDeviceType GetFunctionParamDeviceType(const FunctionNode* function_node, size_t i) { + ICHECK_LT(i, function_node->params.size()) + << "param index " << i << " out of range for function of arity " + << function_node->params.size(); + auto opt_attrs = function_node->GetAttr(FunctionOnDeviceAttrs::kFunctionAttrsKey); + if (!opt_attrs) { + // No annotation. + return kInvalidDeviceType; + } + const auto* opt_function_on_device_attrs = opt_attrs.value().as(); + ICHECK(opt_function_on_device_attrs != nullptr) + << "function '" << FunctionOnDeviceAttrs::kFunctionAttrsKey + << "' annotation must be a FunctionOnDeviceAttrs"; + ICHECK_EQ(opt_function_on_device_attrs->param_device_types.size(), function_node->params.size()) + << "annotation parameters do not match function arity"; + return static_cast(opt_function_on_device_attrs->param_device_types[i]->value); +} + Expr StopFusion(Expr data) { static const Op& op = Op::Get("annotation.stop_fusion"); return Call(op, {data}, Attrs{}, {}); diff --git a/src/relay/attrs/annotation.h b/src/relay/attrs/annotation.h new file mode 100644 index 0000000000000..f4c7e7f73e7b9 --- /dev/null +++ b/src/relay/attrs/annotation.h @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relay/attrs/annotation.h + * \brief Helpers for working with various 'annotation' attributes. + */ +#ifndef TVM_SRC_RELAY_ATTRS_ANNOTATION_ANNOTATION_H_ +#define TVM_SRC_RELAY_ATTRS_ANNOTATION_ANNOTATION_H_ + +#include +#include +#include +#include + +namespace tvm { +namespace relay { + +/*! \brief Returns the "on_device" operator. */ +const Op& OnDeviceOp(); + +/*! + * \brief Wraps \p expr in an "on_device" CallNode for \p device_type and \p is_fixed. + */ +Expr OnDevice(Expr expr, DLDeviceType device_type, bool is_fixed); + +/*! + * \brief Wraps \p expr in an "on_device" CallNode for \p device_type and \p is_fixed. However + * returns \p expr directly if: + * - \p device_type is \p kInvalidDeviceType, which signals there are no device annotations + * already in play. + * - \p expr is an operator or primitive function literal. These are device polymorphic. + * - \p expr is a global or local var. These already have an implied device. + * - \p expr is a constructor. There should probably be device polymorphic but are in an + * in-between state at the moment. + */ +Expr OptOnDevice(Expr expr, DLDeviceType device_type, bool is_fixed); + +/*! \brief Result of \p GetOnDeviceProps. */ +struct OnDeviceProps { + Expr body; // = null + DLDeviceType device_type = kInvalidDeviceType; + bool is_fixed = false; + + OnDeviceProps() = default; + + OnDeviceProps(const Expr& body, DLDeviceType deviceType, bool isFixed) + : body(body), device_type(deviceType), is_fixed(isFixed) {} +}; + +/*! + * \brief Returns the body expression, device type and is_fixed field for \p call_node if it is + * an "on_device" CallNode. Otherwise returns the null expression, \p kInvalidDeviceType and \p + * false. + */ +OnDeviceProps GetOnDeviceProps(const CallNode* call_node); + +/*! + * \brief Returns the body expression, device type and is_fixed field for \p expr if it is an + * "on_device" CallNode. Otherwise returns the null expression, \p kInvalidDeviceType and \p false. + */ +OnDeviceProps GetOnDeviceProps(const Expr& expr); + +/*! \brief Returns true if \p expr is an on_device CallNode. */ +inline bool IsOnDeviceCall(const Expr& expr) { return GetOnDeviceProps(expr).body.defined(); } + +/*! + * \brief Returns \p function annotated with "on_device" attributes capturing parameter and result + * devices types. However returns \p function directly if all device types are \p + * kInvalidDeviceType. + */ +Function FunctionOnDevice(Function function, Array param_device_types, + DLDeviceType body_device_type); +Function FunctionOnDevice(Function function, const std::vector& param_device_types, + DLDeviceType body_device_type); + +/*! + * \brief Returns the device type for the resut of \p function_node, or \p kInvalidDeviceType + * if function does not have "on_device" annotation. + */ +DLDeviceType GetFunctionResultDeviceType(const FunctionNode* function_node); + +/*! 
+ * \brief Returns the device type for the \p i'th parameter of \p function_node, or + * \p kInvalidDeviceType if function does not have "on_device" annotation. + */ +DLDeviceType GetFunctionParamDeviceType(const FunctionNode* function_node, size_t i); + +/*! \brief Wraps \p data in a "stop_fusion" annotation. */ +Expr StopFusion(Expr data); + +/*! \brief Wraps \p data in a "cast_hint" annotation for \p dtype. */ +Expr CastHint(Expr data, DataType dtype); + +} // namespace relay +} // namespace tvm + +#endif // TVM_SRC_RELAY_ATTRS_ANNOTATION_ANNOTATION_H_ diff --git a/src/relay/attrs/device_copy.cc b/src/relay/attrs/device_copy.cc new file mode 100644 index 0000000000000..5dffd3a1c1c33 --- /dev/null +++ b/src/relay/attrs/device_copy.cc @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relay/attrs/device_copy.cc + * \brief Helpers for working with "device_copy" attributes. + */ + +#include "./device_copy.h" + +#include +#include +#include +#include +#include + +#include "../op/type_relations.h" +#include "../transforms/infer_layout_utils.h" + +namespace tvm { +namespace relay { + +// relay.device_copy +TVM_REGISTER_NODE_TYPE(DeviceCopyAttrs); + +const Op& DeviceCopyOp() { + static const Op& op = Op::Get("device_copy"); + return op; +} + +Expr DeviceCopy(Expr expr, DLDeviceType src_dev_type, DLDeviceType dst_dev_type) { + auto attrs = make_object(); + attrs->src_dev_type = src_dev_type; + attrs->dst_dev_type = dst_dev_type; + Span span = expr->span; + return Call(DeviceCopyOp(), {std::move(expr)}, Attrs(attrs), /*type_args=*/{}, span); +} + +Expr OptDeviceCopy(Expr expr, DLDeviceType src_dev_type, DLDeviceType dst_dev_type) { + if (src_dev_type == dst_dev_type) { + return expr; + } + ICHECK_NE(src_dev_type, kInvalidDeviceType); + ICHECK_NE(dst_dev_type, kInvalidDeviceType); + return DeviceCopy(expr, src_dev_type, dst_dev_type); +} + +TVM_REGISTER_GLOBAL("relay.op._make.device_copy") + .set_body_typed([](Expr expr, int src_dev_type, int dst_dev_type) { + return DeviceCopy(expr, static_cast(src_dev_type), + static_cast(dst_dev_type)); + }); + +RELAY_REGISTER_OP("device_copy") + .describe(R"code( +Copy data from one tensor to another. The source and destination might be +on different devices. 
+)code" TVM_ADD_FILELINE) + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input data.") + .set_support_level(10) + .add_type_rel("Identity", IdentityRel) + .set_attr("TOpPattern", kOpaque) + .set_attr("TOpIsStateful", false) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .set_attr("FTVMCompute", + [](const Attrs& attrs, const Array& inputs, + const Type& out_dtype) -> Array { + return {topi::identity(inputs[0])}; + }); + +DeviceCopyProps GetDeviceCopyProps(const CallNode* call_node) { + if (call_node->op == DeviceCopyOp()) { + ICHECK_EQ(call_node->args.size(), 1) << "device_copy expects one argument"; + ICHECK(call_node->attrs.defined()) << "device_copy requires attributes"; + const auto* device_copy_attrs = call_node->attrs.as(); + ICHECK(device_copy_attrs != nullptr) << "device_copy requires DeviceCopyAttrs"; + auto src_dev_type = static_cast(device_copy_attrs->src_dev_type); + auto dst_dev_type = static_cast(device_copy_attrs->dst_dev_type); + // Follow nesting: + // device_copy(device_copy(expr, src_dev_type=1, dst_dev_type=2), + // src_dev_type=2, dst_dev_type=3) ==> {expr, 1, 3} + auto inner = GetDeviceCopyProps(call_node->args[0]); + if (inner.body.defined()) { + return {inner.body, inner.src_dev_type, inner.dst_dev_type}; + } else { + return {call_node->args[0], src_dev_type, dst_dev_type}; + } + } + return {}; +} + +DeviceCopyProps GetDeviceCopyProps(const Expr& expr) { + if (const auto* call_node = expr.as()) { + return GetDeviceCopyProps(call_node); + } + return {}; +} + +} // namespace relay +} // namespace tvm diff --git a/src/relay/attrs/device_copy.h b/src/relay/attrs/device_copy.h new file mode 100644 index 0000000000000..a987d1eabdf37 --- /dev/null +++ b/src/relay/attrs/device_copy.h @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relay/attrs/device_copy.h + * \brief Helpers for working with "device_copy" attributes. + */ + +#ifndef TVM_SRC_RELAY_ATTRS_DEVICE_COPY_H_ +#define TVM_SRC_RELAY_ATTRS_DEVICE_COPY_H_ + +#include +#include + +namespace tvm { +namespace relay { + +/*! \brief Returns the "device_copy" operator. */ +const Op& DeviceCopyOp(); + +/*! + * \brief Wraps \p expr in a "device_copy" CallNode indicating it should be evaluated on + * a device of type \p src_dev_type but then copied to a device of type \p dst_dev_type. + */ +Expr DeviceCopy(Expr expr, DLDeviceType src_dev_type, DLDeviceType dst_dev_type); + +/*! + * \brief Wraps \p expr in a "device_copy" CallNode indicating it should be evaluated on + * a device of type \p src_dev_type but then copied to a device of type \p dst_dev_type. + * However, return \p expr directly if \p src_dev_type equals \p dst_dev_type. 
+ */ +Expr OptDeviceCopy(Expr expr, DLDeviceType src_dev_type, DLDeviceType dst_dev_type); + +/*! \brief Result of \p GetDeviceCopyProps. */ +struct DeviceCopyProps { + Expr body; // = null + DLDeviceType src_dev_type = kInvalidDeviceType; + DLDeviceType dst_dev_type = kInvalidDeviceType; + + DeviceCopyProps() = default; + + DeviceCopyProps(const Expr& body, DLDeviceType srcDevType, DLDeviceType dstDevType) + : body(body), src_dev_type(srcDevType), dst_dev_type(dstDevType) {} +}; + +/*! + * \brief Returns the body expression, source, and destination device types for \p call_node if it + * is a "device_copy" CallNode. Otherwise returns the null expression and \p kInvalidDeviceType + * device types. + */ +DeviceCopyProps GetDeviceCopyProps(const CallNode* call_node); + +/*! + * \brief Returns the body expression, source, and destination device types for \p expr if it + * is a "device_copy" CallNode. Otherwise returns the null expression and \p kInvalidDeviceType + * device types. + */ +DeviceCopyProps GetDeviceCopyProps(const Expr& expr); + +} // namespace relay +} // namespace tvm + +#endif // TVM_SRC_RELAY_ATTRS_DEVICE_COPY_H_ diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 70779ac58abf5..c9a8ad648beac 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -23,6 +23,7 @@ */ #include +#include #include #include #include @@ -38,8 +39,10 @@ #include #include -#include "te_compiler.h" -#include "utils.h" +#include "../attrs/annotation.h" +#include "../transforms/device_planner.h" +#include "./te_compiler.h" +#include "./utils.h" namespace tvm { namespace relay { @@ -53,18 +56,12 @@ using StorageMap = * This is an on demand allocator for AOT. A new temporary * (storage allocator identifier) is allocated for each operation. */ -class AOTOnDemandAllocator : public MixedModeVisitor { +class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { public: - // run the visitor on a function. - void Run(const Function& func) { - node_device_map_ = CollectDeviceInfo(func); + AOTOnDemandAllocator() = default; - for (Expr param : func->params) { - CreateStorage(param.operator->()); - } - - GetStorage(func->body); - } + // run the visitor on a global function. + void Run(const Function& func) { VisitExpr(func); } std::vector GetReturnIds() const { return return_ids_; } @@ -75,8 +72,9 @@ class AOTOnDemandAllocator : public MixedModeVisitor { AssignReturnSid(GetRef(op)); } - void VisitExpr_(const CallNode* op) final { + void DeviceAwareVisitExpr_(const CallNode* op) final { // create token for the call node. + VisitExpr(op->op); CreateStorage(op); for (Expr arg : op->args) { GetStorage(arg); @@ -86,8 +84,15 @@ class AOTOnDemandAllocator : public MixedModeVisitor { void VisitExpr_(const VarNode* op) final { AssignReturnSid(GetRef(op)); } - void VisitExpr_(const FunctionNode* op) final { - // do not recurse into sub function. + void DeviceAwareVisitExpr_(const FunctionNode* func_node) final { + if (function_nesting() > 1) { + // do not recurse into sub function. 
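+      // (Storage is only planned for the outermost function; the bodies of nested
+      // primitive functions are accounted for when they are lowered.)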
+ return; + } + for (const auto& param : func_node->params) { + CreateStorage(param.get()); + } + GetStorage(func_node->body); } void VisitExpr_(const GlobalVarNode* op) final { @@ -127,7 +132,9 @@ class AOTOnDemandAllocator : public MixedModeVisitor { void VisitExpr_(const IfNode* op) final { LOG(FATAL) << "if is not supported."; } - void VisitExpr_(const LetNode* op) final { LOG(FATAL) << "let is not supported."; } + void PreVisitLetBinding_(const Var& var, const Expr& value) final { + LOG(FATAL) << "let is not supported."; + } private: void AssignReturnSid(Expr e) { @@ -152,8 +159,7 @@ class AOTOnDemandAllocator : public MixedModeVisitor { * \param prototype The prototype token. * \return The required memory size. */ - size_t GetMemorySizeBytes(const TensorTypeNode* ttype) { - ICHECK(ttype != nullptr); + size_t GetMemorySizeBytes(const TensorType& ttype) { size_t size = 1; for (IndexExpr dim : ttype->shape) { const int64_t* pval = tir::as_const_int(dim); @@ -170,44 +176,40 @@ class AOTOnDemandAllocator : public MixedModeVisitor { * \return The corresponding token. */ StorageInfo GetStorage(const Expr& expr) { - this->VisitExpr(expr); - auto it = storage_device_map_.find(expr); + auto props = GetOnDeviceProps(expr); + // See through "on_device" calls. + Expr true_expr = props.body.defined() ? props.body : expr; + VisitExpr(true_expr); + auto it = storage_device_map_.find(true_expr); ICHECK(it != storage_device_map_.end()); return it->second; } /*! * \brief Create storage for the expression. - * \param expr The expression. */ void CreateStorage(const ExprNode* op) { + Expr expr = GetRef(op); + return CreateStorage(expr, GetInScopeDeviceType(expr)); + } + + /*! + * \brief Create storage to hold the result of evaluating \p expr on \p device_type. + */ + void CreateStorage(const Expr& expr, DLDeviceType device_type) { std::vector storage_ids; std::vector device_types; std::vector storage_sizes_in_bytes; - Expr expr = GetRef(op); - int device_type_int = - node_device_map_.count(GetRef(op)) ? node_device_map_[expr]->value : 0; - if (const auto* tuple_type = op->checked_type().as()) { - for (Type t : tuple_type->fields) { - const auto* ttype = t.as(); - ICHECK(ttype); - storage_ids.push_back(next_available_sid_++); - storage_sizes_in_bytes.push_back(GetMemorySizeBytes(ttype)); - device_types.push_back(DLDeviceType(device_type_int)); - } - } else { - const auto* ttype = op->checked_type().as(); - ICHECK(ttype); + for (const auto& ttype : FlattenTupleType(expr->checked_type())) { storage_ids.push_back(next_available_sid_++); storage_sizes_in_bytes.push_back(GetMemorySizeBytes(ttype)); - device_types.push_back(DLDeviceType(device_type_int)); + device_types.push_back(device_type); } storage_device_map_[expr] = StorageInfo(storage_ids, device_types, storage_sizes_in_bytes); } + /*! \brief mapping of expression -> storageInfo*/ StorageMap storage_device_map_; - /*! \brief mapping of expression -> device type*/ - Map node_device_map_; /*! \brief current id of the temporary allocated*/ int next_available_sid_{0}; /*! \brief the set of intermediate tensors that are return variables */ @@ -558,54 +560,39 @@ class AOTExecutorCodegen : public MixedModeVisitor { use_unpacked_api_(target_host->GetAttr("unpacked-api").value_or(Bool(false))) {} LoweredOutput Codegen(relay::Function func, String mod_name) { - auto aot_allocator = AOTOnDemandAllocator(); - aot_allocator.Run(func); + // TODO(mbs): Codegen at the IRModule level instead of one-function-at-a-time. 
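+    // IRModule::FromExpr binds the function as the module's "main" global function.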
+ IRModule mod = IRModule::FromExpr(func); - // Pre-lowering storage map and memory plan - StorageMap initial_storage_map = aot_allocator.GetStorageMap(); - StaticMemoryPlan memory_plan(initial_storage_map); + // TODO(mbs): Break LowerTE's dependence on the memory plan, which currently requires us + // to generate the storage map twice. + AOTOnDemandAllocator initial_aot_allocator; + initial_aot_allocator.Run(func); - // Build a map from each operation to device. - tec::DeviceMap device_context_map; - for (const auto& it : memory_plan->expr_to_storage_info) { - auto expr = it.first; - auto storage_info = it.second; - auto device_types = storage_info->device_types; - // CHECK_EQ(device_types.size(), 1); - tvm::Device dev; - dev.device_id = 0; - dev.device_type = device_types[0]; - device_context_map.insert({expr, dev}); - } + StorageMap initial_storage_map = initial_aot_allocator.GetStorageMap(); + StaticMemoryPlan memory_plan(initial_storage_map); - // This first phase moves from implicit use of compile engine, - // to instead explicitly lowering the incoming IRModule, and then - // performing the preexisting AOT executor code generation phase. - IRModule mod = IRModule::FromExpr(func); + IRModule lowered_mod = LowerTEPass(targets_, memory_plan, mod_name, [this](Function func) { + // We need to maintain the constant map for external + // functions so we pass this processing function which + // allows us to process each function as we lower it. + if (func->GetAttr(attr::kCompiler).defined()) { + UpdateConstants(func, ¶ms_); + } - IRModule lowered_mod = - LowerTEPass(targets_, device_context_map, memory_plan, mod_name, [this](Function func) { - // We need to maintain the constant map for external - // functions so we pass this processing function which - // allows us to process each function as we lower it. - if (func->GetAttr(attr::kCompiler).defined()) { - UpdateConstants(func, ¶ms_); - } - - // TODO(@areusch, @jroesch): We should refactor this to - // execute as a further pass, instead writing data to the - // lowering process directly. - tec::UpdateFunctionMetadata(func, this->function_metadata_); - })(mod); + // TODO(@areusch, @jroesch): We should refactor this to + // execute as a further pass, instead writing data to the + // lowering process directly. 
+ tec::UpdateFunctionMetadata(func, this->function_metadata_); + })(mod); auto lowered_main = lowered_mod->Lookup("main"); auto lowered_main_func = GetRef(lowered_main.as()); // Post-lowering storage map for writing main func - this should be the same map as previously // created, just referencing the new expressions created from lowering - auto new_allocator = AOTOnDemandAllocator(); - new_allocator.Run(lowered_main_func); - storage_device_map_ = new_allocator.GetStorageMap(); + AOTOnDemandAllocator final_aot_allocator; + final_aot_allocator.Run(lowered_main_func); + storage_device_map_ = final_aot_allocator.GetStorageMap(); for (auto input : lowered_main_func->params) { input_vars_.push_back(input); @@ -622,7 +609,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { } // Retrieve the return sids - return_sid_ = aot_allocator.GetReturnIds(); + return_sid_ = final_aot_allocator.GetReturnIds(); for (unsigned int output_index = 0; output_index < return_sid_.size(); output_index++) { main_signature_.push_back(tir::Var("output", DataType::Handle())); } diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 69dced36295e7..80db7f576fa33 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -321,6 +321,23 @@ class RelayBuildModule : public runtime::ModuleNode { transform::SplitArgs(target->GetAttr("max_function_args", -1).value())); } + // Handle heterogeneous compilation. + transform::PassContext pass_ctx = PassContext::Current(); + Optional opt_fallback_dev = + pass_ctx->GetConfig("relay.fallback_device_type", Integer(static_cast(kDLCPU))); + DLDeviceType fallback_dev = static_cast(opt_fallback_dev.value()->value); + ICHECK_GT(fallback_dev, 0U); +#if 0 + // TODO(mbs): Remove + if (targets_.size() > 1) { + relay_module = RunDeviceAnnotationPass(relay_module, fallback_dev); + } +#endif + pass_seqs.push_back(transform::PlanDevices(fallback_dev)); + + // Fuse the operations if it is needed. + pass_seqs.push_back(transform::FuseOps()); + // Create a sequential pass and perform optimizations. transform::Pass seq = transform::Sequential(pass_seqs); if (targets.size() == 1) { @@ -331,19 +348,6 @@ class RelayBuildModule : public runtime::ModuleNode { relay_module = seq(relay_module); } - // Handle heterogeneous compilation. - transform::PassContext pass_ctx = PassContext::Current(); - if (targets_.size() > 1) { - Optional opt_fallback_dev = - pass_ctx->GetConfig("relay.fallback_device_type", Integer(static_cast(kDLCPU))); - auto fallback_dev = opt_fallback_dev.value(); - ICHECK_GT(fallback_dev->value, 0U); - relay_module = RunDeviceAnnotationPass(relay_module, fallback_dev->value); - } - - // Fuse the operations if it is needed. - relay_module = transform::FuseOps()(relay_module); - // Do layout rewrite for auto-scheduler. 
if (backend::IsAutoSchedulerEnabled() && targets.size() == 1) { const auto& target = (*targets.begin()).second; @@ -418,42 +422,29 @@ class RelayBuildModule : public runtime::ModuleNode { */ IRModule RunDeviceAnnotationPass(const IRModule& relay_module, int fallback_device) { UpdateHeterogeneousInputs(fallback_device); - auto rewrite = transform::RewriteAnnotatedOps(fallback_device); - auto updated_module = rewrite(relay_module); - ICHECK(updated_module.defined()); - tvm::Map device_map; - for (const auto& it : updated_module->functions) { - device_map = relay::CollectDeviceInfo(it.second); - if (!device_map.empty()) break; - } - - if (device_map.empty()) { - tvm::Map annotation_map; - for (const auto& it : relay_module->functions) { - annotation_map = relay::CollectDeviceAnnotationOps(it.second); - if (!annotation_map.empty()) break; + // If there's a unique device type used by all "on_device" CallNodes then use that + // as the fallback_device. + // TODO(mbs): This defaulting only roughly matches the original behavior. We should + // cleanup all the logic around default host and device targets. + Map annotations = CollectAllDeviceAnnotationOps(relay_module); + if (!annotations.empty()) { + std::unordered_set device_types; + for (const auto& pair : annotations) { + device_types.insert(static_cast((*annotations.begin()).second->value)); } - // None op is annotated but they are fallen back to the default device. - if (annotation_map.empty()) { - targets_.Set(0, CreateDefaultTarget(fallback_device)); - } else { - // All ops are annotated to the same device type. - int64_t dev_type = -1; - for (auto kv : annotation_map) { - dev_type = kv.second->value; - break; - } - for (auto kv : annotation_map) { - ICHECK_EQ(kv.second->value, dev_type) << "Expressions in the function are " - << "annotated with various device types," - << "but not device copy operators " - << "found. Please check the " - << "RewriteAnnotation pass."; - } - targets_.Set(0, CreateDefaultTarget(dev_type)); + if (device_types.size() == 1UL) { + fallback_device = *device_types.begin(); } } + // Make sure the 'default' target is always available keyed by 0. + // This is the current convention for conveying which of the >= 2 targets is the default. + targets_.Set(kNullDeviceType, CreateDefaultTarget(fallback_device)); + + // Insert "device_copy" CallNodes to account for any user-supplied "on_device" CallNodes. 
+ auto updated_module = transform::RewriteAnnotatedOps(fallback_device)(relay_module); + ICHECK(updated_module.defined()); + return updated_module; } diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index aca95db34c4e0..4a95f62847c30 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -36,13 +36,17 @@ #include #include -#include "te_compiler.h" -#include "utils.h" +#include "../attrs/annotation.h" +#include "../transforms/device_planner.h" +#include "./te_compiler.h" +#include "./utils.h" namespace tvm { namespace relay { + // TODO(@jroesch, @csullivan): declare directly elsewhere backend::StaticMemoryPlan GraphPlanMemory(const Function& func); + namespace backend { class GraphNode; @@ -197,32 +201,16 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslatorexpr_to_storage_info) { - auto expr = it.first; - auto storage_info = it.second; - auto device_types = storage_info->device_types; - // CHECK_EQ(device_types.size(), 1); - tvm::Device dev; - dev.device_id = 0; - dev.device_type = device_types[0]; - device_context_map.insert({expr, dev}); - } + // TODO(mbs): LowerTEPass does not really need the memory plan. + memory_plan_ = GraphPlanMemory(func); IRModule lowered_mod = - LowerTEPass(targets_, device_context_map, memory_plan_, mod_name_, [this](Function func) { + LowerTEPass(targets_, memory_plan_, mod_name_, [this](Function func) { // We need to maintain the constant map for external // functions so we pass this processing function which // allows us to process each function as we lower it. @@ -230,11 +218,11 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslatorfunction_metadata_); - })(mod); + // TODO(@areusch, @jroesch): We should refactor this to + // execute as a further pass, instead writing data to the + // lowering process directly. + tec::UpdateFunctionMetadata(func, this->function_metadata_); + })(mod); Optional main_func_info = lowered_mod->GetAttr("main_func_info"); @@ -445,18 +433,21 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator VisitExpr_(const CallNode* call_node) override { relay::Call call = GetRef(call_node); - if (auto global_node = call->op.as()) { - auto prim_fn_name = global_node->name_hint; - - return GraphAddCallNode(call_node, prim_fn_name, GraphAttrs()); - } else { - ICHECK(false) << "Non-primitive-call nodes should have been transformed away.\n" - << "The graph executor code generator expects all calls to have their callee " - "normalized to a GlobalVar but found a " - << call->GetTypeKey() << "." - << "AST: " << PrettyPrint(call) << PrettyPrint(call) << std::endl; - return {}; + auto props = GetOnDeviceProps(call_node); + if (props.body.defined()) { + // See through "on_device" calls. 
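+      // At runtime "on_device" is just the identity function; the annotation only matters
+      // for device planning, so codegen recurses straight into the wrapped body.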
+ return VisitExpr(props.body); } + + const auto* global_node = call->op.as(); + ICHECK(global_node) + << "Non-primitive-call nodes should have been transformed away.\n" + << "The graph executor code generator expects all calls to have their callee " + "normalized to a GlobalVar, but found:" + << std::endl + << PrettyPrint(call); + auto prim_fn_name = global_node->name_hint; + return GraphAddCallNode(call_node, prim_fn_name, GraphAttrs()); } std::vector VisitExpr_(const LetNode* op) override { diff --git a/src/relay/backend/graph_plan_memory.cc b/src/relay/backend/graph_plan_memory.cc index 93c823d8a0076..f65cccaf651a2 100644 --- a/src/relay/backend/graph_plan_memory.cc +++ b/src/relay/backend/graph_plan_memory.cc @@ -27,9 +27,13 @@ #include #include #include +#include #include #include "../../support/arena.h" +#include "../attrs/annotation.h" +#include "../op/memory/memory.h" +#include "../transforms/device_planner.h" #include "./utils.h" namespace tvm { @@ -39,44 +43,45 @@ using backend::StaticMemoryPlan; using backend::StorageInfo; using IntegerArray = Array; +/*! A representation of a block of memory required at runtime on some device. */ struct StorageToken { /*! \brief Reference counter */ int ref_counter{0}; /*! \brief number of bytes */ size_t max_bytes{0}; - /*! \brief The corresponding tensor type node. */ - const TensorTypeNode* ttype{nullptr}; - /*! \brief virtual device index that corresponds to the device_type in - * DLDevice. */ - int device_type{0}; + /*! \brief The corresponding tensor type. */ + TensorType ttype{nullptr}; + /*! \brief Device on which memory will reside. */ + Device device{kInvalidDeviceType, -1}; /*! \brief The storage id */ int64_t storage_id{-1}; + + bool is_valid() const { return device.device_type != kInvalidDeviceType; } + + bool is_compatible(const StorageToken& that) const { + return device.device_type == that.device.device_type; + } }; std::ostream& operator<<(std::ostream& os, StorageToken tok) { return os << "StorageToken: " << std::endl << "ref_counter: " << tok.ref_counter << std::endl << "max_bytes: " << tok.max_bytes << std::endl - << "tttype: " << tok.ttype + << "tttype: " << tok.ttype << std::endl + << "tttype: " << tok.ttype << std::endl + << "device: {" << tok.device.device_type << ", " << tok.device.device_id << "}" << std::endl - // ok idk how to print this properly - << "tttype shape: " << tok.ttype->shape << std::endl - << "device_type: " << tok.device_type << std::endl << "storage_id: " << tok.storage_id << std::endl; } -class StorageAllocaBaseVisitor : public ExprVisitor { +class StorageAllocaBaseVisitor : public transform::DeviceAwareExprVisitor { public: - // run the visitor on a function. - void Run(const Function& func) { - for (Var param : func->params) { - CreateToken(param.operator->(), false); - } - // must always keep output alive. - for (StorageToken* tok : GetToken(func->body)) { - tok->ref_counter += 1; - } - } + StorageAllocaBaseVisitor() = default; + + // run the visitor on a global function. + void Run(const Function& func) { VisitExpr(func); } + + using transform::DeviceAwareExprVisitor::VisitExpr_; void VisitExpr_(const ConstantNode* op) final { this->CreateToken(op, false); } @@ -84,8 +89,18 @@ class StorageAllocaBaseVisitor : public ExprVisitor { // Do nothing. } - void VisitExpr_(const FunctionNode* op) final { - // do not recurse into sub function. + void DeviceAwareVisitExpr_(const FunctionNode* func_node) final { + if (function_nesting() > 1) { + // do not recurse into sub function. 
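+      // (Result storage for calls to nested primitive functions is instead created when
+      // their CallNodes are visited.)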
+ return; + } + for (const auto& param : func_node->params) { + CreateToken(param.get(), /*can_realloc=*/false); + } + // Process the function body, and make sure all result tokens are considered 'alive'. + for (StorageToken* tok : GetToken(func_node->body)) { + tok->ref_counter += 1; + } } void VisitExpr_(const GlobalVarNode* op) final { @@ -113,15 +128,17 @@ class StorageAllocaBaseVisitor : public ExprVisitor { void VisitExpr_(const IfNode* op) final { LOG(FATAL) << "if is not supported."; } - void VisitExpr_(const LetNode* op) final { - auto token = GetToken(op->value); - token_map_[op->var.operator->()] = token; - token_map_[op] = GetToken(op->body); + void PreVisitLetBinding_(const Var& var, const Expr& value) final { + token_map_[var.get()] = GetToken(value); + } + + void PostVisitLet_(const LetNode* let_node) final { + token_map_[let_node] = GetToken(let_node->body); } protected: /*! \brief internal token map */ - std::unordered_map > token_map_; + std::unordered_map> token_map_; /*! * \brief Get the necessary token. @@ -130,27 +147,39 @@ class StorageAllocaBaseVisitor : public ExprVisitor { */ const std::vector& GetToken(const Expr& expr) { this->VisitExpr(expr); - auto it = token_map_.find(expr.operator->()); - ICHECK(it != token_map_.end()) - << "Expression: `" << PrettyPrint(expr) << "` not found in storage map."; + // See through on_device calls. + auto props = GetOnDeviceProps(expr); + Expr real_expr = props.body.defined() ? props.body : expr; + auto it = token_map_.find(real_expr.get()); + ICHECK(it != token_map_.end()) << "Expression not found in storage map:" << std::endl + << PrettyPrint(real_expr); return it->second; } + /*! - * \brief Populate the token map to set op's tokens - * \param op The node to be processed. - * \param can_realloc Whether we can re-allocate the memory. + * \brief Allocates (or reuses if \p can_realloc is true) a storage token for holding + * the result of evaluating \p op. */ - virtual void CreateToken(const ExprNode* op, bool can_realloc) = 0; + void CreateToken(const ExprNode* op, bool can_realloc) { + return CreateTokenOnDevice(op, GetInScopeDeviceType(GetRef(op)), can_realloc); + } + + /*! + * \brief Allocates (or reuses if \p can_realloc is true) a storage token for holding + * the result of evaluating \p op on \p device_type. + */ + virtual void CreateTokenOnDevice(const ExprNode* op, DLDeviceType device_type, + bool can_realloc) = 0; }; +/*! \brief Associate storage with every expression without any concern for sharing. */ class StorageAllocaInit : protected StorageAllocaBaseVisitor { public: explicit StorageAllocaInit(support::Arena* arena) : arena_(arena) {} /*! \return The internal token map */ - std::unordered_map > GetInitTokenMap( + std::unordered_map> GetInitTokenMap( const Function& func) { - node_device_map_ = CollectDeviceInfo(func); this->Run(func); return std::move(token_map_); } @@ -158,32 +187,24 @@ class StorageAllocaInit : protected StorageAllocaBaseVisitor { protected: using StorageAllocaBaseVisitor::VisitExpr_; - void CreateToken(const ExprNode* op, bool can_realloc) final { + void CreateTokenOnDevice(const ExprNode* op, DLDeviceType device_type, + bool can_realloc) override { ICHECK(!token_map_.count(op)); std::vector tokens; - int device_type = - node_device_map_.count(GetRef(op)) ? 
node_device_map_[GetRef(op)]->value : 0; - if (const auto* tuple_type = op->checked_type().as()) { - for (Type t : tuple_type->fields) { - const auto* ttype = t.as(); - ICHECK(ttype); - StorageToken* token = arena_->make(); - token->ttype = ttype; - token->device_type = device_type; - tokens.push_back(token); - } - } else { - const auto* ttype = op->checked_type().as(); - ICHECK(ttype); + for (const auto& ttype : FlattenTupleType(op->checked_type())) { StorageToken* token = arena_->make(); token->ttype = ttype; - token->device_type = device_type; + // TODO(mbs): Should be TargetDevice. + token->device.device_type = device_type; + token->device.device_id = 0; tokens.push_back(token); } token_map_[op] = tokens; } - void VisitExpr_(const CallNode* op) final { + using StorageAllocaBaseVisitor::DeviceAwareVisitExpr_; + + void DeviceAwareVisitExpr_(const CallNode* op) final { // create token for the call node. CreateToken(op, true); @@ -198,13 +219,15 @@ class StorageAllocaInit : protected StorageAllocaBaseVisitor { private: // allocator support::Arena* arena_; - Map node_device_map_; }; +/*! \brief Associate storage with every expression, reusing storage where possible. */ class StorageAllocator : public StorageAllocaBaseVisitor { public: + StorageAllocator() = default; + /*! - * \return totoal number of bytes allocated + * \return total number of bytes allocated */ size_t TotalAllocBytes() const { size_t total = 0; @@ -231,12 +254,12 @@ class StorageAllocator : public StorageAllocaBaseVisitor { std::vector sid_sizes_byte; for (StorageToken* tok : kv.second) { - if (tok->device_type) { + if (tok->is_valid()) { num_annotated_nodes++; } num_nodes++; storage_ids.push_back(tok->storage_id); - device_types.push_back(static_cast(tok->device_type)); + device_types.push_back(static_cast(tok->device.device_type)); sid_sizes_byte.push_back(GetMemorySize(tok)); } auto storage_info = backend::StorageInfo(storage_ids, device_types, sid_sizes_byte); @@ -253,21 +276,21 @@ class StorageAllocator : public StorageAllocaBaseVisitor { } protected: - using StorageAllocaBaseVisitor::VisitExpr_; // override create token by getting token as prototype requirements. - void CreateToken(const ExprNode* op, bool can_realloc) final { + void CreateTokenOnDevice(const ExprNode* op, DLDeviceType device_type, bool can_realloc) final { ICHECK(!token_map_.count(op)); auto it = prototype_.find(op); ICHECK(it != prototype_.end()); std::vector tokens; for (StorageToken* tok : it->second) { + ICHECK_EQ(tok->device.device_type, device_type); if (can_realloc) { tokens.push_back(Request(tok)); } else { // Allocate a new token, StorageToken* allocated_tok = Alloc(tok, GetMemorySize(tok)); - allocated_tok->device_type = tok->device_type; + allocated_tok->device = tok->device; // ensure it never get de-allocated. allocated_tok->ref_counter += 1; tokens.push_back(allocated_tok); @@ -275,6 +298,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor { } token_map_[op] = tokens; } + // Mark op to reuse the input_token // tie the two memories together void ReuseInputToken(const ExprNode* op, StorageToken* input_token) { @@ -291,8 +315,10 @@ class StorageAllocator : public StorageAllocaBaseVisitor { token_map_[op] = {input_token}; } + using StorageAllocaBaseVisitor::DeviceAwareVisitExpr_; + // The call map - void VisitExpr_(const CallNode* op) final { + void DeviceAwareVisitExpr_(const CallNode* op) final { std::vector args; // for each input, visit argument token. 
for (Expr arg : op->args) { @@ -364,8 +390,8 @@ class StorageAllocator : public StorageAllocaBaseVisitor { * \return The required memory size. */ size_t GetMemorySize(StorageToken* prototype) { - const TensorTypeNode* ttype = prototype->ttype; - ICHECK(ttype != nullptr); + TensorType ttype = prototype->ttype; + ICHECK(ttype.defined()); size_t size = 1; for (IndexExpr dim : ttype->shape) { const int64_t* pval = tir::as_const_int(dim); @@ -394,7 +420,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor { // search for memory blocks larger than requested for (auto it = mid; it != end; ++it) { StorageToken* tok = it->second; - if (tok->device_type != prototype->device_type) continue; + if (!tok->is_compatible(*prototype)) continue; ICHECK_EQ(tok->ref_counter, 0); // Use exect matching strategy tok->max_bytes = std::max(size, tok->max_bytes); @@ -407,7 +433,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor { for (auto it = mid; it != begin;) { --it; StorageToken* tok = it->second; - if (tok->device_type != prototype->device_type) continue; + if (!tok->is_compatible(*prototype)) continue; ICHECK_EQ(tok->ref_counter, 0); // Use exect matching strategy tok->max_bytes = std::max(size, tok->max_bytes); @@ -452,7 +478,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor { // all the storage resources available std::vector data_; /*! \brief internal prototype token map */ - std::unordered_map > prototype_; + std::unordered_map> prototype_; }; StaticMemoryPlan GraphPlanMemory(const Function& func) { return StorageAllocator().Plan(func); } diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index df14b9e078b6b..5c667bd0f6cf5 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -35,9 +35,9 @@ #include #include +#include "../attrs/annotation.h" #include "../transforms/pass_utils.h" -#include "compile_engine.h" -#include "te_compiler.h" +#include "./te_compiler.h" namespace tvm { namespace relay { @@ -439,6 +439,7 @@ class Interpreter : public ExprFunctor, const Array& all_prim_shape_fn_vars, const Array& prim_shape_fn_states, size_t num_shape_inputs, size_t num_shape_outputs, + Target prim_shape_target, const std::vector& args) { ICHECK(prim_shape_fn_var.defined()); ICHECK(prim_shape_fn_states.defined()); @@ -460,11 +461,10 @@ class Interpreter : public ExprFunctor, Device shape_device; shape_device.device_type = kDLCPU; shape_device.device_id = 0; - Target shape_target("llvm"); // 'Compile' the TIR shape function to appropriate callable form. PackedFunc packed_shape_func = - TIRToPackedFunc(prim_shape_fn_var, all_prim_shape_fn_vars, shape_target); + TIRToPackedFunc(prim_shape_fn_var, all_prim_shape_fn_vars, prim_shape_target); size_t arity = num_shape_inputs + num_shape_outputs; std::vector values(arity); @@ -481,13 +481,13 @@ class Interpreter : public ExprFunctor, // flattened form of this arg. Does that match what lowering actually does? int64_t state = prim_shape_fn_states[i]->value; for (const auto& nd_array : FlattenADT(args[i])) { - if (state & kNeedInputData) { + if (state & tec::kNeedInputData) { auto arr = nd_array.CopyTo(shape_device); inputs[arg_counter] = arr; setter(arg_counter, arr); ++arg_counter; } - if (state & kNeedInputShape) { + if (state & tec::kNeedInputShape) { int64_t ndim = nd_array.Shape().size(); NDArray shape_arr; if (ndim == 0) { @@ -553,16 +553,17 @@ class Interpreter : public ExprFunctor, * @return Result of primitive. 
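+ * @param prim_target The target for which \p prim_fn_var was lowered.
+ * @param prim_shape_target The target for which any dynamic shape function was lowered
+ * (generally the host CPU).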
*/ ObjectRef InvokePrimitiveOp(const GlobalVar& prim_fn_var, const Array all_prim_fn_vars, - const GlobalVar& prim_shape_fn_var, + Target prim_target, const GlobalVar& prim_shape_fn_var, const Array& all_prim_shape_fn_vars, const Array& prim_shape_fn_states, size_t num_shape_inputs, - size_t num_shape_outputs, const std::vector& args) { + size_t num_shape_outputs, Target prim_shape_target, + const std::vector& args) { ICHECK(prim_fn_var->checked_type().defined()); const FuncTypeNode* ftn = prim_fn_var->checked_type().as(); ICHECK(ftn); // 'Compile' the TIR primitive to appropriate callable form (on the desired target). - PackedFunc packed_func = TIRToPackedFunc(prim_fn_var, all_prim_fn_vars, target_); + PackedFunc packed_func = TIRToPackedFunc(prim_fn_var, all_prim_fn_vars, prim_target); // Argument tuples are flattened. std::vector arg_nd_arrays = FlattenADTs(args); @@ -596,7 +597,7 @@ class Interpreter : public ExprFunctor, ICHECK(prim_shape_fn_states.defined()); runtime_shapes = ComputeDynamicShape(prim_shape_fn_var, all_prim_shape_fn_vars, prim_shape_fn_states, - num_shape_inputs, num_shape_outputs, args); + num_shape_inputs, num_shape_outputs, prim_shape_target, args); ICHECK_EQ(runtime_shapes.size(), result_tensor_types.size()); } @@ -676,26 +677,33 @@ class Interpreter : public ExprFunctor, return WithFrame(Frame(locals), [&]() { return Eval(func->body); }); } - ObjectRef VisitExpr_(const CallNode* call) final { + ObjectRef VisitExpr_(const CallNode* call_node) final { std::vector args; - for (auto arg : call->args) { + for (auto arg : call_node->args) { args.push_back(Eval(arg)); } + if (call_node->op == OnDeviceOp()) { + // Special case: The call 'on_device(expr)' denotes that expr should be executed on + // a particular device. We can ignore this during interpretation. + ICHECK_EQ(call_node->args.size(), 1UL); + return args[0]; + } + // We should not find calls to operators after running fusion and lowering. - if (const OpNode* op_node = call->op.as()) { + if (const OpNode* op_node = call_node->op.as()) { LOG(FATAL) << "found " << op_node->name << "; operators should have been removed by previous passes; try " "fusing and lowering"; } - if (const ConstructorNode* con = call->op.as()) { + if (const ConstructorNode* con = call_node->op.as()) { // Special case: ADT constructor return ConstructorValue(con->tag, args, GetRef(con)); } - if (const GlobalVarNode* gvn = call->op.as()) { - if (const TIRCallAttrs* attrs = call->attrs.as()) { + if (const GlobalVarNode* gvn = call_node->op.as()) { + if (const TIRCallAttrs* attrs = call_node->attrs.as()) { // Special case: Call a lowered TIR function. // TODO(mbs): Make calling convention first-class in Relay. Array all_prim_fn_vars; @@ -726,23 +734,30 @@ class Interpreter : public ExprFunctor, num_shape_outputs = static_cast( Downcast(attrs->metadata.at("prim_shape_fn_num_outputs"))->value); } + Target prim_target; + if (attrs->metadata.count(tvm::attr::kTarget)) { + prim_target = Downcast(attrs->metadata.at(tvm::attr::kTarget)); + } + Target prim_shape_target; + if (attrs->metadata.count("shape_target")) { + prim_shape_target = Downcast(attrs->metadata.at("shape_target")); + } - // Special case: Call TIR primitive. 
- return InvokePrimitiveOp(GetRef(gvn), all_prim_fn_vars, prim_shape_fn_var, - all_prim_shape_fn_vars, prim_shape_fn_states, num_shape_inputs, - num_shape_outputs, args); + return InvokePrimitiveOp(GetRef(gvn), all_prim_fn_vars, prim_target, + prim_shape_fn_var, all_prim_shape_fn_vars, prim_shape_fn_states, + num_shape_inputs, num_shape_outputs, prim_shape_target, args); } } // Now we just evaluate and expect to find a closure. - ObjectRef fn_val = Eval(call->op); + ObjectRef fn_val = Eval(call_node->op); if (const InterpreterClosureObj* closure_node = fn_val.as()) { auto closure = GetRef(closure_node); return Invoke(closure, args); } else if (const RecClosureObj* closure_node = fn_val.as()) { return Invoke(closure_node->clos, args, closure_node->bind); } else { - LOG(FATAL) << "internal error: type error, expected function value in the call " + LOG(FATAL) << "internal error: type error, expected function value in the call_node " << "position"; return ObjectRef(); } @@ -898,25 +913,30 @@ IRModule Prepare(IRModule mod, Device device, Target target) { // Things to initialize to pass into tec::LowerTEPass // We only have one device-specific target. tec::TargetMap targets = {{device.device_type, target}}; - - // All calls to primitives will use the unique target. - tec::DeviceMap device_map; + if (device.device_type != kDLCPU) { + // However some primitives (eg dynamic shape functions) must always execute on the CPU, + // so make sure we have a target for that. + targets.emplace(kDLCPU, Target("llvm")); + } // No need for a memory plan. backend::StaticMemoryPlan memory_plan; /*=nullptr*/ // Run minimal transforms on module to establish invariants needed by interpreter. - transform::Sequential seq( - {transform::SimplifyInference(), - // FuseOps will mark wrapped calls to prim-ops with the 'Primitive' - // attribute. - transform::FuseOps(/*fuse_opt_level=*/0), transform::ToANormalForm(), - // eta expand to support constructors in argument position - transform::EtaExpand( - /*expand_constructor=*/true, /*expand_global_var=*/false), - transform::InferType(), - tec::LowerTEPass(targets, device_map, memory_plan, /*module_name=*/"intrp", - [](Function func) { /* no-op */ })}); + transform::Sequential seq({transform::SimplifyInference(), + // Figure out which devices should be used to execute. + transform::PlanDevices(device.device_type), + // FuseOps will mark wrapped calls to prim-ops with the 'Primitive' + // attribute. + transform::FuseOps(/*fuse_opt_level=*/0), + // Use ANF to reduce number of cases to handle. + transform::ToANormalForm(), + // eta expand to support constructors in argument position. + transform::EtaExpand( + /*expand_constructor=*/true, /*expand_global_var=*/false), + transform::InferType(), + tec::LowerTEPass(targets, memory_plan, /*module_name=*/"intrp", + [](Function func) { /* no-op */ })}); transform::PassContext pass_ctx = transform::PassContext::Current(); With ctx(pass_ctx); @@ -968,6 +988,9 @@ class NeedsPreparationVisitor : public ExprVisitor { TypedPackedFunc)> EvalFunction(IRModule mod, Expr expr, Device device, Target target) { + VLOG_CONTEXT << "EvalFunction"; + VLOG(1) << "evaling module:\n" << PrettyPrint(mod) << "and expression:\n" << PrettyPrint(expr); + // // Step 1: Prepare mod. 
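+  // (Prepare has already run device planning, fusion, ANF conversion, eta expansion, type
+  // inference and TE lowering, so calls to operators have been fused and lowered to TIR.)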
// @@ -1025,6 +1048,8 @@ TypedPackedFunc)> EvalFunction(IRModule mod, Expr expr, De ICHECK(closure->func.defined()); return TypedPackedFunc)>([intrp, closure](Array args) { + VLOG_CONTEXT << "EvalFunction::Apply"; + VLOG(1) << "evaling closure with " << args.size() << " arguments"; // // Step 3: Apply closure to arguments. // diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 0393fdfec70d1..bbb99e114dff7 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -42,8 +42,10 @@ #include #include -#include "te_compiler_cache.h" -#include "utils.h" +#include "../attrs/annotation.h" +#include "../transforms/device_planner.h" +#include "./te_compiler_cache.h" +#include "./utils.h" namespace tvm { namespace relay { @@ -366,14 +368,12 @@ std::tuple IsDeviceCopy(const Function& func) { * ... %p(...) ... * \endcode */ -class LowerTensorExprMutator : public ExprMutator { +class LowerTensorExprMutator : public DeviceAwareExprMutator { public: - LowerTensorExprMutator(const IRModule& module, const TargetMap& targets, - const DeviceMap& device_ctx_map, ProcessFn process_fn, + LowerTensorExprMutator(const IRModule& module, const TargetMap& targets, ProcessFn process_fn, const String& module_name, TECompiler compiler) : module_(module), targets_(targets), - device_context_map_(device_ctx_map), process_fn_(process_fn), module_name_(module_name), compiler_(compiler), @@ -436,11 +436,10 @@ class LowerTensorExprMutator : public ExprMutator { } // Non-External Relay Function - DLOG(INFO) << "lowering to target '" << target->str() << "' for primitive:\n" - << PrettyPrint(func); + VLOG(1) << "lowering to target '" << target->str() << "' for primitive:\n" << PrettyPrint(func); CCacheKey key = CCacheKey(func, target); CachedFunc lowered_func = compiler_->Lower(key, module_name_); - DLOG(INFO) << "lowered primitive bound to '" << PrettyPrint(lowered_func->prim_fn_var) << "'"; + VLOG(1) << "lowered primitive bound to '" << PrettyPrint(lowered_func->prim_fn_var) << "'"; // Collect all the lowered functions produced for this primitive function. Map prim_fns; @@ -449,8 +448,7 @@ class LowerTensorExprMutator : public ExprMutator { CHECK(prim_fn.second.as()) << "must be a prim fn"; prim_fns.Set(prim_fn.first, Downcast(prim_fn.second)); all_prim_fn_vars.push_back(prim_fn.first); - DLOG(INFO) << "lowered primitive includes bindings for '" << PrettyPrint(prim_fn.first) - << "'"; + VLOG(1) << "lowered primitive includes bindings for '" << PrettyPrint(prim_fn.first) << "'"; } // TODO(@areusch, @jroesch): this metadata is for AOT, this should be our interface for AOT @@ -478,6 +476,7 @@ class LowerTensorExprMutator : public ExprMutator { tir_call_attrs->metadata.Set("relay_attrs", func->attrs); tir_call_attrs->metadata.Set("all_prim_fn_vars", all_prim_fn_vars); + tir_call_attrs->metadata.Set(tvm::attr::kTarget, lowered_func->target); if (IsDynamic(func->ret_type)) { // Also lower the dynamic shape function. @@ -486,8 +485,8 @@ class LowerTensorExprMutator : public ExprMutator { // on the host cpu irrespective of where the primitive runs. // TODO(mbs): Cleanup target handling. 
Target shape_target("llvm"); - DLOG(INFO) << "lowering to target '" << shape_target->str() - << "' for dynamic shape function for primitive"; + VLOG(1) << "lowering to target '" << shape_target->str() + << "' for dynamic shape function for primitive"; CCacheKey shape_key(func, shape_target); CachedFunc lowered_shape_func = compiler_->LowerShapeFunc(shape_key); // Capture the shape function's global var and parameters 'states' in call @@ -508,42 +507,42 @@ class LowerTensorExprMutator : public ExprMutator { all_prim_shape_fn_vars.push_back(prim_shape_fn.first); } tir_call_attrs->metadata.Set("all_prim_shape_fn_vars", all_prim_shape_fn_vars); + tir_call_attrs->metadata.Set("shape_target", lowered_shape_func->target); } return {lowered_func->prim_fn_var, Attrs(tir_call_attrs)}; } - Expr VisitExpr_(const LetNode* let) override { - Var var = Downcast(Mutate(let->var)); - Expr value = Mutate(let->value); - Function prim_func = ResolveToPrimitive(value); + std::pair PreVisitLetBinding_(const Var& var, const Expr& value) final { + Var new_var = Downcast(Mutate(var)); + Expr new_value = Mutate(value); + Function prim_func = ResolveToPrimitive(new_value); if (prim_func.defined()) { - // Remember let var is bound to (possibly indirectly) to a primitive. - primitive_functions_.emplace(let->var, prim_func); + // Remember let var is bound (possibly indirectly) to a primitive. + primitive_functions_.emplace(var, prim_func); } - Expr body = Mutate(let->body); + return {new_var, new_value}; + } + + Expr PostVisitLet_(const LetNode* pre_let_node, const LetNode* post_let_node) final { + Function prim_func = ResolveToPrimitive(post_let_node->value); if (prim_func.defined()) { // Leaving let var scope. - primitive_functions_.erase(let->var); - } - if (var.same_as(let->var) && value.same_as(let->value) && body.same_as(let->body)) { - return GetRef(let); - } else { - return Let(var, value, body, let->span); + primitive_functions_.erase(pre_let_node->var); } + return DeviceAwareExprMutator::PostVisitLet_(pre_let_node, post_let_node); } - Expr VisitExpr_(const CallNode* call) override { - Call expr = GetRef(call); - + Expr DeviceAwareVisitExpr_(const CallNode* call_node) override { + Call call = GetRef(call_node); // Look for (indirect) calls to primitives. - Function prim_func = ResolveToPrimitive(call->op); + Function prim_func = ResolveToPrimitive(call_node->op); if (!prim_func.defined()) { - // Not a call to a primitive function. - if (const FunctionNode* fn = call->op.as()) { + // Not a call_node to a primitive function. + if (const FunctionNode* fn = call_node->op.as()) { this->process_fn_(GetRef(fn)); } - return ExprMutator::VisitExpr_(call); + return ExprMutator::VisitExpr_(call_node); } // Find the desired target device. @@ -551,17 +550,11 @@ class LowerTensorExprMutator : public ExprMutator { if (prim_func->GetAttr(attr::kCompiler).defined()) { // The generic 'external device' target. target = Target("ext_dev"); - } else if (device_context_map_.empty() && targets_.size() == 1) { - // The unique target. - target = GetTargetFromInteger(kDLCPU, targets_); } else { - // The target corresponding to the call expression's annotation. 
- auto itr = device_context_map_.find(expr); - ICHECK(itr != device_context_map_.end()) - << "Could not find an entry in the device context map for " << expr - << "The memory planning was either not performed for this precise node, or there is " - "bug in the memory planner."; - target = GetTargetFromInteger(itr->second.device_type, targets_); + // The target corresponding to the call_node expression's annotation. + DLDeviceType device_type = GetInScopeDeviceType(call); + // TODO(mbs): Replace device_type with target so this lookup is unnecessary. + target = GetTargetFromInteger(device_type, targets_); } // Lower the primitive function for that target. @@ -569,18 +562,17 @@ class LowerTensorExprMutator : public ExprMutator { // Similarly transform arguments. Array args; - for (const auto& arg : call->args) { + for (const auto& arg : call_node->args) { args.push_back(VisitExpr(arg)); } - // Replace with direct call to lowered primitive, and attach annotations to record calling + // Replace with direct call_node to lowered primitive, and attach annotations to record calling // convention. return Call(pair.first, args, pair.second); } IRModule module_; TargetMap targets_; - DeviceMap device_context_map_; ProcessFn process_fn_; // Map from in-scope let-bound variables to Relay functions known to be // primitive. We'll rewrite these to the fresh global vars bound to the lowered @@ -596,18 +588,17 @@ class LowerTensorExprMutator : public ExprMutator { const Op& debug_op_; }; -Pass LowerTensorExpr(TargetMap targets, DeviceMap device_context_map, - backend::StaticMemoryPlan memory_plan, const String& module_name, - TECompiler compiler, std::function process_fn) { +Pass LowerTensorExpr(TargetMap targets, const String& module_name, TECompiler compiler, + std::function process_fn) { runtime::TypedPackedFunc pass_func = [=](Function func, IRModule module, PassContext ctx) { - LowerTensorExprMutator lower_te(module, targets, device_context_map, process_fn, - module_name, compiler); + LowerTensorExprMutator lower_te(module, targets, process_fn, module_name, compiler); return Downcast(lower_te.Mutate(func)); }; return CreateFunctionPass(pass_func, 0, "LowerTensorExpr", {}); } +// TODO(mbs): Remove once flow targets through 'device' planning intead of device types. Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets) { if (targets.size() == 1) { // The homogeneous execution case, return the only target. 
@@ -843,11 +834,8 @@ void UpdateFunctionMetadata(Function relay_func, function_metadata.Set(prim_fn_var.value()->name_hint, fi); } -IRModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap device_context_map, - backend::StaticMemoryPlan memory_plan, const String& module_name, - std::function process_fn) { - DLOG(INFO) << "lowering module:\n" << PrettyPrint(module); - +IRModule LowerTE(const IRModule& module, TargetMap targets, backend::StaticMemoryPlan memory_plan, + const String& module_name, std::function process_fn) { TECompiler compiler; backend::FunctionInfo func_info; @@ -856,8 +844,7 @@ IRModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap device_con func_info = UpdateMainWorkspaceSize(module, targets, memory_plan->expr_to_storage_info); } - auto updated_module = LowerTensorExpr(targets, device_context_map, memory_plan, module_name, - compiler, process_fn)(module); + auto updated_module = LowerTensorExpr(targets, module_name, compiler, process_fn)(module); // A temporary solution until we can rewrite the auto-scheduler task extraction code to work // in a more reasonable way. @@ -919,12 +906,11 @@ Map GetPerTargetModules(IRModule mod) { return per_target_modules; } -Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map, - backend::StaticMemoryPlan memory_plan, const String& module_name, - std::function process_fn) { +Pass LowerTEPass(TargetMap targets, backend::StaticMemoryPlan memory_plan, + const String& module_name, std::function process_fn) { runtime::TypedPackedFunc pass_func = [=](IRModule module, PassContext ctx) { - return LowerTE(module, targets, device_context_map, memory_plan, module_name, process_fn); + return LowerTE(module, targets, memory_plan, module_name, process_fn); }; return tvm::transform::Sequential( {tvm::transform::CreateModulePass(pass_func, 0, "LowerTE", {}), InferType()}); diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h index 9d0eb1078ee03..247c85af086c8 100644 --- a/src/relay/backend/te_compiler.h +++ b/src/relay/backend/te_compiler.h @@ -172,7 +172,6 @@ Map GetPerTargetModules(IRModule mod); * * \param module The IRModule. * \param targets The mapping for devices to targets. - * \param device_map An analysis result mapping each sub-expression to a device. * \param memory_plan The memory plan used during lowering * \param module_name The name of this module * \param process_fn Callback allowing one-level up code generators to process @@ -180,7 +179,7 @@ Map GetPerTargetModules(IRModule mod); * \return The lowered module, see above. */ IRModule LowerTE( - const IRModule& module, TargetMap targets, DeviceMap device_map, + const IRModule& module, TargetMap targets, backend::StaticMemoryPlan memory_plan, const String& module_name, ProcessFn process_fn = [](Function f) {}); @@ -191,16 +190,14 @@ IRModule LowerTE( * with their target. * * \param targets The mapping for devices to targets. - * \param device_context_map An analysis result mapping each sub-expression to a device. 
* \param memory_plan The memory plan used during lowering * \param module_name The name of this module * \param process_fn Callback allowing one-level up code generators to process * each function that we lower * \returns The pass which lowers primative functions to TIR */ -transform::Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map, - backend::StaticMemoryPlan memory_plan, const String& module_name, - std::function process_fn); +transform::Pass LowerTEPass(TargetMap targets, backend::StaticMemoryPlan memory_plan, + const String& module_name, std::function process_fn); } // namespace tec } // namespace relay } // namespace tvm diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index b3eab91d202c2..9b151164ab8e2 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -45,10 +45,12 @@ #include #include "../../../target/source/codegen_source_base.h" +#include "../../attrs/annotation.h" #include "../../op/op_common.h" +#include "../../transforms/device_planner.h" #include "../../transforms/pass_utils.h" #include "../utils.h" -#include "compiler.h" +#include "./compiler.h" namespace tvm { namespace relay { @@ -247,15 +249,10 @@ int GetFallbackDevice() { return fallback_dev->value; } -class VMFunctionCompiler : ExprFunctor { +class VMFunctionCompiler : DeviceAwareExprVisitor { public: - VMFunctionCompiler(VMCompilerContext* context, TargetsMap targets, Target target_host, - ExprDeviceMap expr_device_map) - : last_register_(0), - registers_num_(0), - context_(context), - target_host_(target_host), - expr_device_map_(std::move(expr_device_map)) { + VMFunctionCompiler(VMCompilerContext* context, TargetsMap targets, Target target_host) + : last_register_(0), registers_num_(0), context_(context), target_host_(target_host) { CheckAndUpdateHostConsistency(&targets, &target_host); for (const auto& it : targets) { targets_[it.first->value] = it.second; @@ -264,44 +261,48 @@ class VMFunctionCompiler : ExprFunctor { } VMFunction Compile(const GlobalVar& var, const Function& func) { - size_t i = 0; - // We then assign register num to the free variables - for (auto param : func->params) { - auto arg_register = NewRegister(); - ICHECK_EQ(i, arg_register); - var_register_map_.insert({param, arg_register}); - params_.push_back(param->name_hint()); - ++i; - } - + std::vector params_device_type; if (IsClosure(func)) { + // After lifting we'll have functions of the form: + // fn(closure args) { fn(lifted function args) { body } } + // But the closure body will be: + // fn(closure args, lifter function args) { body } + // Do that flattening on-the-fly. 
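+    // E.g. fn (%c) { fn (%x) { %c + %x } } is compiled as if it were fn (%c, %x) { %c + %x }.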
Function inner_func = Downcast(func->body); - for (auto param : inner_func->params) { - auto arg_register = NewRegister(); - ICHECK_EQ(i, arg_register); - var_register_map_.insert({param, arg_register}); - params_.push_back(param->name_hint()); - ++i; + std::vector params; + std::vector param_device_types; + params.reserve(func->params.size() + inner_func->params.size()); + param_device_types.reserve(func->params.size() + inner_func->params.size()); + for (size_t i = 0; i < func->params.size(); ++i) { + params.emplace_back(func->params[i]); + params_device_type.push_back(GetFunctionParamDeviceType(func.get(), i)); + } + for (size_t i = 0; i < inner_func->params.size(); ++i) { + params.emplace_back(inner_func->params[i]); + params_device_type.push_back(GetFunctionParamDeviceType(inner_func.get(), i)); + } + std::vector type_params; + type_params.reserve(func->type_params.size() + inner_func->type_params.size()); + for (const auto& tyvar : func->type_params) { + type_params.push_back(tyvar); + } + for (const auto& tyvar : inner_func->type_params) { + type_params.push_back(tyvar); } - this->VisitExpr(inner_func->body); + Function flattened_func = Function(params, inner_func->body, inner_func->ret_type, + type_params, func->attrs, func->span); + VisitExpr(FunctionOnDevice(flattened_func, params_device_type, + GetFunctionResultDeviceType(inner_func.get()))); } else { - this->VisitExpr(func->body); - } - instructions_.push_back(Instruction::Ret(last_register_)); - - std::vector params_device_type; - for (const auto& it : func->params) { - if (!expr_device_map_.empty()) { - ICHECK_GT(expr_device_map_.count(it), 0U); - params_device_type.push_back(expr_device_map_[it].device_type); - } else { - ICHECK_EQ(targets_.size(), 1U); - params_device_type.push_back((targets_.begin())->first); + params_device_type.reserve(func->params.size()); + for (size_t i = 0; i < func->params.size(); ++i) { + params_device_type.push_back(GetFunctionParamDeviceType(func.get(), i)); } + VisitExpr(func); } - return VMFunction(var->name_hint, params_, instructions_, registers_num_, params_device_type); } + /*! \brief Attrs objects for each op. 
*/ std::map> op_attrs; @@ -342,29 +343,26 @@ class VMFunctionCompiler : ExprFunctor { instructions_.push_back(instr); } - void VisitExpr_(const ConstantNode* const_node) { + using DeviceAwareExprVisitor::VisitExpr_; + + void VisitExpr_(const ConstantNode* const_node) final { // Check the shape is valid NDArray data = const_node->data; size_t konst_idx = context_->constants.size(); - if (expr_device_map_.empty()) { - context_->const_device_type.push_back(targets_.begin()->first); - } else { - auto con = GetRef(const_node); - ICHECK_GT(expr_device_map_.count(con), 0U); - context_->const_device_type.push_back(expr_device_map_[con].device_type); - } + auto con = GetRef(const_node); + context_->const_device_type.push_back(GetInScopeDeviceType(con)); context_->constants.push_back(const_node->data); Emit(Instruction::LoadConst(konst_idx, NewRegister())); } - void VisitExpr_(const VarNode* var_node) { + void VisitExpr_(const VarNode* var_node) final { auto var = GetRef(var_node); auto reg_it = this->var_register_map_.find(var); ICHECK(reg_it != this->var_register_map_.end()); last_register_ = reg_it->second; } - void VisitExpr_(const TupleNode* tuple_node) { + void VisitExpr_(const TupleNode* tuple_node) final { auto tuple = GetRef(tuple_node); std::vector fields_registers; @@ -377,35 +375,28 @@ class VMFunctionCompiler : ExprFunctor { Emit(Instruction::AllocADT(0, tuple->fields.size(), fields_registers, NewRegister())); } - void VisitExpr_(const MatchNode* match_node) { + void VisitExpr_(const MatchNode* match_node) final { auto match = GetRef(match_node); this->VisitExpr(match->data); CompileMatch(match); } - void VisitExpr_(const LetNode* l) final { - Expr let_binding = GetRef(l); - const LetNode* let; - while ((let = let_binding.as())) { - ICHECK(!let->value.as()) - << "invariant violated, inner functions should not exist (did you set opt_level = 2?)"; - VisitExpr(let->value); - var_register_map_.insert({let->var, this->last_register_}); - let_binding = let->body; - } - - VisitExpr(let_binding); + void PreVisitLetBinding_(const Var& var, const Expr& value) final { + ICHECK(!value.as()) + << "invariant violated, inner functions should not exist (did you set opt_level = 2?)"; + VisitExpr(value); + var_register_map_.emplace(var, this->last_register_); } - void VisitExpr_(const TupleGetItemNode* get_node) { + void VisitExpr_(const TupleGetItemNode* get_node) final { auto get = GetRef(get_node); this->VisitExpr(get->tuple); auto tuple_register = last_register_; Emit(Instruction::GetField(tuple_register, get->index, NewRegister())); } - void VisitExpr_(const GlobalVarNode* gvar) { + void VisitExpr_(const GlobalVarNode* gvar) final { auto var = GetRef(gvar); auto func = context_->module->Lookup(var); auto it = context_->global_map.find(var); @@ -414,7 +405,7 @@ class VMFunctionCompiler : ExprFunctor { Emit(Instruction::AllocClosure(it->second, 0, {}, NewRegister())); } - void VisitExpr_(const IfNode* if_node) { + void VisitExpr_(const IfNode* if_node) final { this->VisitExpr(if_node->cond); size_t test_register = last_register_; @@ -501,8 +492,9 @@ class VMFunctionCompiler : ExprFunctor { void EmitInvokeTVMOp(const Function& func, const Expr& inputs, const Expr& outputs) { std::vector argument_registers; - ICHECK(func->GetAttr(attr::kPrimitive, 0) != 0) - << "internal error: invoke_tvm_op requires the first argument to be a relay::Function"; + ICHECK(func->HasNonzeroAttr(attr::kPrimitive)) + << "internal error: invoke_tvm_op requires the first argument to be a primitive " + "relay::Function"; auto 
input_tuple = inputs.as(); ICHECK(input_tuple) << "internal error: invoke_tvm_op inputs must be a tuple," @@ -526,31 +518,21 @@ class VMFunctionCompiler : ExprFunctor { Target target; + // Which target should execute the function? if (func->GetAttr(attr::kCompiler).defined()) { target = Target("ext_dev"); } else { - // Next generate the invoke instruction. - if (expr_device_map_.empty()) { - // homogeneous execution. - ICHECK_EQ(targets_.size(), 1U); - const auto& it = targets_.begin(); - target = (*it).second; + int dev_type = GetInScopeDeviceType(func); + if (targets_.count(dev_type) == 0) { + target = CreateDefaultTarget(dev_type); } else { - ICHECK_GT(expr_device_map_.count(func), 0U) - << "Found not annotated expression, please make sure " - "context analysis has been executed"; - int dev_type = expr_device_map_[func].device_type; - if (targets_.count(dev_type) == 0) { - target = CreateDefaultTarget(dev_type); - } else { - target = targets_[expr_device_map_[func].device_type]; - } + target = targets_[dev_type]; } } CCacheKey key(func, target); auto mangle_fn = [](String name) { return name; }; - auto cfunc = context_->compiler->Lower(key, mangle_fn); + auto cfunc = context_->compiler->Lower(key, mangle_fn); // <<<< one-func-at-a-time lowering auto op_index = -1; if (func->GetAttr(attr::kCompiler).defined()) { @@ -576,7 +558,7 @@ class VMFunctionCompiler : ExprFunctor { argument_registers)); } - void VisitExpr_(const CallNode* call_node) { + void DeviceAwareVisitExpr_(const CallNode* call_node) final { Expr op = call_node->op; // First we handle the case in which we are using an opaque @@ -646,20 +628,8 @@ class VMFunctionCompiler : ExprFunctor { ICHECK(alloc_attrs != nullptr) << "must be the AllocStorage attrs"; auto dtype = alloc_attrs->dtype; - Index device_type; - if (expr_device_map_.empty()) { - // TODO(zhiics) There is bug if all expressions are annotated with the device - // that is different the first one in the target list. - auto& kv = *(targets_.begin()); - device_type = kv.first; - } else { - ICHECK_GT(expr_device_map_.count(GetRef(call_node)), 0U) - << " The alloc_storage node is not annotated"; - device_type = expr_device_map_[GetRef(call_node)].device_type; - } - - Emit(Instruction::AllocStorage(size_register, alignment, dtype, device_type, - NewRegister())); + Emit(Instruction::AllocStorage(size_register, alignment, dtype, + alloc_attrs->device_type, NewRegister())); }) .Match("vm.shape_func", [this](const Array& args, const Attrs& attrs, const Array& type_arg) { @@ -764,12 +734,31 @@ class VMFunctionCompiler : ExprFunctor { } } - void VisitExpr_(const FunctionNode* func_node) { - if (!func_node->HasNonzeroAttr(attr::kPrimitive)) { - LOG(FATAL) << "local functions should have been removed by lambda lifting:" << std::endl - << "Program: " << AsText(GetRef(func_node), false) << std::endl - << "AST: " << GetRef(func_node); + void DeviceAwareVisitExpr_(const FunctionNode* func_node) final { + if (function_nesting() > 1) { + ICHECK(func_node->HasNonzeroAttr(attr::kPrimitive)) + << "local functions should have been removed by lambda lifting:" << std::endl + << "Program: " << AsText(GetRef(func_node), false) << std::endl + << "AST: " << GetRef(func_node); + return; + } + ICHECK(!IsClosure(GetRef(func_node))) + << "closures should have been flattened away by Compile"; + + size_t i = 0; + + // Assign a register num to each parameter. 
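+    // Parameters take registers 0..N-1 in order; the ICHECK below confirms the numbering stays dense.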
+ for (auto param : func_node->params) { + auto arg_register = NewRegister(); + ICHECK_EQ(i, arg_register); + var_register_map_.insert({param, arg_register}); + params_.push_back(param->name_hint()); + ++i; } + + VisitExpr(func_node->body); + + instructions_.push_back(Instruction::Ret(last_register_)); } /*! @@ -862,8 +851,6 @@ class VMFunctionCompiler : ExprFunctor { std::unordered_map targets_; /*! \brief Host target. */ Target target_host_; - /*! \brief Map from Relay expr to device type. */ - ExprDeviceMap expr_device_map_; }; PackedFunc VMCompiler::GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { @@ -930,15 +917,11 @@ void VMCompiler::Lower(IRModule mod, const TargetsMap& targets, const tvm::Targe // the global state. exec_->functions.resize(context_.module->functions.size()); - // Collect the annotated device information. - // This indicates which device each Relay expr should be executed on. - ExprDeviceMap expr_device_map = AnalyzeContext(); - for (auto named_func : context_.module->functions) { auto gvar = named_func.first; if (auto* n = named_func.second.as()) { auto func = GetRef(n); - VMFunctionCompiler func_compiler(&context_, targets_, target_host_, expr_device_map); + VMFunctionCompiler func_compiler(&context_, targets_, target_host_); auto vm_func = func_compiler.Compile(gvar, func); size_t func_index = context_.global_map.at(gvar); @@ -1041,13 +1024,17 @@ IRModule VMCompiler::OptimizeModule(IRModule mod, const TargetsMap& targets_arg, Array pass_seqs = relay::backend::GetPassPrefix(targets, true); - if (targets_.size() > 1) { - // Handle heterogeneous compilation. - int fallback_dev = GetFallbackDevice(); - pass_seqs.push_back(transform::RewriteAnnotatedOps(fallback_dev)); + DLDeviceType default_device_type; + if (targets_arg.size() == 1UL) { + default_device_type = + static_cast(static_cast((*targets_arg.begin()).first->value)); + } else { + default_device_type = static_cast(GetFallbackDevice()); } + pass_seqs.push_back(PlanDevices(default_device_type)); pass_seqs.push_back(transform::FuseOps()); + // Do layout rewrite for auto-scheduler. 
transform::PassContext pass_ctx = PassContext::Current(); if (backend::IsAutoSchedulerEnabled() && targets.size() == 1) { @@ -1145,25 +1132,6 @@ void VMCompiler::Codegen() { exec_->SetLib(lib); } -ExprDeviceMap VMCompiler::AnalyzeContext() const { - Device default_device; - ExprDeviceMap expr_device_map; - if (targets_.size() > 1) { - int fallback_dev = GetFallbackDevice(); - default_device.device_type = static_cast(fallback_dev); - default_device.device_id = 0; - expr_device_map = ContextAnalysis(context_.module, default_device); - } else { - const auto& tgt = targets_.begin(); - default_device.device_type = static_cast((*tgt).first->value); - if (default_device.device_type != kDLCPU) { - default_device.device_id = 0; - expr_device_map = ContextAnalysis(context_.module, default_device); - } - } - return expr_device_map; -} - runtime::Module CreateVMCompiler() { auto exec = make_object(); return runtime::Module(exec); diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h index a05c52ced07f9..af3c5bccbeff2 100644 --- a/src/relay/backend/vm/compiler.h +++ b/src/relay/backend/vm/compiler.h @@ -63,7 +63,6 @@ using GlobalMap = NodeMap; using ConstMap = NodeMap; using ConstTensorShapeMap = NodeMap>; using TargetsMap = Map; -using ExprDeviceMap = std::unordered_map; struct VMCompilerContext { // The module context for the compilation @@ -108,8 +107,8 @@ class VMCompiler : public runtime::ModuleNode { * \brief Lower the functions in a Module * * \param mod Relay Module - * \param targets For heterogeneous compilation, it is a dictionary indicating context - * to target mapping. For homogeneous compilation, it is a build target. + * \param targets For heterogeneous compilation, it is a dictionary indicating device type + * to target mapping. For homogeneous compilation, it is a singleton build target. * \param target_host Host compilation target, if target is device. */ void Lower(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host); @@ -122,8 +121,8 @@ class VMCompiler : public runtime::ModuleNode { * \brief Perform a series of optimizations on the input IR module. * * \param mod The input IRModule. - * \param targets For heterogeneous compilation, it is a dictionary indicating context - * to target mapping. For homogeneous compilation, it is a build target. + * \param targets For heterogeneous compilation, it is a dictionary indicating device type + * to target mapping. For homogeneous compilation, it is a singleton build target. * \param target_host Host compilation target. * * \return The optimized IRModule. @@ -136,9 +135,6 @@ class VMCompiler : public runtime::ModuleNode { */ void PopulateGlobalMap(); - /*! \brief Analyze the device context of each expression. */ - ExprDeviceMap AnalyzeContext() const; - protected: /*! \brief Target devices. 
*/ TargetsMap targets_; diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index c768a2c300ec1..c329c655f0285 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -26,13 +26,15 @@ #include #include #include -#include #include #include #include #include +#include "../../attrs/annotation.h" +#include "../../transforms/device_planner.h" + using namespace tvm::runtime; namespace tvm { @@ -44,7 +46,7 @@ inline std::string GenerateName(const Function& func) { return std::string("lifted_name") + std::to_string(hash); } -bool IsClosure(const Function& func) { return func->GetAttr(attr::kClosure, 0) != 0; } +bool IsClosure(const Function& func) { return func->HasNonzeroAttr(attr::kClosure); } Function MarkClosure(Function func) { return WithAttr(std::move(func), attr::kClosure, tvm::Integer(1)); @@ -56,39 +58,28 @@ Function MarkClosure(Function func) { * We will lift a function out into a global which takes the set of the free * vars and then return the new created function. */ -class LambdaLifter : public ExprMutator { +class LambdaLifter : public transform::DeviceAwareExprMutator { public: explicit LambdaLifter(const IRModule& module) : module_(module) {} - Expr VisitExpr_(const LetNode* let_node) final { - auto pre_visit = [this](const LetNode* op) { - bool is_lambda = false; - if (auto func = op->value.as()) { - if (!func->HasNonzeroAttr(attr::kPrimitive)) { - is_lambda = true; - this->letrec_.push_back(op->var); - } + std::pair PreVisitLetBinding_(const Var& var, const Expr& value) final { + bool is_lambda = false; + if (const auto* func_node = value.as()) { + if (!func_node->HasNonzeroAttr(attr::kPrimitive)) { + is_lambda = true; + this->letrec_.push_back(var); } - Expr value = this->VisitExpr(op->value); + } + Expr new_value = this->VisitExpr(value); - if (is_lambda) { - this->letrec_.pop_back(); - } - }; - auto post_visit = [this](const LetNode* op) { - // Rely on the Memoizer to cache pre-visit values - Expr value = this->VisitExpr(op->value); - // Visit body and cache the op - Expr body = this->VisitExpr(op->body); - auto expr = GetRef(op); - this->memo_[expr] = Let(op->var, value, body); - }; - ExpandANormalForm(let_node, pre_visit, post_visit); - return memo_[GetRef(let_node)]; + if (is_lambda) { + this->letrec_.pop_back(); + } + return {var, new_value}; } - Expr VisitExpr_(const CallNode* call_node) final { - auto call = Downcast(ExprMutator::VisitExpr_(call_node)); + Expr DeviceAwareVisitExpr_(const CallNode* call_node) final { + auto call = Downcast(DeviceAwareExprMutator::DeviceAwareVisitExpr_(call_node)); if (auto var_node = call_node->op.as()) { auto var = GetRef(var_node); if (!letrec_.empty() && var == letrec_.back()) { @@ -100,20 +91,27 @@ class LambdaLifter : public ExprMutator { return std::move(call); } - Expr VisitExpr_(const FunctionNode* func_node) final { + Expr DeviceAwareVisitExpr_(const FunctionNode* func_node) final { auto func = GetRef(func_node); - // We should not transform primitive functions. if (func->HasNonzeroAttr(attr::kPrimitive)) { + // We should not transform primitive functions. return std::move(func); } + if (function_nesting() == 1) { + // We don't need to lift global functions. 
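+      // Just rewrite the body; the function's parameters, return type, attributes and span are preserved as-is.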
+ return Function(func_node->params, VisitExpr(func_node->body), func_node->ret_type, + func_node->type_params, func_node->attrs, func_node->span); + } + auto name = GenerateName(func); auto global = GlobalVar(name); auto free_vars = FreeVars(func); auto free_type_vars = FreeTypeVars(func, module_); Array captured_vars; + std::vector captured_var_device_types; bool recursive = false; for (const auto& var : free_vars) { if (!letrec_.empty() && var == letrec_.back()) { @@ -121,8 +119,10 @@ class LambdaLifter : public ExprMutator { continue; } captured_vars.push_back(var); + captured_var_device_types.push_back(GetInScopeDeviceType(var)); } + // Freshen all the captured vars. Array typed_captured_vars; Map rebinding_map; for (auto free_var : captured_vars) { @@ -131,6 +131,8 @@ class LambdaLifter : public ExprMutator { rebinding_map.Set(free_var, var); } + DLDeviceType result_device_type = GetInScopeDeviceType(func_node->body); + if (recursive) { if (!captured_vars.empty()) { Array fvs; @@ -143,7 +145,7 @@ class LambdaLifter : public ExprMutator { } } - auto body = Downcast(ExprMutator::VisitExpr_(func_node)); + auto body = Downcast(DeviceAwareExprMutator::DeviceAwareVisitExpr_(func_node)); // When performing this optimization there are two cases. // @@ -168,8 +170,9 @@ class LambdaLifter : public ExprMutator { // The "inner" function should be used to generate the // code for the closure. Function lifted_func; - if (captured_vars.size() == 0 && free_type_vars.size() == 0) { - lifted_func = Function(body->params, body->body, body->ret_type, body->type_params); + if (captured_vars.empty() && free_type_vars.empty()) { + lifted_func = Function(body->params, body->body, body->ret_type, body->type_params, + body->attrs, body->span); } else { // When a closure is locally bound in a program, we have its full type information // avalible to us. @@ -183,13 +186,15 @@ class LambdaLifter : public ExprMutator { // bind to go from unannotated free variables -> annotated free variables and then // construct the "closure" function with fully annotated arguments, no longer relying // on type inference. 
- auto before = Downcast(body)->params.size(); + size_t before_arity = body->params.size(); auto rebound_body = Function(func->params, Bind(body->body, rebinding_map), func->ret_type, func->type_params, func->attrs, func->span); - auto after = Downcast(rebound_body)->params.size(); - CHECK_EQ(before, after); + size_t after_arity = rebound_body->params.size(); + CHECK_EQ(before_arity, after_arity); lifted_func = - Function(typed_captured_vars, rebound_body, func->func_type_annotation(), free_type_vars); + Function(typed_captured_vars, rebound_body, /*ret_type=*/func->func_type_annotation(), + free_type_vars, /*attrs=*/{}, func->span); + lifted_func = FunctionOnDevice(lifted_func, captured_var_device_types, result_device_type); lifted_func = MarkClosure(lifted_func); } @@ -206,7 +211,7 @@ class LambdaLifter : public ExprMutator { module_->Add(global, lifted_func); } - if (captured_vars.size() == 0) { + if (captured_vars.empty()) { return std::move(global); } else { // If we need to allocate a closure, @@ -226,9 +231,7 @@ class LambdaLifter : public ExprMutator { if (auto* n = pair.second.as()) { if (n->GetAttr(attr::kCompiler).defined()) continue; auto func = GetRef(n); - func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params, - func->attrs); - module_->Add(pair.first, func, true); + module_->Add(pair.first, Downcast(Mutate(func)), /*update=*/true); } } return module_; diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index 89f22cfb25b21..4a554a58f1f56 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -298,7 +298,7 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex // Recursively find the Dominator parent along all inputs paths. 
bool DFPatternMatcher::MatchesPath(const DominatorPatternNode* op, const Expr& expr) { auto call_node = expr.as(); - for (auto node : expr_graph_.node_map_.at(expr)->inputs_) { + for (auto node : expr_graph_[expr]->inputs_) { if (!(call_node && node->ref_ == call_node->op)) { memoize_ = true; if (VisitDFPattern(op->parent, node->ref_)) { @@ -322,7 +322,7 @@ bool DFPatternMatcher::DominatesParent(const DominatorPatternNode* op, const Exp while (!stack.empty()) { Expr current = stack.top(); stack.pop(); - for (auto node : expr_graph_.node_map_.at(current)->dominator_children_) { + for (auto node : expr_graph_[current]->dominator_children_) { if (visited.count(node->ref_) == 0) { if (VisitDFPattern(op->parent, node->ref_)) { return true; @@ -707,11 +707,11 @@ void PatternGrouper::CreateGroup(const Expr& expr) { return; } else if (kv.second != body) { // if the node isn't the output of the group - auto node = matcher_->expr_graph_.node_map_.at(kv.first); + auto node = matcher_->expr_graph_[kv.first]; for (auto* output : node->outputs_) { // and the node is used by nodes outside of the group if (memo.count(output->ref_) == 0 && - !matcher_->expr_graph_.node_map_.at(expr)->Dominates(output)) { + !matcher_->expr_graph_[expr]->Dominates(output)) { // Exit because nodes in this pattern's body are used outside the pattern // fusing it would be invalid return; diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 3b3c8797d7f20..47dca1a29f54f 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -120,6 +120,31 @@ Call::Call(Expr op, Array args, Attrs attrs, Array type_args, Span s data_ = std::move(n); } +Call Call::CopyWith(Optional opt_op, Optional> opt_args, + Optional opt_attrs, Optional> opt_type_args, + Optional opt_span) { + Expr op = opt_op.value_or(get()->op); + Array args = opt_args.value_or(get()->args); + Attrs attrs = opt_attrs.value_or(get()->attrs); + Array type_args = opt_type_args.value_or(get()->type_args); + Span span = opt_span.value_or(get()->span); + if (op == get()->op && + std::equal(args.begin(), args.end(), get()->args.begin(), get()->args.end()) && + attrs == get()->type_args && + std::equal(type_args.begin(), type_args.end(), get()->type_args.begin(), + get()->type_args.end()) && + span == get()->span) { + return *this; + } + CallNode* new_call_node = CopyOnWrite(); + new_call_node->op = op; + new_call_node->args = args; + new_call_node->attrs = attrs; + new_call_node->type_args = type_args; + new_call_node->span = span; + return GetRef(new_call_node); +} + TVM_REGISTER_NODE_TYPE(CallNode); TVM_REGISTER_GLOBAL("relay.ir.Call") diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index 5984a208efe0f..f3958fb080433 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -31,6 +31,8 @@ #include +#include "../attrs/annotation.h" + namespace tvm { namespace relay { MixedModeVisitor::MixedModeVisitor(int visit_limit) { @@ -527,15 +529,19 @@ Expr Bind(const Expr& expr, const tvm::Map& args_map) { if (const FunctionNode* func = expr.as()) { Expr new_body = ExprBinder(args_map).VisitExpr(func->body); Array new_params; - for (Var param : func->params) { - if (!args_map.count(param)) { - new_params.push_back(param); + std::vector new_param_device_types; + for (size_t i = 0; i < func->params.size(); ++i) { + if (!args_map.count(func->params[i])) { + new_params.push_back(func->params[i]); + new_param_device_types.push_back(GetFunctionParamDeviceType(func, i)); } } if (new_body.same_as(func->body) && new_params.size() == 
func->params.size()) { return expr; } - auto ret = Function(new_params, new_body, func->ret_type, func->type_params, func->attrs); + auto ret = + Function(new_params, new_body, func->ret_type, func->type_params, func->attrs, func->span); + ret = FunctionOnDevice(ret, new_param_device_types, GetFunctionResultDeviceType(func)); std::unordered_set set; for (const auto& v : FreeVars(expr)) { set.insert(v); @@ -543,9 +549,15 @@ Expr Bind(const Expr& expr, const tvm::Map& args_map) { for (const auto& v : FreeVars(ret)) { if (set.count(v) == 0) { new_params.push_back(v); + // TODO(mbs): We've lost the device context for any introduced vars. + LOG(WARNING) << "introduced free var '" << PrettyPrint(v) + << "' into function body but no device is known for it"; + new_param_device_types.push_back(kInvalidDeviceType); } } - ret = Function(new_params, new_body, func->ret_type, func->type_params, func->attrs); + ret = + Function(new_params, new_body, func->ret_type, func->type_params, func->attrs, func->span); + ret = FunctionOnDevice(ret, new_param_device_types, GetFunctionResultDeviceType(func)); ICHECK_EQ(FreeVars(expr).size(), FreeVars(ret).size()); return std::move(ret); } else { diff --git a/src/relay/ir/indexed_graph.cc b/src/relay/ir/indexed_graph.cc index e4d9585470a64..94c256c905e21 100644 --- a/src/relay/ir/indexed_graph.cc +++ b/src/relay/ir/indexed_graph.cc @@ -34,13 +34,15 @@ namespace relay { // IndexedGraph IndexedGraph CreateIndexedGraph(const Expr& expr) { + VLOG_CONTEXT << "CreateIndexedGraph"; + VLOG(1) << "creating for:" << std::endl << PrettyPrint(expr); using NodePtr = std::shared_ptr::Node>; /*! \brief Creator Creates an IndexedGraph and determintes Topological order */ class Creator : public MixedModeVisitor { public: IndexedGraph CreateGraph(const Expr& expr) { VisitExpr(expr); - graph_.node_map_[expr]->is_external_ = true; + graph_[expr]->is_external_ = true; return std::move(graph_); } @@ -48,7 +50,7 @@ IndexedGraph CreateIndexedGraph(const Expr& expr) { void VisitLeaf(const Expr& expr) override { MixedModeVisitor::VisitLeaf(expr); auto node = std::make_shared::Node>(expr, index_++); - graph_.node_map_[expr] = node; + graph_.Set(expr, node); graph_.topological_order_.push_back(node); } IndexedGraph graph_; @@ -76,7 +78,7 @@ IndexedGraph CreateIndexedGraph(const Expr& expr) { /*! 
Default visitation pushes the parent to the child's outputs and the child to the parent's * inputs*/ void VisitExpr(const Expr& expr, NodePtr parent) override { - auto current = graph_.node_map_[expr]; + auto current = graph_[expr]; if (parent) { current->outputs_.push_back(parent.get()); parent->inputs_.push_back(current.get()); @@ -97,59 +99,59 @@ IndexedGraph CreateIndexedGraph(const Expr& expr) { void VisitExpr_(const TupleNode* op, NodePtr parent) override { for (auto field : op->fields) { - this->VisitExpr(field, graph_.node_map_[GetRef(op)]); + this->VisitExpr(field, graph_[GetRef(op)]); } } void VisitExpr_(const FunctionNode* op, NodePtr parent) override { for (auto param : op->params) { - this->VisitExpr(param, graph_.node_map_[GetRef(op)]); + this->VisitExpr(param, graph_[GetRef(op)]); } - this->VisitExpr(op->body, graph_.node_map_[GetRef(op)]); + this->VisitExpr(op->body, graph_[GetRef(op)]); } void VisitExpr_(const CallNode* op, NodePtr parent) override { - this->VisitExpr(op->op, graph_.node_map_[GetRef(op)]); + this->VisitExpr(op->op, graph_[GetRef(op)]); for (auto ty_arg : op->type_args) { this->VisitType(ty_arg); } for (auto arg : op->args) { - this->VisitExpr(arg, graph_.node_map_[GetRef(op)]); + this->VisitExpr(arg, graph_[GetRef(op)]); } } void VisitExpr_(const LetNode* op, NodePtr parent) override { - this->VisitExpr(op->value, graph_.node_map_[GetRef(op)]); - this->VisitExpr(op->var, graph_.node_map_[GetRef(op)]); - this->VisitExpr(op->body, graph_.node_map_[GetRef(op)]); + this->VisitExpr(op->value, graph_[GetRef(op)]); + this->VisitExpr(op->var, graph_[GetRef(op)]); + this->VisitExpr(op->body, graph_[GetRef(op)]); } void VisitExpr_(const IfNode* op, NodePtr parent) override { - this->VisitExpr(op->cond, graph_.node_map_[GetRef(op)]); - this->VisitExpr(op->true_branch, graph_.node_map_[GetRef(op)]); - this->VisitExpr(op->false_branch, graph_.node_map_[GetRef(op)]); + this->VisitExpr(op->cond, graph_[GetRef(op)]); + this->VisitExpr(op->true_branch, graph_[GetRef(op)]); + this->VisitExpr(op->false_branch, graph_[GetRef(op)]); } void VisitExpr_(const OpNode* op, NodePtr parent) override { return; } void VisitExpr_(const TupleGetItemNode* op, NodePtr parent) override { - this->VisitExpr(op->tuple, graph_.node_map_[GetRef(op)]); + this->VisitExpr(op->tuple, graph_[GetRef(op)]); } void VisitExpr_(const RefCreateNode* op, NodePtr parent) override { - this->VisitExpr(op->value, graph_.node_map_[GetRef(op)]); + this->VisitExpr(op->value, graph_[GetRef(op)]); } void VisitExpr_(const RefReadNode* op, NodePtr parent) override { - this->VisitExpr(op->ref, graph_.node_map_[GetRef(op)]); + this->VisitExpr(op->ref, graph_[GetRef(op)]); } void VisitExpr_(const RefWriteNode* op, NodePtr parent) override { - this->VisitExpr(op->ref, graph_.node_map_[GetRef(op)]); - this->VisitExpr(op->value, graph_.node_map_[GetRef(op)]); + this->VisitExpr(op->ref, graph_[GetRef(op)]); + this->VisitExpr(op->value, graph_[GetRef(op)]); } void VisitExpr_(const ConstructorNode* op, NodePtr parent) override { @@ -160,9 +162,9 @@ IndexedGraph CreateIndexedGraph(const Expr& expr) { } void VisitExpr_(const MatchNode* op, NodePtr parent) override { - this->VisitExpr(op->data, graph_.node_map_[GetRef(op)]); + this->VisitExpr(op->data, graph_[GetRef(op)]); for (const Clause& c : op->clauses) { - this->VisitClause(c, graph_.node_map_[GetRef(op)]); + this->VisitClause(c, graph_[GetRef(op)]); } } @@ -185,7 +187,7 @@ IndexedGraph CreateIndexedGraph(const DFPattern& pattern) { public: IndexedGraph CreateGraph(const 
DFPattern& pattern) { VisitDFPattern(pattern); - graph_.node_map_[pattern]->is_external_ = true; + graph_[pattern]->is_external_ = true; return std::move(graph_); } @@ -194,7 +196,7 @@ IndexedGraph CreateIndexedGraph(const DFPattern& pattern) { if (this->visited_.count(pattern.get()) == 0) { DFPatternVisitor::VisitDFPattern(pattern); auto node = std::make_shared::Node>(pattern, index_++); - graph_.node_map_[pattern] = node; + graph_.Set(pattern, node); graph_.topological_order_.push_back(node); } } @@ -222,7 +224,7 @@ IndexedGraph CreateIndexedGraph(const DFPattern& pattern) { /*! Default visitation pushes the parent to the child's outputs */ void VisitDFPattern(const DFPattern& pattern, NodePtr parent) override { - auto current = graph_.node_map_[pattern]; + auto current = graph_[pattern]; if (parent) { current->outputs_.push_back(parent.get()); parent->inputs_.push_back(current.get()); @@ -232,19 +234,19 @@ IndexedGraph CreateIndexedGraph(const DFPattern& pattern) { protected: IndexedGraph graph_; void VisitDFPattern_(const AltPatternNode* op, NodePtr parent) override { - VisitDFPattern(op->left, graph_.node_map_[GetRef(op)]); - VisitDFPattern(op->right, graph_.node_map_[GetRef(op)]); + VisitDFPattern(op->left, graph_[GetRef(op)]); + VisitDFPattern(op->right, graph_[GetRef(op)]); } void VisitDFPattern_(const AttrPatternNode* op, NodePtr parent) override { - VisitDFPattern(op->pattern, graph_.node_map_[GetRef(op)]); + VisitDFPattern(op->pattern, graph_[GetRef(op)]); } void VisitDFPattern_(const CallPatternNode* op, NodePtr parent) override { - VisitDFPattern(op->op, graph_.node_map_[GetRef(op)]); + VisitDFPattern(op->op, graph_[GetRef(op)]); if (op->args.defined()) { for (auto arg : op->args) { - VisitDFPattern(arg, graph_.node_map_[GetRef(op)]); + VisitDFPattern(arg, graph_[GetRef(op)]); } } } @@ -252,13 +254,13 @@ IndexedGraph CreateIndexedGraph(const DFPattern& pattern) { void VisitDFPattern_(const ConstantPatternNode* op, NodePtr parent) override {} void VisitDFPattern_(const DataTypePatternNode* op, NodePtr parent) override { - VisitDFPattern(op->pattern, graph_.node_map_[GetRef(op)]); + VisitDFPattern(op->pattern, graph_[GetRef(op)]); } void VisitDFPattern_(const DominatorPatternNode* op, NodePtr parent) override { - VisitDFPattern(op->parent, graph_.node_map_[GetRef(op)]); - VisitDFPattern(op->path, graph_.node_map_[GetRef(op)]); - VisitDFPattern(op->child, graph_.node_map_[GetRef(op)]); + VisitDFPattern(op->parent, graph_[GetRef(op)]); + VisitDFPattern(op->path, graph_[GetRef(op)]); + VisitDFPattern(op->child, graph_[GetRef(op)]); } void VisitDFPattern_(const ExprPatternNode* op, NodePtr parent) override {} @@ -266,42 +268,42 @@ IndexedGraph CreateIndexedGraph(const DFPattern& pattern) { void VisitDFPattern_(const FunctionPatternNode* op, NodePtr parent) override { if (op->params.defined()) { for (auto param : op->params) { - VisitDFPattern(param, graph_.node_map_[GetRef(op)]); + VisitDFPattern(param, graph_[GetRef(op)]); } } - VisitDFPattern(op->body, graph_.node_map_[GetRef(op)]); + VisitDFPattern(op->body, graph_[GetRef(op)]); } void VisitDFPattern_(const ShapePatternNode* op, NodePtr parent) override { - VisitDFPattern(op->pattern, graph_.node_map_[GetRef(op)]); + VisitDFPattern(op->pattern, graph_[GetRef(op)]); } void VisitDFPattern_(const TupleGetItemPatternNode* op, NodePtr parent) override { - VisitDFPattern(op->tuple, graph_.node_map_[GetRef(op)]); + VisitDFPattern(op->tuple, graph_[GetRef(op)]); } void VisitDFPattern_(const TuplePatternNode* op, NodePtr parent) override 
{ if (op->fields.defined()) { for (auto field : op->fields) { - VisitDFPattern(field, graph_.node_map_[GetRef(op)]); + VisitDFPattern(field, graph_[GetRef(op)]); } } } void VisitDFPattern_(const IfPatternNode* op, NodePtr parent) override { - VisitDFPattern(op->cond, graph_.node_map_[GetRef(op)]); - VisitDFPattern(op->true_branch, graph_.node_map_[GetRef(op)]); - VisitDFPattern(op->false_branch, graph_.node_map_[GetRef(op)]); + VisitDFPattern(op->cond, graph_[GetRef(op)]); + VisitDFPattern(op->true_branch, graph_[GetRef(op)]); + VisitDFPattern(op->false_branch, graph_[GetRef(op)]); } void VisitDFPattern_(const LetPatternNode* op, NodePtr parent) override { - VisitDFPattern(op->var, graph_.node_map_[GetRef(op)]); - VisitDFPattern(op->value, graph_.node_map_[GetRef(op)]); - VisitDFPattern(op->body, graph_.node_map_[GetRef(op)]); + VisitDFPattern(op->var, graph_[GetRef(op)]); + VisitDFPattern(op->value, graph_[GetRef(op)]); + VisitDFPattern(op->body, graph_[GetRef(op)]); } void VisitDFPattern_(const TypePatternNode* op, NodePtr parent) override { - VisitDFPattern(op->pattern, graph_.node_map_[GetRef(op)]); + VisitDFPattern(op->pattern, graph_[GetRef(op)]); } void VisitDFPattern_(const VarPatternNode* op, NodePtr parent) override {} diff --git a/src/relay/ir/indexed_graph.h b/src/relay/ir/indexed_graph.h index d073bcaeea5c9..5d4edd7952afd 100644 --- a/src/relay/ir/indexed_graph.h +++ b/src/relay/ir/indexed_graph.h @@ -97,6 +97,7 @@ class IndexedGraph { return false; } }; + /*! \brief Construct the domination tree inside IndexedGraph */ void PostDom() { for (size_t i = topological_order_.size(); i != 0; --i) { @@ -113,8 +114,18 @@ class IndexedGraph { } } } - /*! \brief Map of input nodes to IndexedGraph Nodes */ - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> node_map_; + + std::shared_ptr operator[](const T& item) const { + auto itr = node_map_.find(item); + ICHECK(itr != node_map_.end()) << "no entry for:" << std::endl << PrettyPrint(item); + return itr->second; + } + + void Set(const T& item, std::shared_ptr ptr) { + ICHECK(ptr != nullptr) << "null ptr for:" << std::endl << PrettyPrint(item); + node_map_[item] = std::move(ptr); + } + /*! \brief Topological IndexedGraph Nodes */ std::vector> topological_order_; @@ -150,6 +161,10 @@ class IndexedGraph { } return lhs; } + + private: + /*! \brief Map of input nodes to IndexedGraph Nodes */ + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> node_map_; }; /*! \brief Create an Indexed Graph based on an Expr */ diff --git a/src/relay/op/memory/memory.cc b/src/relay/op/memory/memory.cc index c2997fb6cf958..1a26be53e9100 100644 --- a/src/relay/op/memory/memory.cc +++ b/src/relay/op/memory/memory.cc @@ -34,10 +34,10 @@ #include +#include "../../attrs/annotation.h" #include "../../transforms/infer_layout_utils.h" #include "../op_common.h" #include "../type_relations.h" -#include "tvm/relay/attrs/device_copy.h" namespace tvm { namespace relay { @@ -97,14 +97,21 @@ RELAY_REGISTER_OP("memory.alloc_storage") return {topi::identity(inputs[0])}; }); -Expr AllocTensor(Expr storage, Expr offset, tvm::relay::Expr shape, DataType dtype, +Expr AllocTensor(Expr storage, Expr offset, Expr shape, DataType dtype, Array assert_shape) { auto attrs = make_object(); attrs->dtype = dtype; if (assert_shape.defined()) { attrs->assert_shape = assert_shape; } else { - attrs->const_shape = Downcast(shape); + // Look through any on_device for the shape argument expression. 
+ Expr literal_shape = shape; + auto props = GetOnDeviceProps(literal_shape); + if (props.body.defined()) { + // See through on_device calls. + literal_shape = props.body; + } + attrs->const_shape = Downcast(literal_shape); } static const Op& op = Op::Get("memory.alloc_tensor"); return Call(op, {storage, offset, shape}, Attrs(attrs), {}); @@ -307,36 +314,5 @@ TVM_REGISTER_GLOBAL("relay.op.memory._make.ToTupleType") return ToTupleType(t, std::vector(array.begin(), array.end())); }); -// relay.device_copy -TVM_REGISTER_NODE_TYPE(DeviceCopyAttrs); - -Expr DeviceCopy(Expr data, int src_dev_type, int dst_dev_type) { - auto attrs = make_object(); - attrs->src_dev_type = src_dev_type; - attrs->dst_dev_type = dst_dev_type; - static const Op& op = Op::Get("device_copy"); - return Call(op, {data}, Attrs(attrs), {}); -} - -TVM_REGISTER_GLOBAL("relay.op._make.device_copy").set_body_typed(DeviceCopy); - -RELAY_REGISTER_OP("device_copy") - .describe(R"code( -Copy data from one tensor to another. The source and destination might be -on different devices. -)code" TVM_ADD_FILELINE) - .set_num_inputs(1) - .add_argument("data", "Tensor", "The input data.") - .set_support_level(10) - .add_type_rel("Identity", IdentityRel) - .set_attr("TOpPattern", kOpaque) - .set_attr("TOpIsStateful", false) - .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) - .set_attr("FTVMCompute", - [](const Attrs& attrs, const Array& inputs, - const Type& out_dtype) -> Array { - return {topi::identity(inputs[0])}; - }); - } // namespace relay } // namespace tvm diff --git a/src/relay/op/memory/memory.h b/src/relay/op/memory/memory.h index bbbd11867549d..558c409782f57 100644 --- a/src/relay/op/memory/memory.h +++ b/src/relay/op/memory/memory.h @@ -33,7 +33,6 @@ namespace tvm { namespace relay { Expr AllocStorage(Expr size, Expr alignment, Device dev, DataType dtype_hint); -Expr DeviceCopy(Expr data, int src_dev_type, int dst_dev_type); Expr AllocTensor(Expr storage, Expr offset, tvm::relay::Expr shape, DataType dtype, Array assert_shape); Expr ToTupleType(const Type& ty, const std::vector& exprs); diff --git a/src/relay/quantize/partition.cc b/src/relay/quantize/partition.cc index c65cc18799327..46d5b2f7f4de5 100644 --- a/src/relay/quantize/partition.cc +++ b/src/relay/quantize/partition.cc @@ -26,7 +26,7 @@ #include -#include "../transforms/pattern_utils.h" +#include "../attrs/annotation.h" #include "./quantize.h" namespace tvm { diff --git a/src/relay/quantize/realize.cc b/src/relay/quantize/realize.cc index 968628fbfe39c..6912e9c5f39d6 100644 --- a/src/relay/quantize/realize.cc +++ b/src/relay/quantize/realize.cc @@ -29,8 +29,8 @@ #include #include +#include "../attrs/annotation.h" #include "../qnn/utils.h" -#include "../transforms/pattern_utils.h" #include "./quantize.h" namespace tvm { diff --git a/src/relay/transforms/device_annotation.cc b/src/relay/transforms/device_annotation.cc index 02f9d474411ab..bf90036dc2172 100644 --- a/src/relay/transforms/device_annotation.cc +++ b/src/relay/transforms/device_annotation.cc @@ -18,7 +18,7 @@ */ /*! - * \file deivce_annotation.cc + * \file device_annotation.cc * \brief Passes to rewrite annotated program and retrieve the device allocation * of expression. 
* @@ -39,17 +39,14 @@ #include #include +#include "../attrs/annotation.h" +#include "../attrs/device_copy.h" + namespace tvm { namespace relay { namespace { -bool IsOnDeviceNode(const ExprNode* node) { - if (!node->IsInstance()) return false; - const auto* call_node = static_cast(node); - return call_node->attrs.as(); -} - bool IsDeviceCopyNode(const ExprNode* node) { if (!node->IsInstance()) return false; const auto* call_node = static_cast(node); @@ -69,6 +66,13 @@ bool IsDeviceCopyNode(const ExprNode* node) { } // namespace +/*! + * \brief Builds a map from expression to device type based on existing + * "on_device" CallNodes. + * + * Only "on_device" CallNodes, their single args, and tuple-projection from such will be indexed. + */ +// TODO(mbs): Retire. class ValidateAnnotation : private ExprVisitor { public: static std::unordered_map Validate(const Expr& expr) { @@ -80,45 +84,33 @@ class ValidateAnnotation : private ExprVisitor { private: void VisitExpr_(const CallNode* call_node) final { ExprVisitor::VisitExpr_(call_node); - if (IsOnDeviceNode(call_node)) { - int device_type = GetDeviceId(call_node); + auto props = GetOnDeviceProps(call_node); + if (props.body.defined()) { if (annotation_map_.count(call_node)) { - ICHECK_EQ(annotation_map_.at(call_node), device_type) + ICHECK_EQ(annotation_map_.at(call_node), props.device_type) << "An expression node can only be annotated to one device."; } else { - annotation_map_.insert({call_node, GetDeviceId(call_node)}); + annotation_map_.insert({call_node, props.device_type}); } - ICHECK_EQ(call_node->args.size(), 1U); - const auto* node = call_node->args[0].operator->(); + const auto* node = props.body.get(); if (annotation_map_.count(node)) { - ICHECK_EQ(annotation_map_.at(node), device_type) + ICHECK_EQ(annotation_map_.at(node), props.device_type) << "An expression node can only be annotated to one device."; } else { - annotation_map_.insert({node, GetDeviceId(call_node)}); + annotation_map_.insert({node, props.device_type}); } } } void VisitExpr_(const TupleGetItemNode* get_elem) final { ExprVisitor::VisitExpr_(get_elem); - const auto* tn = get_elem->tuple.operator->(); + const auto* tn = get_elem->tuple.get(); if (annotation_map_.count(tn)) { annotation_map_.insert({get_elem, annotation_map_.at(tn)}); } } - /* - * \brief Get the device type of the annotation node. - * \param call_node The on_device annotation call node. - * \return The device type. 
- */ - int GetDeviceId(const CallNode* call_node) { - ICHECK(IsOnDeviceNode(call_node)) << "The input call node must be on_device node."; - const OnDeviceAttrs* on_device_attr = call_node->attrs.as(); - return on_device_attr->device_type; - } - std::unordered_map annotation_map_; }; @@ -145,7 +137,7 @@ class RewriteAnnotation : public ExprMutator { return ExprMutator::VisitExpr_(op); } else { Expr new_let = Let(op->var, value, body); - UpdateAnnotationMap(op, new_let.operator->()); + UpdateAnnotationMap(op, new_let.get()); return this->VisitExpr(new_let); } } @@ -154,13 +146,13 @@ class RewriteAnnotation : public ExprMutator { Array fields; bool annotated = false; for (const auto& field : op->fields) { - annotated |= NeedDeviceCopy(field.operator->(), op); + annotated |= NeedDeviceCopy(field.get(), op); fields.push_back(GetDeviceCopyExpr(field, op)); } if (annotated) { Expr new_tuple = Tuple(fields); - UpdateAnnotationMap(op, new_tuple.operator->()); + UpdateAnnotationMap(op, new_tuple.get()); return this->VisitExpr(new_tuple); } else { return ExprMutator::VisitExpr_(op); @@ -169,9 +161,9 @@ class RewriteAnnotation : public ExprMutator { Expr VisitExpr_(const TupleGetItemNode* op) final { Expr tuple = op->tuple; - if (NeedDeviceCopy(tuple.operator->(), op)) { + if (NeedDeviceCopy(tuple.get(), op)) { Expr new_expr = TupleGetItem(GetDeviceCopyExpr(tuple, op), op->index); - UpdateAnnotationMap(op, new_expr.operator->()); + UpdateAnnotationMap(op, new_expr.get()); return this->VisitExpr(new_expr); } else { return ExprMutator::VisitExpr_(op); @@ -188,14 +180,15 @@ class RewriteAnnotation : public ExprMutator { return ExprMutator::VisitExpr_(if_node); } else { Expr new_if = If(cond, true_br, false_br); - UpdateAnnotationMap(if_node, new_if.operator->()); + UpdateAnnotationMap(if_node, new_if.get()); return this->VisitExpr(new_if); } } Expr VisitExpr_(const CallNode* call_node) final { - if (IsOnDeviceNode(call_node)) { - return this->VisitExpr(call_node->args[0]); + auto props = GetOnDeviceProps(call_node); + if (props.body.defined()) { + return this->VisitExpr(props.body); } if (IsDeviceCopyNode(call_node)) { @@ -205,14 +198,14 @@ class RewriteAnnotation : public ExprMutator { Array new_args; bool annotated = false; for (const auto& arg : call_node->args) { - annotated |= NeedDeviceCopy(arg.operator->(), call_node); + annotated |= NeedDeviceCopy(arg.get(), call_node); new_args.push_back(GetDeviceCopyExpr(arg, call_node)); } if (annotated) { Call new_call = Call(call_node->op, new_args, call_node->attrs, call_node->type_args); - UpdateAnnotationMap(call_node, new_call.operator->()); + UpdateAnnotationMap(call_node, new_call.get()); return this->VisitExpr(new_call); } else { return ExprMutator::VisitExpr_(call_node); @@ -231,7 +224,7 @@ class RewriteAnnotation : public ExprMutator { } Expr GetDeviceCopyExpr(const Expr& src, const ExprNode* dst) { - const auto* src_node = src.operator->(); + const auto* src_node = src.get(); if (!NeedDeviceCopy(src_node, dst)) return src; const auto sit = annotation_map_.find(src_node); @@ -286,225 +279,47 @@ class RewriteAnnotation : public ExprMutator { * \param dst_dev_type The device type where the data is copied to. * \return The created call node. 
*/ - Call CreateDeviceCopy(const Expr& src, int src_dev_type, int dst_dev_type) { - auto attrs = make_object(); - attrs->src_dev_type = src_dev_type; - attrs->dst_dev_type = dst_dev_type; - static const Op& op = Op::Get("device_copy"); - Call device_copy = Call(op, {src}, Attrs(attrs), {}); - annotation_map_.insert({device_copy.operator->(), dst_dev_type}); + Expr CreateDeviceCopy(const Expr& src, int src_dev_type, int dst_dev_type) { + Expr device_copy = DeviceCopy(src, static_cast(src_dev_type), + static_cast(dst_dev_type)); + annotation_map_.insert({device_copy.get(), dst_dev_type}); return device_copy; } + const Op& device_copy_op_ = Op::Get("device_copy"); std::unordered_map annotation_map_; int fallback_device_; }; -// Get all annotation expressions. -class AnnotatationVisitor : private ExprVisitor { - public: - static Map GetAnnotations(const Expr& expr) { - AnnotatationVisitor visitor; - visitor(expr); - return visitor.annotations_; - } - - private: - void VisitExpr_(const CallNode* call_node) { - if (IsOnDeviceNode(call_node)) { - const auto* attr = call_node->attrs.as(); - annotations_.Set(GetRef(call_node), attr->device_type); - } - ExprVisitor::VisitExpr_(call_node); - } - Map annotations_; -}; - -/* - * \brief Return device allocation map based on the post order traversed graph. - * For the following program: - * .. code-block:: python - * x = relay.var("x") - * y = relay.var("y") - * add = relay.add(x, y) - * sqrt = relay.sqrt(add) - * log = relay.log(add) - * subtract = relay.subtract(sqrt, log) - * exp = relay.exp(subtract) - * - * Suppose we have annotated add, sqrt, and log with device 1, 2, and 3, - * respectively. The fallback/default device is 4. After Rewriting the - * program, we can have the following graph, where each copy op has both - * source and destination device type denoting which device the data should be - * copied from and to. - * - * x y - * \ / - * add/1 - * / \ - * copy1 copy2 - * | | - * sqrt/2 log/3 - * | | - * copy3 copy4 - * \ / - * subtract - * | - * exp +/*! \brief Builds a map from "on_device" CallNodes to their device types. * - * To Get the device mapping of each expression, we need to propagate the - * device information from the copy ops. This can be done in two passes. - * -Pass 1: Propagating the source device type to ops in a bottom-up way to the - * ancestors until encountering another copy op. For example, this way - * provides add, x, and y device types from the copy operator, `copy1`. - * -Pass 2: Propagating the destination device type of "the last" copy op to the - * remain nodes. For instance, this offers `subtract` and `exp` the - * same device type as `copy3`. + * No other expression appear in the result map. 
*/ - -class DeviceInfo { +class AnnotationVisitor : private ExprVisitor { public: - static Map GetDeviceMap(const Expr& expr) { - DeviceInfo device_info; - device_info.post_visitor_ = PostDfsOrderVisitor(); - device_info.post_visitor_.Visit(expr); - if (device_info.post_visitor_.num_device_copy_ops_ > 0) { - device_info.PropagateDeviceId(); - return device_info.device_map_; - } else { - return Map(); - } + static void AccumAnnotations(const Expr& expr, Map* annotations) { + AnnotationVisitor visitor(annotations); + visitor(expr); } private: - class PostDfsOrderVisitor : private ExprVisitor { - public: - void Visit(const Expr& expr) { - if (const auto* fn = expr.as()) { - for (const auto& param : fn->params) { - this->VisitExpr(param); - } - this->VisitExpr(fn->body); - } else { - this->VisitExpr(expr); - } - } + explicit AnnotationVisitor(Map* annotations) : annotations_(annotations) {} - private: - // Post order traversal. - void VisitExpr_(const FunctionNode* fn) final { - // TODO(zhiics) Skip annotation of function node for now. - } - - void VisitExpr_(const ConstantNode* cn) final { device_tag_[cn] = dev_type_; } - - void VisitExpr_(const CallNode* call) final { - // Skip annotation nodes. - if (!IsOnDeviceNode(call)) { - if (const auto* node = GetDeviceCopyNode(call)) { - ICHECK(node->IsInstance()); - const auto* call_node = static_cast(node); - auto attrs = call_node->attrs.as(); - - if (attrs) { - num_device_copy_ops_++; - dev_type_ = attrs->src_dev_type; - for (auto& arg : call->args) { - Visit(arg); - // restore the type for remaining arguments - dev_type_ = attrs->src_dev_type; - } - device_tag_[call] = attrs->dst_dev_type; - // update the out_dev_type_, which should be the dst_dev_type of last copy - out_dev_type_ = attrs->dst_dev_type; - } else { - auto attrs = call_node->attrs.as(); - CHECK(attrs) << "must be non-null"; - num_device_copy_ops_++; - dev_type_ = Downcast(attrs->metadata["source_device"]); - for (auto& arg : call->args) { - Visit(arg); - // restore the type for remaining arguments - dev_type_ = Downcast(attrs->metadata["source_device"]); - } - device_tag_[call] = Downcast(attrs->metadata["dst_device"]); - // update the out_dev_type_, which should be the dst_dev_type of last copy - out_dev_type_ = Downcast(attrs->metadata["dst_device"]); - } - } else { - for (auto& arg : call->args) { - int cur_dev_type = dev_type_; - Visit(arg); - // restore the type for remaining arguments - dev_type_ = cur_dev_type; - } - device_tag_[call] = dev_type_; - } - } - } - - void VisitExpr_(const TupleNode* tn) final { - ExprVisitor::VisitExpr_(tn); - // TODO(zhiics) Skip annotation of tuple node for now. - } - - void VisitExpr_(const TupleGetItemNode* op) final { ExprVisitor::VisitExpr_(op); } - - void VisitExpr_(const VarNode* vn) final { device_tag_[vn] = dev_type_; } - - void VisitExpr_(const LetNode* ln) final { - ExprVisitor::VisitExpr_(ln); - device_tag_[ln] = dev_type_; - } - - void VisitExpr_(const IfNode* in) final { - ExprVisitor::VisitExpr_(in); - device_tag_[in] = dev_type_; - } - - int num_device_copy_ops_{0}; - int dev_type_ = -1; - int out_dev_type_ = -1; - std::unordered_map device_tag_; - friend DeviceInfo; - }; - - /* - * \brief Returns a device copy node based on the current expr node. It - * returns a device copy node either the current expr node is a device copy - * node or the current expr node is a function node whose body is a device - * copy node (i.e. the fused function of a device copy call node). 
- */ - static const ExprNode* GetDeviceCopyNode(const ExprNode* node) { - if (IsDeviceCopyNode(node)) { - return node; - } else if (node->IsInstance()) { - const auto* call_node = static_cast(node); - if (const auto* fn = call_node->op.as()) { - const ExprNode* body = fn->body.operator->(); - if (IsDeviceCopyNode(body)) { - return body; - } - } - } - return nullptr; - } - - void PropagateDeviceId() { - int out_dev_type = post_visitor_.out_dev_type_; - for (auto& it : post_visitor_.device_tag_) { - if (it.second != -1) { - device_map_.Set(GetRef(it.first), it.second); - } else { - device_map_.Set(GetRef(it.first), out_dev_type); - } + void VisitExpr_(const CallNode* call_node) final { + auto props = GetOnDeviceProps(call_node); + if (props.body.defined()) { + annotations_->Set(GetRef(call_node), props.device_type); } + ExprVisitor::VisitExpr_(call_node); } - PostDfsOrderVisitor post_visitor_; - Map device_map_; + Map* annotations_; }; +/*! + * \brief Inserts "device_copy" CallNodes where an existing "on_device" CallNode suggests + * a transition between device domains. All existing "on_device" CallNodes are removed. + */ Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device) { RewriteAnnotation rewrote = RewriteAnnotation(); Expr new_expr = rewrote.Rewrite(expr, fallback_device); @@ -518,7 +333,7 @@ Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device) { std::vector new_body; if (const TupleNode* tuple = body.as()) { for (const auto& field : tuple->fields) { - if (!IsOnDeviceNode(field.operator->())) { + if (!IsOnDeviceCall(field)) { new_body.push_back(field); } } @@ -537,7 +352,7 @@ Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device) { } else if (const TupleNode* tuple = new_expr.as()) { std::vector new_fields; for (const auto& field : tuple->fields) { - if (!IsOnDeviceNode(field.operator->())) { + if (!IsOnDeviceCall(field)) { new_fields.push_back(field); } } @@ -552,13 +367,19 @@ Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device) { } } -Map CollectDeviceInfo(const Expr& expr) { return DeviceInfo::GetDeviceMap(expr); } - Map CollectDeviceAnnotationOps(const Expr& expr) { - return AnnotatationVisitor::GetAnnotations(expr); + Map annotations; + AnnotationVisitor::AccumAnnotations(expr, &annotations); + return annotations; } -TVM_REGISTER_GLOBAL("relay.analysis.CollectDeviceInfo").set_body_typed(CollectDeviceInfo); +Map CollectAllDeviceAnnotationOps(const IRModule& mod) { + Map annotations; + for (const auto& pair : mod->functions) { + AnnotationVisitor::AccumAnnotations(pair.second, &annotations); + } + return annotations; +} TVM_REGISTER_GLOBAL("relay.analysis.CollectDeviceAnnotationOps") .set_body_typed(CollectDeviceAnnotationOps); @@ -573,7 +394,7 @@ Pass RewriteAnnotatedOps(int fallback_device) { return CreateFunctionPass(pass_func, 1, "RewriteAnnotatedOps", {"InferType"}); } -TVM_REGISTER_GLOBAL("relay._transform.RewriteDeviceAnnotation").set_body_typed(RewriteAnnotatedOps); +TVM_REGISTER_GLOBAL("relay._transform.RewriteAnnotatedOps").set_body_typed(RewriteAnnotatedOps); } // namespace transform diff --git a/src/relay/transforms/device_planner.cc b/src/relay/transforms/device_planner.cc new file mode 100644 index 0000000000000..274a31c06d0bc --- /dev/null +++ b/src/relay/transforms/device_planner.cc @@ -0,0 +1,1941 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/analysis/device_planner.cc + * \brief Determines a unique device to hold the result of every Relay sub-expression. + * + * We say a Relay expression E is 'on device D' if the result of executing E is stored on D. + * Currently we only track the 'device_type' of D and not its 'device id'. We do not track the + * specific target associated with D (this is recovered independently via a TargetMap), and we + * do not track the storage scope within D (this is yet to be implemented). + * + * Note that 'stored on device D' is almost but not quite the same as 'executes on device D', + * see below. + * + * This pass assumes the module already contains some "on_device" and/or "device_copy" CallNodes: + * - "device_copy" CallNodes (with a \p DeviceCopyAttrs attribute) specify a 'src_dev_type' and + * 'dst_dev_type' device type, which constrain the argument and context of the call + * respectively. It is ok if source and destination devices are the same, such no-op copies + * will be removed after accounting for the device preference. + * - "on_device" CallNodes (with a \p OnDeviceAttrs attribute) specify a 'device_type', which + * constrains the argument of the call, but (usually, see below) leaves the context + * unconstrained. These are called 'annotations' in the rest of the code, have no operational + * significance by themselves, but may trigger the insertion of a new "device_copy". + * - In two situations the result of an "on_device" CallNode may also be constrained to the + * given device: + * - The "on_device" call occurs at the top-level of a function body, or occurs as an + * immediately let-bound expression. In this situation the extra degree of freedom in + * the function result and let-binding leads to surprising device copies, so we simply + * force the function result or let-bound variable to the given device. + * - The \p OnDeviceAttrs has an \p is_fixed field of \p true, which indicates we inserted + * it ourselves during an earlier invocation of this pass. This helps make this pass + * idempotent. + * + * We proceed in four phases: + * + * Phase 0 + * ------- + * We rewrite the programs to handle some special cases: + * - "on_device" calls at the top-level of function or immediately let-bound are rewritten + * to have \code is_fixed=true \endcode. + * - We wish to treat \code on_device(expr, device_type=d).0 \endcode as if it were written + * \code on_device(expr.0, device_type_d) \endcode. I.e. we prefer to copy the projection from + * the tuple rather than project from a copy of the tuple. We'll do this by rewriting. + * + * Phase 1 + * ------- + * We flow constraints from the "on_device" and "device_copy" calls (and some special ops, see + * below) to all other Relay sub-expressions. (For idempotence we also respect any existing + * "on_device" function attributes we introduce below.) 
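+ *
+ * For example (an illustrative sketch only; the program, variable names and the symbolic GPU
+ * device are made up for exposition):
+ * \code
+ *   def @main(%x, %y, %z) {
+ *     multiply(on_device(add(%x, %y), device_type=GPU), %z)
+ *   }
+ * \endcode
+ * Here the annotation constrains the add call to the GPU, and since add is a primitive the
+ * constraint flows on to %x and %y. The multiply, %z and the overall function result remain
+ * unconstrained at this stage and are only fixed by the defaulting of Phase 2.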
+ *
+ * For a primitive such as \code add(e1, e2) \endcode all arguments and results must be on the
+ * same device. However each call site can use a different device. In other words primitives are
+ * 'device polymorphic' since we compile and execute them for each required device.
+ *
+ * For most Relay expressions the device for the overall expression is the same as the device
+ * for its sub-expressions. E.g. each field of a tuple must be on the same device as the tuple
+ * itself, the condition and arms of an \p if must all be on the same device as the overall if,
+ * and so on.
+ *
+ * Some special ops (or 'dialects') are handled:
+ *  - Relay supports computing the shape of tensors and operators at runtime using "shape_of",
+ *    "shape_func", and "reshape_tensor". Shapes must only be held on the CPU, but the tensors
+ *    they describe may reside on any device.
+ *  - Explicit memory allocation is done using the "alloc_storage" and "alloc_tensor" ops. Again
+ *    shapes reside on the CPU, but the allocated tensors may reside on any device.
+ *
+ * Two Relay expressions have special handling:
+ *  - For \code let x = e1; e2 \endcode the result of \p e2 must be on the same device as the
+ *    overall let. However the result of \p e1 may be on a different device.
+ *  - For a function \code fn(x, y) { body } \endcode the result of the function must be on the
+ *    same device as \p body. However parameters \p x and \p y may be on different devices, even
+ *    different from each other. Every call to the function must use the same choice of parameter
+ *    and result devices -- there is no 'device polymorphism' for Relay functions.
+ *
+ * Phase 2
+ * -------
+ * After flowing constraints we apply some defaulting heuristics (using a global default device)
+ * to fix the device for any as-yet unconstrained sub-expressions.
+ *  - Unconstrained function result devices default to the global default device.
+ *  - Unconstrained function parameter devices default to the device for the function result.
+ *  - Unconstrained let-bound expression devices default to the device for the overall let.
+ *
+ * Phase 3
+ * -------
+ * Finally, the results of this analysis are reified back into the program as:
+ *  - Additional "on_device" attributes (an Attrs resolving to a \p FunctionOnDeviceAttrs) for
+ *    every function (both top-level and local). These describe the devices for the function's
+ *    parameters and the result.
+ *  - Additional "device_copy" CallNodes where a copy is required in order to respect the
+ *    intent of the original "on_device" CallNodes.
+ *  - Additional "on_device" CallNodes where the device type of an expression does not match
+ *    that of the lexically enclosing "on_device" CallNode or function attribute. In practice
+ *    this means "on_device" CallNodes may appear in two places:
+ *     - On a let-bound expression if its device differs from the overall let expression.
+ *     - On a call argument if its device differs from the call result. In particular, the
+ *       argument to a "device_copy" call will always be wrapped in an "on_device". (That may
+ *       seem pedantic but simplifies downstream handling.)
+ *    However since we make it easy to track devices for variables we never wrap an "on_device"
+ *    around a var or global var. These uses of "on_device" imply both the argument and result are
+ *    on the same device. We signal this by setting the 'is_fixed' OnDeviceAttrs field to true,
+ *    which helps make this pass idempotent.
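+ *
+ * Continuing the example above (schematically; the function's "on_device" attribute is elided),
+ * if the global default device is the CPU then the reified program is:
+ * \code
+ * def @main(%x, %y, %z) {  // %x, %y on the GPU; %z and the result on the CPU
+ *   multiply(device_copy(on_device(add(%x, %y), device_type=GPU, is_fixed=true),
+ *                        src_dev_type=GPU, dst_dev_type=CPU),
+ *            %z)
+ * }
+ * \endcode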
+ *
+ * A helper \p LexicalOnDeviceMixin class can be used by downstream transforms to recover the
+ * device for any expression for their own use, e.g. during memory planning. All downstream
+ * passes must preserve the lexical scoping of the "on_device" CallNodes. In particular
+ * conversion to ANF must respect the lexical scoping convention:
+ * \code
+ * f(on_device(g(h(a, b), c), device_type=CPU))
+ * ==>
+ * let %x0 = on_device(h(a, b), device_type=CPU)
+ * let %x1 = on_device(g(%x0), device_type=CPU)
+ * f(on_device(%x1, device_type=CPU))
+ * \endcode
+ *
+ * This pass should be run before FuseOps so it can use device-specific fusion rules.
+ *
+ * 'Stored on' vs 'Executes on'
+ * ----------------------------
+ * Obviously for a primitive call \code add(x, y) \endcode we can execute the primitive on the
+ * same device as will hold its result. Thus 'executes on' is the same as 'stored on' for
+ * primitives.
+ *
+ * But what about for arbitrary Relay expressions? Most backends (interpreter, graph, VM) are
+ * implicitly executed on the 'host' CPU, with only primitive evaluation handed off to specific
+ * devices, thus the notion of 'executes on' is moot. AOT backends on the other hand need to
+ * know exactly which device (possibly one of a number of available 'CPU'-like devices) is
+ * responsible for execution. Currently that's handled independently by the \p AnnotateTargets
+ * pass, but we'd like to fold that into device planning here to ensure everything is consistent.
+ *
+ * Obviously since tensors are passed-by-pointer it's quite possible to execute a Relay
+ * expression (e.g. an if expression) on one device even though the tensor data resides on
+ * another. But for AOT that flexibility seems excessive. So we'd like to just take 'executes on'
+ * to be 'stored on' exactly. In particular, for a Relay function, we'd like to be able to just
+ * compile the function body for the function's result device.
+ *
+ * This works after conversion to ANF provided the compilation for a let expression is prepared
+ * to make a cross-device call. However we leave it to a downstream transformation to
+ * heuristically minimize cross-device calls by moving device copies out of functions. E.g.:
+ * \code
+ * def @f() {  // execute on CPU
+ *   let x = on_device(...GPU computation..., device_type=GPU);
+ *   device_copy(...GPU computation..., src_dev_type=GPU, dst_dev_type=CPU)
+ * }
+ * def @main() {
+ *   ... call @f() on CPU ...
+ * }
+ * \endcode
+ * could be rewritten to:
+ * \code
+ * def @f() {  // execute on GPU
+ *   let x = ...GPU computation...;
+ *   ...GPU computation...
+ * }
+ * def @main() {
+ *   let x = device_copy(@f(), src_dev_type=GPU, dst_dev_type=CPU)
+ *   ... use x on CPU ...
+ * }
+ * \endcode
+ *
+ * Higher-order shenanigans
+ * ------------------------
+ * Relay is a 'mostly' higher-order language -- we can let-bind functions, pass functions
+ * as arguments (even anonymous functions), return functions, evaluate conditional expressions
+ * over functions, and so on. We handle this during constraint solving using the domain:
+ * \code
+ * D ::= <device_type>   -- first-order
+ *     | fn(D,...,D):D   -- higher-order
+ * \endcode
+ * In this way we can determine the device for all function parameters and results. E.g. for
+ * \code
+ * let f = fn(x, y) { ... }
+ * let g = fn(f, z) { f(z, z) }
+ * g(f, on_device(..., device_type=CPU))
+ * \endcode
+ * the parameters \p x and \p y will be on the CPU.
+ *
+ * But now look closely at the call \code e1(e2, e3) \endcode.
We know \p e1 must evaluate to a + * function. Our analysis must guarantee that the function's parameters and result devices are + * consistent for \p e2, \p e3, and the context of the call. But: + * - Which device holds the closure result of evaluating \p e1 ? + * - If \p e2 is of function type, what does that mean when we say every function parameter + * is on a device? + * - If \p e1 returns a function, what does that mean when we say every function result is + * on a device? + * + * Since higher-order aspects are later compiled away (by 'defunctionalization' + * aka 'firstification') we'd prefer not to have to answer any of those questions. In particular, + * we really don't want our domain \p D to allow for yet another device for the function closure. + * So we'll just force the 'device for a function' to be the same as the device for the function's + * result using the notion of the 'result domain' for a domain: + * \code + * result_domain() = + * result_domain(fn(D1,...,Dn):Dr) = result_domain(Dr) + * \endcode + * + * Similarly the domain does not have entries for tuples, references, or ADTs. Whenever the + * analysis encounters a function inside one of those it simply forces all argument and result + * devices for the function to match the device for the first-order expression. For example, + * if the tuple \code (fn(x, y) { ... }, 3) \endcode is on the GPU then the inner function + * parameters and result must similarly be on the GPU. + * + * ------- + * | AOR | This pass supports all of Relay. + * ------- + * ^ + * | + * `-- Mark's stamp of completeness :-) + * + * TODO(mbs): + * * Though on_device is the identity for all types we can't wrap it around functions/constructors + * taking type args (or at least not without changing type_infer.cc to see through them). + * This is not currently handled generally. + * * Proper diagnostics for unification failure using spans. + * * Make sure the pass is idempotent even after FuseOps etc. + * * Support application of constructors properly. Are they device polymorphic? + * * Replace DLDeviceType with TargetDevice, and unify 'target annotation' with 'device planning'. + * * Support running the pass post FuseOps (so need to understand primitive functions, both + * outlines and lined) and post the VM transforms (probably need to support more intrinsic + * forms?). + * * Don't hardcode the 'CPU' device for shape funcs etc, and distinguish between the default + * device for primitives vs the default device for the rest of Relay. + * * We'll probably need some support for partial 'device polymorphism' for functions once we + * incorporate targets and memory scopes into the domain. For example it's ok for the function + * body to be executed on different device ids provided they have the same target and memory + * scope. + * * Might be simpler to just let every type have a device annotation rather than work in + * a separate domain? + * * Switch to expr.CopyWith(...) form once implemented to avoid unnecessary copies. + */ + +#include "./device_planner.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "../attrs/annotation.h" +#include "../attrs/device_copy.h" + +namespace tvm { +namespace relay { +namespace transform { + +namespace { + +class DeviceDomain; +using DeviceDomainPtr = std::shared_ptr; + +/****** + ****** Domains + ******/ + +/*! + * \brief Represents the domain over which we collect equality constraints. + * + * \code + * D ::= ?x? 
-- first order, free + * | -- first order, bound + * | fn(D1, ..., Dn):Dr -- higher order + * \endcode + * + * We require a function value to be on the same device as its result. To support that we need + * a notion of the 'result domain' of a domain: + * \code + * result_domain(?x?) = ?x? + * result_domain() = + * result_domain(fn(D1, ..., Dn):Dr) = result_domain(Dr) + * \endcode + */ +class DeviceDomain { + public: + /*! + * \brief Constructs a first-order domain of \p device_type, which may be + * \p kInvalidDeviceType to indicate the domain is free. + */ + explicit DeviceDomain(DLDeviceType device_type) : device_type_(device_type) {} + + /*! + * \brief Constructs a higher-order domain, where \p args_and_result contain the + * function argument and result domains in order. + */ + explicit DeviceDomain(std::vector args_and_result) + : device_type_(kInvalidDeviceType), args_and_result_(std::move(args_and_result)) {} + + /*! \brief Returns true if domain is first-order and free. */ + bool is_free() const { return device_type_ == kInvalidDeviceType && args_and_result_.empty(); } + + /*! \brief Returns true if domain is higher-order. */ + bool is_higher_order() const { return !args_and_result_.empty(); } + + DLDeviceType first_order_device_type() const { + ICHECK(args_and_result_.empty()); + return device_type_; + } + + size_t function_arity() const { + ICHECK(!args_and_result_.empty()); + return args_and_result_.size() - 1UL; + } + + DeviceDomainPtr function_param(size_t i) const { + ICHECK(!args_and_result_.empty()); + ICHECK_LT(i + 1, args_and_result_.size()); + return args_and_result_[i]; + } + + DeviceDomainPtr function_result() const { + ICHECK(!args_and_result_.empty()); + return args_and_result_.back(); + } + + private: + /*! + * \brief If this is a function domain then always kInvalidDevice. Otherwise will be + * kInvalidDevice if the domain is still free, or the specific concrete device if the domain is + * bound. + */ + const DLDeviceType device_type_; + + /*! + * \brief If this is a function domain then the sub-domains for each of the function's + * arguments, and the domain for its result. Otherwise empty. + */ + const std::vector args_and_result_; + + friend struct DeviceDomainHash; + friend struct DeviceDomainEqual; + friend class DeviceDomains; +}; + +// Ye olde boost hash mixer. +constexpr size_t mix(size_t h1, size_t h2) { + return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2)); +} + +// The following hash and equality helpers give each free first-order domain pointer its own +// distinct identity. +struct DeviceDomainHash { + size_t operator()(const DeviceDomainPtr& domain) const { + if (domain->is_free()) { + // Give each free first-order domain its own identity. + return static_cast(reinterpret_cast(domain.get())); + } else { + size_t h = domain->args_and_result_.size(); + h = mix(h, std::hash()(static_cast(domain->device_type_))); + for (const auto& sub_domain_ptr : domain->args_and_result_) { + h = mix(h, DeviceDomainHash()(sub_domain_ptr)); + } + return h; + } + } +}; + +struct DeviceDomainEqual { + public: + bool operator()(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) const { + if (lhs->args_and_result_.size() != rhs->args_and_result_.size()) { + // Mismatched arities are never equal. + // (Though we'll never ask to do such a comparison explicitly, the hash map + // may do so implicitly due to hash collisions.) + return false; + } + if (lhs->is_free() && rhs->is_free()) { + // Compare first-order free domains by their address. 
+ return lhs.get() == rhs.get(); + } + if (lhs->args_and_result_.empty()) { + // Compare first-order domains by their device type -- free vs bound will compare as false. + return lhs->device_type_ == rhs->device_type_; + } else { + // Compare higher-order domains pointwise. + for (size_t i = 0; i < lhs->args_and_result_.size(); ++i) { + if (!(*this)(lhs->args_and_result_[i], rhs->args_and_result_[i])) { + return false; + } + } + return true; + } + } +}; + +/*! + * \brief Tracks the device domains for a set of expressions w.r.t. an equivalence relation + * built up by calls to \p Unify. + */ +class DeviceDomains { + public: + DeviceDomains() = default; + + /*! + * \brief Returns a domain appropriate for \p type who's result domain is bound + * to \p device_type. If \p device_type is \p kInvalidDeviceType then the entire domain + * will be free. + */ + static DeviceDomainPtr MakeDomain(const Type& type, DLDeviceType device_type) { + if (const auto* func_type_node = type.as()) { + std::vector args_and_result; + args_and_result.reserve(func_type_node->arg_types.size() + 1); + for (const auto& arg_type : func_type_node->arg_types) { + args_and_result.emplace_back(MakeDomain(arg_type, kInvalidDeviceType)); + } + args_and_result.emplace_back(MakeDomain(func_type_node->ret_type, device_type)); + return std::make_shared(std::move(args_and_result)); + } else { + return std::make_shared(device_type); + } + } + + /*! + * \brief Returns a higher-order domain with \p args_and_results. + */ + static DeviceDomainPtr MakeDomain(std::vector arg_and_results) { + return std::make_shared(std::move(arg_and_results)); + } + + /*! \brief Returns a domain with the given result device type appropriate \p device_type. */ + static DeviceDomainPtr ForDeviceType(const Type& type, DLDeviceType device_type) { + ICHECK_NE(device_type, kInvalidDeviceType); + return MakeDomain(type, device_type); + } + + /*! \brief Returns a free domain appropriate for \p type. */ + static DeviceDomainPtr Free(const Type& type) { return MakeDomain(type, kInvalidDeviceType); } + + /*! \brief Returns the domain representing the equivalence class containing \p domain. */ + DeviceDomainPtr Lookup(DeviceDomainPtr domain) { + DeviceDomainPtr root = domain; + while (true) { + auto itr = domain_to_equiv_.find(root); + if (itr == domain_to_equiv_.end()) { + break; + } + ICHECK_NE(itr->second, root); + root = itr->second; + ICHECK_NOTNULL(root); + } + // Path compression. + while (domain != root) { + auto itr = domain_to_equiv_.find(domain); + ICHECK(itr != domain_to_equiv_.end()); + domain = itr->second; + ICHECK_NOTNULL(domain); + itr->second = root; + } + return root; + } + + /*! + * \brief Returns the domain accounting for all bound devices in \p lhs and \p rhs. + * + * Throws \p Error on failure. + */ + DeviceDomainPtr Join(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) { + // TODO(mbs): Proper diagnostics. + ICHECK_EQ(lhs->args_and_result_.size(), rhs->args_and_result_.size()) + << "Device domains:" << std::endl + << ToString(lhs) << std::endl + << "and" << std::endl + << ToString(rhs) << std::endl + << "do not have the same kind and can't be unified."; + if (rhs->is_free()) { + return lhs; + } else if (lhs->is_free()) { + return rhs; + } else if (lhs->args_and_result_.empty()) { + // Must have consistent device types for first order domains. + if (lhs->device_type_ != rhs->device_type_) { + // TODO(mbs): Proper diagnostics. 
+ std::ostringstream os; + os << "Inconsistent device types " << lhs->device_type_ << " and " << rhs->device_type_; + throw Error(os.str()); + } + return lhs; + } else { + // Recurse for higher-order. + std::vector args_and_result; + args_and_result.reserve(lhs->args_and_result_.size()); + for (size_t i = 0; i < lhs->args_and_result_.size(); ++i) { + args_and_result.emplace_back(Unify(lhs->args_and_result_[i], rhs->args_and_result_[i])); + } + return MakeDomain(std::move(args_and_result)); + } + } + + /*! + * \brief Unifies \p lhs and \p rhs, returning the most-bound of the two. Fails if \p lhs and \p + * rhs disagree on bound device type. + * + * Throws \p Error on failure. + */ + // TODO(mbs): I don't think we need an occurs check since the program is well-typed, but + // given we have refs to functions I'm prepared to be surprised. + DeviceDomainPtr Unify(DeviceDomainPtr lhs, DeviceDomainPtr rhs) { + lhs = Lookup(lhs); + rhs = Lookup(rhs); + auto joined_domain = Join(lhs, rhs); + if (!DeviceDomainEqual()(lhs, joined_domain)) { + domain_to_equiv_.emplace(lhs, joined_domain); + } + if (!DeviceDomainEqual()(rhs, joined_domain)) { + domain_to_equiv_.emplace(rhs, joined_domain); + } + return joined_domain; + } + + /*! + * \brief Unifies \p lhs and \p rhs. If \p lhs is first-order and \p rhs is higher-order, + * require all arguments and result of \p rhs to unify with \p lhs. Otherwise same as + * \p Unify. + * + * Throws \p Error on failure. + */ + void UnifyCollapsed(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) { + if (!lhs->is_higher_order() && rhs->is_higher_order()) { + Collapse(lhs, rhs); + } else { + Unify(lhs, rhs); + } + } + + /*! \brief Returns true if a domain is known for \p expr. */ + bool contains(const Expr& expr) const { return expr_to_domain_.count(expr.get()); } + + /*! \brief Returns the domain representing \p expr. */ + DeviceDomainPtr DomainFor(const Expr& expr) { + ICHECK(expr.defined()); + auto itr = expr_to_domain_.find(expr.get()); + if (itr != expr_to_domain_.end()) { + return Lookup(itr->second); + } + auto domain = Free(expr->checked_type()); + expr_to_domain_.emplace(expr.get(), domain); + return domain; + } + + /*! + * \brief Returns the domain representing the callee (ie 'op') in \p call expression. If the + * callee is a primitive or special operation we handle it specially. Otherwise defers to \p + * DomainFor(call->op). + * + * This special handling is needed: + * - To handle the "on_device" and "device_copy" ops which constrain devices to the given devices. + * - To handle some special ops which constrain devices to the CPU. + * - To allow the same primitive to be called on different devices at different call sites. + * Since each call to the op can have a different domain we index the ops by the call expression + * rather than the op itself. + */ + DeviceDomainPtr DomainForCallee(const Call& call) { + auto itr = call_to_callee_domain_.find(call.get()); + if (itr != call_to_callee_domain_.end()) { + return Lookup(itr->second); + } + std::vector args_and_result; + if (call->op == OnDeviceOp()) { + // on_device(expr, device_type=, is_fixed=false) + // on_device : fn():?x? 
+ // + // on_device(expr, device_type=, is_fixed=true) + // on_device: fn(): + auto props = GetOnDeviceProps(call.get()); + args_and_result.emplace_back(ForDeviceType(props.body->checked_type(), props.device_type)); + if (props.is_fixed) { + args_and_result.emplace_back(args_and_result.front()); + } else { + args_and_result.emplace_back(Free(props.body->checked_type())); + } + } else if (call->op == DeviceCopyOp()) { + // device_copy(expr, src_dev_type=, dst_dev_type=) + // device_copy: fn(): + auto props = GetDeviceCopyProps(call.get()); + args_and_result.emplace_back(ForDeviceType(props.body->checked_type(), props.src_dev_type)); + args_and_result.emplace_back(ForDeviceType(props.body->checked_type(), props.dst_dev_type)); + } else if (call->op == alloc_storage_op) { + ICHECK_EQ(call->args.size(), 2U); + // alloc_storage(size, alignment, device_type=) + // alloc_storage: fn(, ): + const auto* attrs = call->attrs.as(); + args_and_result.emplace_back(cpu_domain_); + args_and_result.emplace_back(cpu_domain_); + args_and_result.emplace_back( + ForDeviceType(call->checked_type(), static_cast(attrs->device_type))); + } else if (call->op == alloc_tensor_op) { + ICHECK_EQ(call->args.size(), 3U); + // alloc_tensor(storage, offset, shape) + // alloc_tensor: fn(?x?, , ):?x? + auto free_domain = Free(call->checked_type()); + args_and_result.emplace_back(free_domain); + args_and_result.emplace_back(cpu_domain_); + args_and_result.emplace_back(cpu_domain_); + args_and_result.emplace_back(free_domain); + } else if (call->op == shape_func_op) { + ICHECK_EQ(call->args.size(), 3U); + // shape_func(func, inputs, outputs, is_inputs=[...]) + // shape_func: fn(..., , ): + // where ... is a free domain appropriate for func's type + args_and_result.emplace_back(Free(call->args[0]->checked_type())); + // TODO(mbs): I think this should be on the cpu only when is_input = [false], but + // what do we do when we have multiple arguments with different is_input values? + args_and_result.emplace_back(cpu_domain_); + args_and_result.emplace_back(cpu_domain_); + args_and_result.emplace_back(cpu_domain_); + } else if (call->op == shape_of_op) { + ICHECK_EQ(call->args.size(), 1U); + // shape_of(tensor) + // shape_of: fn(?x?): + args_and_result.emplace_back(Free(call->args[0]->checked_type())); + args_and_result.emplace_back(cpu_domain_); + } else if (call->op == invoke_tvm_op) { + ICHECK_EQ(call->args.size(), 3U); + // invoke_tvm_op(op, inputs, outputs) + // invoke_tvm_op: fn(..., ?x?, ?x?):?x? + // where ... is a free domain appropriate for op's type + auto free_domain = Free(call->checked_type()); + args_and_result.emplace_back(Free(call->args[0]->checked_type())); + args_and_result.emplace_back(free_domain); + args_and_result.emplace_back(free_domain); + args_and_result.emplace_back(free_domain); + } else if (call->op == reshape_tensor_op) { + ICHECK_EQ(call->args.size(), 2U); + // reshape_tensor(data, shape) + // reshape_tensor: fn(?x?, ):?x? + auto free_domain = Free(call->checked_type()); + args_and_result.emplace_back(free_domain); + args_and_result.emplace_back(cpu_domain_); + args_and_result.emplace_back(free_domain); + } else if (call->op->IsInstance()) { + // (arg1, ..., argn) + // : fn(?x?, ..., ?x?):?x? + // (all args and result must be first-order). + auto free_domain = Free(arb_); + for (size_t i = 0; i < call->args.size(); ++i) { + args_and_result.emplace_back(free_domain); + } + args_and_result.emplace_back(free_domain); + } else { + // Defer to normal case where op can be an arbitrary expression. 
+ return DomainFor(call->op); + } + auto domain = MakeDomain(std::move(args_and_result)); + call_to_callee_domain_.emplace(call.get(), domain); + return domain; + } + + /*! \brief Unifies the domains for expressions \p lhs and \p rhs. */ + void UnifyExprExact(const Expr& lhs, const Expr& rhs) { + auto lhs_domain = DomainFor(lhs); + auto rhs_domain = DomainFor(rhs); + try { + Unify(lhs_domain, rhs_domain); + } catch (const Error& e) { + // TODO(mbs): Proper diagnostics. + LOG(FATAL) << "Incompatible devices for expressions:" << std::endl + << PrettyPrint(lhs) << std::endl + << "with device:" << std::endl + << ToString(lhs_domain) << "and:" << std::endl + << PrettyPrint(rhs) << std::endl + << "with device:" << std::endl + << ToString(rhs_domain) << std::endl + << e.what(); + } + } + + /*! + * \brief Unifies the domain for \p expr with \p expected_domain. + */ + void UnifyExprExact(const Expr& expr, const DeviceDomainPtr& expected_domain) { + auto actual_domain = DomainFor(expr); + try { + Unify(actual_domain, expected_domain); + } catch (const Error& e) { + // TODO(mbs): Proper diagnostics. + LOG(FATAL) << "Incompatible devices for expression:" << std::endl + << PrettyPrint(expr) << std::endl + << "with actual device:" << std::endl + << ToString(actual_domain) << std::endl + << "and expected device:" << std::endl + << ToString(expected_domain) << std::endl + << e.what(); + } + } + + /*! + * \brief Unifies the domain for \p expr with \p expected_domain. + * If \p expected_domain is higher-order but \p expr is first-order, require all arguments + * and the result of \p expected_domain to have the same domain as for \p expr. + */ + void UnifyExprCollapsed(const Expr& expr, const DeviceDomainPtr& expected_domain) { + auto actual_domain = DomainFor(expr); + try { + UnifyCollapsed(actual_domain, expected_domain); + } catch (const Error& e) { + // TODO(mbs): Proper diagnostics. + LOG(FATAL) << "Incompatible devices for expression:" << std::endl + << PrettyPrint(expr) << std::endl + << "with actual device:" << std::endl + << ToString(actual_domain) << std::endl + << "and expected device:" << std::endl + << ToString(expected_domain) << std::endl + << e.what(); + } + } + + /*! \brief Returns true if \p domain contains any free sub-domains. */ + bool AnyFree(DeviceDomainPtr domain) { + domain = Lookup(domain); + if (domain->is_free()) { + return true; + } + for (const auto& sub_domain : domain->args_and_result_) { + if (AnyFree(sub_domain)) { + return true; + } + } + return false; + } + + /* + * \brief Force all domains in \p higher_order_domain to unify with \p first_order_domain. + * This can be used to handle functions within tuples, references and ADTs since we don't + * attempt to track anything beyond 'the device' for expressions of those first-order types. + * + * Throws \p Error on failure. + */ + void Collapse(const DeviceDomainPtr& first_order_domain, + const DeviceDomainPtr& higher_order_domain) { + for (size_t i = 0; i < higher_order_domain->function_arity(); ++i) { + Unify(higher_order_domain->function_param(i), first_order_domain); + } + Unify(higher_order_domain->function_result(), first_order_domain); + } + + /*! \brief Force all free domains in \p domain to default to \p default_device_type. */ + void SetDefault(DeviceDomainPtr domain, DLDeviceType default_device_type) { + ICHECK_NE(default_device_type, kInvalidDeviceType); + domain = Lookup(domain); + if (domain->is_free()) { + // Will never throw since lhs is free. 
+ Unify(domain, std::make_shared(default_device_type)); + } else if (!domain->args_and_result_.empty()) { + for (const auto& sub_domain : domain->args_and_result_) { + SetDefault(sub_domain, default_device_type); + } + } + } + + /*! + * \brief If \p domain is higher-order and its result domain is free, force it to + * \p default_device_type. Then force any remaining free domains to the result domain + * (freshly defaulted or original). If \p domain is first-order same as \p SetDefault. + */ + void SetResultDefaultThenParams(const DeviceDomainPtr& domain, DLDeviceType default_device_type) { + if (!domain->is_higher_order()) { + SetDefault(domain, default_device_type); + return; + } + DLDeviceType result_device_type = ResultDeviceType(domain); + if (result_device_type == kInvalidDeviceType) { + // If the function result device is still free use the given default. + result_device_type = default_device_type; + } + // Default any remaining free parameters to the function result device. + SetDefault(domain, result_device_type); + } + + /*! \brief Returns one-line description of \p domain for debugging. */ + std::string ToString(DeviceDomainPtr domain) { + domain = Lookup(domain); + std::ostringstream os; + if (domain->is_free()) { + // first-order free + os << "?" << static_cast(reinterpret_cast(domain.get())) << "?"; + } else if (domain->args_and_result_.empty()) { + // first-order bound + os << "<" << domain->device_type_ << ">"; + } else { + // higher-order + os << "fn("; + for (size_t i = 0; i + 1 < domain->args_and_result_.size(); ++i) { + if (i > 0) { + os << ","; + } + os << ToString(domain->args_and_result_[i]); + } + os << "):" << ToString(domain->args_and_result_.back()); + } + return os.str(); + } + + /*! \brief Returns description of entire system of constraints for debugging */ + std::string ToString() { + std::ostringstream os; + for (const auto& pair : expr_to_domain_) { + os << "expression:" << std::endl + << PrettyPrint(GetRef(pair.first)) << std::endl + << "domain:" << std::endl + << ToString(pair.second) << std::endl + << std::endl; + } + for (const auto& pair : call_to_callee_domain_) { + os << "call:" << std::endl + << PrettyPrint(GetRef(pair.first)) << std::endl + << "callee domain:" << std::endl + << ToString(pair.second) << std::endl + << std::endl; + } + return os.str(); + } + + /*! + * \brief Returns the result domain for \p domain (see defn in DeviceDomain comment). + */ + DeviceDomainPtr ResultDomain(DeviceDomainPtr domain) { + domain = Lookup(domain); + while (!domain->args_and_result_.empty()) { + domain = Lookup(domain->args_and_result_.back()); + } + return domain; + } + + /*! + * \brief Returns the result (possibly free) device type for \p domain (see defn in DeviceDomain + * comment). + */ + DLDeviceType ResultDeviceType(const DeviceDomainPtr& domain) { + return ResultDomain(domain)->first_order_device_type(); + } + + private: + /*! \brief Intrinsics we need to handle specially. */ + const Op& alloc_storage_op = Op::Get("memory.alloc_storage"); + const Op& alloc_tensor_op = Op::Get("memory.alloc_tensor"); + const Op& shape_of_op = Op::Get("vm.shape_of"); + const Op& invoke_tvm_op = Op::Get("vm.invoke_tvm_op"); + const Op& shape_func_op = Op::Get("vm.shape_func"); + const Op& reshape_tensor_op = Op::Get("vm.reshape_tensor"); + /*! \brief The CPU device type for special operators such as dynamic shape functions. */ + const DLDeviceType cpu_device_type_ = kDLCPU; + /*! \brief Placeholder for any first-order type. */ + Type arb_ = TupleType(); + /*! 
\brief The domain for first-order expressions on the CPU. */ + DeviceDomainPtr cpu_domain_ = ForDeviceType(arb_, cpu_device_type_); + + /*! \brief Maps expressions to their domains as determined during analysis. */ + std::unordered_map expr_to_domain_; + + /*! + * \brief Maps call expressions to the domains for their callee where the callee is a primitive. + */ + std::unordered_map call_to_callee_domain_; + + /*! \brief Maps device domains to their equivalent domains as determined during unification. */ + std::unordered_map + domain_to_equiv_; +}; + +/****** + ****** Phase 0 + ******/ + +/*! + * \brief Rewrites "on_device" calls to handle some special cases. + */ +class RewriteOnDevices : public ExprMutator { + public: + RewriteOnDevices() = default; + + private: + Expr VisitExpr_(const TupleGetItemNode* tuple_get_item_node) final { + Expr tuple = VisitExpr(tuple_get_item_node->tuple); + // TODO(mbs): Avoid copy. + Expr tuple_get_item = + TupleGetItem(tuple, tuple_get_item_node->index, tuple_get_item_node->span); + auto props = GetOnDeviceProps(tuple); + if (props.body.defined() && !props.is_fixed) { + VLOG(1) << "wrapping tuple get item:" << std::endl + << PrettyPrint(GetRef(tuple_get_item_node)) << std::endl + << "with \"on_device\" for device " << props.device_type; + return OnDevice(tuple_get_item, props.device_type, /*is_fixed=*/false); + } else { + return tuple_get_item; + } + } + + Expr VisitExpr_(const LetNode* let_node) final { + auto expr = GetRef(let_node); + std::vector> bindings; + while (const auto* inner_let_node = expr.as()) { + Expr inner_let = GetRef(inner_let_node); + Expr value = VisitExpr(inner_let_node->value); + auto props = GetOnDeviceProps(value); + if (props.body.defined() && !props.is_fixed) { + VLOG(1) << "revising let-bound expression of let:" << std::endl + << PrettyPrint(expr) << std::endl + << "to be fixed to device " << props.device_type; + value = OnDevice(props.body, props.device_type, /*is_fixed=*/true); + } + bindings.emplace_back(inner_let_node->var, value, inner_let_node->span); + expr = inner_let_node->body; + } + expr = VisitExpr(expr); + // TODO(mbs): Avoid copy. + for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) { + expr = Let(/*var=*/std::get<0>(*itr), /*value=*/std::get<1>(*itr), expr, + /*span=*/std::get<2>(*itr)); + } + return expr; + } + + Expr VisitExpr_(const FunctionNode* function_node) final { + Expr body = VisitExpr(function_node->body); + auto props = GetOnDeviceProps(body); + if (props.body.defined() && !props.is_fixed) { + VLOG(1) << "revising body of function:" << std::endl + << PrettyPrint(GetRef(function_node)) << std::endl + << "to be fixed to device " << props.device_type; + body = OnDevice(props.body, props.device_type, /*is_fixed=*/true); + } + // TODO(mbs): Avoid copy + return Function(function_node->params, body, function_node->ret_type, + function_node->type_params, function_node->attrs, function_node->span); + } +}; + +/****** + ****** Phase 1 + ******/ + +/* + * \brief Collects the system of device constraints for all sub-expressions in a module. + * It is possible some devices remain free and will need to be defaulted by \p DeviceDefaulter. + */ +class DeviceAnalyzer : public ExprVisitor { + public: + explicit DeviceAnalyzer(IRModule mod) + : mod_(std::move(mod)), domains_(std::make_unique()) {} + + /*! + * \brief Returns the expression-to-device-domain map for all expressions in all the global + * function definitions in the module. 
Expressions may have free domains, these will be resolved + * by \p DeviceDefaulter below. + */ + std::unique_ptr Analyze() { + VLOG_CONTEXT << "DeviceAnalyzer"; + for (const auto& pair : mod_->functions) { + VLOG(1) << "collecting constraints for '" << PrettyPrint(pair.first) << "'"; + VisitExpr(pair.second); + } + return std::move(domains_); + } + + private: + void VisitExpr_(const CallNode* call_node) final { + auto call = GetRef(call_node); + + // Find the higher-order domain for the callee. See DomainForCallee for the special rules + // for primitives. + VisitExpr(call_node->op); + auto func_domain = domains_->DomainForCallee(call); // higher-order + + // Build the domain for the function implied by its arguments and call context. + ICHECK_EQ(func_domain->function_arity(), call_node->args.size()); + std::vector args_and_result_domains; + args_and_result_domains.reserve(call_node->args.size() + 1); + for (const auto& arg : call_node->args) { + args_and_result_domains.emplace_back(domains_->DomainFor(arg)); + VisitExpr(arg); + } + args_and_result_domains.emplace_back(domains_->DomainFor(call)); + auto implied_domain = + DeviceDomains::MakeDomain(std::move(args_and_result_domains)); // higher-order + + VLOG(1) << "initial call function domain:" << std::endl + << domains_->ToString(func_domain) << std::endl + << "and implied domain:" << std::endl + << domains_->ToString(implied_domain) << "for call:" << std::endl + << PrettyPrint(call); + + // The above must match. + try { + domains_->Unify(func_domain, implied_domain); // higher-order + } catch (const Error& e) { + // TODO(mbs): Proper diagnostics. + LOG(FATAL) << "Function parameters and result devices do not match those of call. Call:" + << std::endl + << PrettyPrint(call) << std::endl + << "with function devices:" << std::endl + << domains_->ToString(func_domain) << std::endl + << "and implied call devices:" << std::endl + << domains_->ToString(implied_domain) << std::endl + << e.what(); + } + + VLOG(1) << "final call function domain:" << std::endl + << domains_->ToString(func_domain) << std::endl + << "for call:" << std::endl + << PrettyPrint(call); + } + + void VisitExpr_(const LetNode* let_node) final { + Expr expr = GetRef(let_node); + // Iteratively visit let nodes to avoid stack overflow. + while (expr->IsInstance()) { + Let let = Downcast(expr); + // Let var must be same device as value it is bound to. + domains_->UnifyExprExact(let->var, let->value); // may be higher-order + // Let body must be same device as overall let. + domains_->UnifyExprExact(let, let->body); // may be higher-order + + VisitExpr(let->var); + VisitExpr(let->value); + + expr = let->body; + } + + // Visit the last body + VisitExpr(expr); + } + + void VisitExpr_(const FunctionNode* function_node) final { + // No need to step into fused primitive functions as they are lowered individually according + // to the devices of all their call sites. + if (function_node->HasNonzeroAttr(attr::kPrimitive)) { + return; + } + + auto function = GetRef(function_node); + auto func_domain = domains_->DomainFor(function); // higher-order + + // The function body domain must match the function result domain. 
+ domains_->UnifyExprExact(function_node->body, + func_domain->function_result()); // may be higher-order + + VLOG(1) << "initial function domain:" << std::endl + << domains_->ToString(func_domain) << std::endl + << "and function body domain:" << std::endl + << domains_->ToString(domains_->DomainFor(function_node->body)) << std::endl + << "for function:" << std::endl + << PrettyPrint(function); + + ICHECK_EQ(func_domain->function_arity(), function_node->params.size()); + for (size_t i = 0; i < function_node->params.size(); ++i) { + // The parameter domains must match the function argument domains. + domains_->UnifyExprExact(function_node->params[i], + func_domain->function_param(i)); // may be higher-order + VisitExpr(function_node->params[i]); + } + + // If the function already has an "on_device" attribute then we can further + // constrain the function's domain to match it. + Optional opt_attrs = + function_node->GetAttr(FunctionOnDeviceAttrs::kFunctionAttrsKey); + if (opt_attrs) { + std::vector args_and_result; + for (size_t i = 0; i < function_node->params.size(); ++i) { + args_and_result.emplace_back( + domains_->ForDeviceType(function_node->params[i]->checked_type(), + GetFunctionParamDeviceType(function_node, i))); + } + args_and_result.emplace_back(domains_->ForDeviceType( + function_node->body->checked_type(), GetFunctionResultDeviceType(function_node))); + auto annotation_domain = domains_->MakeDomain(std::move(args_and_result)); + try { + domains_->Unify(func_domain, annotation_domain); // higher-order + } catch (const Error& e) { + // TODO(mbs): Proper diagnostics. + LOG(FATAL) + << "Function devices are incompatible with its \"on_device\" annotation. Function:" + << std::endl + << PrettyPrint(function) << std::endl + << "with function devices:" << std::endl + << domains_->ToString(func_domain) << std::endl + << "and annotation devices:" << std::endl + << domains_->ToString(annotation_domain) << std::endl + << e.what(); + } + } + + VisitExpr(function_node->body); + + VLOG(1) << "final function domain:" << std::endl + << domains_->ToString(func_domain) << std::endl + << "and function body domain:" << std::endl + << domains_->ToString(domains_->DomainFor(function_node->body)) << std::endl + << "for function:" << std::endl + << PrettyPrint(function); + } + + void VisitExpr_(const TupleNode* tuple_node) final { + Tuple tuple = GetRef(tuple_node); + for (size_t i = 0; i < tuple->fields.size(); i++) { + auto domain = domains_->DomainFor(tuple->fields[i]); // may be higher-order + domains_->UnifyExprCollapsed(tuple, domain); // collapse to first-order if needed + VisitExpr(tuple->fields[i]); + } + } + + void VisitExpr_(const TupleGetItemNode* tuple_get_item_node) final { + TupleGetItem tuple_get_item = GetRef(tuple_get_item_node); + auto domain = domains_->DomainFor(tuple_get_item); // may be higher-order + domains_->UnifyExprCollapsed(tuple_get_item_node->tuple, + domain); // collapse to first-order if needed + VisitExpr(tuple_get_item_node->tuple); + } + + class DevicePatternAnalyzer : public PatternVisitor { + public: + DevicePatternAnalyzer(DeviceDomains* domains, const ExprNode* adt_node) + : domains_(domains), adt_node_(adt_node) {} + + private: + void VisitPattern_(const PatternVarNode* pattern_var_node) final { + auto var_domain = domains_->DomainFor(pattern_var_node->var); // may be higher order + domains_->UnifyExprCollapsed(GetRef(adt_node_), + var_domain); // collapse to first-order if needed + } + + /*! 
\brief (Mutable borrow of) the domains for all expressions processed so far. */ + DeviceDomains* domains_; + /*! \brief The expression for the ADT we are matching over. */ + const ExprNode* adt_node_; + }; + + void VisitPattern(const Pattern& pattern) final {} + + void VisitExpr_(const MatchNode* match_node) final { + // For match node, we unify the value and the rhs of each clause + Match match = GetRef(match_node); + auto match_domain = domains_->DomainFor(match); // may be higher-order + DevicePatternAnalyzer pattern_analyzer(domains_.get(), match->data.get()); + domains_->UnifyExprCollapsed(match->data, match_domain); // collapse to first-order if needed + for (const auto& clause : match->clauses) { + pattern_analyzer.VisitPattern(clause->lhs); + domains_->UnifyExprExact(clause->rhs, match_domain); + VisitExpr(clause->rhs); + } + VisitExpr(match_node->data); + } + + void VisitExpr_(const GlobalVarNode* global_var_node) final { + domains_->DomainFor(GetRef(global_var_node)); + } + + void VisitExpr_(const VarNode* var_node) final { domains_->DomainFor(GetRef(var_node)); } + + void VisitExpr_(const ConstantNode* constant_node) final { + domains_->DomainFor(GetRef(constant_node)); + } + + void VisitExpr_(const ConstructorNode* constructor_node) final { + // Probably needs to be device polymorphic. + domains_->DomainFor(GetRef(constructor_node)); + } + + void VisitExpr_(const IfNode* if_node) final { + auto ife = GetRef(if_node); + auto domain = domains_->DomainFor(ife); // may be higher-order + domains_->UnifyExprCollapsed(if_node->cond, domain); // collapse to first-order if needed + domains_->UnifyExprExact(if_node->true_branch, domain); + domains_->UnifyExprExact(if_node->false_branch, domain); + VisitExpr(if_node->cond); + VisitExpr(if_node->true_branch); + VisitExpr(if_node->false_branch); + } + + void VisitExpr_(const OpNode* op) final { + // no-op, primitive operators are handled at their call-sites. + } + + void VisitExpr_(const RefCreateNode* ref_create_node) final { + auto ref_create = GetRef(ref_create_node); + auto domain = domains_->DomainFor(ref_create_node->value); // may be higher-order + domains_->UnifyExprCollapsed(ref_create, domain); // collapse to first-order if needed + VisitExpr(ref_create_node->value); + } + + void VisitExpr_(const RefReadNode* ref_read_node) final { + auto ref_read = GetRef(ref_read_node); + auto domain = domains_->DomainFor(ref_read); // may be higher-order + domains_->UnifyExprCollapsed(ref_read_node->ref, domain); // collapse to first-order if needed + VisitExpr(ref_read_node->ref); + } + + void VisitExpr_(const RefWriteNode* ref_write_node) final { + auto ref_write = GetRef(ref_write_node); + auto domain = domains_->DomainFor(ref_write->value); // may be higher-order + domains_->UnifyExprCollapsed(ref_write->ref, domain); // collapse to first-order if needed + domains_->UnifyExprCollapsed(ref_write, domain); // collapse to first-order if needed + VisitExpr(ref_write_node->ref); + VisitExpr(ref_write_node->value); + } + + /*! \brief The module we are analyzing. */ + IRModule mod_; + /*! \brief The domains for all expressions processed so far. */ + std::unique_ptr domains_; +}; + +/****** + ****** Phase 2 + ******/ + +/*! + * \brief Ensures every sub-expression in a module has a device type, using both the global + * default and some local heuristics to avoid unnecessary additional "device_copy" CallNodes. + * + * TODO(mbs): I think this is deterministic? We do however visit the top-level defs in hashmap + * order. 
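+ *
+ * For example (schematically, in the domain notation of \p DeviceDomains::ToString), a function
+ * whose domain after constraint collection is
+ * \code
+ * fn(?x?, <GPU>):?y?
+ * \endcode
+ * is defaulted to
+ * \code
+ * fn(<default>, <GPU>):<default>
+ * \endcode
+ * i.e. the free result domain takes the global default device, and the still-free parameter
+ * then follows the (freshly defaulted) result.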
+ */ +class DeviceDefaulter : public ExprVisitor { + public: + DeviceDefaulter(IRModule mod, std::unique_ptr domains, + DLDeviceType default_device_type) + : mod_(std::move(mod)), + domains_(std::move(domains)), + default_device_type_(default_device_type) {} + + std::unique_ptr Default() { + VLOG_CONTEXT << "DeviceDefaulter"; + for (const auto& pair : mod_->functions) { + VLOG(1) << "defaulting devices for '" << PrettyPrint(pair.first) << "'"; + VisitExpr(pair.second); + } + return std::move(domains_); + } + + private: + void VisitExpr_(const FunctionNode* function_node) final { + if (function_node->HasNonzeroAttr(attr::kPrimitive)) { + return; + } + + auto function = GetRef(function_node); + auto func_domain = domains_->DomainFor(function); // higher-order + ICHECK_EQ(func_domain->function_arity(), function_node->params.size()); + if (domains_->AnyFree(func_domain)) { + VLOG(1) << "before defaulting function:" << std::endl << domains_->ToString(func_domain); + domains_->SetResultDefaultThenParams(func_domain, default_device_type_); + VLOG(1) << "after defaulting function:" << std::endl << domains_->ToString(func_domain); + } + VisitExpr(function_node->body); + } + + void VisitExpr_(const CallNode* call_node) final { + auto call = GetRef(call_node); + auto func_domain = domains_->DomainForCallee(call); // higher-order + ICHECK_EQ(func_domain->function_arity(), call_node->args.size()); + if (domains_->AnyFree(func_domain)) { + // For calls to Relay functions this step is identical to that for VisitExpr_(FunctionNode*) + // above. But for calls to primitives we may still need to force free domains to be + // defaulted. + VLOG(1) << "before defaulting callee:" << std::endl << domains_->ToString(func_domain); + domains_->SetResultDefaultThenParams(func_domain, default_device_type_); + VLOG(1) << "after defaulting callee:" << std::endl << domains_->ToString(func_domain); + } + return ExprVisitor::VisitExpr_(call_node); + } + + void VisitExpr_(const LetNode* let_node) final { + Expr expr = GetRef(let_node); + // Iteratively visit let nodes to avoid stack overflow. + while (expr->IsInstance()) { + Let let = Downcast(expr); + // If the let-var device is still free force it to match the overall let. + auto let_domain = domains_->DomainFor(let); // may be higher-order + DLDeviceType let_device_type = domains_->ResultDeviceType(let_domain); + ICHECK_NE(let_device_type, kInvalidDeviceType); + auto let_var_domain = domains_->DomainFor(let->var); // may be higher-order + if (domains_->AnyFree(let_var_domain)) { + VLOG(1) << "before defaulting let-var:" << std::endl << domains_->ToString(let_var_domain); + domains_->SetDefault(let_var_domain, let_device_type); + VLOG(1) << "after defaulting let-var:" << std::endl << domains_->ToString(let_var_domain); + } + VisitExpr(let->var); + VisitExpr(let->value); + expr = let->body; + } + VisitExpr(expr); + } + + /*! \brief The module we are processing. */ + IRModule mod_; + /*! \brief The domains for all expressions. */ + std::unique_ptr domains_; + /*! \brief The default device type. */ + DLDeviceType default_device_type_; +}; + +/****** + ****** Phase 3 + ******/ + +/*! + * \brief Inserts missing "device_copy" CallNodes, and ensures the device type of every + * sub-expression in a module can be easily recovered by a later transformation using simple + * lexical scoping rules (e.g. for memory planning). + * + * - Discard any existing "on_device" CallNodes since their job is done. Similarly, discard + * any existing "device_copy" CallNodes which are no-ops. 
+ * + * - Functions are given an "on_device" attribute bound to a FunctionOnDeviceAttrs to capture + * the device type for its parameters and result. + * + * - Additional "device_copy" CallNodes are inserted wherever there's a transition between + * storage device types. Since the DeviceAnalyzer phase succeeded this can only happen + * where the original program explicitly allowed a transition using an "on_device" CallNode. + * That is, we do not not try to 'fix' a program with inconsistent devices. + * + * - Additional "on_device" CallNodes are inserted so that a later transform can discover + * the device for an arbitrary sub-expression by looking only for the lexically enclosing + * "on_device" CallNode or "on_device" function attribute. In particular, since function + * arguments and let-bound expressions can be on a device different from the function + * or let body itself we will insert "on_device" CallNodes to spell out any differences. This + * applies even to the argument to a "device_copy" CallNode, which may look pedantic but + * keeps downstream processing simple. The "on_device" calls should be removed before code gen, + * which is easily done on-the-fly. + */ +class DeviceCapturer : public ExprMutator { + public: + DeviceCapturer(IRModule mod, std::unique_ptr domains) + : mod_(std::move(mod)), domains_(std::move(domains)) {} + + IRModule Capture() { + VLOG_CONTEXT << "CaptureDevices"; + IRModule result(/*functions=*/{}, mod_->type_definitions, mod_->Imports(), mod_->source_map); + for (const auto& pair : mod_->functions) { + VLOG(1) << "capturing devices for '" << PrettyPrint(pair.first) << "'"; + result->Add(pair.first, Downcast(Mutate(pair.second))); + } + return result; + } + + private: + // Nothing interesting for VarNode, ConstantNode, GlobalVarNode and OpNode. + + Expr VisitExpr_(const TupleNode* tuple_node) final { + auto tuple = GetRef(tuple_node); + Array fields; + fields.reserve(tuple_node->fields.size()); + for (const auto& field : tuple_node->fields) { + fields.push_back(VisitChild(tuple, field)); + } + // TODO(mbs): Avoid copy + return Tuple(std::move(fields), tuple_node->span); + } + + Expr VisitExpr_(const FunctionNode* function_node) final { + if (function_node->HasNonzeroAttr(attr::kPrimitive)) { + return GetRef(function_node); + } + + auto function = GetRef(function_node); + auto func_domain = domains_->DomainFor(function); // higher-order + VLOG(1) << "capturing function:" << std::endl + << PrettyPrint(function) << std::endl + << "with domain:" << std::endl + << domains_->ToString(func_domain); + + // Gather the parameter and result device types for the "on_device" function attribute. + ICHECK_EQ(func_domain->function_arity(), function_node->params.size()); + DLDeviceType result_device_type = domains_->ResultDeviceType(func_domain); + ICHECK_NE(result_device_type, kInvalidDeviceType); + Array param_device_types; + param_device_types.reserve(function_node->params.size()); + for (size_t i = 0; i < function_node->params.size(); ++i) { + DLDeviceType param_device_type = domains_->ResultDeviceType(func_domain->function_param(i)); + ICHECK_NE(param_device_type, kInvalidDeviceType); + param_device_types.push_back(param_device_type); + } + + // Rewrite the body. Note that the body may have begun with an "on_device" so + // be prepared to insert a "device_copy". 
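+    // (Schematically, if the body's device were the GPU while the result device is the CPU,
+    // the VisitChild call below would produce
+    //   device_copy(on_device(...GPU body..., device_type=GPU, is_fixed=true),
+    //               src_dev_type=GPU, dst_dev_type=CPU).)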
+ Expr body = VisitChild( + /*lexical_device_type=*/result_device_type, + /*expected_device_type=*/result_device_type, + /*child_device_type=*/GetDeviceType(function_node->body), function_node->body); + + // TODO(mbs): Avoid copy + Function func = Function(function_node->params, body, function_node->ret_type, + function_node->type_params, function_node->attrs, function_node->span); + return FunctionOnDevice(func, param_device_types, result_device_type); + } + + Expr VisitExpr_(const CallNode* call_node) final { + auto call = GetRef(call_node); + DLDeviceType call_device_type = GetDeviceType(call); + + auto on_device_props = GetOnDeviceProps(call_node); + if (on_device_props.body.defined()) { + // We're done with the original "on_device" calls and can pinch them out. + // Note that this step has already been simulated by GetDeviceType. + return VisitExpr(on_device_props.body); + } + + auto device_copy_props = GetDeviceCopyProps(call_node); + if (device_copy_props.body.defined()) { + DLDeviceType src_device_type = device_copy_props.src_dev_type; + ICHECK_EQ(call_device_type, device_copy_props.dst_dev_type); + if (call_device_type == src_device_type) { + // We can pinch out existing "device_copy" CallNodes if their source and destinations + // match. + return VisitExpr(device_copy_props.body); + } + // else: handle as for any other call. + } + + auto func_domain = domains_->DomainForCallee(call); // higher-order + VLOG(1) << "considering call:" << std::endl + << PrettyPrint(call) << std::endl + << "on device " << call_device_type << " with function domain:" << std::endl + << domains_->ToString(func_domain); + DLDeviceType result_device_type = domains_->ResultDeviceType(func_domain); + ICHECK_NE(result_device_type, kInvalidDeviceType); + + // The callee is on the current device. + Expr op = VisitChild( + /*lexical_device_type=*/call_device_type, + /*expected_device_type=*/call_device_type, + /*child_device_type=*/result_device_type, call_node->op); + + // Each argument can be on the device for the corresponding function parameter. However if + // any of those differ from the overall call device then wrap them in an "on_device" to + // help downstream transforms track devices lexically. + Array args; + args.reserve(call_node->args.size()); + ICHECK_EQ(func_domain->function_arity(), call->args.size()); + for (size_t i = 0; i < call_node->args.size(); ++i) { + DLDeviceType param_device_type = domains_->ResultDeviceType(func_domain->function_param(i)); + ICHECK_NE(param_device_type, kInvalidDeviceType) + << "for parameter " << i << " for call:" << std::endl + << PrettyPrint(call); + args.push_back(VisitChild(/*lexical_device_type=*/call_device_type, + /*expected_device_type=*/param_device_type, + /*child_device_type=*/GetDeviceType(call_node->args[i]), + call_node->args[i])); + } + return call.CopyWith(op, args); + } + + Expr VisitExpr_(const LetNode* let_node) final { + Expr expr = GetRef(let_node); + // Iterate through chained lets, provided they all agree on their device type. + DLDeviceType let_device_type = GetDeviceType(expr); + std::vector> bindings; + while (const auto* inner_let_node = expr.as()) { + Expr inner_let = GetRef(inner_let_node); + if (GetDeviceType(inner_let) != let_device_type) { + // We have a device transition which needs to be handled. + break; + } + // The let-bound value can be on a different device than the overall let. However if those + // devices don't agree wrap the let-bound value in an "on_device" to help downstream + // transforms track devices lexically. 
+ Expr value = VisitChild(/*lexical_device_type=*/let_device_type, + /*expected_device_type=*/GetDeviceType(inner_let_node->var), + /*child_device_type=*/GetDeviceType(inner_let_node->value), + inner_let_node->value); + bindings.emplace_back(inner_let_node->var, value, inner_let_node->span); + expr = inner_let_node->body; + } + Expr body = VisitChild(/*lexical_device_type=*/let_device_type, + /*expected_device_type=*/let_device_type, + /*child_device_type=*/GetDeviceType(expr), expr); + for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) { + body = Let(/*var=*/std::get<0>(*itr), /*value=*/std::get<1>(*itr), body, + /*span=*/std::get<2>(*itr)); + } + return body; + } + + Expr VisitExpr_(const IfNode* if_node) final { + auto ife = GetRef(if_node); + Expr cond = VisitChild(ife, if_node->cond); + Expr true_branch = VisitChild(ife, if_node->true_branch); + Expr false_branch = VisitChild(ife, if_node->false_branch); + // TODO(mbs): Avoid copy + return If(cond, true_branch, false_branch, if_node->span); + } + + Expr VisitExpr_(const TupleGetItemNode* tuple_get_item_node) final { + auto tuple_get_item = GetRef(tuple_get_item_node); + Expr tuple = VisitChild(tuple_get_item, tuple_get_item_node->tuple); + // TODO(mbs): Avoid copy + return TupleGetItem(tuple, tuple_get_item_node->index, tuple_get_item_node->span); + } + + Expr VisitExpr_(const RefCreateNode* ref_create_node) final { + auto ref_create = GetRef(ref_create_node); + Expr value = VisitChild(ref_create, ref_create_node->value); + // TODO(mbs): Avoid copy + return RefCreate(value, ref_create_node->span); + } + + Expr VisitExpr_(const RefReadNode* ref_read_node) final { + auto ref_read = GetRef(ref_read_node); + Expr ref = VisitChild(ref_read, ref_read_node->ref); + // TODO(mbs): Avoid copy + return RefRead(ref, ref_read_node->span); + } + + Expr VisitExpr_(const RefWriteNode* ref_write_node) final { + auto ref_write = GetRef(ref_write_node); + Expr ref = VisitChild(ref_write, ref_write_node->ref); + Expr value = VisitChild(ref_write, ref_write_node->value); + // TODO(mbs): Avoid copy + return RefWrite(ref, value, ref_write_node->span); + } + + Expr VisitExpr_(const ConstructorNode* constructor_node) final { + auto constructor = GetRef(constructor_node); + // check we have a device type. + (void)GetDeviceType(constructor); + return constructor; + } + + Expr VisitExpr_(const MatchNode* match_node) final { + auto match = GetRef(match_node); + Expr data = VisitChild(match, match_node->data); + Array clauses; + clauses.reserve(match_node->clauses.size()); + for (const auto& clause : match_node->clauses) { + Pattern lhs = VisitPattern(clause->lhs); // actually a no-op, so we're not checking vars + Expr rhs = VisitChild(match, clause->rhs); + clauses.push_back(Clause(lhs, rhs)); + } + // TODO(mbs): Avoid copy + return Match(data, std::move(clauses), match_node->complete, match_node->span); + } + + DLDeviceType GetDeviceType(const Expr& expr) { + // Look through any "on_device" CallNodes, to mimic how we will be pinching them out. + auto props = GetOnDeviceProps(expr); + Expr true_expr = props.body.defined() ? props.body : expr; + ICHECK(domains_->contains(true_expr)); + // If expr is higher order we'll return only the result domain's device type. + DLDeviceType device_type = domains_->ResultDeviceType(domains_->DomainFor(true_expr)); + ICHECK_NE(device_type, kInvalidDeviceType) + << "no device type was determined for expression:" << std::endl + << PrettyPrint(true_expr); + return device_type; + } + + /*! 
+ * \brief Reconcile the \p child_device_type for \p child with both the \p expected_device_type + * (as required by the expression context the \p child is in) and the \p lexical_device_type + * (as a downstream transform would infer based only on lexically enclosing "on_device" + * CallNodes and function attributes.) Generally \p lexical_device_type and \p + * expected_device_type are the same by definition, but may differ in arguments to functions + * and let-bound expressions. + * + * If \p child_device_type differs from \p expected_device_type, wrap it as: + * \code + * device_copy(on_device(child', device_type=child_device_type), + * src_dev_type=child_device_type, dst_dev_type=expected_device_type) + * \endcode + * (where child is rewritten to child'). Note the pedantic spelling out of "on_device" on the + * child. + * + * If \p expected_device_type differs from \p lexical_device_type, then (also) wrap + * the expression as: + * \code + * on_device(..., device_type=expected_device_type) + * \endcode + * + * TODO(mbs): There's no attempt at sharing here. If usage of child's node could be wrapped + * by a "device_copy", even though those copies will generally all be to the same destination + * device. + */ + Expr VisitChild(DLDeviceType lexical_device_type, DLDeviceType expected_device_type, + DLDeviceType child_device_type, const Expr& child) { + ICHECK_NE(lexical_device_type, kInvalidDeviceType); + ICHECK_NE(expected_device_type, kInvalidDeviceType); + if (child->IsInstance()) { + // Primitive operators don't need to be rewritten and can have a different domain for + // each call site. + return child; + } + Expr result = VisitExpr(child); + if (child_device_type != expected_device_type) { + VLOG(1) << "creating " << DeviceCopyOp()->name << " from device type " << child_device_type + << " to device type " << expected_device_type << " for:" << std::endl + << PrettyPrint(result); + // Also wrap the child in an "on_device" so downstream transforms can track devices + // lexically. + result = OptOnDevice(result, child_device_type, /*is_fixed=*/true); + result = DeviceCopy(result, child_device_type, expected_device_type); + } + if (expected_device_type != lexical_device_type) { + VLOG(1) << "creating " << OnDeviceOp()->name << " for device type " << expected_device_type + << " for:" << std::endl + << PrettyPrint(result); + result = OptOnDevice(result, expected_device_type, /*is_fixed=*/true); + } + return result; + } + + /*! + * Common case of visiting a direct \p child of \p parent where by default the \p child + * is expected to be on the same device as the \p parent. + */ + Expr VisitChild(const Expr& parent, const Expr& child) { + DLDeviceType expected_device_type = GetDeviceType(parent); + DLDeviceType child_device_type = GetDeviceType(child); + return VisitChild(expected_device_type, expected_device_type, child_device_type, child); + } + + /*! \brief Module we are rewriting, so we can lookup global variables. */ + IRModule mod_; + /*! \brief Device domain for every expression from DeviceAnalyzer. */ + std::unique_ptr domains_; +}; + +/*! \brief Rewrite the "on_device" calls (and implicitly re-type-check). */ +tvm::transform::Pass Rewrite() { + auto pass_func = [](Function f, IRModule m, transform::PassContext ctxt) { + return Downcast(RewriteOnDevices().Mutate(f)); + }; + return tvm::relay::transform::CreateFunctionPass(pass_func, 0, "PlanDevicesRewrite", {}); +} + +/*! \brief Run the remaining phases. 
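+ *
+ * Specifically: analyze the system of "on_device" and "device_copy" constraints (DeviceAnalyzer),
+ * default any sub-expressions left unconstrained to \p default_device_type (DeviceDefaulter),
+ * then capture the resulting assignment back into the IR (DeviceCapturer). See the pass body
+ * below.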
*/ +tvm::transform::Pass PlanDevicesCore(DLDeviceType default_device_type) { + return tvm::transform::CreateModulePass( + [=](IRModule mod, tvm::transform::PassContext pass_cnxt) -> IRModule { + // Collect the system of constraints for every sub-expression using existing "on_device" + // and "device_copy" calls. + std::unique_ptr domains = DeviceAnalyzer(mod).Analyze(); + VLOG(1) << "Domains after analysis:" << std::endl << domains->ToString(); + + // Choose sensible default devices for every sub-expression if otherwise unconstrained + // by existing "on_device" or "device_copy" calls. + domains = DeviceDefaulter(mod, std::move(domains), default_device_type).Default(); + VLOG(1) << "Domains after defaulting: " << std::endl << domains->ToString(); + + // Insert "device_copy" and "on_device" CallNodes where needed to unambiguously capture + // the above map, and attach additional "on_device" attributes to all function + // definitions. + return DeviceCapturer(mod, std::move(domains)).Capture(); + }, + /*opt_level=*/0, "PlanDevicesCore", {}); +} + +} // namespace + +/****** + ****** Pass + ******/ + +TVM_DLL tvm::transform::Pass PlanDevices(DLDeviceType default_device_type) { + std::vector passes; + passes.emplace_back(Rewrite()); + passes.emplace_back(PlanDevicesCore(default_device_type)); + return tvm::transform::Sequential(std::move(passes), "PlanDevices"); +} + +TVM_REGISTER_GLOBAL("relay._transform.PlanDevices") + .set_body_typed([](const Device& default_device) { + return PlanDevices(default_device.device_type); + }); + +/****** + ****** Visitor/Mutator Helpers + ******/ + +DLDeviceType LexicalOnDeviceMixin::GetInScopeDeviceType(const Expr& expr) const { + auto props = GetOnDeviceProps(expr); + if (props.body.defined() && props.is_fixed) { + // Look through any fixed "on_device" annotations. + return props.device_type; + } + if (expr->IsInstance()) { + // Lookup variable binding. + auto itr = var_device_types_.find(Downcast(expr)); + if (itr == var_device_types_.end()) { + return kInvalidDeviceType; + } else { + return itr->second; + } + } + // Otherwise use the currently in-scope device type. + if (expr_device_types_.empty()) { + return kInvalidDeviceType; + } else { + return expr_device_types_.back(); + } +} + +void LexicalOnDeviceMixin::EnterFunctionBody() { ++function_nesting_; } + +void LexicalOnDeviceMixin::ExitFunctionBody() { + ICHECK_GT(function_nesting_, 0); + --function_nesting_; +} + +void LexicalOnDeviceMixin::PushDeviceType(DLDeviceType device_type) { + if (device_type == kInvalidDeviceType) { + return; + } + expr_device_types_.emplace_back(device_type); +} + +void LexicalOnDeviceMixin::PopDeviceType() { + if (expr_device_types_.empty()) { + return; + } + expr_device_types_.pop_back(); +} + +void LexicalOnDeviceMixin::PushBoundVar(Var var, DLDeviceType device_type) { + if (device_type == kInvalidDeviceType) { + return; + } + ICHECK(var_device_types_.find(var) == var_device_types_.end()); + var_device_types_.emplace(std::move(var), device_type); +} + +void LexicalOnDeviceMixin::PopBoundVar(const Var& var) { + auto itr = var_device_types_.find(var); + if (itr == var_device_types_.end()) { + return; + } + var_device_types_.erase(itr); +} + +void DeviceAwareExprVisitor::VisitExpr_(const FunctionNode* function_node) { + if (function_node->HasNonzeroAttr(attr::kPrimitive)) { + // No tracking inside primitive functions. + DeviceAwareVisitExpr_(function_node); + } else { + // Function parameters come into scope. 
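+    // (Their device types are recovered from the function attributes attached by PlanDevices,
+    // via GetFunctionParamDeviceType below.)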
+ for (size_t i = 0; i < function_node->params.size(); ++i) { + PushBoundVar(function_node->params[i], GetFunctionParamDeviceType(function_node, i)); + } + // Entering scope of function body. + PushDeviceType(GetFunctionResultDeviceType(function_node)); + EnterFunctionBody(); + + DeviceAwareVisitExpr_(function_node); + + // Leaving scope of function body. + ExitFunctionBody(); + PopDeviceType(); + // Function parameters go out of scope. + for (size_t i = 0; i < function_node->params.size(); ++i) { + PopBoundVar(function_node->params[i]); + } + } +} + +void DeviceAwareExprVisitor::VisitExpr_(const LetNode* let_node) { + std::vector bindings; + Expr expr = GetRef(let_node); + while (const auto* inner_let_node = expr.as()) { + // Let-bound var (in pre visited version) goes into scope. + // (We'll just assume this is a letrec). + PushBoundVar(inner_let_node->var, GetInScopeDeviceType(inner_let_node->value)); + PreVisitLetBinding_(inner_let_node->var, inner_let_node->value); + bindings.emplace_back(inner_let_node); + expr = inner_let_node->body; + } + + VisitExpr(expr); + + for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) { + // Let-bound var goes out of scope. + PopBoundVar((*itr)->var); + PostVisitLet_(*itr); + } +} + +void DeviceAwareExprVisitor::VisitExpr_(const CallNode* call_node) { + auto props = GetOnDeviceProps(call_node); + if (props.body.defined() && props.is_fixed) { + // Entering lexical scope of fixed "on_device" call. + PushDeviceType(props.device_type); + VisitExpr(props.body); + // Leaving lexical scope of "on_device" call. + PopDeviceType(); + } else { + DeviceAwareVisitExpr_(call_node); + } +} + +void DeviceAwareExprVisitor::DeviceAwareVisitExpr_(const FunctionNode* function_node) { + ExprVisitor::VisitExpr_(function_node); +} + +void DeviceAwareExprVisitor::DeviceAwareVisitExpr_(const CallNode* call_node) { + ExprVisitor::VisitExpr_(call_node); +} + +void DeviceAwareExprVisitor::PreVisitLetBlock_(const LetNode* let_node) { + // no-op +} + +void DeviceAwareExprVisitor::PreVisitLetBinding_(const Var& var, const Expr& value) { + VisitExpr(var); + VisitExpr(value); +} + +void DeviceAwareExprVisitor::PostVisitLet_(const LetNode* let_node) { + // no-op +} + +void DeviceAwareExprVisitor::PostVisitLetBlock_(const LetNode* let_node) { + // no-op +} + +Expr DeviceAwareExprMutator::VisitExpr_(const FunctionNode* function_node) { + if (function_node->HasNonzeroAttr(attr::kPrimitive)) { + // No tracking inside primitive functions. + return DeviceAwareVisitExpr_(function_node); + } else { + // Function parameters come into scope. + for (size_t i = 0; i < function_node->params.size(); ++i) { + PushBoundVar(function_node->params[i], GetFunctionParamDeviceType(function_node, i)); + } + // Entering scope of function body. + PushDeviceType(GetFunctionResultDeviceType(function_node)); + EnterFunctionBody(); + + Expr result = DeviceAwareVisitExpr_(function_node); + + // Leaving scope of function body. + ExitFunctionBody(); + PopDeviceType(); + // Function parameters go out of scope. + for (size_t i = 0; i < function_node->params.size(); ++i) { + PopBoundVar(function_node->params[i]); + } + + return result; + } +} + +Expr DeviceAwareExprMutator::VisitExpr_(const LetNode* let_node) { + PreVisitLetBlock_(let_node); + std::vector> bindings; + Expr expr = GetRef(let_node); + while (const auto* inner_let_node = expr.as()) { + // Let-bound var (in pre visited version) goes into scope. + // (We'll just assume this is a letrec.) 
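+    // Pushing the var before visiting its value means a recursive reference to the var from
+    // within the value already sees the var's device type.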
+ PushBoundVar(inner_let_node->var, GetInScopeDeviceType(inner_let_node->value)); + std::pair pair = PreVisitLetBinding_(inner_let_node->var, inner_let_node->value); + bindings.emplace_back(pair.first, pair.second, inner_let_node->span, inner_let_node); + expr = inner_let_node->body; + } + + expr = VisitExpr(expr); + + for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) { + // Let-bound var goes out of scope. + const LetNode* pre_let_node = std::get<3>(*itr); + PopBoundVar(pre_let_node->var); + Let post_let = Let(/*var=*/std::get<0>(*itr), /*value=*/std::get<1>(*itr), + /*body=*/expr, /*span=*/std::get<2>(*itr)); + expr = PostVisitLet_(pre_let_node, post_let.get()); + } + return PostVisitLetBlock_(let_node, expr.as()); +} + +Expr DeviceAwareExprMutator::VisitExpr_(const CallNode* call_node) { + auto props = GetOnDeviceProps(call_node); + if (props.body.defined() && props.is_fixed) { + // Entering lexical scope of fixed "on_device" call. + PushDeviceType(props.device_type); + Expr expr = VisitExpr(props.body); + // Leaving lexical scope of "on_device" call. + PopDeviceType(); + return OnDevice(expr, props.device_type, props.is_fixed); + } else { + return DeviceAwareVisitExpr_(call_node); + } +} + +Expr DeviceAwareExprMutator::DeviceAwareVisitExpr_(const FunctionNode* function_node) { + return ExprMutator::VisitExpr_(function_node); +} + +Expr DeviceAwareExprMutator::DeviceAwareVisitExpr_(const CallNode* call_node) { + return ExprMutator::VisitExpr_(call_node); +} + +void DeviceAwareExprMutator::PreVisitLetBlock_(const LetNode* let_node) { /* no-op */ +} + +std::pair DeviceAwareExprMutator::PreVisitLetBinding_(const Var& var, + const Expr& value) { + return std::make_pair(Downcast(VisitExpr(var)), VisitExpr(value)); +} + +Expr DeviceAwareExprMutator::PostVisitLet_(const LetNode* pre_let_node, + const LetNode* post_let_node) { + if (pre_let_node->var == post_let_node->var && pre_let_node->value == post_let_node->value && + pre_let_node->body == post_let_node->body) { + return GetRef(pre_let_node); + } else { + return GetRef(post_let_node); + } +} + +Expr DeviceAwareExprMutator::PostVisitLetBlock_(const LetNode* pre_let_node, + const LetNode* post_let_node) { + if (pre_let_node->var == post_let_node->var && pre_let_node->value == post_let_node->value && + pre_let_node->body == post_let_node->body) { + return GetRef(pre_let_node); + } else { + return GetRef(post_let_node); + } +} + +} // namespace transform +} // namespace relay +} // namespace tvm diff --git a/src/relay/transforms/device_planner.h b/src/relay/transforms/device_planner.h new file mode 100644 index 0000000000000..f8ecda48d220c --- /dev/null +++ b/src/relay/transforms/device_planner.h @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#ifndef TVM_SRC_RELAY_TRANSFORMS_DEVICE_PLANNER_H_ +#define TVM_SRC_RELAY_TRANSFORMS_DEVICE_PLANNER_H_ + +#include +#include +#include +#include + +#include + +namespace tvm { +namespace relay { +namespace transform { + +// PlanDevices() is declared in the public . + +/*! + * \brief Helper class for expression transformers which need to keep track of the device + * holding the results of expressions and bound variables. This is recovered from the + * "on_device" function attributes and fixed "on_device" CallNodes added by the PlanDevices + * pass. + * + * \sa \p DeviceAwareExpr{Visitor,Mutator}. + */ +class LexicalOnDeviceMixin { + protected: + /*! + * \brief Returns the device type on which the result of \p expr should/will be stored, assuming + * Push/Pop DeviceType/BoundVar have been correctly called. Returns \p kInvalidDeviceType if + * stack is empty and no bound vars have device types. + */ + DLDeviceType GetInScopeDeviceType(const Expr& expr) const; + + /*! \brief Indicate a function body is being entered. */ + void EnterFunctionBody(); + + /*! \brief Indicate a function body has been processed. */ + void ExitFunctionBody(); + + /*! \brief Push a device type onto the lexical device stack. Ignore if \p kInvalidDeviceType. */ + void PushDeviceType(const DLDeviceType device_type); + + /*! \brief Pop a device type from the lexical device stack. Ignore if stack is empty. */ + void PopDeviceType(); + + /*! \brief Remember that \p var will be stored on \p device_type. Ignore if \p kInvalidDeviceType. + * + * CAUTION: Despite the name we don't support re-entering the same function body. + */ + void PushBoundVar(Var var, DLDeviceType device_type); + + /*! \brief Remove the binding for \p var to it's device type. Ignore if var is not bound. */ + void PopBoundVar(const Var& var); + + /*! + * \brief Returns the number of function definitions wrapping the currently visited expression. + */ + int function_nesting() const { return function_nesting_; } + + private: + /*! + * \brief The number of function bodies entered. Since many transforms need to distinguish global + * functions from local functions this supports the mixin's \p is_global() helper method. + */ + int function_nesting_ = 0; + + /*! + * \brief The stack of lexically enclosing "on_device" devices types, from outermost to innermost. + * When visiting an expression other than a variable we can assume the expression result is + * to be stored on device_type_.back(). + */ + std::vector expr_device_types_; + /*! + * \brief A map from in-scope variable to their device types. We may assume the variable is only + * ever bound to a value stored on this device at runtime. + */ + std::unordered_map + var_device_types_; +}; + +/*! \brief ExprVisitor which tracks devices. */ +class DeviceAwareExprVisitor : public ExprVisitor, public LexicalOnDeviceMixin { + public: + using ExprVisitor::VisitExpr_; + + void VisitExpr_(const FunctionNode* function_node) final; + void VisitExpr_(const LetNode* let_node) final; + void VisitExpr_(const CallNode* call_node) final; + + /*! + * \brief These are as for VisitExpr_. Devices for expressions and function parameters will be + * tracked automatically. Default implementation defers to ExprMutator::VisitExpr_. For + * functions the function_nesting count will already include that of \p function_node. + */ + virtual void DeviceAwareVisitExpr_(const FunctionNode* function_node); + virtual void DeviceAwareVisitExpr_(const CallNode* call_node); + + /*! 
+ * \brief Visit the first let in a chain of let expressions before any let bindings or final + * body has been visited. Default implementation is a no-op. + */ + virtual void PreVisitLetBlock_(const LetNode* let_node); + + /*! + * \brief Visit a let-bound expression before the let body has been visited. Devices for the + * let-bound variable will be tracked automatically. Default implementation just visits var and + * value. + */ + virtual void PreVisitLetBinding_(const Var& var, const Expr& value); + + /*! + * \brief Visit a let expression after the let-bound value and body have been visited. + * Default implementation is a no-op. + */ + virtual void PostVisitLet_(const LetNode* let_node); + + /*! + * \brief Visit the first let in a chain of let expressions after it has been visited. + * Default implementation is a no-op. + */ + virtual void PostVisitLetBlock_(const LetNode* let_node); +}; + +/*! \brief ExprMutator which tracks devices. */ +class DeviceAwareExprMutator : public ExprMutator, public LexicalOnDeviceMixin { + public: + Expr VisitExpr_(const FunctionNode* function_node) final; + Expr VisitExpr_(const LetNode* let_node) final; + Expr VisitExpr_(const CallNode* call_node) final; + + /*! + * \brief These are as for VisitExpr_. Devices for expressions and function parameters will be + * tracked automatically. Default implementation defers to ExprMutator::VisitExpr_. For + * functions the function_nesting count will already include that of \p function_node. + */ + virtual Expr DeviceAwareVisitExpr_(const FunctionNode* function_node); + virtual Expr DeviceAwareVisitExpr_(const CallNode* call_node); + + /*! + * \brief Visit the first let in a chain of let expressions before any let bindings or final + * body has been visited. Default implementation is a no-op. + */ + virtual void PreVisitLetBlock_(const LetNode* let_node); + + /*! + * \brief Visit a let-bound expression before the let body has been visited. Devices for the + * let-bound variable will be tracked automatically. Default implementation just visits var and + * value. + */ + virtual std::pair PreVisitLetBinding_(const Var& var, const Expr& value); + + /*! + * \brief Visit a let expression after the let-bound value and body have been visited. + * Default implementation just returns a reference to the post-visited node. + */ + virtual Expr PostVisitLet_(const LetNode* pre_let_node, const LetNode* post_let_node); + + /*! + * \brief Visit the first let in a chain of let expressions after it has been visited. + * Default implementation returns reference to let node. + */ + virtual Expr PostVisitLetBlock_(const LetNode* pre_let_node, const LetNode* post_let_node); +}; + +} // namespace transform +} // namespace relay +} // namespace tvm + +#endif // TVM_SRC_RELAY_TRANSFORMS_DEVICE_PLANNER_H_ diff --git a/src/relay/transforms/fold_scale_axis.cc b/src/relay/transforms/fold_scale_axis.cc index 7b3f2da716aac..ca9a286ff6a78 100644 --- a/src/relay/transforms/fold_scale_axis.cc +++ b/src/relay/transforms/fold_scale_axis.cc @@ -242,6 +242,10 @@ class ForwardPrep : private MixedModeVisitor { message_[key] = message; } } + + // We intended the following overrides on implementations from ExprVisitor. + using MixedModeVisitor::VisitExpr_; + // Visitor pattern override. 
void VisitExpr_(const TupleGetItemNode* op) final { MixedModeVisitor::VisitExpr_(op); } diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc index f1f7a95e33e80..88ef24ed1dcfd 100644 --- a/src/relay/transforms/fuse_ops.cc +++ b/src/relay/transforms/fuse_ops.cc @@ -31,8 +31,9 @@ #include #include "../../support/arena.h" -#include "pass_utils.h" -#include "pattern_utils.h" +#include "../attrs/annotation.h" +#include "./pass_utils.h" +#include "./pattern_utils.h" namespace tvm { namespace relay { @@ -159,12 +160,6 @@ class IndexedForwardGraph::Creator : private ExprVisitor { } private: - /*! \brief allocator of all the internal node object */ - support::Arena* arena_; - // The output. - IndexedForwardGraph graph_; - // attribute equal comparator - StructuralEqual attr_equal_; // Update the message stored at the node. void Update(const Expr& node, IndexedForwardGraph::Node* parent, OpPatternKind pattern) { const tvm::Object* key = node.get(); @@ -367,6 +362,13 @@ class IndexedForwardGraph::Creator : private ExprVisitor { ExprVisitor::VisitExpr_(op); this->AddNode(op); } + + /*! \brief allocator of all the internal node object */ + support::Arena* arena_; + // The output. + IndexedForwardGraph graph_; + // attribute equal comparator + StructuralEqual attr_equal_; }; IndexedForwardGraph IndexedForwardGraph::Create(support::Arena* arena, const Expr& body) { @@ -847,12 +849,6 @@ class FuseMutator : private MixedModeMutator { return var; } }; - /*! \brief Internal arena. */ - support::Arena arena_; - /*! \brief The group assignment map. */ - std::unordered_map gmap_; - /* \brief Internal group information map. */ - std::unordered_map ginfo_; // Skip primitive function. Expr VisitExpr_(const FunctionNode* fn_node) { @@ -1013,6 +1009,13 @@ class FuseMutator : private MixedModeMutator { }); LOG(INFO) << "Dump of group info:\n" << text; } + + /*! \brief Internal arena. */ + support::Arena arena_; + /*! \brief The group assignment map. */ + std::unordered_map gmap_; + /* \brief Internal group information map. */ + std::unordered_map ginfo_; }; Expr FuseOps(const Expr& expr, int fuse_opt_level, size_t max_fuse_depth, const IRModule& module) { @@ -1026,9 +1029,10 @@ Pass FuseOps(int fuse_opt_level) { [=](Function f, IRModule m, PassContext pc) { int opt_level = fuse_opt_level == -1 ? 
pc->opt_level : fuse_opt_level; auto max_fuse_depth = pc->GetConfig("relay.FuseOps.max_depth", Integer(kMaxFusedOps)); - return Downcast(FuseOps(f, opt_level, max_fuse_depth.value(), m)); + Function result = Downcast(FuseOps(f, opt_level, max_fuse_depth.value(), m)); + return result; }; - return CreateFunctionPass(pass_func, 1, "FuseOps", {"InferType"}); + return CreateFunctionPass(pass_func, 0, "FuseOps", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.FuseOps").set_body_typed(FuseOps); diff --git a/src/relay/transforms/let_list.h b/src/relay/transforms/let_list.h index c75f18f6831c5..56875f6c16a16 100644 --- a/src/relay/transforms/let_list.h +++ b/src/relay/transforms/let_list.h @@ -65,7 +65,7 @@ class LetList { */ Var Push(Var pv, Expr expr) { ICHECK(!used_); - ICHECK(WellFormed(expr)); + ICHECK(WellFormed(expr)) << "expression:" << std::endl << PrettyPrint(expr); lets_.emplace_back(std::make_pair(pv, expr)); return pv; } diff --git a/src/relay/transforms/memory_alloc.cc b/src/relay/transforms/memory_alloc.cc index 657e2c3924555..c298bf40e8347 100644 --- a/src/relay/transforms/memory_alloc.cc +++ b/src/relay/transforms/memory_alloc.cc @@ -41,13 +41,16 @@ #include #include +#include "../attrs/annotation.h" +#include "../attrs/device_copy.h" #include "../backend/te_compiler.h" #include "../backend/te_compiler_cache.h" #include "../op/memory/memory.h" #include "../op/vm/vm.h" +#include "./device_planner.h" +#include "./let_list.h" #include "./pass_utils.h" -#include "let_list.h" -#include "pattern_utils.h" +#include "./pattern_utils.h" using namespace tvm::runtime; using namespace tvm::relay::tec; @@ -55,9 +58,6 @@ using namespace tvm::relay::tec; namespace tvm { namespace relay { -using AnalysisResultMap = - std::unordered_map; - inline Constant MakeConstant(const std::vector& value) { return MakeConstantTensor(DataType::Int(64), {static_cast(value.size())}, value); } @@ -85,29 +85,17 @@ bool IsReshapeOnly(const Expr& expr) { return false; } -class DialectRewriter : public ExprMutator { +class DialectRewriter : public transform::DeviceAwareExprMutator { public: - DialectRewriter(const Target& target_host, const AnalysisResultMap& context_analysis_map) - : target_host_(target_host), context_analysis_map_(context_analysis_map) {} - - // Get the device of an expression. 
- Device GetDevice(const Expr& expr) const { - auto it = context_analysis_map_.find(expr); - CHECK(it != context_analysis_map_.end()) << "Cannot find expr in the context analysis map:\n" - << AsText(expr, false); - return it->second; - } + DialectRewriter(const Target& target_host) : target_host_(target_host) {} - Function Rewrite(const Function& expr) { - auto ret = ExprMutator::Mutate(expr); - return Downcast(ret); - } + Function Rewrite(const Function& expr) { return Downcast(Mutate(expr)); } Expr VisitExpr_(const TupleNode* tn) final { LetList& scope = scopes_.back(); Array new_fields; for (auto field : tn->fields) { - auto new_field = ExprMutator::Mutate(field); + auto new_field = Mutate(field); if (new_field->IsInstance()) { Var const_var("const", Type(nullptr)); new_field = scope.Push(const_var, new_field); @@ -117,32 +105,38 @@ class DialectRewriter : public ExprMutator { return Tuple(new_fields); } - Expr VisitExpr_(const LetNode* ln) final { - scopes_.emplace_back(); + void PreVisitLetBlock_(const LetNode* let_node) final { scopes_.emplace_back(); } - const LetNode* let = ln; - Expr body; - while (let) { - auto new_value = ExprMutator::Mutate(let->value); - scopes_.back().Push(let->var, new_value); - body = let->body; - let = body.as(); - } + std::pair PreVisitLetBinding_(const Var& var, const Expr& value) final { + Expr new_value = Mutate(value); + scopes_.back().Push(var, new_value); + // Since we always need a let block on which to bind sub-expressions the rewritten bindings + // are tracked in the current scopes. But return the rewritten binding anyway. + return {var, new_value}; + } - CHECK(body.defined()); - auto new_body = ExprMutator::Mutate(body); + Expr PostVisitLetBlock_(const LetNode* pre_let_node, const LetNode* post_let_node) final { + // The current scope has captured all the rewritten let-binding, as well as any additional + // bindings we needed to add. All we need is the rewritted body. + Expr new_body = post_let_node->body; + while (const auto* inner_let_node = new_body.as()) { + new_body = inner_let_node->body; + } auto ret = scopes_.back().Get(new_body); scopes_.pop_back(); return ret; } - Expr VisitExpr_(const CallNode* cn) final { + Expr DeviceAwareVisitExpr_(const CallNode* cn) final { + Call call = GetRef(cn); + DLDeviceType device_type = GetInScopeDeviceType(call); if (IsPrimitive(cn)) { // Because we are in ANF we do not need to visit the arguments. + // TODO(mbs): But does so anyway... LetList& scope = scopes_.back(); std::vector new_args; for (const auto& it : cn->args) { - new_args.push_back(ExprMutator::Mutate(it)); + new_args.push_back(Mutate(it)); } Tuple ins(new_args); @@ -170,30 +164,36 @@ class DialectRewriter : public ExprMutator { return DeviceCopy(new_args[0], copy_attr->src_dev_type, copy_attr->dst_dev_type); } else if (IsDynamic(ret_type)) { Function func = Downcast(cn->op); - return DynamicInvoke(&scope, func, ins, new_args, out_types, ret_type); + // TODO(mbs): Device id is always zero. + Device device{device_type, /*device_id=*/0}; + return DynamicInvoke(&scope, func, ins, new_args, out_types, ret_type, device); } else { // Handle the static case Array outs; for (size_t i = 0; i < out_types.size(); ++i) { - Device dev = GetDevice(GetRef(cn)); - auto out = MakeStaticAllocation(&scope, out_types[i], dev, std::to_string(i)); + DLDeviceType device_type = GetInScopeDeviceType(GetRef(cn)); + // TODO(mbs): Device id is always zero. 
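+          // (The planner tracks DLDeviceType only, not concrete device ids, so device id zero is
+          // assumed here.)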
+ Device device{device_type, /*device_id=*/0}; + auto out = MakeStaticAllocation(&scope, out_types[i], device, std::to_string(i)); outs.push_back(out); } Tuple output(outs); + // TODO(mbs): Capture device in attributes. Expr invoke = InvokeTVMOp(cn->op, ins, output); - scope.Push(invoke); + scope.Push(OnDevice(invoke, device_type, /*is_fixed=*/true)); return ToTupleType(ret_type, std::vector(output->fields.begin(), output->fields.end())); } } else { - return ExprMutator::VisitExpr_(cn); + return transform::DeviceAwareExprMutator::DeviceAwareVisitExpr_(cn); } } private: // Insert a device copy node. Expr DeviceCopy(const Expr& inp, int src_dev, int dst_dev) { - return ExprMutator::Mutate(relay::DeviceCopy(inp, src_dev, dst_dev)); + return Mutate(relay::DeviceCopy(inp, static_cast(src_dev), + static_cast(dst_dev))); } // Check if a call invokes a primitive function. @@ -210,9 +210,9 @@ class DialectRewriter : public ExprMutator { if (const auto* fn = expr.as()) { auto body = fn->body; const CallNode* call = body.as(); - return call && call->op == Op::Get("device_copy"); + return call && call->op == device_copy_op_; } else if (const CallNode* cn = expr.as()) { - return cn->op == Op::Get("device_copy"); + return cn->op == device_copy_op_; } else { return false; } @@ -297,17 +297,17 @@ class DialectRewriter : public ExprMutator { if (state == 2) { std::vector exprs = FromTupleType(ty, arg); for (size_t j = 0; j < exprs.size(); ++j) { - Expr sh_of = ExprMutator::Mutate(ShapeOf(exprs[j])); + Expr sh_of = Mutate(ShapeOf(exprs[j])); Var in_shape_var("in_shape_" + std::to_string(input_pos + j), Type(nullptr)); shape_func_ins.push_back(scope->Push(in_shape_var, sh_of)); input_pos++; } is_inputs.push_back(0); } else if (state == 1) { - auto new_arg = ExprMutator::Mutate(arg); - auto dev = GetDevice(arg); - if (dev.device_type != cpu_dev.device_type) { - new_arg = DeviceCopy(new_arg, dev.device_type, cpu_dev.device_type); + auto new_arg = Mutate(arg); + DLDeviceType device_type = GetInScopeDeviceType(arg); + if (device_type != cpu_dev.device_type) { + new_arg = DeviceCopy(new_arg, device_type, cpu_dev.device_type); } Var in_shape_var("in_shape_" + std::to_string(input_pos), Type(nullptr)); shape_func_ins.push_back(scope->Push(in_shape_var, new_arg)); @@ -339,10 +339,9 @@ class DialectRewriter : public ExprMutator { // Generate the code for invoking a TVM op with a dynamic shape. Expr DynamicInvoke(LetList* scope, const Function& func, const Tuple& ins, const std::vector& new_args, const std::vector& out_types, - const Type& ret_type) { + const Type& ret_type, Device dev) { auto out_shapes = EmitShapeFunc(scope, func, new_args); std::vector storages; - auto func_dev = GetDevice(func); CHECK_EQ(out_shapes.size(), out_types.size()); for (size_t i = 0; i < out_shapes.size(); ++i) { auto out_shape = out_shapes[i]; @@ -350,7 +349,7 @@ class DialectRewriter : public ExprMutator { auto size = ComputeStorageInRelay(out_shape, out_type); auto alignment = ComputeAlignment(out_type->dtype); Var sto_var("storage_" + std::to_string(i), Type(nullptr)); - auto val = AllocStorage(size, alignment, func_dev, out_type->dtype); + auto val = AllocStorage(size, alignment, dev, out_type->dtype); storages.push_back(scope->Push(sto_var, val)); } @@ -365,8 +364,9 @@ class DialectRewriter : public ExprMutator { } Tuple tuple_outs(outs); + // TODO(mbs): Capure device in invoke attributes. 
auto invoke = InvokeTVMOp(func, ins, tuple_outs); - scope->Push(invoke); + scope->Push(OnDevice(invoke, dev.device_type, /*is_fixed=*/true)); return ToTupleType(ret_type, std::vector(tuple_outs->fields.begin(), tuple_outs->fields.end())); } @@ -391,8 +391,9 @@ class DialectRewriter : public ExprMutator { } private: + const Op& device_copy_op_ = Op::Get("device_copy"); + Target target_host_; - AnalysisResultMap context_analysis_map_; std::vector scopes_; runtime::DataType compute_dtype_ = runtime::DataType::Int(64); @@ -411,27 +412,11 @@ Pass ManifestAlloc(Target target_host, Map targets) { mod->ImportFromStd("core.rly"); mod = relay::transform::InferType()(mod); - Device fallback_dev; - if (targets.size() > 1) { - auto pass_ctx = PassContext::Current(); - Optional opt_fallback_dev_type = - pass_ctx->GetConfig("relay.fallback_device_type", Integer(static_cast(kDLCPU))); - auto fallback_dev_type = opt_fallback_dev_type.value(); - CHECK_GT(fallback_dev_type->value, 0U); - fallback_dev.device_type = static_cast(fallback_dev_type->value); - fallback_dev.device_id = 0; - } else { - const auto& it = targets.begin(); - fallback_dev.device_type = static_cast((*it).first->value); - fallback_dev.device_id = 0; - } - auto ca = ContextAnalysis(mod, fallback_dev); - auto glob_funcs = mod->functions; for (const auto& it : glob_funcs) { if (auto* func_node = it.second.as()) { auto func = GetRef(func_node); - auto rewriter = DialectRewriter(target_host, ca); + auto rewriter = DialectRewriter(target_host); auto updated_func = rewriter.Rewrite(func); mod->Update(it.first, updated_func); diff --git a/src/relay/transforms/pass_utils.h b/src/relay/transforms/pass_utils.h index bb2f268a23d7d..91e86e130d6b4 100644 --- a/src/relay/transforms/pass_utils.h +++ b/src/relay/transforms/pass_utils.h @@ -36,7 +36,8 @@ #include #include "../analysis/dependency_graph.h" -#include "let_list.h" +#include "../attrs/annotation.h" +#include "./let_list.h" namespace tvm { namespace relay { @@ -118,8 +119,11 @@ inline Expr TransformF(const std::function& func, const Expr& * if so, the compute cost of the expression is bounded so it can be copy without graph mode. */ inline bool IsAtomic(const Expr& e) { - return e.as() || e.as() || e.as() || e.as() || - e.as(); // Constant is always by reference. + auto props = GetOnDeviceProps(e); + Expr true_expr = props.body.defined() ? props.body : e; + return true_expr.as() || true_expr.as() || true_expr.as() || + true_expr.as() || + true_expr.as(); // Constant is always by reference. } /*! @@ -222,57 +226,10 @@ std::pair CalcScope(const DependencyGraph& dg); */ Scope LCA(Scope lhs, Scope rhs); -/* Special care is needed to handle local recursion. - * Fill additionally take a (possibly null) Var argument, - * If it is not null, Fill is required to bind the transformed result to that var. - */ -class Fill : ExprFunctor { - public: - static Expr ToANormalForm(const Expr& e, const DependencyGraph& dg, NodeScopeMap* node_scope); - - // For basic block normal form, bind expressions only if the original expression's - // scope should be lifted - static Expr ToBasicBlockNormalForm(const Expr& e, const DependencyGraph& dg, - NodeScopeMap* node_scope, ExprSet* lifted); - - private: - const DependencyGraph& dg_; - NodeScopeMap* node_scope_ = nullptr; - std::unordered_map memo; - // a set of Expressions to include for let bindings. If set to nullptr - // all Exprs will be pushed to the let list. 
- ExprSet* include_set_ = nullptr; - - Fill(const DependencyGraph& dg, NodeScopeMap* node_scope, ExprSet* include_set) - : dg_(dg), node_scope_(node_scope), include_set_(include_set) {} - - Scope GetScope(const Expr& e); - Scope GetSubScope(const Expr& e, size_t i); - - Expr VisitExpr(const Expr& e, const Var& v) final; - Expr VisitExpr(const Expr& e); - - Expr Atomic(const Expr& e, const Var& v); - // Bind expression `now` to var `v` if the original expression is in the include set, or if - // v is already defined (e.g. coming from a Let expression). Otherwise return `now` directly. - Expr Compound(const Expr& orig, const Expr& now, const Var& v); - - Expr VisitExpr_(const CallNode* c, const Var& v) final; - Expr VisitExpr_(const TupleNode* t, const Var& v) final; - Expr VisitExpr_(const TupleGetItemNode* t, const Var& v) final; - Expr VisitExpr_(const RefCreateNode* r, const Var& v) final; - Expr VisitExpr_(const RefReadNode* r, const Var& v) final; - Expr VisitExpr_(const RefWriteNode* r, const Var& v) final; - Expr VisitExpr_(const IfNode* i, const Var& v) final; - Expr VisitExpr_(const FunctionNode* f, const Var& v) final; - Expr VisitExpr_(const LetNode* l, const Var& v) final; - Expr VisitExpr_(const ConstantNode* c, const Var& v) final; - Expr VisitExpr_(const VarNode* vn, const Var& v) final; - Expr VisitExpr_(const GlobalVarNode* gvn, const Var& v) final; - Expr VisitExpr_(const OpNode* op, const Var& v) final; - Expr VisitExpr_(const ConstructorNode* c, const Var& v) final; - Expr VisitExpr_(const MatchNode* m, const Var& v) final; -}; +// For basic block normal form. +Expr ToBasicBlockNormalFormAux(const Expr& e); + +// ToANormalForm for expressions and as a Pass are declared in transform.h } // namespace relay } // namespace tvm diff --git a/src/relay/transforms/pattern_utils.h b/src/relay/transforms/pattern_utils.h index 920ac153b63da..692ef3c9f557a 100644 --- a/src/relay/transforms/pattern_utils.h +++ b/src/relay/transforms/pattern_utils.h @@ -696,10 +696,6 @@ static inline Expr BroadCastTo(Expr data, Array shape) { return MakeBroadCastTo(data, CheckConstantShapeArrayInteger(shape)); } -Expr StopFusion(Expr data); - -Expr CastHint(Expr data, DataType dtype); - } // namespace relay } // namespace tvm #endif // TVM_RELAY_TRANSFORMS_PATTERN_UTILS_H_ diff --git a/src/relay/transforms/split_args.cc b/src/relay/transforms/split_args.cc index 70d37d822d71e..d2d5da4947df8 100644 --- a/src/relay/transforms/split_args.cc +++ b/src/relay/transforms/split_args.cc @@ -23,7 +23,8 @@ #include #include -#include "pattern_utils.h" +#include "../attrs/annotation.h" +#include "./pattern_utils.h" namespace tvm { namespace relay { diff --git a/src/relay/transforms/to_a_normal_form.cc b/src/relay/transforms/to_a_normal_form.cc index 91e8d90c1232f..6e4a7057524af 100644 --- a/src/relay/transforms/to_a_normal_form.cc +++ b/src/relay/transforms/to_a_normal_form.cc @@ -30,8 +30,10 @@ #include "../../support/arena.h" #include "../analysis/dependency_graph.h" -#include "let_list.h" -#include "pass_utils.h" +#include "../attrs/annotation.h" +#include "./device_planner.h" +#include "./let_list.h" +#include "./pass_utils.h" namespace tvm { namespace relay { @@ -94,189 +96,309 @@ std::pair CalcScope(const DependencyGraph& dg) { return std::make_pair(expr_scope, lifted_exprs); } -Expr Fill::ToANormalForm(const Expr& e, const DependencyGraph& dg, NodeScopeMap* node_scope) { - Fill fi(dg, node_scope, nullptr); - return fi.GetScope(e)->let_list->Get(fi.VisitExpr(e)); -} +namespace { -// For basic block normal 
form, bind expressions only if the original expression's scope -// should be lifted -Expr Fill::ToBasicBlockNormalForm(const Expr& e, const DependencyGraph& dg, - NodeScopeMap* node_scope, ExprSet* lifted) { - Fill fi(dg, node_scope, lifted); - auto var = fi.VisitExpr(e); - return fi.GetScope(e)->let_list->Get(var); -} +/* Special care is needed to handle local recursion. + * Fill additionally take a (possibly null) Var argument, + * If it is not null, Fill is required to bind the transformed result to that var. + * + * ToANormalForm and PlanDevices + * ----------------------------- + * If PlanDevices has run this transform must respect the lexical scoping rules for the residual + * "on_device" calls. Eg: + * \code + * on_device(add(subtract(x, y), add(y, z)), device_type=2, is_fixed=true) + * ==> + * let %x0 = on_device(subtract(x, y), device_type=2, is_fixed=true) + * let %x1 = on_device(add(y, z), device_type=2, is_fixed=true) + * let %x2 = on_device(add(%x0, %x1), device_type=2, is_fixed=true) + * %x2 + * \endcode + * + * In addition to conversion to ANF this pass is also handling hoisting implicitly shared + * sub-expressions to the inner-most scope common to all their uses: + * \code + * on_device( + * if y { + * on_device(%0, device_type=2, is_fixed=true) + * } else { + * on_device(subtract(%0, b), device_type=2, is_fixed=true) + * }, + * device_type=1, is_fixed=true) + * (where %0 = add(a, b)) + * ==> + * let %x0 = on_device(add(a, b), device_type=2, is_fixed=true); + * on_device( + * if y { + * on_device(%x0, device_type=2, is_fixed=true) + * } else { + * let %x1 = on_device(subtract(%x0, b), device_type=2, is_fixed=true); + * %x1 + * }, + * device_type=1, is_fixed=true) + * \endcode + * Though the PlanDevices has already avoided inserting "on_device" calls where they are redundant + * due to lexical scope, it's fiddly to do the same in this pass since the notion of 'scope' is + * now determined by the scope map. So we'll just insert them mechanically on every let-binding. + * + * TODO(mbs): Rewrite to derive from DeviceAwareExprMutator and not track device types + * explicitly. It's easy to get rid of the need for the extra var argument on VisitExpr by shifting + * the recursion a '1/2 step' to return a possibly compound expression who's inner expressions are + * all atomic. However the use of the scope map is currently subtle enough I want to leave it + * alone for now. 
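+ *
+ * "Mechanically on every let-binding" means even bindings which agree with their enclosing
+ * lexical scope pick up an annotation. An illustrative sketch (the names and device numbers are
+ * invented, assuming everything was planned onto device_type=1):
+ * \code
+ *   let %x0 = on_device(add(a, b), device_type=1, is_fixed=true);
+ *   let %x1 = on_device(multiply(%x0, c), device_type=1, is_fixed=true);
+ *   %x1
+ * \endcode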
+ */ +class Fill : ExprFunctor, private transform::LexicalOnDeviceMixin { + public: + static Expr ToANormalForm(const Expr& e, const DependencyGraph& dg, NodeScopeMap* node_scope) { + Fill fi(dg, node_scope, nullptr); + return fi.GetScope(e)->let_list->Get(fi.VisitExpr(e)); + } + + // For basic block normal form, bind expressions only if the original expression's scope + // should be lifted + static Expr ToBasicBlockNormalForm(const Expr& e, const DependencyGraph& dg, + NodeScopeMap* node_scope, ExprSet* lifted) { + Fill fi(dg, node_scope, lifted); + return fi.GetScope(e)->let_list->Get(fi.VisitExpr(e)); + } -Scope Fill::GetScope(const Expr& e) { return node_scope_->at(dg_.expr_node.at(e)); } + private: + Fill(const DependencyGraph& dg, NodeScopeMap* node_scope, ExprSet* include_set) + : dg_(dg), node_scope_(node_scope), include_set_(include_set) {} -Scope Fill::GetSubScope(const Expr& e, size_t i) { - DependencyGraph::Node* n = dg_.expr_node.at(e); - auto h = n->children.head; - while (i != 0) { + Scope GetScope(const Expr& e) { return node_scope_->at(dg_.expr_node.at(e)); } + + Scope GetSubScope(const Expr& e, size_t i) { + DependencyGraph::Node* n = dg_.expr_node.at(e); + auto h = n->children.head; + while (i != 0) { + ICHECK(h); + --i; + h = h->next; + } ICHECK(h); - --i; - h = h->next; + return node_scope_->at(h->value); } - ICHECK(h); - return node_scope_->at(h->value); -} -Expr Fill::VisitExpr(const Expr& e, const Var& v) { - if (memo.count(e) == 0) { - memo.insert({e, ExprFunctor::VisitExpr(e, v)}); - } else if (v.defined()) { - GetScope(e)->let_list->Push(v, memo.at(e)); - } - auto ret = memo.at(e); - // if no include_set is specified, every expression should be atomic. - if (include_set_ == nullptr) ICHECK(IsAtomic(ret)); - return ret; -} + Expr VisitExpr(const Expr& e) { return this->VisitExpr(e, Var()); } -Expr Fill::VisitExpr(const Expr& e) { return this->VisitExpr(e, Var()); } + Expr VisitExpr(const Expr& e, const Var& v) { + if (memo.count(e) == 0) { + memo.insert({e, ExprFunctor::VisitExpr(e, v)}); + } else if (v.defined()) { + GetScope(e)->let_list->Push(v, memo.at(e)); + } + auto ret = memo.at(e); + // if no include_set is specified, every expression should be atomic. + // TODO(mbs): Note that Constants must be let-bound even though they are considered 'atomic' + // by this test. + if (include_set_ == nullptr && function_nesting() > 0) { + ICHECK(IsAtomic(ret)) << "expression:" << std::endl << PrettyPrint(ret); + } + return ret; + } -Expr Fill::Atomic(const Expr& e, const Var& v) { - return v.defined() ? GetScope(e)->let_list->Push(v, e) : e; -} + Expr Atomic(const Expr& e, const Var& v) { + Expr annotated_expr = OptOnDevice(e, GetInScopeDeviceType(e), /*is_fixed=*/true); + return v.defined() ? GetScope(e)->let_list->Push(v, annotated_expr) : annotated_expr; + } -// Bind expression `now` to var `v` if the original expression is in the include set, or if -// v is already defined (e.g. coming from a Let expression). Otherwise return `now` directly -Expr Fill::Compound(const Expr& orig, const Expr& now, const Var& v) { - Var var = v.defined() ? v : Var(String("x"), Type()); - bool not_included = include_set_ && include_set_->find(orig) == include_set_->end(); - if (!v.defined() && not_included) { - return now; - } else { - return GetScope(orig)->let_list->Push(var, now); + // Bind expression `now` to var `v` if the original expression is in the include set, or if + // v is already defined (e.g. coming from a Let expression). 
Otherwise return `now` directly + Expr Compound(const Expr& orig, const Expr& now, const Var& v) { + Expr annotated_expr = OptOnDevice(now, GetInScopeDeviceType(orig), /*is_fixed=*/true); + Var var = v.defined() ? v : Var(String("x"), Type()); + bool not_included = include_set_ && include_set_->find(orig) == include_set_->end(); + if (!v.defined() && not_included) { + return annotated_expr; + } else { + return GetScope(orig)->let_list->Push(var, annotated_expr); + } } -} -Expr Fill::VisitExpr_(const CallNode* c, const Var& v) { - Expr e = GetRef(c); - std::vector args; - for (const auto& a : c->args) { - args.push_back(VisitExpr(a)); + Expr VisitExpr_(const CallNode* c, const Var& v) final { + auto props = GetOnDeviceProps(c); + if (props.body.defined() && props.is_fixed) { + // Keep track of expression device type for lexically enclosing sub-expressions. + PushDeviceType(props.device_type); + Expr body = VisitExpr(props.body, v); + // We are done with this sub-expression. + PopDeviceType(); + return body; + } + + Expr e = GetRef(c); + std::vector args; + for (const auto& a : c->args) { + args.push_back(VisitExpr(a)); + } + return Compound(e, Call(VisitExpr(c->op), args, c->attrs, c->type_args), v); } - return Compound(e, Call(VisitExpr(c->op), args, c->attrs, c->type_args), v); -} -Expr Fill::VisitExpr_(const TupleNode* t, const Var& v) { - Expr e = GetRef(t); - std::vector fields; - for (const auto& a : t->fields) { - fields.push_back(VisitExpr(a)); + Expr VisitExpr_(const TupleNode* t, const Var& v) final { + Expr e = GetRef(t); + std::vector fields; + for (const auto& a : t->fields) { + fields.push_back(VisitExpr(a)); + } + return Compound(e, Tuple(fields), v); } - return Compound(e, Tuple(fields), v); -} -Expr Fill::VisitExpr_(const TupleGetItemNode* t, const Var& v) { - Expr e = GetRef(t); - return Compound(e, TupleGetItem(VisitExpr(t->tuple), t->index), v); -} + Expr VisitExpr_(const TupleGetItemNode* t, const Var& v) final { + Expr e = GetRef(t); + return Compound(e, TupleGetItem(VisitExpr(t->tuple), t->index), v); + } -Expr Fill::VisitExpr_(const RefCreateNode* r, const Var& v) { - Expr e = GetRef(r); - return Compound(e, RefCreate(VisitExpr(r->value)), v); -} + Expr VisitExpr_(const RefCreateNode* r, const Var& v) final { + Expr e = GetRef(r); + return Compound(e, RefCreate(VisitExpr(r->value)), v); + } -Expr Fill::VisitExpr_(const RefReadNode* r, const Var& v) { - Expr e = GetRef(r); - return Compound(e, RefRead(VisitExpr(r->ref)), v); -} + Expr VisitExpr_(const RefReadNode* r, const Var& v) final { + Expr e = GetRef(r); + return Compound(e, RefRead(VisitExpr(r->ref)), v); + } -Expr Fill::VisitExpr_(const RefWriteNode* r, const Var& v) { - Expr e = GetRef(r); - return Compound(e, RefWrite(VisitExpr(r->ref), VisitExpr(r->value)), v); -} + Expr VisitExpr_(const RefWriteNode* r, const Var& v) final { + Expr e = GetRef(r); + return Compound(e, RefWrite(VisitExpr(r->ref), VisitExpr(r->value)), v); + } -Expr Fill::VisitExpr_(const IfNode* i, const Var& v) { - Expr e = GetRef(i); - Expr ret = If(VisitExpr(i->cond), GetSubScope(e, 1)->let_list->Get(VisitExpr(i->true_branch)), - GetSubScope(e, 2)->let_list->Get(VisitExpr(i->false_branch))); - return Compound(e, ret, v); -} + Expr VisitExpr_(const IfNode* i, const Var& v) final { + Expr e = GetRef(i); + Expr ret = If(VisitExpr(i->cond), GetSubScope(e, 1)->let_list->Get(VisitExpr(i->true_branch)), + GetSubScope(e, 2)->let_list->Get(VisitExpr(i->false_branch))); + return Compound(e, ret, v); + } -Expr Fill::VisitExpr_(const FunctionNode* f, 
const Var& v) { - Expr e = GetRef(f); - Expr ret; - if (f->HasNonzeroAttr(attr::kPrimitive)) { - ret = e; - } else { - ret = Function(f->params, GetSubScope(e, 0)->let_list->Get(VisitExpr(f->body)), f->ret_type, - f->type_params, f->attrs); + Expr VisitExpr_(const FunctionNode* f, const Var& v) final { + Expr e = GetRef(f); + Expr ret; + if (f->HasNonzeroAttr(attr::kPrimitive)) { + ret = e; + } else { + // Keep track of expression and bound variable device types for lexically enclosing + // sub-expressions. + PushDeviceType(GetFunctionResultDeviceType(f)); + for (size_t i = 0; i < f->params.size(); ++i) { + PushBoundVar(f->params[i], GetFunctionParamDeviceType(f, i)); + } + EnterFunctionBody(); + ret = Function(f->params, GetSubScope(e, 0)->let_list->Get(VisitExpr(f->body)), f->ret_type, + f->type_params, f->attrs); + // We are done with this function. + ExitFunctionBody(); + for (size_t i = 0; i < f->params.size(); ++i) { + PopBoundVar(f->params[i]); + } + PopDeviceType(); + } + if (function_nesting() == 0) { + ICHECK(!v.defined()); + // This is a global function which can be bound directly in the module. + return ret; + } else { + // This is a local function which must be let-bound. + return Compound(e, ret, v); + } } - return Compound(e, ret, v); -} -Expr Fill::VisitExpr_(const LetNode* l, const Var& v) { - Expr e = GetRef(l); - VisitExpr(l->value, l->var); - Expr ret = GetSubScope(e, 0)->let_list->Get(VisitExpr(l->body)); - return Compound(e, ret, v); -} + Expr VisitExpr_(const LetNode* l, const Var& v) final { + Expr e = GetRef(l); + // Keep track of bound variable device types for lexically enclosing sub-expressions. + PushBoundVar(l->var, GetInScopeDeviceType(l->value)); + VisitExpr(l->value, l->var); + Expr ret = GetSubScope(e, 0)->let_list->Get(VisitExpr(l->body)); + // We are done with these sub-expressions. 
+ PopBoundVar(l->var); + return Compound(e, ret, v); + } -Expr Fill::VisitExpr_(const ConstantNode* c, const Var& v) { - Expr e = GetRef(c); - return Compound(e, e, v); -} + Expr VisitExpr_(const ConstantNode* c, const Var& v) final { + Expr e = GetRef(c); + return Compound(e, e, v); + } -Expr Fill::VisitExpr_(const VarNode* vn, const Var& v) { - Expr e = GetRef(vn); - return Atomic(e, v); -} + Expr VisitExpr_(const VarNode* vn, const Var& v) final { + Expr e = GetRef(vn); + return Atomic(e, v); + } -Expr Fill::VisitExpr_(const GlobalVarNode* gvn, const Var& v) { - GlobalVar gv = GetRef(gvn); - return Atomic(gv, v); -} + Expr VisitExpr_(const GlobalVarNode* gvn, const Var& v) final { + GlobalVar gv = GetRef(gvn); + return Atomic(gv, v); + } -Expr Fill::VisitExpr_(const OpNode* op, const Var& v) { - Expr e = GetRef(op); - return Atomic(e, v); -} + Expr VisitExpr_(const OpNode* op, const Var& v) final { + Expr e = GetRef(op); + return Atomic(e, v); + } -Expr Fill::VisitExpr_(const ConstructorNode* c, const Var& v) { - Expr e = GetRef(c); - return Atomic(e, v); -} + Expr VisitExpr_(const ConstructorNode* c, const Var& v) final { + Expr e = GetRef(c); + return Atomic(e, v); + } -Expr Fill::VisitExpr_(const MatchNode* m, const Var& v) { - Expr e = GetRef(m); - Expr data = VisitExpr(m->data); - std::vector clauses; - for (const Clause& c : m->clauses) { - clauses.push_back( - Clause(c->lhs, GetSubScope(e, 1 + clauses.size())->let_list->Get(VisitExpr(c->rhs)))); + Expr VisitExpr_(const MatchNode* m, const Var& v) final { + Expr e = GetRef(m); + Expr data = VisitExpr(m->data); + std::vector clauses; + for (const Clause& c : m->clauses) { + clauses.emplace_back(c->lhs, + GetSubScope(e, 1 + clauses.size())->let_list->Get(VisitExpr(c->rhs))); + } + return Compound(e, Match(data, clauses, m->complete), v); } - return Compound(e, Match(data, clauses, m->complete), v); -} -IRModule ToANormalForm(const IRModule& m) { - DLOG(INFO) << "ToANF:" << std::endl << m; + const DependencyGraph& dg_; + NodeScopeMap* node_scope_ = nullptr; + std::unordered_map memo; + // a set of Expressions to include for let bindings. If set to nullptr + // all Exprs will be pushed to the let list. + ExprSet* include_set_ = nullptr; +}; +IRModule ModuleToANormalForm(const IRModule& m) { tvm::Map updates; auto funcs = m->functions; for (const auto& it : funcs) { ICHECK_EQ(FreeVars(it.second).size(), 0); if (const auto* n = it.second.as()) { if (n->GetAttr(attr::kCompiler).defined()) continue; + Function func = GetRef(n); + Function ret = Downcast(transform::ToANormalForm(func)); + ICHECK_EQ(FreeVars(ret).size(), 0) << "rewritten:" << std::endl + << PrettyPrint(ret) << std::endl + << "should not have free vars: " << FreeVars(ret); + VLOG(1) << "rewritten:" << std::endl + << PrettyPrint(func) << std::endl + << "to ANF:" << std::endl + << PrettyPrint(ret); + updates.Set(it.first, ret); } - Expr ret = TransformF([&](const Expr& e) { return transform::ToANormalForm(e); }, it.second); - ICHECK_EQ(FreeVars(ret).size(), 0) - << AsText(ret) << "should not has free vars: " << FreeVars(ret); - updates.Set(it.first, Downcast(ret)); } for (auto pair : updates) { m->Add(pair.first, pair.second, true); } - DLOG(INFO) << "ToANF: transformed" << std::endl << m; - return m; } +} // namespace + +Expr ToBasicBlockNormalFormAux(const Expr& e) { + // calculate all the dependency between nodes. + support::Arena arena; + DependencyGraph dg = DependencyGraph::Create(&arena, e); + /* The scope of the whole expr is global. 
+ * The scope of any subexpr, is the lowest common ancestor of all incoming edge. + * We also record the set of expressions whose scope is lifted. + */ + std::pair scopes = CalcScope(dg); + return Fill::ToBasicBlockNormalForm(e, dg, &scopes.first, &scopes.second); +} + namespace transform { Expr ToANormalForm(const Expr& e) { @@ -307,7 +429,7 @@ Expr ToANormalForm(const Expr& e) { Pass ToANormalForm() { runtime::TypedPackedFunc pass_func = - [=](IRModule m, PassContext pc) { return relay::ToANormalForm(m); }; + [=](IRModule m, PassContext pc) { return ModuleToANormalForm(m); }; return CreateModulePass(pass_func, 1, "ToANormalForm", {}); } diff --git a/src/relay/transforms/to_basic_block_normal_form.cc b/src/relay/transforms/to_basic_block_normal_form.cc index 8e952d60b8b75..931543d2640c0 100644 --- a/src/relay/transforms/to_basic_block_normal_form.cc +++ b/src/relay/transforms/to_basic_block_normal_form.cc @@ -30,48 +30,36 @@ #include "../../support/arena.h" #include "../analysis/dependency_graph.h" -#include "let_list.h" -#include "pass_utils.h" +#include "./pass_utils.h" namespace tvm { namespace relay { -Expr ToBasicBlockNormalFormAux(const Expr& e) { - // calculate all the dependency between nodes. - support::Arena arena; - DependencyGraph dg = DependencyGraph::Create(&arena, e); - /* The scope of the whole expr is global. - * The scope of any subexpr, is the lowest common ancestor of all incoming edge. - * We also record the set of expressions whose scope is lifted. - */ - std::pair scopes = CalcScope(dg); - return Fill::ToBasicBlockNormalForm(e, dg, &scopes.first, &scopes.second); -} - IRModule ToBasicBlockNormalForm(const IRModule& mod) { - DLOG(INFO) << "ToBBlock:" << std::endl << mod; - // Create a new module by shallow copy. - IRModule mod_ = mod->ShallowCopy(); + IRModule new_mod = mod->ShallowCopy(); tvm::Map updates; - auto funcs = mod_->functions; + auto funcs = new_mod->functions; for (const auto& it : funcs) { ICHECK_EQ(FreeVars(it.second).size(), 0) << "Expected no free variables"; if (const auto* n = it.second.as()) { if (n->GetAttr(attr::kCompiler).defined()) continue; + Function func = GetRef(n); + Function ret = Downcast(ToBasicBlockNormalFormAux(func)); + VLOG(1) << "rewritten:" << std::endl + << PrettyPrint(func) << std::endl + << "to BasicBlockANF:" << std::endl + << PrettyPrint(ret); + updates.Set(it.first, Downcast(ret)); } - Expr ret = TransformF([&](const Expr& e) { return ToBasicBlockNormalFormAux(e); }, it.second); - updates.Set(it.first, Downcast(ret)); } for (auto pair : updates) { - mod_->Add(pair.first, pair.second, true); + new_mod->Add(pair.first, pair.second, true); } - DLOG(INFO) << "ToBBlock: transformed" << std::endl << mod_; - - return mod_; + return new_mod; } bool BasicBlockNormalFormCheck(const Expr& e) { diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index 6c2371716b167..5ca6d86b1d52f 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -486,8 +486,9 @@ class TypeInferencer : private ExprFunctor, if (type_args.size() > fn_ty_node->type_params.size()) { this->EmitFatal(Diagnostic::Error(call->span) << "Incorrect number of type args in " << call->span << ": " - << "Expected " << fn_ty_node->type_params.size() << "but got " - << type_args.size()); + << "Expected " << fn_ty_node->type_params.size() << " but got " + << type_args.size() << " for call:\n" + << PrettyPrint(GetRef(call))); } for (size_t i = type_args.size(); i < fn_ty_node->type_params.size(); i++) { 
      type_args.push_back(IncompleteType(TypeKind::kType));
@@ -824,7 +825,6 @@ Pass InferType() {
   auto pass_info = PassInfo(0, "InferType", {});
   return tvm::transform::CreateModulePass(
       [=](IRModule mod, const PassContext& pass_ctx) {
-        DLOG(INFO) << "tvm::relay::transform::InferType";
         // Execute the pass function and return a new module.
         IRModule updated_mod = mod->ShallowCopy();
diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc
index 968a4488bbcfe..8db89c59a85d7 100644
--- a/src/runtime/ndarray.cc
+++ b/src/runtime/ndarray.cc
@@ -272,7 +272,7 @@ int TVMArrayAlloc(const tvm_index_t* shape, int ndim, int dtype_code, int dtype_
   dtype.code = static_cast<uint8_t>(dtype_code);
   dtype.bits = static_cast<uint8_t>(dtype_bits);
   dtype.lanes = static_cast<uint16_t>(dtype_lanes);
-  Device dev;
+  tvm::Device dev;
   dev.device_type = static_cast<DLDeviceType>(device_type);
   dev.device_id = device_id;
   auto ndarray = NDArray::Empty(ShapeTuple(shape, shape + ndim), dtype, dev);
@@ -286,7 +286,7 @@ TVM_REGISTER_GLOBAL("runtime.TVMArrayAllocWithScope").set_body([](TVMArgs args,
   int ndim = args[1];
   ShapeTuple shape(shape_ptr, shape_ptr + ndim);
   DataType dtype = args[2];
-  Device dev = args[3];
+  tvm::Device dev = args[3];
   Optional<String> mem_scope = args[4];
   auto ndarray = NDArray::Empty(shape, dtype, dev, mem_scope);
   *ret = ndarray;
diff --git a/src/runtime/vm/serialize_utils.h b/src/runtime/vm/serialize_utils.h
index b4a10806caaf5..cbcdb1bdfa161 100644
--- a/src/runtime/vm/serialize_utils.h
+++ b/src/runtime/vm/serialize_utils.h
@@ -59,13 +59,13 @@ struct VMFunctionSerializer {
   /*! \brief The parameters of the VMFunction. */
   std::vector<std::string> params;
   /*! \brief The device type of each parameter of the VMFunction. */
-  std::vector<Index> params_device_type;
+  std::vector<DLDeviceType> params_device_type;
   VMFunctionSerializer() = default;
   VMFunctionSerializer(const std::string& name, Index register_file_size, size_t num_instructions,
                        const std::vector<std::string>& params,
-                       const std::vector<Index>& params_device_type)
+                       const std::vector<DLDeviceType>& params_device_type)
       : name(name),
         register_file_size(register_file_size),
         num_instructions(num_instructions),
diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc
index 4df013baa2fb8..45c0ec29ea445 100644
--- a/src/runtime/vm/vm.cc
+++ b/src/runtime/vm/vm.cc
@@ -233,7 +233,7 @@ void VirtualMachine::SetInput(std::string func_name, TVMArgs args, int offset) {
       << "The number of provided parameters doesn't match the number of assigned devices";
   std::vector<ObjectRef> func_args(param_names.size());
   for (int i = offset; i < args.size(); ++i) {
-    Index device_type = vm_func.params_device_type[i - offset];
+    DLDeviceType device_type = vm_func.params_device_type[i - offset];
     Device dev = GetDevice(device_type);
     if (args[i].type_code() == kTVMDLTensorHandle) {
diff --git a/tests/python/relay/test_pass_annotation.py b/tests/python/relay/test_pass_annotation.py
index c33bd57922421..81249eafdb10d 100644
--- a/tests/python/relay/test_pass_annotation.py
+++ b/tests/python/relay/test_pass_annotation.py
@@ -44,12 +44,12 @@ def check_graph_executor(
     with tvm.transform.PassContext(opt_level=opt_level, config=config):
         graph_executor_factory = relay.build(func, target, params=params)
-    contexts = [tvm.cpu(0), tvm.device(device)]
+    devices = [tvm.cpu(0), tvm.device(device)]
     graph_json = json.loads(graph_executor_factory.graph_json)
     if "device_index" in graph_json["attrs"]:
         device_index = graph_json["attrs"]["device_index"][1]
         assert device_index == expected_index
-    mod = graph_executor.GraphModule(graph_executor_factory["default"](*contexts))
+    mod = 
graph_executor.GraphModule(graph_executor_factory["default"](*devices)) mod.run() res = mod.get_output(0).numpy() tvm.testing.assert_allclose(res, ref_res, rtol=1e-5, atol=1e-5) @@ -267,8 +267,11 @@ def expected(): return func def check_storage_and_device_types(): + default_device_type = 3; func = annotated() - func = run_opt_pass(func, [transform.RewriteAnnotatedOps(3), transform.FuseOps(2)]) + func = run_opt_pass(func, [transform.RewriteAnnotatedOps(default_device_type), + transform.PlanDevices(default_device_type), + transform.FuseOps(2)]) smap = relay.backend._backend.GraphPlanMemory(func) storage_ids = [] device_types = [] diff --git a/tests/python/relay/test_pass_context_analysis.py b/tests/python/relay/test_pass_context_analysis.py deleted file mode 100644 index fe19c479292f4..0000000000000 --- a/tests/python/relay/test_pass_context_analysis.py +++ /dev/null @@ -1,205 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=no-else-return,invalid-name,len-as-condition,too-many-nested-blocks - -import numpy as np -import pytest - -import tvm -from tvm import relay -from tvm.relay import expr as _expr -from tvm.relay.analysis import context_analysis - - -def test_device_copy(): - if not tvm.testing.device_enabled("cuda") or not tvm.cuda(0).exist: - return - - mod = tvm.IRModule() - x = relay.var("x", shape=(2, 3)) - copy = relay.op.device_copy(x, tvm.cpu(), tvm.cuda()) - out = copy + relay.const(np.random.rand(2, 3)) - glb_var = relay.GlobalVar("main") - mod[glb_var] = relay.Function([x], out) - ca = context_analysis(mod, tvm.cpu()) - - cpu_dev = tvm.cpu().device_type - gpu_dev = tvm.cuda().device_type - for expr, dev in ca.items(): - if isinstance(expr, _expr.Call): - assert dev[0].value == gpu_dev - elif isinstance(expr, _expr.Var): - assert dev[0].value == cpu_dev - elif isinstance(expr, _expr.Constant): - assert dev[0].value == gpu_dev - - -def test_shape_func(): - if not tvm.testing.device_enabled("cuda") or not tvm.cuda(0).exist: - return - - mod = tvm.IRModule() - data_shape = (relay.Any(),) - x = relay.var("x", shape=data_shape) - y = relay.op.vm.shape_of(x) - z = relay.nn.relu(y) - p0 = relay.var("p0", shape=data_shape) - fn = relay.Function([p0], z) - out = relay.var("out", shape=(1,), dtype="int64") - ins = relay.Tuple([y]) - outs = relay.Tuple([out]) - is_inputs = [False] - shape_func = relay.op.vm.shape_func(fn, ins, outs, is_inputs) - mod["main"] = relay.Function([x, out], shape_func) - ca = context_analysis(mod, tvm.cuda()) - main = mod["main"] - - cpu_dev = tvm.cpu().device_type - gpu_dev = tvm.cuda().device_type - assert main.params[0] in ca and ca[main.params[0]][0].value == gpu_dev - # The output of shape func should be on cpu. 
- assert main.params[1] in ca and ca[main.params[1]][0].value == cpu_dev - # shape func is the body and it should be on cpu - assert main.body in ca and ca[main.body][0].value == cpu_dev - - -def test_vm_shape_of(): - if not tvm.testing.device_enabled("cuda") or not tvm.cuda(0).exist: - return - - mod = tvm.IRModule() - data_shape = (relay.Any(),) - x = relay.var("x", shape=data_shape) - y = relay.op.vm.shape_of(x) - mod["main"] = relay.Function([x], y) - ca = context_analysis(mod, tvm.cuda()) - main = mod["main"] - - cpu_dev = tvm.cpu().device_type - gpu_dev = tvm.cuda().device_type - assert main.params[0] in ca and ca[main.params[0]][0].value == gpu_dev - assert main.body in ca and ca[main.body][0].value == cpu_dev - - -def test_alloc_storage(): - if not tvm.testing.device_enabled("cuda") or not tvm.cuda(0).exist: - return - - mod = tvm.IRModule() - mod.import_from_std("core.rly") - size = relay.Var("size", relay.scalar_type("int64")) - alignment = relay.Var("alignment", relay.scalar_type("int64")) - # allocate a chunk on of memory on gpu. - sto = relay.op.memory.alloc_storage(size, alignment, tvm.cuda()) - mod["main"] = relay.Function([size, alignment], sto) - ca = context_analysis(mod, tvm.cuda()) - main = mod["main"] - body = main.body - - cpu_dev = tvm.cpu().device_type - gpu_dev = tvm.cuda().device_type - # Inputs are unified with alloc storage inputs which are on cpu - assert main.params[0] in ca and ca[main.params[0]][0].value == cpu_dev - assert main.params[1] in ca and ca[main.params[1]][0].value == cpu_dev - - assert isinstance(body, relay.Call) and len(body.args) == 2 - # size of alloc_storage is on cpu - assert body.args[0] in ca and ca[body.args[0]][0].value == cpu_dev - # alignment of alloc_storage is on cpu - assert body.args[1] in ca and ca[body.args[1]][0].value == cpu_dev - # alloc_storage is on gpu as specified - assert body in ca and ca[body][0].value == gpu_dev - - -def test_alloc_tensor(): - if not tvm.testing.device_enabled("cuda") or not tvm.cuda(0).exist: - return - - mod = tvm.IRModule() - mod.import_from_std("core.rly") - sto_type = relay.TypeCall(mod.get_global_type_var("Storage"), []) - sto = relay.Var("x", sto_type) - sh = relay.const(np.array([3, 2]), dtype="int64") - at = relay.op.memory.alloc_tensor(sto, relay.const(0, dtype="int64"), sh) - mod["main"] = relay.Function([sto], at) - ca = context_analysis(mod, tvm.cuda()) - main = mod["main"] - body = main.body - - cpu_dev = tvm.cpu().device_type - gpu_dev = tvm.cuda().device_type - # Input of the function falls back to the default device gpu - assert main.params[0] in ca and ca[main.params[0]][0].value == gpu_dev - - assert isinstance(body, relay.Call) and len(body.args) == 3 - # storage of alloc_tensor falls back to the default device gpu - assert body.args[0] in ca and ca[body.args[0]][0].value == gpu_dev - # shape of alloc_tensor is on cpu - assert body.args[1] in ca and ca[body.args[1]][0].value == cpu_dev - # alloc_tensor keeps the same device context as storage which is is on gpu - assert body in ca and ca[body][0].value == gpu_dev - - -def test_vm_reshape_tensor(): - if not tvm.testing.device_enabled("cuda") or not tvm.cuda(0).exist: - return - - x = relay.var("x", shape=(2, 8), dtype="float32") - shape = relay.const([-1, 4, 2], dtype="int64") - y = relay.op.vm.reshape_tensor(x, shape, [2, 4, 2]) - mod = tvm.IRModule() - mod["main"] = relay.Function([x], y) - ca = context_analysis(mod, tvm.cuda()) - main = mod["main"] - body = main.body - - cpu_dev = tvm.cpu().device_type - gpu_dev = 
tvm.cuda().device_type - # Input of the function falls back to the default device gpu - assert main.params[0] in ca and ca[main.params[0]][0].value == gpu_dev - - # dats of reshape_tensor falls back to the default device gpu - assert body.args[0] in ca and ca[body.args[0]][0].value == gpu_dev - # shape of reshape_tensor is on cpu - assert body.args[1] in ca and ca[body.args[1]][0].value == cpu_dev - # reshape_tensor sits on the same device as the data - assert body in ca and ca[body][0].value == gpu_dev - - -def test_dynamic_input(): - if not tvm.testing.device_enabled("cuda") or not tvm.cuda(0).exist: - return - - mod = tvm.IRModule() - data_shape = (relay.Any(), relay.Any()) - x0 = relay.var("x0", shape=data_shape) - x1 = relay.var("x1", shape=data_shape) - mod["main"] = relay.Function([x0, x1], x0 + x1) - - compiler = relay.vm.VMCompiler() - mod, _ = compiler.optimize(mod, target="cuda") - ca = context_analysis(mod, tvm.cpu()) - main = mod["main"] - - gpu_dev = tvm.cuda().device_type - assert main.params[0] in ca and ca[main.params[0]][0].value == gpu_dev - assert main.params[1] in ca and ca[main.params[1]][0].value == gpu_dev - assert main.body in ca and ca[main.body][0].value == gpu_dev - - -if __name__ == "__main__": - pytest.main([__file__]) diff --git a/tests/python/relay/test_pass_lambda_lift.py b/tests/python/relay/test_pass_lambda_lift.py index ce737b7bedbbd..3a1f4dbbbf98b 100644 --- a/tests/python/relay/test_pass_lambda_lift.py +++ b/tests/python/relay/test_pass_lambda_lift.py @@ -83,4 +83,7 @@ def test_recursive(): if __name__ == "__main__": - pytest.main() + # pytest.main() + test_basic() + test_closure() + test_recursive() \ No newline at end of file diff --git a/tests/python/relay/test_pass_plan_devices.py b/tests/python/relay/test_pass_plan_devices.py new file mode 100644 index 0000000000000..a2523219b61ac --- /dev/null +++ b/tests/python/relay/test_pass_plan_devices.py @@ -0,0 +1,1062 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License + +# TODO(mbs): All the input/expected programs should be directly quoted using @script +# TODO(mbs): Test all the relay constructs: If, Match, Constructor, Ref*, GlobalVar + +import tvm +from tvm import relay +import tvm.testing +import numpy as np + +N = 5 +M = 7 +CPU = tvm.device("cpu") # device_type=1 +GPU = tvm.device("cuda") # device_type=2 +DEFAULT = GPU + + +def rand(shape): + return np.random.rand(*shape).astype("float32") + + +def rands(shape, n): + return [rand(shape) for i in range(n)] + + +def rewrite_and_assert(in_mod, expected_mod): + actual_mod = relay.transform.InferType()(in_mod) + actual_mod = relay.transform.PlanDevices(DEFAULT)(actual_mod) + actual_mod = relay.transform.InferType()(actual_mod) + expected_mod = relay.transform.InferType()(expected_mod) + if not tvm.ir.structural_equal(actual_mod, expected_mod): + # Print everything in full so we can see what's going on. + print("Input module:") + print(in_mod) + print("Expected module:") + print(expected_mod) + print("Actual module:") + print(actual_mod) + # Assert again so as to see the actual disagreeing sub-expressions. + tvm.ir.assert_structural_equal(actual_mod, expected_mod) + + +def eval_and_assert(in_mod: tvm.IRModule, reference_func, args): + with tvm.transform.PassContext(opt_level=3): + compiled = relay.create_executor("vm", mod=in_mod, device=GPU, target="cuda").evaluate() + actual = compiled(*args).numpy() + expected = reference_func(*args) + tvm.testing.assert_allclose(actual, expected) + + +def exercise(in_mod: tvm.IRModule, expected_mod: tvm.IRModule, reference_func, args): + # Correctness + # rewrite_and_assert(in_mod, expected_mod) + # Idempotence + # rewrite_and_assert(expected_mod, expected_mod) + # The VM can compile + if not (reference_func is None) and not (args is None): + eval_and_assert(in_mod, reference_func, args) + + +def on_cpu(expr: relay.Expr): + return relay.annotation.on_device(expr, CPU) + + +def on_gpu(expr: relay.Expr): + return relay.annotation.on_device(expr, GPU) + + +def cpu_to_gpu(expr: relay.Expr): + return relay.op.device_copy(expr, CPU, GPU) + + +def gpu_to_cpu(expr: relay.Expr): + return relay.op.device_copy(expr, GPU, CPU) + + +def fixed_cpu(expr: relay.Expr): + return relay.annotation.on_device(expr, CPU, True) + + +def fixed_gpu(expr: relay.Expr): + return relay.annotation.on_device(expr, GPU, True) + + +def test_plain(): + shape = (N, M) + a = relay.var("a", shape=shape) + b = relay.var("b", shape=shape) + c = relay.var("c", shape=shape) + d = relay.var("d", shape=shape) + + # def @main(a, b, c, d) { subtract(add(a, b), add(c, d)) } + def input(): + return tvm.IRModule.from_expr( + relay.Function([a, b, c, d], + relay.subtract(relay.add(a, b), + relay.add(c, d)))) + + # def @main(a, b, c, d, on_device={param_device_types=[2,2,2,2], result_device_type=2}) { + # subtract(add(a, b), add(c, d)) + # } + def expected(): + return tvm.IRModule.from_expr( + relay.annotation.function_on_device( + relay.Function([a, b, c, d], + relay.subtract(relay.add(a, b), + relay.add(c, d))), + [GPU, GPU, GPU, GPU], GPU)) + + def ref(a, b, c, d): + return np.subtract(np.add(a, b), np.add(c, d)) + + exercise(input(), expected(), ref, rands(shape, 4)) + + +def test_left_add_on_cpu(): + shape = (N, M) + a = relay.var("a", shape=shape) + b = relay.var("b", shape=shape) + c = relay.var("c", shape=shape) + d = relay.var("d", shape=shape) + + # def @main(a, b, c, d) { subtract(on_cpu(add(a, b)), 
add(c, d)) } + def input(): + return tvm.IRModule.from_expr( + relay.Function([a, b, c, d], + relay.subtract(on_cpu(relay.add(a, b)), + relay.add(c, d)))) + + # def @main(a, b, c, d, on_device={param_device_types=[1,1,2,2], result_device_type=2}) { + # subtract(cpu_to_gpu(fixed_cpu(add(a, b)), add(c, d)) + def expected(): + return tvm.IRModule.from_expr( + relay.annotation.function_on_device( + relay.Function([a, b, c, d], + relay.subtract(cpu_to_gpu(fixed_cpu(relay.add(a, b))), + relay.add(c, d))), + [CPU, CPU, GPU, GPU], GPU)) + + def ref(a, b, c, d): + return np.subtract(np.add(a, b), np.add(c, d)) + + exercise(input(), expected(), ref, rands(shape, 4)) + + +def test_left_add_on_cpu_via_copy(): + shape = (N, M) + a = relay.var("a", shape=shape) + b = relay.var("b", shape=shape) + c = relay.var("c", shape=shape) + d = relay.var("d", shape=shape) + + # def @main(a, b, c, d) { subtract(cpu_to_gpu(add(a, b)), add(c, d)) } + def input(): + return tvm.IRModule.from_expr( + relay.Function([a, b, c, d], + relay.subtract(cpu_to_gpu(relay.add(a, b)), + relay.add(c, d)))) + + # def @main(a, b, c, d, on_device={param_device_types=[1,1,2,2], result_device_type=2}) { + # subtract(cpu_to_gpu(fixed_cpu(add(a, b)), add(c, d)) + def expected(): + return tvm.IRModule.from_expr( + relay.annotation.function_on_device( + relay.Function([a, b, c, d], + relay.subtract(cpu_to_gpu(fixed_cpu(relay.add(a, b))), + relay.add(c, d))), + [CPU, CPU, GPU, GPU], GPU)) + + def ref(a, b, c, d): + return np.subtract(np.add(a, b), np.add(c, d)) + + exercise(input(), expected(), ref, rands(shape, 4)) + + +def test_both_adds_on_cpu(): + shape = (N, M) + a = relay.var("a", shape=shape) + b = relay.var("b", shape=shape) + c = relay.var("c", shape=shape) + d = relay.var("d", shape=shape) + + # def @main(a, b, c, d) { subtract(on_cpu(add(a, b)), on_cpu(add(c, d))) } + def input(): + return tvm.IRModule.from_expr( + relay.Function([a, b, c, d], + relay.subtract(on_cpu(relay.add(a, b)), + on_cpu(relay.add(c, d))))) + + # def @main(a, b, c, d, on_device={param_device_types=[1,1,1,1], result_device_type=2}) { + # subtract(cpu_to_gpu(fixed_cpu(add(a, b)), cpu_to_gpu(fixed_cpu(add(c, d)))) + def expected(): + return tvm.IRModule.from_expr( + relay.annotation.function_on_device( + relay.Function([a, b, c, d], + relay.subtract(cpu_to_gpu(fixed_cpu(relay.add(a, b))), + cpu_to_gpu(fixed_cpu(relay.add(c, d))))), + [CPU, CPU, CPU, CPU], GPU)) + + def ref(a, b, c, d): + return np.subtract(np.add(a, b), np.add(c, d)) + + exercise(input(), expected(), ref, rands(shape, 4)) + + +def test_sharing(): + shape = (N, M) + a = relay.var("a", shape=shape) + b = relay.var("b", shape=shape) + + # def @main(a, b) { + # %0 = add(a, b) + # subtract(on_cpu(%0), %0) } + def input(): + add = relay.add(a, b) + return tvm.IRModule.from_expr( + relay.Function([a, b], + relay.subtract(on_cpu(add), + on_cpu(add)))) + + # def @main(a, b, on_device={param_device_types=[1,1], result_device_type=2}) { + # %0 = add(a, b) + # subtract(cpu_to_gpu(fixed_cpu(%0), cpu_to_gpu(fixed_cpu(%0))) + def expected(): + add = relay.add(a, b) + return tvm.IRModule.from_expr( + relay.annotation.function_on_device( + relay.Function([a, b], + relay.subtract(cpu_to_gpu(fixed_cpu(add)), + cpu_to_gpu(fixed_cpu(add)))), + [CPU, CPU], GPU)) + + def ref(a, b): + x = np.add(a, b) + return np.subtract(x, x) + + exercise(input(), expected(), ref, rands(shape, 2)) + + +def test_let_on_cpu(): + shape = (N, M) + a = relay.var("a", shape=shape) + b = relay.var("b", shape=shape) + c = relay.var("c", 
shape=shape) + d = relay.var("d", shape=shape) + l = relay.Var("l") + r = relay.Var("r") + + # def @main(a, b, c, d) { + # let l = add(a, b); + # let r = add(c, d); + # subtract(on_cpu(l), r) + # } + def input(): + return tvm.IRModule.from_expr( + relay.Function([a, b, c, d], + relay.Let(l, relay.add(a, b), + relay.Let(r, relay.add(c, d), + relay.subtract(on_cpu(l), r))))) + + # def @main(a, b, c, d, on_device={param_device_types=[1,1,2,2], result_device_type=2}) { + # let l = fixed_cpu(add(a, b)); + # let r = add(c, d); + # subtract(cpu_to_gpu(fixed_cpu(l)), r) + # } + def expected(): + return tvm.IRModule.from_expr( + relay.annotation.function_on_device( + relay.Function([a, b, c, d], + relay.Let(l, fixed_cpu(relay.add(a, b)), + relay.Let(r, relay.add(c, d), + relay.subtract(cpu_to_gpu(fixed_cpu(l)), r)))), + [CPU, CPU, GPU, GPU], GPU)) + + def ref(a, b, c, d): + return np.subtract(np.add(a, b), np.add(c, d)) + + exercise(input(), expected(), ref, rands(shape, 4)) + + +def test_func_param_on_cpu(): + shape = (N, M) + a = relay.var("a", shape=shape) + b = relay.var("b", shape=shape) + c = relay.var("c", shape=shape) + d = relay.var("d", shape=shape) + f = relay.Var("f") + x = relay.Var("x") + y = relay.Var("y") + + # def @main(a, b, c, d) { + # let f = fn(x, y) { on_cpu(add(x, y)) } -- forces both body and result on CPU + # subtract(f(a, b), add(c, d)) + # } + def input(): + return tvm.IRModule.from_expr( + relay.Function([a, b, c, d], + relay.Let(f, + relay.Function([x, y], + on_cpu(relay.add(x, y))), + relay.subtract(relay.Call(f, [a, b]), + relay.add(c, d))))) + + # def @main(a, b, c, d, on_device={param_device_types=[1,1,1,1], result_device_type=1}) { + # let f = fn(x, y, on_device={param_device_types[1,1], result_device_type=1}) { + # add(x, y) + # }; + # subtract(f(a, b), add(c, d)) + # } + def expected(): + return tvm.IRModule.from_expr( + relay.annotation.function_on_device( + relay.Function([a, b, c, d], + relay.Let(f, + relay.annotation.function_on_device( + relay.Function([x, y], relay.add(x, y)), + [CPU, CPU], CPU + ), + relay.subtract(relay.Call(f, [a, b]), + relay.add(c, d)))), + [CPU, CPU, CPU, CPU], CPU)) + + def ref(a, b, c, d): + return np.subtract(np.add(a, b), np.add(c, d)) + + exercise(input(), expected(), ref, rands(shape, 4)) + + +def test_func_result_on_cpu(): + shape = (N, M) + a = relay.var("a", shape=shape) + b = relay.var("b", shape=shape) + c = relay.var("c", shape=shape) + d = relay.var("d", shape=shape) + f = relay.Var("f") + x = relay.Var("x") + y = relay.Var("y") + + # def @main(a, b, c, d) { + # let f = fn(x, y) { add(x, y) } + # subtract(on_cpu(f(a, b)), add(c, d)) + # } + def input(): + return tvm.IRModule.from_expr( + relay.Function([a, b, c, d], + relay.Let(f, + relay.Function([x, y], + relay.add(x, y)), + relay.subtract(on_cpu(relay.Call(f, [a, b])), + relay.add(c, d))))) + + # def @main(a, b, c, d, on_device={param_device_types=[1,1,2,2], result_device_type=2}) { + # let f = fixed_cpu(fn(x, y, on_device={param_device_types=[1,1], result_device_type=1}) { + # add(x, y) + # }); + # subtract(cpu_to_gpu(fixed_cpu(f(a, b))), add(c, d)) + # } + def expected(): + return tvm.IRModule.from_expr( + relay.annotation.function_on_device( + relay.Function([a, b, c, d], + relay.Let(f, + fixed_cpu(relay.annotation.function_on_device( + relay.Function([x, y], + relay.add(x, y)), + [CPU, CPU], CPU + )), + relay.subtract(cpu_to_gpu(fixed_cpu(relay.Call(f, [a, b]))), + relay.add(c, d)))), + [CPU, CPU, GPU, GPU], GPU)) + + def ref(a, b, c, d): + return 
np.subtract(np.add(a, b), np.add(c, d))
+
+    exercise(input(), expected(), ref, rands(shape, 4))
+
+
+def test_higher_order():
+    shape = (N, M)
+    x = relay.var("x", shape=shape)
+    y = relay.var("y", shape=shape)
+    f = relay.Var("f")
+    g = relay.Var("g")
+    a = relay.Var("a")
+    h = relay.Var("h")
+    b = relay.Var("b")
+
+    # The constraint on a flows back to y via f and h.
+    # def @main(x, y) {
+    #   let f = fn(g) { fn(a) { add(g(on_cpu(a)), x) } }
+    #   let h = fn(b) { relu(b) }
+    #   subtract(x, f(h)(y))
+    # }
+    def input():
+        return tvm.IRModule.from_expr(
+            relay.Function([x, y],
+                           relay.Let(f,
+                                     relay.Function([g],
+                                                    relay.Function([a],
+                                                                   relay.add(
+                                                                       relay.Call(g, [on_cpu(a)]), x))),
+                                     relay.Let(h,
+                                               relay.Function([b], relay.nn.relu(b)),
+                                               relay.subtract(x, relay.Call(relay.Call(f, [h]), [y]))))))
+
+    # def @main(x, y, on_device={param_device_types=[GPU, CPU], result_device_type=GPU}) {
+    #   let f = fn(g, on_device={param_device_types=[GPU], result_device_type=GPU}) {
+    #     fn(a, on_device={param_device_types=[CPU], result_device_type=GPU}) {
+    #       add(g(cpu_to_gpu(fixed_cpu(a))), x)
+    #     }
+    #   }
+    #   let h = fn(b, on_device={param_device_types=[GPU], result_device_type=GPU}) { relu(b) }
+    #   subtract(x, f(h)(fixed_cpu(y)))
+    # }
+    def expected():
+        # Yeah, this is illegible.
+        return tvm.IRModule.from_expr(
+            relay.annotation.function_on_device(
+                relay.Function([x, y],
+                               relay.Let(f,
+                                         relay.annotation.function_on_device(
+                                             relay.Function([g],
+                                                            relay.annotation.function_on_device(
+                                                                relay.Function([a],
+                                                                               relay.add(
+                                                                                   relay.Call(g,
+                                                                                              [cpu_to_gpu(
+                                                                                                  fixed_cpu(a))]),
+                                                                                   x)),
+                                                                [CPU], GPU)),
+                                             [GPU], GPU),
+                                         relay.Let(h,
+                                                   relay.annotation.function_on_device(
+                                                       relay.Function([b], relay.nn.relu(b)),
+                                                       [GPU], GPU),
+                                                   relay.subtract(x, relay.Call(relay.Call(f, [h]),
+                                                                                [fixed_cpu(y)]))))),
+                [GPU, CPU], GPU))
+
+    def ref(x, y):
+        # numpy has no relu; np.maximum(y, 0.0) is the reference for relay.nn.relu.
+        return np.subtract(x, np.add(np.maximum(y, 0.0), x))
+
+    exercise(input(), expected(), ref, rands(shape, 2))
+
+
+def test_function_in_tuple():
+    shape = (N, M)
+    x = relay.var("x", shape=shape)
+    y = relay.var("y", shape=shape)
+    a = relay.var("a", shape=shape)
+    b = relay.var("b", shape=shape)
+    f = relay.Var("f")
+    t = relay.Var("t")
+
+    # Since f ends up in a tuple, its arguments and result are forced to be on the CPU.
+    # def @main(x, y) {
+    #   let f = fn(a, b) { add(a, on_cpu(b)) }
+    #   let t = (f, x)
+    #   t.0(t.1, y)
+    # }
+    def input():
+        return tvm.IRModule.from_expr(
+            relay.Function([x, y],
+                           relay.Let(f,
+                                     relay.Function([a, b], relay.add(a, on_cpu(b))),
+                                     relay.Let(t,
+                                               relay.Tuple([f, x]),
+                                               relay.Call(relay.TupleGetItem(t, 0),
+                                                          [relay.TupleGetItem(t, 1), y])))))
+
+    # def @main(x, y, on_device={param_device_types=[1,1], result_device_type=1}) {
+    #   let f = fn(a, b, on_device={param_device_types=[1,1], result_device_type=1}) { add(a, b) }
+    #   let t = (f, x)
+    #   t.0(t.1, y)
+    # }
+    def expected():
+        return tvm.IRModule.from_expr(
+            relay.annotation.function_on_device(
+                relay.Function([x, y],
+                               relay.Let(f,
+                                         relay.annotation.function_on_device(
+                                             relay.Function([a, b], relay.add(a, b)),
+                                             [CPU, CPU], CPU),
+                                         relay.Let(t,
+                                                   relay.Tuple([f, x]),
+                                                   relay.Call(relay.TupleGetItem(t, 0),
+                                                              [relay.TupleGetItem(t, 1), y])))),
+                [CPU, CPU], CPU))
+
+    def ref(x, y):
+        return np.add(x, y)
+
+    exercise(input(), expected(), ref, rands(shape, 2))
+
+
+def test_device_copy():
+    shape = (N, M)
+    x = relay.var("x", shape=shape)
+    const = relay.const(rand(shape))
+
+    # def @main(x) { add(cpu_to_gpu(x), const) }
+    def input():
+        return 
tvm.IRModule.from_expr(relay.Function([x], relay.add(cpu_to_gpu(x), const))) + + # def @main(x, on_device={param_device_types=[1], result_device_type=2}) { + # add(cpu_to_gpu(fixed_cpu(x)), constant) + # } + def expected(): + return tvm.IRModule.from_expr( + relay.annotation.function_on_device( + relay.Function([x], relay.add(cpu_to_gpu(fixed_cpu(x)), const)), + [CPU], GPU)) + + def ref(x): + return np.add(x, const) + + exercise(input(), expected(), ref, rands(shape, 1)) + + +def test_shape_func(): + p = relay.var("p") + data_shape = (relay.Any(),) + x = relay.var("x", shape=data_shape) + y = relay.var("y", shape=data_shape) + s = relay.var("s", shape=(1,), dtype="int64") + + # def @main(x, s) { + # let p = fixed_gpu(fn(y) { relu(y) }) -- simulates a primitive post FuseOps + # shape_func(p, + # (shape_of(fixed_gpu(x)),), -- shape of primitive input tensor + # (s,), -- space for output shape + # [False]) -- calling with input shapes not tensors + # } + def input(): + return tvm.IRModule.from_expr( + relay.Function([x, s], + relay.Let(p, + fixed_gpu(relay.Function([y], relay.nn.relu(y))), + relay.op.vm.shape_func(p, + relay.Tuple([ + relay.op.vm.shape_of(fixed_gpu(x))]), + relay.Tuple([s]), + [False])))) + + # def @main(x, s, on_device={param_device_types=[2,1], result_device_type=1}) { + # let p = fixed_gpu(fn(y, param_device_types=[2], result_device_type=2) { relu(y) }) + # shape_func(fixed_gpu(p), + # (shape_of(fixed_gpu(x)),), + # (s,), + # [False]) + # } + def expected(): + return tvm.IRModule.from_expr( + relay.annotation.function_on_device( + relay.Function([x, s], + relay.Let(p, + fixed_gpu( + relay.annotation.function_on_device( + relay.Function([y], relay.nn.relu(y)), + [GPU], GPU)), + relay.op.vm.shape_func(fixed_gpu(p), + relay.Tuple([ + relay.op.vm.shape_of(fixed_gpu(x))]), + relay.Tuple([s]), + [False]))), + [GPU, CPU], CPU)) + + # Don't try to execute this -- it's too fiddly to setup. 
+ exercise(input(), expected(), None, None) + + +def test_shape_of(): + compiletime_shape = (relay.Any(), relay.Any()) + runtime_shape = (N, M) + x = relay.var("x", shape=compiletime_shape) + + # def @main(x) { shape_of(fixed_gpu(x)) } + def input(): + return tvm.IRModule.from_expr( + relay.Function([x], relay.op.vm.shape_of(fixed_gpu(x)))) + + # def @main(x, on_device={param_device_types=[2], result_dev_type=1}) { + # shape_of(fixed_gpu(x)) + # } + def expected(): + return tvm.IRModule.from_expr( + relay.annotation.function_on_device( + relay.Function([x], relay.op.vm.shape_of(fixed_gpu(x))), + [GPU], CPU)) + + def ref(x): + return x.shape + + exercise(input(), expected(), ref, rands(runtime_shape, 1)) + + +def test_alloc_storage(): + size = relay.Var("size", relay.scalar_type("int64")) + alignment = relay.Var("alignment", relay.scalar_type("int64")) + main = relay.GlobalVar("main") + stdlib = tvm.IRModule() + stdlib.import_from_std("core.rly") + + # def @main(size, alignment) { alloc_storage(size, alignment, GPU) } + def input(): + mod = tvm.IRModule() + mod.update(stdlib) + mod[main] = relay.Function([size, alignment], + relay.op.memory.alloc_storage(size, alignment, GPU)) + return mod + + # def @main(size, alignment, on_device={param_device_types=[1,1], result_device_type=2}) { + # alloc_storage(fixed_cpu(size), fixed_cpu(alignment), GPU) + # } + def expected(): + mod = tvm.IRModule() + mod.update(stdlib) + mod[main] = relay.annotation.function_on_device( + relay.Function([size, alignment], + relay.op.memory.alloc_storage(fixed_cpu(size), fixed_cpu(alignment), GPU)), + [CPU, CPU], GPU) + return mod + + # Don't try to execute. + exercise(input(), expected(), None, None) + + +def test_alloc_tensor(): + stdlib = tvm.IRModule() + stdlib.import_from_std("core.rly") + sto_type = relay.TypeCall(stdlib.get_global_type_var("Storage"), []) + sto = relay.Var("sto", sto_type) + main = relay.GlobalVar("main") + shape = relay.const(np.array([3, 2]), dtype="int64") + + # def @main(sto) { alloc_tensor(sto, 0, [3, 2]) } + def input(): + mod = tvm.IRModule() + mod.update(stdlib) + mod[main] = relay.Function([sto], + relay.op.memory.alloc_tensor(sto, + relay.const(0, dtype="int64"), + shape)) + return mod + + # def @main(sto, on_device={param_device_types=[2], result_device_type=2}) { + # alloc_tensor(sto, fixed_cpu(0), fixed_cpu([3, 2])) + # } + def expected(): + mod = tvm.IRModule() + mod.update(stdlib) + mod[main] = relay.annotation.function_on_device( + relay.Function([sto], + relay.op.memory.alloc_tensor(sto, + fixed_cpu(relay.const(0, dtype="int64")), + fixed_cpu(shape))), + [GPU], GPU) + return mod + + # Don't try to execute. 
+ exercise(input(), expected(), None, None) + + +def test_reshape_tensor(): + shape = (2, 8) + x = relay.var("x", shape=shape, dtype="float32") + newshape_expr = relay.const([-1, 4, 2], dtype="int64") + newshape_prim = [2, 4, 2] + + # def @main(x) { reshape_tensor(x, shape, newshape=[2,4,2]) } + def input(): + return tvm.IRModule.from_expr( + relay.Function([x], relay.op.vm.reshape_tensor(x, newshape_expr, newshape_prim))) + + # def @main(x, on_device={param_device_types=[2], result_device_type=2}) { + # reshape_tensor(x, fixed_cpu(shape), newshape=[2,4,2]) + # } + def expected(): + return tvm.IRModule.from_expr( + relay.annotation.function_on_device( + relay.Function([x], relay.op.vm.reshape_tensor(x, fixed_cpu(newshape_expr), newshape_prim)), + [GPU], GPU)) + + def ref(x): + return np.reshape(x, newshape_prim) + + exercise(input(), expected(), ref, rands(shape, 1)) + + +def test_dynamic_input(): + compiletime_shape = (relay.Any(), relay.Any()) + runtime_shape = (N, M) + x0 = relay.var("x0", shape=compiletime_shape) + x1 = relay.var("x1", shape=compiletime_shape) + + # def @main(x0, x1) { add(x0, x1) } + def input(): + return tvm.IRModule.from_expr( + relay.Function([x0, x1], relay.add(x0, x1))) + + # def @main(x0, x1), on_device={param_device_types=[2,2], result_device_type=2}) { + # add(x0, x1) + # } + def expected(): + return tvm.IRModule.from_expr( + relay.annotation.function_on_device( + relay.Function([x0, x1], relay.add(x0, x1)), + [GPU, GPU], GPU)) + + def ref(x0, x1): + return np.add(x0, x1) + + exercise(input(), expected(), ref, rands(runtime_shape, 2)) + + +def test_redundant_annotation(): + shape = (N, M) + x = relay.var("x", shape=shape) + y = relay.var("y", shape=shape) + z = relay.var("z", shape=shape) + + # def @main(x, y, z) { + # %0 = add(x, y) + # (subtract(on_cpu(%0), z), subtract(on_cpu(%0), z)) + # } + def input(): + add = relay.add(x, y) + return tvm.IRModule.from_expr( + relay.Function([x, y, z], + relay.Tuple([ + relay.subtract(on_cpu(add), z), + relay.subtract(on_cpu(add), z)]))) + + # def @main(x, y, z, on_device={param_device_types=[1,1,2], result_device_type=2}) { + # %0 = add(x, y) + # (subtract(cpu_to_gpu(fixed_cpu(%0)), z), subtract(cpu_to_gpu(fixed_cpu(%0)), z)) + # } + def expected(): + add = relay.add(x, y) + return tvm.IRModule.from_expr( + relay.annotation.function_on_device( + relay.Function([x, y, z], + relay.Tuple([ + relay.subtract(cpu_to_gpu(fixed_cpu(add)), z), + relay.subtract(cpu_to_gpu(fixed_cpu(add)), z)])), + [CPU, CPU, GPU], GPU)) + + def ref(x, y, z): + t = np.add(x, y) + return (np.subtract(t, z), np.subtract(t, z)) + + exercise(input(), expected(), ref, rands(shape, 3)) + + +def test_annotate_expr(): + shape = (N, M) + x = relay.var("x", shape=shape) + y = relay.var("y", shape=shape) + z = relay.var("z", shape=shape) + + # def @main(x, y, z) { on_cpu(subtract(on_gpu(add(x, y)), z)) } -- forces function result also on cpu + def input(): + return tvm.IRModule.from_expr( + relay.Function([x, y, z], + on_cpu(relay.subtract(on_gpu(relay.add(x, y)), z)))) + + # def @main(x, y, z, on_device={param_device_types=[2,2,1], result_device_type=1}) { + # subtract(gpu_to_cpu(fixed_gpu(add(x, y))), z) + # } + def expected(): + add = relay.add(x, y) + return tvm.IRModule.from_expr( + relay.annotation.function_on_device( + relay.Function([x, y, z], + relay.subtract(gpu_to_cpu(fixed_gpu(relay.add(x, y))), z)), + [GPU, GPU, CPU], CPU)) + + def ref(x, y, z): + return np.subtract(np.add(x, y), z) + + exercise(input(), expected(), ref, rands(shape, 3)) + + 
+def test_annotate_all(): + shape = (N, M) + x = relay.var("x", shape=shape) + y = relay.var("y", shape=shape) + z = relay.var("z", shape=shape) + + # def @main(x, y, z) { on_cpu(subtract(on_cpu(add(x, y)), z) } -- top-level also forces result to be CPU + def input(): + return tvm.IRModule.from_expr( + relay.Function([x, y, z], + on_cpu(relay.subtract(on_cpu(relay.add(x, y)), z)))) + + # def @main(x, y, z, on_device={param_device_types=[CPU, CPU, CPU], result_device_type=CPU}) { + # subtract(add(x, y), z) + # } + def expected(): + return tvm.IRModule.from_expr( + relay.annotation.function_on_device( + relay.Function([x, y, z], + relay.subtract(relay.add(x, y), z)), + [CPU, CPU, CPU], CPU)) + + def ref(x, y, z): + return np.subtract(np.add(x, y), z) + + exercise(input(), expected(), ref, rands(shape, 3)) + + +def test_conv_network(): + r"""The network and devices are as follows: + data1 data2 <--- CPU + | | + conv2d conv2d <--- CPU + \ / + \ / + add <--- GPU + | + conv2d <--- CPU + | + <--- CPU + """ + batch_size = 1 + dshape = (batch_size, 64, 56, 56) + wshape = (64, 64, 3, 3) + weight = relay.var("weight", shape=wshape) + data1 = relay.var("data1", shape=dshape) + data2 = relay.var("data2", shape=dshape) + + def input(): + conv2d_1 = relay.nn.conv2d(data1, weight, channels=64, kernel_size=(3, 3), padding=(1, 1)) + conv2d_2 = relay.nn.conv2d(data2, weight, channels=64, kernel_size=(3, 3), padding=(1, 1)) + add = relay.add(on_cpu(conv2d_1), on_cpu(conv2d_2)) + conv2d_3 = relay.nn.conv2d(on_gpu(add), weight, channels=64, kernel_size=(3, 3), padding=(1, 1)) + return tvm.IRModule.from_expr( + relay.Function([data1, data2, weight], on_cpu(conv2d_3))) + + def expected(): + conv2d_1 = relay.nn.conv2d(data1, weight, channels=64, kernel_size=(3, 3), padding=(1, 1)) + conv2d_2 = relay.nn.conv2d(data2, weight, channels=64, kernel_size=(3, 3), padding=(1, 1)) + add = relay.add(cpu_to_gpu(fixed_cpu(conv2d_1)), cpu_to_gpu(fixed_cpu(conv2d_2))) + conv2d_3 = relay.nn.conv2d(gpu_to_cpu(fixed_gpu(add)), weight, channels=64, kernel_size=(3, 3), + padding=(1, 1)) + return tvm.IRModule.from_expr( + relay.annotation.function_on_device( + relay.Function([data1, data2, weight], conv2d_3), + [CPU, CPU, CPU], CPU)) + + # Don't try to execute. + exercise(input(), expected(), None, None) + + +def test_tuple_get_item(): + shape = (3, 3, 4) + x = relay.Var("x", relay.ty.TensorType(shape, "float32")) + t = relay.Var("t") + + # We'll device copy after projection, not before. 
+ # def @main(x) { + # let t = split(x, 3); + # subtract(on_cpu(t).0, on_cpu(t).1) + # } + def input(): + return tvm.IRModule.from_expr( + relay.Function([x], + relay.Let(t, + relay.op.split(x, 3).astuple(), + on_gpu(relay.subtract(relay.TupleGetItem(on_cpu(t), 0), + relay.TupleGetItem(on_cpu(t), 1)))))) + + # def @main(x, on_device={param_device_type=[1], result_device_type=2}) { + # let t = fixed_cpu(split(x, 3)) + # subtract(cpu_to_gpu(fixed_cpu(t.0)), cpu_to_gpu(fixed_cpu(t.1))) + # } + def expected(): + return tvm.IRModule.from_expr( + relay.annotation.function_on_device( + relay.Function([x], + relay.Let(t, + fixed_cpu(relay.op.split(x, 3).astuple()), + relay.subtract( + cpu_to_gpu(fixed_cpu(relay.TupleGetItem(t, 0))), + cpu_to_gpu(fixed_cpu(relay.TupleGetItem(t, 1)))))), + [CPU], GPU)) + + # Don't try to execute + exercise(input(), expected(), None, None) + + +def test_propogation(): + R""" The network and devices are as follows: + x <--- CPU + | + log <--- CPU + / \ + log2 log10 <--- GPU + \ / + add <--- GPU + | + tan <--- CPU + | + <--- CPU + """ + shape = (N, M) + x = relay.var("x", shape=shape) + + def input(): + log = relay.log(x) + log2 = relay.log2(on_cpu(log)) + log10 = relay.log10(on_cpu(log)) + add = relay.add(on_gpu(log2), on_gpu(log10)) + tan = relay.tan(on_gpu(add)) + return tvm.IRModule.from_expr( + relay.Function([x], on_cpu(tan)) + ) + + def expected(): + log = relay.log(x) + log2 = relay.log2(cpu_to_gpu(fixed_cpu(log))) + log10 = relay.log10(cpu_to_gpu(fixed_cpu(log))) + add = relay.add(log2, log10) + tan = relay.tan(gpu_to_cpu(fixed_gpu(add))) + return tvm.IRModule.from_expr( + relay.annotation.function_on_device( + relay.Function([x], tan), + [CPU], CPU)) + + def ref(x): + y = np.log(x) + return np.tan(np.add(np.log2(y), np.log10(y))) + + exercise(input(), expected(), ref, rands(shape, 1)) + + +def test_fusible_network(): + R""" The network is as follows: + x y <--- GPU + \ / + add <--- GPU + / \ + sqrt \ <--- CPU + \ \ + \ log <--- GPU + \ / + subtract <--- GPU + | + exp <--- CPU + | + <--- CPU + """ + shape = (N, M) + x = relay.var("x", shape=shape) + y = relay.var("y", shape=shape) + + def input(): + add = relay.add(x, y) + sqrt = relay.sqrt(on_gpu(add)) + log = relay.log(add) + subtract = relay.subtract(on_cpu(sqrt), log) + exp = relay.exp(on_gpu(subtract)) + return tvm.IRModule.from_expr(relay.Function([x, y], on_cpu(exp))) + + def expected(): + add = relay.add(x, y) + sqrt = relay.sqrt(gpu_to_cpu(fixed_gpu(add))) + log = relay.log(add) + subtract = relay.subtract(cpu_to_gpu(fixed_cpu(sqrt)), log) + exp = relay.exp(gpu_to_cpu(fixed_gpu(subtract))) + return tvm.IRModule.from_expr( + relay.annotation.function_on_device( + relay.Function([x, y], exp), + [GPU, GPU], CPU)) + + def ref(x, y): + z = np.add(x, y) + return np.exp(np.subtract(np.sqrt(z), np.log(z))) + + exercise(input(), expected(), ref, rands(shape, 2)) + + +def test_unpropagatable_graph(): + r"""The network is as follows: + a b <--- CPU + \ / + \ / c d <--- GPU + \ / \ / + add \ / <--- CPU + \ \ / + \ mul <--- GPU + \ / + subtract <--- CPU + | + <--- CPU + """ + shape = (N, M) + a = relay.var("a", shape=shape) + b = relay.var("b", shape=shape) + c = relay.var("c", shape=shape) + d = relay.var("d", shape=shape) + + def input(): + return tvm.IRModule.from_expr( + relay.Function([a, b, c, d], + on_cpu(relay.subtract(on_cpu(relay.add(a, b)), on_gpu(relay.multiply(c, d)))))) + + def expected(): + return tvm.IRModule.from_expr( + relay.annotation.function_on_device( + relay.Function([a, b, c, d], + 
relay.subtract(relay.add(a, b), gpu_to_cpu(fixed_gpu(relay.multiply(c, d))))),
+                [CPU, CPU, GPU, GPU], CPU))
+
+    def ref(a, b, c, d):
+        # Mirror the graph above: subtract(add(a, b), multiply(c, d)).
+        return np.subtract(np.add(a, b), np.multiply(c, d))
+
+    exercise(input(), expected(), ref, rands(shape, 4))
+
+
+if __name__ == "__main__":
+    # test_plain()
+    test_left_add_on_cpu()
+    # test_left_add_on_cpu_via_copy()
+    # test_both_adds_on_cpu()
+    # test_sharing()
+    # test_let_on_cpu()
+    # test_func_param_on_cpu()
+    # test_func_result_on_cpu()
+    # test_higher_order()
+    # test_function_in_tuple()
+    # test_device_copy()
+    # test_shape_func()
+    # test_shape_of()
+    # test_alloc_storage()
+    # test_alloc_tensor()
+    # test_reshape_tensor()
+    # test_dynamic_input()
+    # test_redundant_annotation()
+    # test_annotate_expr()
+    # test_annotate_all()
+    # test_conv_network()
+    # test_tuple_get_item()
+    # test_propogation()
+    # test_fusible_network()
+    # test_unpropagatable_graph()