From 7d9d9c956049710fa59991170cd97f4ef9896676 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 17 Aug 2021 16:52:54 -0700 Subject: [PATCH 1/2] Remove compile_engine.h for real --- src/relay/backend/build_module.cc | 4 +- src/relay/backend/compile_engine.cc | 338 ------------------ src/relay/backend/compile_engine.h | 115 ------ src/relay/backend/interpreter.cc | 7 +- .../auto_scheduler_layout_rewrite.cc | 4 +- 5 files changed, 6 insertions(+), 462 deletions(-) delete mode 100644 src/relay/backend/compile_engine.cc delete mode 100644 src/relay/backend/compile_engine.h diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index b2b73e9bad02..88e9c8f058f5 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -33,7 +33,7 @@ #include "../../target/func_registry_generator.h" #include "../../target/source/codegen_source_base.h" -#include "compile_engine.h" +#include "te_compiler.h" #include "utils.h" namespace tvm { @@ -286,8 +286,6 @@ class RelayBuildModule : public runtime::ModuleNode { executor_ = executor; CheckAndUpdateHostConsistency(&targets_, &target_host_); BuildRelay(mod, params_, mod_name); - // Clear compile engine so that tuning schedules can be changed between runs. See issue #6096. - CompileEngine::Global()->Clear(); } protected: diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc deleted file mode 100644 index 6142e8323dea..000000000000 --- a/src/relay/backend/compile_engine.cc +++ /dev/null @@ -1,338 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file relay/backend/compile_engine.cc - * \brief Internal compialtion engine. - */ -#include "compile_engine.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -#include "../../runtime/meta_data.h" -#include "../transforms/pass_utils.h" -#include "te_compiler_cache.h" -#include "utils.h" - -namespace tvm { -namespace relay { - -TVM_REGISTER_OBJECT_TYPE(CompileEngineNode); - -class CompileEngineImpl : public CompileEngineNode { - public: - // Lower the function. - CachedFunc Lower(const CCacheKey& key, std::function mangle_fn) { - return LowerInternal(key, mangle_fn)->cached_func; - } - - CachedFunc Lower(const CCacheKey& key, const String mod_name) { - auto mangle_fn = [mod_name](String name) { return runtime::get_name_mangled(mod_name, name); }; - - return Lower(key, mangle_fn); - } - - // For now, build one module per function. - PackedFunc JIT(const CCacheKey& key) final { - auto mangle_fn = [](String name) { return name; }; - CCacheValue value = LowerInternal(key, mangle_fn); - if (value->packed_func != nullptr) return value->packed_func; - auto m = build(value->cached_func->funcs, key->target, Target(nullptr)); - value->packed_func = m.GetFunction(value->cached_func->prim_fn_var->name_hint); - return value->packed_func; - } - - CachedFunc LowerShapeFunc(const CCacheKey& key) final { - return LowerShapeFuncInternal(key)->cached_func; - } - - Array LowerExternalFunctions() { - Array ret; - std::unordered_map cached_symbol; - std::vector cached_ext_funcs; - for (const auto& it : cache_) { - auto src_func = it.first->source_func; - ICHECK(src_func.defined()); - - if (src_func->GetAttr(attr::kCompiler).defined()) { - auto code_gen = src_func->GetAttr(attr::kCompiler); - ICHECK(code_gen.defined()) << "No external codegen is set"; - std::string code_gen_name = code_gen.value(); - cached_ext_funcs.push_back(it.first); - - auto symbol_name = src_func->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(symbol_name.defined()) << "No external symbol is set for:\n" - << AsText(src_func, false) << "\n" - << "Functions with external codegen must have the " - << tvm::attr::kGlobalSymbol << " attr set."; - - std::string sn = symbol_name.value(); - if (!cached_symbol.count(sn)) { - cached_symbol[sn] = code_gen_name; - } else { - ICHECK_NE(cached_symbol[sn], code_gen_name) - << "Found duplicated symbol: " << sn << " for: " << code_gen_name; - } - - std::string ext_name = "relay.ext." + code_gen_name; - auto pf = tvm::runtime::Registry::Get(ext_name); - ICHECK(pf) << "Failed to find the codegen tool for " << ext_name << "\n"; - // No need to keep compiler attribute at this point, functions have been - // extracted for specific codegen. - src_func = WithAttr(std::move(src_func), attr::kCompiler, NullValue()); - runtime::Module ext_mod = (*pf)(src_func); - - // todo(@zhiics, @jroesch): Should this be a user visible error? - ICHECK(ext_mod.defined()) << "No external library was generated for " << ext_name - << "even though it was requested" - "by the annotated function " - << PrettyPrint(src_func); - - ret.push_back(ext_mod); - } - } - - // No need to cache external functions as we collected them all to create - // external runtime modules. - for (const auto& it : cached_ext_funcs) { - cache_.erase(it); - } - return ret; - } - - void Clear() final { cache_.clear(); } - - // List all items in the cache. - Array ListItems() { - std::lock_guard lock(mutex_); - Array items; - for (auto& kv : cache_) { - items.push_back(kv.first); - items.push_back(kv.second); - } - return items; - } - - // List all items in the shape_func_cache. - Array ListShapeFuncItems() { - std::lock_guard lock(mutex_); - Array items; - for (auto& kv : shape_func_cache_) { - items.push_back(kv.first); - items.push_back(kv.second); - } - return items; - } - - /*! - * \brief Get the cache key of the function that is being lowered currently - * \return the cache key - */ - CCacheKey GetCurrentCCacheKey() { return cur_ccache_key_; } - - private: - // implement lowered func - CCacheValue LowerInternal(const CCacheKey& key, std::function mangle_fn) { - std::lock_guard lock(mutex_); - CCacheValue value; - auto it = cache_.find(key); - if (it != cache_.end()) { - it->second->use_count += 1; - if (it->second->cached_func.defined()) return it->second; - value = it->second; - } else { - value = CCacheValue(make_object()); - value->use_count = 0; - if (!backend::IsCompileEngineCacheDisabled()) { - cache_[key] = value; - } - } - cur_ccache_key_ = key; - - // No need to lower external functions for now. We will invoke the external - // codegen tool once and lower all functions together. - if (key->source_func->GetAttr(attr::kCompiler).defined()) { - auto ir_module = IRModule(); - const auto name_node = key->source_func->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(name_node.defined()) << "External function has not been attached a name yet."; - auto func_name = std::string(name_node.value()); - auto target = Target("ext_dev"); - auto global_var = GlobalVar(func_name); - global_var->checked_type_ = key->source_func->checked_type(); - ir_module->Add(global_var, key->source_func); - value->cached_func = CachedFunc(target, global_var, {}, {}, te::Schedule(), {}, ir_module); - return value; - } - - // Enforce use the target. - With target_scope(key->target); - - ICHECK(!value->cached_func.defined()); - auto cfunc = PrimFuncFor(key->source_func, key->target, [&](std::string name) { - return GetUniqueName(mangle_fn(name), &name_map_); - }); - - // Skip lowering for device copy node. - const Expr body = (key->source_func)->body; - if (const CallNode* call_node = body.as()) { - if (call_node->attrs.as()) { - value->cached_func = cfunc; - return value; - } - } - - // NOTE: array will copy on write. - Array all_args = Array(cfunc->inputs); - for (te::Tensor arg : cfunc->outputs) { - all_args.push_back(arg); - } - // lower the function - std::unordered_map binds; - auto func_name = cfunc->prim_fn_var->name_hint; - cfunc->funcs->Update(tvm::LowerSchedule(cfunc->schedule, all_args, func_name, binds)); - value->cached_func = cfunc; - - return value; - } - - // implement lowered shape func - CCacheValue LowerShapeFuncInternal(const CCacheKey& key) { - std::lock_guard lock(mutex_); - CCacheValue value; - auto it = shape_func_cache_.find(key); - if (it != shape_func_cache_.end()) { - it->second->use_count += 1; - if (it->second->cached_func.defined()) return it->second; - value = it->second; - } else { - value = CCacheValue(make_object()); - value->use_count = 0; - shape_func_cache_[key] = value; - } - // Enforce use the target. - With target_scope(key->target); - - ICHECK(!value->cached_func.defined()); - using tvm::transform::PassContext; - With fresh_pass_ctx_scope(PassContext::Create()); - - auto cached_func = ShapeFuncFor(key->source_func, key->target, [&](std::string name) { - return GetUniqueName(name, &name_map_); - }); - - value->cached_func = cached_func; - return value; - } - - /*! \brief compiler cache lock*/ - std::mutex mutex_; - /*! \brief internal name map to get an unique name */ - std::unordered_map name_map_; - /*! \brief internal compiler cache */ - std::unordered_map cache_; - /*! \brief internal compiler cache for shape funcs */ - std::unordered_map shape_func_cache_; - /*! \brief the cache key of the function that is being lowered currently*/ - CCacheKey cur_ccache_key_; -}; - -/*! \brief The global compile engine */ -CompileEngine& CompileEngine::Global() { - // intentionally allocate raw pointer to avoid - // free during destructuion. - static CompileEngine* inst = new CompileEngine(make_object()); - return *inst; -} - -TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_auto_scheduler", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.disable_compile_engine_cache", Bool); - -TVM_REGISTER_GLOBAL("relay.backend._make_LoweredOutput") - .set_body_typed([](tvm::Array outputs, OpImplementation impl) { - return LoweredOutput(outputs, impl); - }); - -TVM_REGISTER_GLOBAL("relay.backend._make_CCacheKey") - .set_body_typed([](Function source_func, Target target) { - return CCacheKey(source_func, target); - }); - -TVM_REGISTER_GLOBAL("relay.backend._CompileEngineGlobal").set_body_typed([]() { - return CompileEngine::Global(); -}); - -TVM_REGISTER_GLOBAL("relay.backend._CompileEngineClear").set_body_typed([](CompileEngine self) { - self->Clear(); -}); - -TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLower") - .set_body_typed([](CompileEngine self, CCacheKey key, const String mod_name) { - return self->Lower(key, mod_name); - }); - -TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLowerShapeFunc") - .set_body_typed([](CompileEngine self, CCacheKey key) { return self->LowerShapeFunc(key); }); - -TVM_REGISTER_GLOBAL("relay.backend._CompileLowerExternalFunctions") - .set_body_typed([](CompileEngine self) { return self->LowerExternalFunctions(); }); - -TVM_REGISTER_GLOBAL("relay.backend._CompileEngineJIT") - .set_body_typed([](CompileEngine self, CCacheKey key) { return self->JIT(key); }); - -TVM_REGISTER_GLOBAL("relay.backend._CompileEngineListItems").set_body_typed([](CompileEngine self) { - CompileEngineImpl* ptr = dynamic_cast(self.operator->()); - ICHECK(ptr != nullptr); - return ptr->ListItems(); -}); - -TVM_REGISTER_GLOBAL("relay.backend._CompileEngineListShapeFuncItems") - .set_body_typed([](CompileEngine self) { - CompileEngineImpl* ptr = dynamic_cast(self.operator->()); - ICHECK(ptr != nullptr); - return ptr->ListShapeFuncItems(); - }); - -TVM_REGISTER_GLOBAL("relay.backend._CompileEngineGetCurrentCCacheKey") - .set_body_typed([](CompileEngine self) { - CompileEngineImpl* ptr = dynamic_cast(self.operator->()); - ICHECK(ptr != nullptr); - return ptr->GetCurrentCCacheKey(); - }); - -} // namespace relay -} // namespace tvm diff --git a/src/relay/backend/compile_engine.h b/src/relay/backend/compile_engine.h deleted file mode 100644 index 4afdc6d30485..000000000000 --- a/src/relay/backend/compile_engine.h +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file relay/backend/compile_engine.h - * \brief Internal compilation layer which lowers Relay "primitive functions" to TIR PrimFns. - * - * This layer represents the older design of the Relay compilation flow and is being deprecated - * in favor of te_compiler.h which is a migration step towards a standard pass based lowering of - * Relay functions. - * - */ -#ifndef TVM_RELAY_BACKEND_COMPILE_ENGINE_H_ -#define TVM_RELAY_BACKEND_COMPILE_ENGINE_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -#include "te_compiler_cache.h" - -namespace tvm { -namespace relay { - -using namespace tvm::relay::tec; - -/*! - * \brief Backend compilation engine for - * low level code generation. - */ -class CompileEngineNode : public Object { - public: - /*! \brief destructor */ - virtual ~CompileEngineNode() {} - /*! - * \brief Get lowered result. - * \param key The key to the cached function. - * \param mod_name The mangling function for mangling names. - * \return The result. - */ - virtual CachedFunc Lower(const CCacheKey& key, std::function mangle_fn) = 0; - - /*! - * \brief Get lowered result. - * \param key The key to the cached function. - * \param mod_name The module name to mangle the functions. - * \return The result. - */ - virtual CachedFunc Lower(const CCacheKey& key, const String mangle_fn) = 0; - /*! - * \brief Just in time compile to get a PackedFunc. - * \param key The key to the cached function. - * \return The result. - */ - virtual PackedFunc JIT(const CCacheKey& key) = 0; - /*! - * \brief Lower the shape function. - * \param key The key to the cached function. - * \return The result. - */ - virtual CachedFunc LowerShapeFunc(const CCacheKey& key) = 0; - /*! - * \brief Lower the external function using external codegen tools. - * \return The runtime moduels for each needed external codegen tool. - */ - virtual tvm::Array LowerExternalFunctions() = 0; - - /*! \brief clear the cache. */ - virtual void Clear() = 0; - - // VisitAttrs - void VisitAttrs(AttrVisitor*) {} - - static constexpr const char* _type_key = "relay.CompileEngine"; - TVM_DECLARE_FINAL_OBJECT_INFO(CompileEngineNode, Object); -}; - -/*! \brief cache entry used in compile engine */ -class CompileEngine : public ObjectRef { - public: - CompileEngine() {} - explicit CompileEngine(ObjectPtr n) : ObjectRef(n) {} - CompileEngineNode* operator->() { return static_cast(get_mutable()); } - using ContainerType = CompileEngineNode; - /*! \brief The global compile engine. */ - TVM_DLL static CompileEngine& Global(); -}; - -} // namespace relay -} // namespace tvm - -#endif // TVM_RELAY_BACKEND_COMPILE_ENGINE_H_ diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index af2cbae1f72d..b264fe8f5c85 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -36,7 +36,6 @@ #include #include "../transforms/pass_utils.h" -#include "compile_engine.h" #include "te_compiler.h" namespace tvm { @@ -479,13 +478,13 @@ class Interpreter : public ExprFunctor, // flattened form of this arg. Does that match what lowering actually does? int64_t state = prim_shape_fn_states[i]->value; for (const auto& nd_array : FlattenADT(args[i])) { - if (state & kNeedInputData) { + if (state & tec::kNeedInputData) { auto arr = nd_array.CopyTo(shape_device); inputs[arg_counter] = arr; setter(arg_counter, arr); ++arg_counter; } - if (state & kNeedInputShape) { + if (state & tec::kNeedInputShape) { int64_t ndim = nd_array.Shape().size(); NDArray shape_arr; if (ndim == 0) { @@ -922,7 +921,7 @@ std::pair> Prepare(IRModule mod, Device device, // Lower all primitive functions reachable from expr. // TODO(mbs): This should be just another pass in seq above, which requires LoweredModule to // be merged into IRModule. - LoweredModule lowered_module = + tec::LoweredModule lowered_module = tec::LowerTE(mod, targets, device_map, memory_plan, /*module_name=*/"intrp", [](Function func) { /* no-op */ }); return {lowered_module.main_module, lowered_module.per_target_module}; diff --git a/src/relay/transforms/auto_scheduler_layout_rewrite.cc b/src/relay/transforms/auto_scheduler_layout_rewrite.cc index 7a86af8aeffa..c24c41a086d4 100644 --- a/src/relay/transforms/auto_scheduler_layout_rewrite.cc +++ b/src/relay/transforms/auto_scheduler_layout_rewrite.cc @@ -34,7 +34,7 @@ #include #include -#include "../backend/compile_engine.h" +#include "../backend/te_compiler.h" #include "pattern_utils.h" namespace tvm { @@ -126,7 +126,7 @@ Expr AutoSchedulerLayoutRewriter::VisitExpr_(const CallNode* n) { CHECK(f) << "Could not find auto_scheduler.enter_layout_rewrite function."; (*f)(); - PrimFuncFor(GetRef(func), Target::Current(), [](std::string name) { return name; }); + tec::PrimFuncFor(GetRef(func), Target::Current(), [](std::string name) { return name; }); f = runtime::Registry::Get("auto_scheduler.exit_layout_rewrite"); CHECK(f) << "Could not find ansor.exit_layout_rewrite function."; From 79d42e5c81d4cdad5c65f47cdd4591987014b645 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 17 Aug 2021 22:56:42 -0700 Subject: [PATCH 2/2] Fix format --- src/relay/transforms/auto_scheduler_layout_rewrite.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/relay/transforms/auto_scheduler_layout_rewrite.cc b/src/relay/transforms/auto_scheduler_layout_rewrite.cc index c24c41a086d4..c538dac048b3 100644 --- a/src/relay/transforms/auto_scheduler_layout_rewrite.cc +++ b/src/relay/transforms/auto_scheduler_layout_rewrite.cc @@ -126,7 +126,8 @@ Expr AutoSchedulerLayoutRewriter::VisitExpr_(const CallNode* n) { CHECK(f) << "Could not find auto_scheduler.enter_layout_rewrite function."; (*f)(); - tec::PrimFuncFor(GetRef(func), Target::Current(), [](std::string name) { return name; }); + tec::PrimFuncFor(GetRef(func), Target::Current(), + [](std::string name) { return name; }); f = runtime::Registry::Get("auto_scheduler.exit_layout_rewrite"); CHECK(f) << "Could not find ansor.exit_layout_rewrite function.";