From 845f51627d2099b48b18aa24778012311026b954 Mon Sep 17 00:00:00 2001 From: Mark Shields Date: Thu, 18 Nov 2021 11:01:48 -0800 Subject: [PATCH] [checkpoint] externs treated as call_lowered so can see dyn shape function info It compiles something, but not the same as the original. --- include/tvm/runtime/vm/executable.h | 9 +- python/tvm/runtime/vm.py | 25 ++-- src/relay/backend/te_compiler.cc | 193 +++++++++++++-------------- src/relay/backend/vm/compiler.cc | 5 +- src/relay/transforms/dead_code.cc | 5 + src/relay/transforms/memory_alloc.cc | 32 ++--- src/runtime/vm/executable.cc | 19 +++ 7 files changed, 160 insertions(+), 128 deletions(-) diff --git a/include/tvm/runtime/vm/executable.h b/include/tvm/runtime/vm/executable.h index d382e2fd1ec31..5b18d05af80c5 100644 --- a/include/tvm/runtime/vm/executable.h +++ b/include/tvm/runtime/vm/executable.h @@ -144,6 +144,13 @@ class Executable : public ModuleNode { */ std::string GetVirtualDevices() const; + /*! + * \brief Returns a description of all the 'primitive' (ie PackedFuncs) in the executable. + * These correspond to either PrimFuncs we've compiled locally, or functions compiled by + * a BYOC external codegen. + */ + std::string GetPrimitives() const; + /*! * \brief Print the detailed statistics of the given code, i.e. number of * globls and constants, etc. @@ -201,7 +208,7 @@ class Executable : public ModuleNode { int host_device_index = -1; /*! \brief The global constant pool. */ std::vector constants; - /*! \brief A map from globals (as strings) to their index in the function map. */ + /*! \brief A map from globals (as strings) to their index in the Relay function map. */ std::unordered_map global_map; /*! \brief A mapping from the packed function's global name (as string) to the index that * corresponds to the position of the `packed_funcs` list in a `VirtualMachine` object. 
diff --git a/python/tvm/runtime/vm.py b/python/tvm/runtime/vm.py index 365e38c6e06c0..e07cdcdfa915a 100644 --- a/python/tvm/runtime/vm.py +++ b/python/tvm/runtime/vm.py @@ -73,6 +73,7 @@ def __init__(self, mod): self._get_bytecode = self.mod["get_bytecode"] self._get_constants = self.mod["get_constants"] self._get_virtual_devices = self.mod["get_virtual_devices"] + self._get_primitives = self.mod["get_primitives"] self._get_stats = self.mod["get_stats"] self._get_function_arity = self.mod["get_function_arity"] self._get_function_param_name = self.mod["get_function_param_name"] @@ -257,6 +258,12 @@ def virtual_devices(self): """Returns a human-readable description of all the (virtual) devices in the executable.""" return self._get_virtual_devices() + @property + def primitive(self): + """Returns a human-readable description of all the primitives (ie PackedFuncs) in the + executable""" + return self._get_primitives() + @property def globals(self): """Get the globals used by the Relay VM executable. @@ -522,15 +529,15 @@ def get_input_index(self, input_name, func_name="main"): return self._get_input_index(input_name, func_name) def benchmark( - self, - device, - *args, - func_name="main", - repeat=5, - number=5, - min_repeat_ms=None, - end_to_end=False, - **kwargs, + self, + device, + *args, + func_name="main", + repeat=5, + number=5, + min_repeat_ms=None, + end_to_end=False, + **kwargs, ): """Calculate runtime of a function by repeatedly calling it. diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 2f15fd8b02fd6..436847b7621eb 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -505,6 +505,8 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { * nullptr if none. */ BaseFunc ResolveToPrimitive(const Expr& expr) { + // NOTE: We can't assume expr->checked_type_ is defined, so can't early exit for first-order + // expressions. 
if (const auto* global_var_node = expr.as()) { if (!module_->ContainGlobalVar(global_var_node->name_hint)) { // TODO(mbs): extern function cleanup @@ -518,8 +520,12 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { return GetRef(prim_func_node); } else if (const auto* var_node = expr.as()) { auto itr = primitive_functions_.find(GetRef(var_node)); - ICHECK(itr != primitive_functions_.end()); - return itr->second; + if (itr == primitive_functions_.end()) { + // Not bound to a function. + return {}; + } else { + return itr->second; + } } else if (const auto* function_node = expr.as()) { if (!function_node->HasNonzeroAttr(attr::kPrimitive)) { // Not marked as primitive by FuseOps. @@ -545,108 +551,88 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { */ Expr MakeLoweredCall(Function func, Array visited_args, Array type_args, Span span, Target target) { + CCacheKey key = CCacheKey(func, target); + CachedFunc cfunc = compiler_->Lower(key, module_name_); + ICHECK(cfunc.defined()); + auto opt_compiler = func->GetAttr(attr::kCompiler); - if (opt_compiler.defined()) { - // BYOC flow. - CCacheKey key = CCacheKey(func, target); - CachedFunc ext_func = compiler_->Lower(key, module_name_); - ICHECK(ext_func.defined()) << "Lowering returned undefined function for " - << PrettyPrint(ext_func->prim_fn_var); - - // TODO(@areusch, @jroesch): this metadata is for AOT, this should be our interface for AOT - Map prim_fns; - relay::Function func_with_metadata = func; - func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", ext_func->prim_fn_var); - func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns); - func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, ext_func->target); - - // Provide a callback hook which allows one-level up code generators to - // act when we process a function. - this->process_fn_(func_with_metadata); - // TODO(mbs): Dynamic shapes? 
- // TODO(@mbs, electriclilies): Make extern functions explicit - return Call(ext_func->prim_fn_var, visited_args, Attrs(), type_args, span); - } else { - // Non-External Relay Function - CCacheKey key = CCacheKey(func, target); - CachedFunc lowered_func = compiler_->Lower(key, module_name_); - - // Collect all the lowered functions produced for this primitive function. - // Special handling for "device_copy": No lowered definition is returned (indeed the - // lowered_func->funcs is empty). - // TODO(mbs): "device_copy" cleanup - Map prim_fns; - Array all_prim_fn_vars; - for (const auto& pair : lowered_func->funcs->functions) { - CHECK(pair.second.as()) + // Add some metadata on top of the *original function* and invoke the callback so it can + // be captured. + // TODO(@areusch, @jroesch): this metadata is for AOT, this should be our interface for AOT + Map prim_fns; + Array all_prim_fn_vars; + for (const auto& pair : cfunc->funcs->functions) { + if (opt_compiler) { + // We expect just the original func but with just the ExternalSymbol attribute signalling + // the function (will be) compiled externally. + ICHECK(pair.second.as()) + << PrettyPrint(pair.first) << " must be bound to an (external) Function"; + } else { + // We expect one or more PrimFuncs, one of which corresponds to 'the' lowered primitive + // (and the rest in support of that via tir::Calls). 
+ ICHECK(pair.second.as()) << PrettyPrint(pair.first) << " must be bound to a PrimFunc"; prim_fns.Set(pair.first, Downcast(pair.second)); all_prim_fn_vars.push_back(pair.first); } + } + Function func_with_metadata = func; + func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", cfunc->prim_fn_var); + func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns); + func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, cfunc->target); + this->process_fn_(func_with_metadata); + + auto call_lowered_attrs = make_object(); + + // Non-External Relay Function + // TODO(mbs): "reshape" cleanup. + if (!opt_compiler && func->HasNonzeroAttr(attr::kReshapeOnly)) { + call_lowered_attrs->metadata.Set(attr::kReshapeOnly, tvm::Integer(1)); + } - // TODO(@areusch, @jroesch): this metadata is for AOT, this should be our interface for AOT - relay::Function func_with_metadata = func; - func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", lowered_func->prim_fn_var); - func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns); - func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, lowered_func->target); - - // Provide a callback hook which allows one-level up code generators to - // act when we process a function. - this->process_fn_(func_with_metadata); - - auto call_lowered_attrs = make_object(); - // TODO(mbs): "reshape" cleanup. - if (func->HasNonzeroAttr(attr::kReshapeOnly)) { - call_lowered_attrs->metadata.Set(attr::kReshapeOnly, tvm::Integer(1)); - } - - // Special handling for "device_copy": Look inside the primitive Function to see if it - // is a device_copy. If so capture source and destination attributes so downstream passes - // can follow along. 
- // TODO(mbs): "device_copy" cleanup - DeviceCopyProps props = GetDeviceCopyProps(func->body); - if (props.body.defined()) { - call_lowered_attrs->metadata.Set("src_se_scope", props.src_se_scope); - call_lowered_attrs->metadata.Set("dst_se_scope", props.dst_se_scope); - } + // Special handling for "device_copy": Look inside the primitive Function to see if it + // is a device_copy. If so capture source and destination attributes so downstream passes + // can follow along. + // TODO(mbs): "device_copy" cleanup + DeviceCopyProps props = GetDeviceCopyProps(func->body); + if (!opt_compiler && props.body.defined()) { + call_lowered_attrs->metadata.Set("src_se_scope", props.src_se_scope); + call_lowered_attrs->metadata.Set("dst_se_scope", props.dst_se_scope); + } - call_lowered_attrs->metadata.Set("relay_attrs", func->attrs); - call_lowered_attrs->metadata.Set("all_prim_fn_vars", all_prim_fn_vars); - - if (IsDynamic(func->ret_type)) { - // Also lower the dynamic shape function. - // Shape function keys use the underlying primitive function as their 'function', - // but the generic 'cpu' target as the target since all shape functions run - // on the host cpu irrespective of where the primitive runs. - // TODO(mbs): Cleanup target handling. - Target shape_target = host_se_scope_->target; - CCacheKey shape_key(func, shape_target); - CachedFunc lowered_shape_func = compiler_->LowerShapeFunc(shape_key); - // Capture the shape function's global var and parameters 'states' in call - // annotations so calling convention can be recovered. - // TODO(mbs): Capture all this as part of a 'call into TIR' construct once available. - // The way the shape function calling convention is derived and passed to call sites - // via the 'parameter states' could be improved. 
- call_lowered_attrs->metadata.Set("prim_shape_fn_var", lowered_shape_func->prim_fn_var); - call_lowered_attrs->metadata.Set("prim_shape_fn_states", - lowered_shape_func->shape_func_param_states); - call_lowered_attrs->metadata.Set( - "prim_shape_fn_num_inputs", - Integer(static_cast(lowered_shape_func->inputs.size()))); - call_lowered_attrs->metadata.Set( - "prim_shape_fn_num_outputs", - Integer(static_cast(lowered_shape_func->outputs.size()))); - Array all_prim_shape_fn_vars; - for (const auto& pair : lowered_shape_func->funcs->functions) { - CHECK(pair.second.as()) << "must be a prim fn"; - all_prim_shape_fn_vars.push_back(pair.first); - } - call_lowered_attrs->metadata.Set("all_prim_shape_fn_vars", all_prim_shape_fn_vars); + call_lowered_attrs->metadata.Set("relay_attrs", func->attrs); + call_lowered_attrs->metadata.Set("all_prim_fn_vars", all_prim_fn_vars); + + if (IsDynamic(func->ret_type)) { + // Also lower the companion dynamic shape function. + // Shape function keys use the underlying primitive function as their 'function', + // but the generic 'cpu' target as the target since all shape functions run + // on the host cpu irrespective of where the primitive runs. + CCacheKey shape_key(func, host_se_scope_->target); + CachedFunc lowered_shape_func = compiler_->LowerShapeFunc(shape_key); + + // Capture the shape function's global var and parameters 'states' in call + // annotations so calling convention can be recovered. + // TODO(mbs): Shape cleanup. 
+ call_lowered_attrs->metadata.Set("prim_shape_fn_var", lowered_shape_func->prim_fn_var); + call_lowered_attrs->metadata.Set("prim_shape_fn_states", + lowered_shape_func->shape_func_param_states); + call_lowered_attrs->metadata.Set( + "prim_shape_fn_num_inputs", Integer(static_cast(lowered_shape_func->inputs.size()))); + call_lowered_attrs->metadata.Set( + "prim_shape_fn_num_outputs", + Integer(static_cast(lowered_shape_func->outputs.size()))); + Array all_prim_shape_fn_vars; + for (const auto& pair : lowered_shape_func->funcs->functions) { + CHECK(pair.second.as()) << "must be a prim fn"; + all_prim_shape_fn_vars.push_back(pair.first); } - return CallLowered(lowered_func->prim_fn_var, visited_args, Attrs(call_lowered_attrs), - type_args, span); + call_lowered_attrs->metadata.Set("all_prim_shape_fn_vars", all_prim_shape_fn_vars); } + + return CallLowered(cfunc->prim_fn_var, std::move(visited_args), Attrs(call_lowered_attrs), + type_args, std::move(span)); } std::pair PreVisitLetBinding_(const Var& var, const Expr& value) final { @@ -655,7 +641,7 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { BaseFunc prim_func = ResolveToPrimitive(new_value); if (prim_func.defined() && !prim_func->IsInstance()) { - // Remember let var is bound (possibly indirectly) to a non-tir primitive. + // Remember let var is bound (possibly indirectly) to a primitive. Function func = Downcast(prim_func); primitive_functions_.emplace(var, func); } @@ -672,8 +658,9 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { } Expr DeviceAwareVisitExpr_(const FunctionNode* function_node) override { - if (function_node->HasNonzeroAttr(attr::kPrimitive)) { - // Nothing to lower inside primitive functions. + if (function_node->HasNonzeroAttr(attr::kPrimitive) || + function_node->GetAttr(attr::kExternalSymbol)) { + // Nothing to lower inside primitive/external functions. 
return GetRef(function_node); } else { return DeviceAwareExprMutator::DeviceAwareVisitExpr_(function_node); @@ -682,8 +669,9 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { Expr DeviceAwareVisitExpr_(const CallNode* call_node) override { // Passes before lowering might insert a call_lowered to call a function that has already - // been lowered. Therefore we might see call_lowered ops here, but we don't need to do anything - // because ResolveToPrimitive returns null for all calls where the call_node->op is an OpNode + // been lowered. Therefore we might see call_lowered ops here, but we don't need to do + // anything because ResolveToPrimitive returns null for all calls where the call_node->op is + // an OpNode Call call = GetRef(call_node); // Look for (indirect) calls to primitives. @@ -1056,8 +1044,6 @@ IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn pr // retrieved. They may, however, have been renamed. compiler->AddExterns(updated_module); - VLOG(1) << "rewritten but without lowered functions:" << std::endl << PrettyPrint(updated_module); - // Add the lowered functions. IRModule lowered_module = compiler->GetLoweredFunctions(); for (const auto& kv : lowered_module->functions) { @@ -1066,7 +1052,9 @@ IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn pr VLOG(1) << "rewritten with lowered functions:" << std::endl << PrettyPrint(updated_module); - // Invoke external codegen for all Functions in the cache tagged with "Compiler". + // Invoke external codegen for all Functions in the cache tagged with "Compiler", and + // annotate the module with the resulting runtime modules. + // TODO(mbs): runtime modules should be first class rather than attributes. 
Array external_modules = compiler->LowerExternalFunctions(); // Annotate the module with C Device API context mapping (this is until we have Target's @@ -1080,6 +1068,7 @@ IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn pr if (backend::IsAutoSchedulerEnabled()) { // Capture all the operator weights. + // TODO(mbs): Constants should be represented as global defns in the IRModule. updated_module = WithAttr(updated_module, "op_weights", compiler->GetOpWeights()); } diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 867f9dbb287a5..f66505a4322b9 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -934,7 +934,10 @@ void VMCompiler::Lower(IRModule mod, TargetMap targets, tvm::Target target_host) VLOG(1) << std::endl << "-------------------------------------------------" << std::endl - << exec_->GetVirtualDevices() << exec_->GetConstants() << exec_->GetBytecode() + << exec_->GetVirtualDevices() // + << exec_->GetConstants() // + << exec_->GetPrimitives() // + << exec_->GetBytecode() // << "-------------------------------------------------"; if (backend::IsAutoSchedulerEnabled()) { diff --git a/src/relay/transforms/dead_code.cc b/src/relay/transforms/dead_code.cc index 07c935a221720..dcd35811c7d6d 100644 --- a/src/relay/transforms/dead_code.cc +++ b/src/relay/transforms/dead_code.cc @@ -79,6 +79,11 @@ class PurityVisitor : ExprFunctor { // allowed. for (const auto& kv : mod_->functions) { if (const auto* function_node = kv.second.as()) { + if (function_node->HasNonzeroAttr(attr::kPrimitive) || + function_node->GetAttr(attr::kExternalSymbol)) { + // Ignore primitive and external functions. + continue; + } // Everything of interest will be recorded in the purity maps so we ignore the result. 
(void)VisitGlobalFunction(kv.first, GetRef(function_node)); } diff --git a/src/relay/transforms/memory_alloc.cc b/src/relay/transforms/memory_alloc.cc index 7bed4ced868b9..8a32b8db6d5bf 100644 --- a/src/relay/transforms/memory_alloc.cc +++ b/src/relay/transforms/memory_alloc.cc @@ -106,24 +106,23 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { } Expr DeviceAwareVisitExpr_(const CallNode* call_node) final { - Call call = GetRef(call_node); - SEScope se_scope = GetSEScope(call); CallLoweredProps props = GetCallLoweredProps(call_node); - if (!props.lowered_func.defined()) { + // This is a call to a user-defined Relay function, which will be handled directly by + // the VM and does not need conversion to DPS. return transform::DeviceAwareExprMutator::DeviceAwareVisitExpr_(call_node); } - VLOG(1) << "looking at lowered call:" << std::endl << PrettyPrint(call); + Call call = GetRef(call_node); + VLOG(1) << "converting lowered call to DPS:" << std::endl << PrettyPrint(call); - // Because we are in ANF we do not need to visit the arguments. - // TODO(mbs): But does so anyway... + SEScope se_scope = GetSEScope(call); LetList& scope = scopes_.back(); + std::vector new_args; for (const auto& arg : props.arguments) { new_args.push_back(Mutate(arg)); } - Tuple ins(new_args); Type ret_type = call_node->checked_type_; std::vector out_types = FlattenTupleType(ret_type); @@ -167,6 +166,9 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { call_node->span); } + // At this point we could be calling a PrimFunc or an 'external' and already compiled primitive. + // The calling conventions are identical. + // Handle 'dynamic' calls, ie to PrimFuncs whose result shape must be first computed // by a companion shape function. if (IsDynamic(ret_type)) { @@ -174,20 +176,20 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { se_scope); } - // Handle ordinary calls. - Array outs; + // Handle ordinary primitive calls. 
+ Array outputs; for (size_t i = 0; i < out_types.size(); ++i) { - auto out = MakeStaticAllocation(&scope, out_types[i], se_scope, std::to_string(i)); - outs.push_back(out); + outputs.push_back(MakeStaticAllocation(&scope, out_types[i], se_scope, std::to_string(i))); } - Tuple output(outs); - Expr invoke = InvokeTVMOp(props.lowered_func, ins, output, + Tuple outs(outputs); + Expr invoke = InvokeTVMOp(props.lowered_func, ins, outs, Downcast(props.attrs.metadata.at("relay_attrs"))); scope.Push(OnDevice(invoke, se_scope, /*is_fixed=*/true)); - return ToTupleType(ret_type, std::vector(output->fields.begin(), output->fields.end())); + return ToTupleType(ret_type, std::vector(outputs.begin(), outputs.end())); } - /*! Returns the Relay Constant representing the 1d tensor with \p value. + /*! + * \brief Returns the Relay Constant representing the 1d tensor with \p value. * * CAUTION: Make sure the constant ends up on the correct device. */ diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index 58b52da7fab85..0d8ca3ed39275 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -65,6 +65,8 @@ PackedFunc Executable::GetFunction(const std::string& name, const ObjectPtrGetConstants(); }); } else if (name == "get_virtual_devices") { return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetVirtualDevices(); }); + } else if (name == "get_primitives") { + return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetPrimitives(); }); } else if (name == "get_stats") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->Stats(); }); } else if (name == "save") { @@ -190,6 +192,23 @@ std::string Executable::GetVirtualDevices() const { return oss.str(); } +std::string Executable::GetPrimitives() const { + std::ostringstream os; + std::vector> entries; + entries.reserve(primitive_map.size()); + for (const auto& kv : primitive_map) { + entries.emplace_back(kv.second, kv.first); + } 
+ std::sort(entries.begin(), entries.end(), + [](const std::pair& left, const std::pair& right) { + return left.first < right.first; + }); + for (const auto& entry : entries) { + os << "VM Primitive[" << entry.first << "]: " << entry.second << std::endl; + } + return os.str(); +} + std::string Executable::Stats() const { std::ostringstream oss; oss << "Relay VM executable statistics:" << std::endl;