From d89917b6705fa18ec9d1287f67dc8380d8dc3a1e Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Sat, 4 Feb 2017 15:46:25 -0800 Subject: [PATCH] [PASS] StorageFlatten and StorageSync, safe condition in schedul_ops, gemm example. (#31) --- include/tvm/codegen.h | 48 --- include/tvm/ir.h | 19 ++ include/tvm/ir_pass.h | 58 ++++ include/tvm/ir_visitor.h | 1 + python/tvm/build.py | 12 +- src/api/api_codegen.cc | 12 - src/api/api_pass.cc | 3 + src/codegen/codegen_c.cc | 158 ++++++---- src/codegen/codegen_c.h | 35 ++- src/codegen/codegen_cuda.cc | 18 +- src/codegen/codegen_cuda.h | 3 + src/codegen/codegen_opencl.cc | 40 ++- src/codegen/codegen_opencl.h | 6 +- src/codegen/codegen_stack_vm.cc | 17 +- src/codegen/codegen_stack_vm.h | 2 +- src/pass/ir_mutator.cc | 1 + src/pass/ir_util.h | 5 + src/pass/ir_visitor.cc | 15 +- src/{codegen => pass}/make_api.cc | 9 +- src/{codegen => pass}/split_host_device.cc | 13 +- src/pass/storage_flatten.cc | 188 ++++++++---- src/pass/storage_sync.cc | 283 ++++++++++++++++++ src/runtime/cuda/cuda_module.cc | 2 +- src/runtime/opencl/opencl_module.cc | 6 +- src/runtime/thread_axis_args.h | 106 ------- src/runtime/thread_storage_scope.h | 161 ++++++++++ src/schedule/bound.cc | 23 +- src/schedule/compute_expr.h | 8 +- src/schedule/schedule_ops.cc | 116 +++++-- tests/python/integration/test_gemm.py | 87 ++++++ tests/python/unittest/test_codegen_device.py | 4 +- ...odegen_makeapi.py => test_pass_makeapi.py} | 2 +- .../python/unittest/test_pass_storage_sync.py | 31 ++ .../python/unittest/test_runtime_stack_vm.py | 10 +- 34 files changed, 1113 insertions(+), 389 deletions(-) rename src/{codegen => pass}/make_api.cc (98%) rename src/{codegen => pass}/split_host_device.cc (96%) create mode 100644 src/pass/storage_sync.cc delete mode 100644 src/runtime/thread_axis_args.h create mode 100644 src/runtime/thread_storage_scope.h create mode 100644 tests/python/integration/test_gemm.py rename tests/python/unittest/{test_codegen_makeapi.py => test_pass_makeapi.py} (92%) create mode 100644 tests/python/unittest/test_pass_storage_sync.py diff --git a/include/tvm/codegen.h b/include/tvm/codegen.h index 4d76a88a7265..c1796be8e0ca 100644 --- a/include/tvm/codegen.h +++ b/include/tvm/codegen.h @@ -21,54 +21,6 @@ using runtime::PackedFunc; using runtime::TVMArgs; using runtime::TVMRetValue; -/*! - * \brief Make an user callable API LoweredFunc. - * - * The main task of this function is to create code to : - * - Map the values in the api_args to of Var that is required by body. - * - Insert assertions to check type/value of the passed arguments. - * - * \param body The body of the function. - * \param name The name of the function. - * \param api_args Arguments to the function, can be either Var, or Buffer - * \param num_packed_args Number of arguments that are processed in packed form. - * \return a LoweredFunc with the specified signiture. - * - * \note - * The function signiture have two cases - * - * if num_packed_args is zero: - * f(api_arg_0, api_arg_1, .., api_arg_n) where n == len(api_args) - * - * if num_packed_args is not zero: - * f(TVMArg* packed_args, int* packed_arg_type_ids, int num_packed_args, - * api_arg_k, api_arg_k+1, ... api_arg_n) - * - * where n == len(api_args), k == num_packed_args - * - * There is no thread_axis in generated function. - */ -LoweredFunc MakeAPI(Stmt body, - std::string name, - Array api_args, - int num_packed_args); - -/*! - * \brief Count number of undefined vars in f. - * \param f The function to be checked. - * \return Number of undefined vars. 
- */
-Array<Var> UndefinedVars(const LoweredFunc& f);
-
-/*!
- * \brief Split the function into a host function and device functions.
- * \param func The function to be splitted.
- *
- * \return Array of functions, the first one is host function,
- *  the others are device functions.
- */
-Array<LoweredFunc> SplitHostDevice(LoweredFunc func);
-
 /*!
  * \brief Build a stack VM function.
  * \param func The LoweredFunc to be built
diff --git a/include/tvm/ir.h b/include/tvm/ir.h
index 8de8615b0e06..5c22fe27bb2c 100644
--- a/include/tvm/ir.h
+++ b/include/tvm/ir.h
@@ -88,6 +88,25 @@ constexpr const char* tvm_handle_is_null = "tvm_handle_is_null";
  * }
  */
 constexpr const char* tvm_call_global = "tvm_call_global";
+/*!
+ * \brief See pseudo code
+ *
+ *  int tvm_call_device(name, TVMValue* args) {
+ *     PackedFunc df = CodeGenEnv->GetDevice(name);
+ *     df(args, type_code_of(args), len(args));
+ *     return 0;
+ *  }
+ */
+constexpr const char* tvm_call_device = "tvm_call_device";
+/*!
+ * \brief See pseudo code
+ *
+ *  int tvm_storage_sync(std::string storage_scope) {
+ *     __sync(storage_scope);
+ *     return 0;
+ *  }
+ */
+constexpr const char* tvm_storage_sync = "tvm_storage_sync";
 
 /*! \brief The field id of each field in array */
 enum TVMArrayFieldKind {
diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h
index fc7eab94a4cf..8eaec0f52315 100644
--- a/include/tvm/ir_pass.h
+++ b/include/tvm/ir_pass.h
@@ -14,9 +14,11 @@
 #include
 #include
 #include
+#include
 #include "./expr.h"
 #include "./buffer.h"
 #include "./schedule.h"
+#include "./lowered_func.h"
 
 namespace tvm {
 namespace ir {
@@ -95,6 +97,62 @@ Stmt Inline(Stmt stmt,
 Stmt StorageFlatten(Stmt stmt,
                     Map<Tensor, Buffer> extern_buffer);
 
+/*!
+ * \brief Make a user-callable API LoweredFunc.
+ *
+ *  The main task of this function is to create code to:
+ *  - Map the values in api_args to the Vars required by body.
+ *  - Insert assertions to check the type/value of the passed arguments.
+ *
+ * \param body The body of the function.
+ * \param name The name of the function.
+ * \param api_args Arguments to the function, can be either Var, or Buffer
+ * \param num_packed_args Number of arguments that are processed in packed form.
+ * \return a LoweredFunc with the specified signature.
+ *
+ * \note
+ *  The function signature has two cases
+ *
+ *  if num_packed_args is zero:
+ *     f(api_arg_0, api_arg_1, .., api_arg_n) where n == len(api_args)
+ *
+ *  if num_packed_args is not zero:
+ *     f(TVMArg* packed_args, int* packed_arg_type_ids, int num_packed_args,
+ *       api_arg_k, api_arg_k+1, ... api_arg_n)
+ *
+ *     where n == len(api_args), k == num_packed_args
+ *
+ *  There is no thread_axis in the generated function.
+ */
+LoweredFunc MakeAPI(Stmt body,
+                    std::string name,
+                    Array<NodeRef> api_args,
+                    int num_packed_args);
+
+/*!
+ * \brief Collect the undefined vars in f.
+ * \param f The function to be checked.
+ * \return Array of undefined vars.
+ */
+Array<Var> UndefinedVars(const LoweredFunc& f);
+
+/*!
+ * \brief Split the function into a host function and device functions.
+ * \param func The function to be split.
+ *
+ * \return Array of functions, the first one is the host function,
+ *  the others are device functions.
+ */
+Array<LoweredFunc> SplitHostDevice(LoweredFunc func);
+
+/*!
+ * \brief Insert syncs between parallel reads/writes of shared buffers.
+ *
+ * \param stmt The function to be transformed.
+ * \param storage_scope The storage scope considered.
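+ * \return The transformed function with storage syncs inserted.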
+ */ +LoweredFunc StorageSync(LoweredFunc stmt, std::string storage_scope); + } // namespace ir } // namespace tvm diff --git a/include/tvm/ir_visitor.h b/include/tvm/ir_visitor.h index 0df5d3e324f6..e5711f65ff86 100644 --- a/include/tvm/ir_visitor.h +++ b/include/tvm/ir_visitor.h @@ -40,6 +40,7 @@ class IRVisitor { virtual void Visit_(const LetStmt* op); virtual void Visit_(const For* op); virtual void Visit_(const Allocate* op); + virtual void Visit_(const IfThenElse* op); virtual void Visit_(const Load* op); virtual void Visit_(const Store* op); virtual void Visit_(const Let* op); diff --git a/python/tvm/build.py b/python/tvm/build.py index 8839031311e9..29321eabe711 100644 --- a/python/tvm/build.py +++ b/python/tvm/build.py @@ -65,17 +65,19 @@ def build(sch, stmt = schedule.ScheduleOps(sch, bounds) stmt = ir_pass.StorageFlatten(stmt, binds) stmt = ir_pass.Simplify(stmt) - print(stmt) - fapi = codegen.MakeAPI(stmt, name, arg_list, len(arg_list)) - fsplits = codegen.SplitHostDevice(fapi) + fapi = ir_pass.MakeAPI(stmt, name, arg_list, len(arg_list)) + fsplits = ir_pass.SplitHostDevice(fapi) + fsplits = [x for x in fsplits] + for i in range(1, len(fsplits)): + fsplits[i] = ir_pass.StorageSync(fsplits[i], "shared") + fsplits[i] = ir_pass.StorageSync(fsplits[i], "global") if record_codes is not None: output_ssa = False for i, f in enumerate(fsplits): t = target if i >= 1 else "c" record_codes.append(codegen.CompileToC(f, output_ssa, t)) - for c in record_codes: - print(c) + if target == "cuda": ret = codegen.BuildNVRTC(fsplits, "stackvm") elif target == "opencl": diff --git a/src/api/api_codegen.cc b/src/api/api_codegen.cc index 1cb5a2ad0088..7161016f507f 100644 --- a/src/api/api_codegen.cc +++ b/src/api/api_codegen.cc @@ -31,18 +31,6 @@ TVM_REGISTER_API(_codegen_CompileToC) } }); - -TVM_REGISTER_API(_codegen_MakeAPI) -.set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = MakeAPI( - args[0], args[1], args[2], args[3]); - }); - -TVM_REGISTER_API(_codegen_SplitHostDevice) -.set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = SplitHostDevice(args[0]); - }); - TVM_REGISTER_API(_codegen_BuildStackVM) .set_body([](TVMArgs args, TVMRetValue *ret) { *ret = BuildStackVM(args[0], diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index f549b6b2ee25..6e7bbd849171 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -52,6 +52,9 @@ REGISTER_PASS1(ConvertSSA); REGISTER_PASS1(VerifySSA); REGISTER_PASS4(Inline); REGISTER_PASS2(StorageFlatten); +REGISTER_PASS2(StorageSync); +REGISTER_PASS4(MakeAPI); +REGISTER_PASS1(SplitHostDevice); } // namespace ir } // namespace tvm diff --git a/src/codegen/codegen_c.cc b/src/codegen/codegen_c.cc index 737cdc18bd7a..2f61e3be920f 100644 --- a/src/codegen/codegen_c.cc +++ b/src/codegen/codegen_c.cc @@ -20,7 +20,6 @@ std::string CodeGenC::Compile(LoweredFunc f, HandleTypeRegister(kv.first.get(), kv.second.type()); } - this->indent += 2; this->stream << "void " << f->name << "("; for (size_t i = 0; i < f->args.size(); ++i) { Var v = f->args[i]; @@ -38,8 +37,9 @@ std::string CodeGenC::Compile(LoweredFunc f, stream << ' ' << vid; } stream << ") {\n"; + int func_scope = this->BeginScope(); this->PrintStmt(f->body); - this->indent -= 2; + this->EndScope(func_scope); this->PrintIndent(); this->stream << "}\n"; return stream.str(); @@ -54,19 +54,23 @@ std::string CodeGenC::SSAGetID(std::string src, Type t) { if (name_alloc_map_.count(src)) return src; auto it = ssa_assign_map_.find(src); if (it != ssa_assign_map_.end()) { - return it->second; - } else { - 
this->PrintIndent(); - std::string id = GetUniqueName("_"); - ssa_assign_map_[src] = id; - if (src.length() > 3 && - src[0] == '(' && src[src.length() - 1] == ')') { - src = src.substr(1, src.length() - 2); + if (scope_mark_.at(it->second.scope_id)) { + return it->second.vid; } - PrintType(t, stream); - stream << ' ' << id << " = " << src << ";\n"; - return id; } + + this->PrintIndent(); + SSAEntry e; + e.vid = GetUniqueName("_"); + e.scope_id = static_cast(scope_mark_.size() - 1); + ssa_assign_map_[src] = e; + if (src.length() > 3 && + src[0] == '(' && src[src.length() - 1] == ')') { + src = src.substr(1, src.length() - 2); + } + PrintType(t, stream); + stream << ' ' << e.vid << " = " << src << ";\n"; + return e.vid; } void CodeGenC::PrintExpr(const Expr& n, std::ostream& os) { // NOLINT(*) @@ -142,9 +146,12 @@ void CodeGenC::MarkConst(std::string vid) { if (print_ssa_form_) { auto it = ssa_assign_map_.find(vid); if (it == ssa_assign_map_.end()) { - ssa_assign_map_[vid] = vid; + SSAEntry e; + e.vid = vid; + e.scope_id = 0; + ssa_assign_map_[vid] = e; } else { - CHECK_EQ(it->second, vid); + CHECK_EQ(it->second.vid, vid); } } } @@ -242,6 +249,9 @@ TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_expr) }) .set_dispatch([](const FloatImm *op, std::ostream& os, CodeGenC *p) { // NOLINT(*) PrintConst(op, os, p); + }) +.set_dispatch([](const StringImm *op, std::ostream& os, CodeGenC *p) { // NOLINT(*) + os << "\"" << op->value << "\""; }); template @@ -340,49 +350,22 @@ TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_stmt) .set_dispatch([](const ProducerConsumer *op, CodeGenC* p) { p->PrintStmt(op->body); }) -.set_dispatch([](const For *op, CodeGenC* p) { - std::string extent = p->PrintExpr(op->extent); - p->PrintIndent(); - std::string vid = p->AllocVarID(op->loop_var.get()); - CHECK(is_zero(op->min)); - p->stream << "for ("; - p->PrintType(op->loop_var.type(), p->stream); - p->stream << ' ' << vid << " = 0; " - << vid << " < " << extent - << "; ++" << vid << ") {\n"; - p->indent += 2; - p->PrintStmt(op->body); - p->indent -= 2; - p->PrintIndent(); - p->stream << "}\n"; - }) .set_dispatch([](const Block *op, CodeGenC* p) { p->PrintStmt(op->first); if (op->rest.defined()) p->PrintStmt(op->rest); }) .set_dispatch([](const Evaluate *op, CodeGenC* p) { if (is_const(op->value)) return; - std::string vid = p->PrintExpr(op->value); - p->PrintIndent(); - p->stream << "(void)" << vid << ";\n"; - }) -.set_dispatch([](const IfThenElse *op, CodeGenC* p) { - std::string cond = p->PrintExpr(op->condition); - p->PrintIndent(); - p->stream << "if (" << cond << ") {\n"; - p->indent += 2; - p->PrintStmt(op->then_case); - p->indent -= 2; - if (op->else_case.defined()) { + const Call* call = op->value.as(); + + if (call && call->is_intrinsic(intrinsic::tvm_storage_sync)) { + p->PrintStorageSync(call->args[0].as()->value); + } else { + std::string vid = p->PrintExpr(op->value); p->PrintIndent(); - p->stream << "} else {\n"; - p->indent += 2; - p->PrintStmt(op->else_case); - p->indent -= 2; + p->stream << "(void)" << vid << ";\n"; } - p->PrintIndent(); - p->stream << "}\n"; -}); + }); #define DISPATCH_EXPR(OP) \ @@ -517,13 +500,22 @@ TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_stmt) .set_dispatch([](const Store *op, CodeGenC* p) { p->PrintStmt(op); }) .set_dispatch([](const Allocate *op, CodeGenC* p) { p->PrintStmt(op); }) .set_dispatch([](const AttrStmt *op, CodeGenC* p) { p->PrintStmt(op); }) -.set_dispatch([](const AssertStmt *op, CodeGenC* p) { p->PrintStmt(op); }); +.set_dispatch([](const AssertStmt *op, CodeGenC* p) 
{ p->PrintStmt(op); }) +.set_dispatch([](const For *op, CodeGenC* p) { p->PrintStmt(op); }) +.set_dispatch([](const IfThenElse *op, CodeGenC* p) { p->PrintStmt(op); }); -void CodeGenC::PrintThreadTagExpr( - std::string thread_tag, std::ostream& os) const { // NOLINT(*) +void CodeGenC::PrintThreadIndexExpr( + std::string thread_tag, std::ostream& os) { // NOLINT(*) os << thread_tag; } +void CodeGenC::PrintStorageSync(const std::string& sync) { // NOLINT(*) +} + +void CodeGenC::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*) + CHECK_EQ(scope, "global"); +} + void CodeGenC::PrintStmt(const LetStmt* op) { std::string value = PrintExpr(op->value); if (print_ssa_form_) { @@ -581,9 +573,12 @@ void CodeGenC::PrintStmt(const Allocate* op) { int32_t constant_size = op->constant_allocation_size(); CHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now"; + const Variable* buffer = op->buffer_var.as(); + std::string scope = alloc_storage_scope_.at(buffer); + PrintStorageScope(scope, stream); PrintType(op->type, stream); stream << ' '<< vid << '[' - << constant_size << "]\n;"; + << constant_size << "];\n"; } HandleTypeRegister(op->buffer_var.get(), op->type); this->PrintStmt(op->body); @@ -599,10 +594,14 @@ void CodeGenC::PrintStmt(const AttrStmt* op) { stream << ' ' << AllocVarID(iv->var.get()) << " = "; - PrintThreadTagExpr(iv->thread_tag, stream); + PrintThreadIndexExpr(iv->thread_tag, stream); stream << ";\n"; } } + } else if (op->type_key == "storage_scope") { + const Variable* v = op->node.as(); + CHECK(v); + alloc_storage_scope_[v] = op->value.as()->value; } this->PrintStmt(op->body); } @@ -619,5 +618,54 @@ void CodeGenC::PrintStmt(const AssertStmt* op) { } } +int CodeGenC::BeginScope() { + int sid = static_cast(scope_mark_.size()); + scope_mark_.push_back(true); + indent += 2; + return sid; +} + +void CodeGenC::EndScope(int scope_id) { + scope_mark_[scope_id] = false; + indent -= 2; +} + +void CodeGenC::PrintStmt(const For* op) { + std::string extent = PrintExpr(op->extent); + PrintIndent(); + std::string vid = AllocVarID(op->loop_var.get()); + CHECK(is_zero(op->min)); + stream << "for ("; + PrintType(op->loop_var.type(), stream); + stream << ' ' << vid << " = 0; " + << vid << " < " << extent + << "; ++" << vid << ") {\n"; + int for_scope = BeginScope(); + PrintStmt(op->body); + this->EndScope(for_scope); + PrintIndent(); + stream << "}\n"; +} + +void CodeGenC::PrintStmt(const IfThenElse* op) { + std::string cond = PrintExpr(op->condition); + PrintIndent(); + stream << "if (" << cond << ") {\n"; + int then_scope = BeginScope(); + PrintStmt(op->then_case); + this->EndScope(then_scope); + + if (op->else_case.defined()) { + PrintIndent(); + stream << "} else {\n"; + int else_scope = BeginScope(); + PrintStmt(op->else_case); + this->EndScope(else_scope); + } + PrintIndent(); + stream << "}\n"; +} + + } // namespace codegen } // namespace tvm diff --git a/src/codegen/codegen_c.h b/src/codegen/codegen_c.h index 30ae1d6c46bb..d4e70379eee9 100644 --- a/src/codegen/codegen_c.h +++ b/src/codegen/codegen_c.h @@ -10,6 +10,7 @@ #include #include #include +#include #include namespace tvm { @@ -80,13 +81,18 @@ class CodeGenC { virtual void PrintType(Type t, std::ostream& os) const; // NOLINT(*) /*! * \brief Print expr representing the thread tag - * \param thread_tag The tag in the thread. + * \param tag The tag in the thread. 
 * \param os The stream to output to
  */
-  virtual void PrintThreadTagExpr(
-      std::string thread_tag, std::ostream& os) const;  // NOLINT(*)
+  virtual void PrintThreadIndexExpr(
+      std::string tag, std::ostream& os);  // NOLINT(*)
+  virtual void PrintStorageScope(const std::string& scope, std::ostream& os);  // NOLINT(*)
+  virtual void PrintStorageSync(const std::string& scope);  // NOLINT(*)
+
   virtual void PrintStmt(const ir::LetStmt* op);
   virtual void PrintStmt(const ir::Store* op);
+  virtual void PrintStmt(const ir::For* op);
+  virtual void PrintStmt(const ir::IfThenElse* op);
   virtual void PrintStmt(const ir::Allocate* op);
   virtual void PrintStmt(const ir::AttrStmt* op);
   virtual void PrintStmt(const ir::AssertStmt* op);
@@ -114,6 +120,13 @@ class CodeGenC {
   std::string arg_addr_space_;
 
  private:
+  /*! \brief entry in ssa assign map */
+  struct SSAEntry {
+    /*! \brief The value id */
+    std::string vid;
+    /*! \brief The scope id */
+    int scope_id;
+  };
   /*!
    * \brief Get the SSA ID that corresponds to src.
    *  If necessary, generate a new assignment.
    * \param src The source expression
    * \param t The type of the expression.
    */
   std::string SSAGetID(std::string src, Type t);
+  /*!
+   * \brief mark the beginning of a new scope
+   * \return The scope id.
+   */
+  int BeginScope();
+  /*!
+   * \brief mark the end of an old scope.
+   * \param scope_id The scope id to be ended.
+   */
+  void EndScope(int scope_id);
   /*!
    * \brief If buffer is allocated as type t.
    * \param buf_var The buffer variable.
@@ -145,10 +168,14 @@ class CodeGenC {
   std::unordered_map<const Variable*, std::string> var_idmap_;
   /*! \brief the data type of allocated buffers */
   std::unordered_map<const Variable*, Type> handle_data_type_;
+  /*! \brief the storage scope of allocation */
+  std::unordered_map<const Variable*, std::string> alloc_storage_scope_;
   /*! \brief name allocation map */
   std::unordered_map<std::string, int> name_alloc_map_;
   /*! \brief assignment map of ssa */
-  std::unordered_map<std::string, std::string> ssa_assign_map_;
+  std::unordered_map<std::string, SSAEntry> ssa_assign_map_;
+  /*!
\brief array to check whether we are inside certain scope */ + std::vector scope_mark_; }; } // namespace codegen diff --git a/src/codegen/codegen_cuda.cc b/src/codegen/codegen_cuda.cc index a9b69ed9e491..b4957a3d543e 100644 --- a/src/codegen/codegen_cuda.cc +++ b/src/codegen/codegen_cuda.cc @@ -22,6 +22,23 @@ std::string CodeGenCUDA::Compile( return CodeGenC::Compile(f, output_ssa); } +void CodeGenCUDA::PrintStorageSync(const std::string& sync) { + if (sync == "shared") { + this->PrintIndent(); + this->stream << "__syncthreads();\n"; + } else if (sync == "global") { + LOG(FATAL) << "not supported"; + } +} + +void CodeGenCUDA::PrintStorageScope( + const std::string& scope, std::ostream& os) { // NOLINT(*) + CHECK_NE(scope, "global"); + if (scope == "shared") { + os << "__shared__ "; + } +} + #if TVM_CUDA_RUNTIME std::unordered_map MakeNVRTC(Array funcs) { @@ -56,7 +73,6 @@ MakeNVRTC(Array funcs) { PackedFunc BuildNVRTC(Array fsplits, std::string host_mode) { Array device_list(fsplits.begin() + 1, fsplits.end()); std::unordered_map device_funcs = MakeNVRTC(device_list); - if (host_mode == "stackvm") { StackVM vm = codegen::CodeGenStackVM().Compile(fsplits[0], device_funcs); auto f = [vm](TVMArgs args, TVMRetValue* rv) { diff --git a/src/codegen/codegen_cuda.h b/src/codegen/codegen_cuda.h index b8e55b80f4b2..a8cca432f49a 100644 --- a/src/codegen/codegen_cuda.h +++ b/src/codegen/codegen_cuda.h @@ -25,6 +25,9 @@ class CodeGenCUDA : public CodeGenC { */ std::string Compile(LoweredFunc f, bool output_ssa); + // override behavior + void PrintStorageSync(const std::string& sync) final; + void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) }; } // namespace codegen diff --git a/src/codegen/codegen_opencl.cc b/src/codegen/codegen_opencl.cc index 54b9b849461a..3d54a66a8251 100644 --- a/src/codegen/codegen_opencl.cc +++ b/src/codegen/codegen_opencl.cc @@ -10,6 +10,7 @@ #include "./codegen_stack_vm.h" #include "../runtime/opencl/opencl_common.h" #include "../runtime/opencl/opencl_module.h" +#include "../runtime/thread_storage_scope.h" namespace tvm { namespace codegen { @@ -22,22 +23,31 @@ std::string CodeGenOpenCL::Compile( return CodeGenC::Compile(f, output_ssa); } -void CodeGenOpenCL::PrintThreadTagExpr( - std::string thread_tag, std::ostream& os) const { // NOLINT(*) - if (thread_tag == "threadIdx.x") { - os << "get_local_id(0)"; - } else if (thread_tag == "threadIdx.y") { - os << "get_local_id(1)"; - } else if (thread_tag == "threadIdx.z") { - os << "get_local_id(2)"; - } else if (thread_tag == "blockIdx.x") { - os << "get_global_id(0) / get_local_size(0)"; - } else if (thread_tag == "blockIdx.y") { - os << "get_global_id(1) / get_local_size(1)"; - } else if (thread_tag == "blockIdx.z") { - os << "get_global_id(2) / get_local_size(2)"; +void CodeGenOpenCL::PrintThreadIndexExpr( + std::string tag, std::ostream& os) { // NOLINT(*) + runtime::ThreadScope ts = runtime::ThreadScope::make(tag); + if (ts.rank == 1) { + os << "get_local_id(" << ts.dim_index << ")"; } else { - LOG(FATAL) << "unknown thread tag"; + os << "get_global_id(" << ts.dim_index << ")" + << " / get_local_size(" << ts.dim_index << ")"; + } +} + + +void CodeGenOpenCL::PrintStorageSync(const std::string& sync) { + if (sync == "shared") { + this->PrintIndent(); + this->stream << "barrier(CLK_LOCAL_MEM_FENCE);\n"; + } else if (sync == "global") { + LOG(FATAL) << "not supported"; + } +} + +void CodeGenOpenCL::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*) + CHECK_NE(scope, 
"global"); + if (scope == "shared") { + os << "__local "; } } diff --git a/src/codegen/codegen_opencl.h b/src/codegen/codegen_opencl.h index 748599708752..a0b8120f1c30 100644 --- a/src/codegen/codegen_opencl.h +++ b/src/codegen/codegen_opencl.h @@ -26,8 +26,10 @@ class CodeGenOpenCL : public CodeGenC { std::string Compile(LoweredFunc f, bool output_ssa); // override print thread tag. - void PrintThreadTagExpr( - std::string thread_tag, std::ostream& os) const final; // NOLINT(*) + void PrintThreadIndexExpr( + std::string tag, std::ostream& os) final; // NOLINT(*) + void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) + void PrintStorageSync(const std::string& scope) final; // NOLINT(*) }; } // namespace codegen diff --git a/src/codegen/codegen_stack_vm.cc b/src/codegen/codegen_stack_vm.cc index a2fdf6235348..d1fb0751a8ab 100644 --- a/src/codegen/codegen_stack_vm.cc +++ b/src/codegen/codegen_stack_vm.cc @@ -37,7 +37,7 @@ StackVM CodeGenStackVM::Compile( for (const auto& kv : device_funcs) { int fid = static_cast(vm_.packed_func.size()); vm_.packed_func.push_back(kv.second); - device_fun_idmap_[kv.first] = fid; + device_fun_idmap_[kv.first->name] = fid; } this->Push(f->body); return std::move(vm_); @@ -228,20 +228,19 @@ void CodeGenStackVM::Push_(const ir::Call* op) { this->Push(op->args[0]); this->PushOp(StackVM::PUSH_I64, 0); this->PushOp(StackVM::EQ_I64); - } else if (op->call_type == Call::Extern && op->func.defined()) { - CHECK(op->func->is_type()); - LoweredFunc f(op->func.node_); - auto it = device_fun_idmap_.find(f); + } else if (op->is_intrinsic(intrinsic::tvm_call_device)) { + std::string func_name = op->args[0].as()->value; + auto it = device_fun_idmap_.find(func_name); CHECK(it != device_fun_idmap_.end()) - << "Cannot find device function " << f->name; + << "Cannot find device function " << func_name; const int fid = it->second; - std::vector arg_type_codes(op->args.size()); - for (size_t i = 0; i < op->args.size(); ++i) { + std::vector arg_type_codes; + for (size_t i = 1; i < op->args.size(); ++i) { this->Push(op->args[i]); Type t = op->args[i].type(); int lanes = t.lanes(); CHECK_EQ(lanes, 1); - arg_type_codes[i] = t.code(); + arg_type_codes.push_back(t.code()); } this->PushCallPacked(fid, arg_type_codes); } else { diff --git a/src/codegen/codegen_stack_vm.h b/src/codegen/codegen_stack_vm.h index 6a81b3bd6b7f..bf640d0b5ef6 100644 --- a/src/codegen/codegen_stack_vm.h +++ b/src/codegen/codegen_stack_vm.h @@ -110,7 +110,7 @@ class CodeGenStackVM { /*! \brief id of each global function */ std::unordered_map global_fun_idmap_; /*! 
\brief id of device function */ - std::unordered_map device_fun_idmap_; + std::unordered_map device_fun_idmap_; }; } // namespace codegen diff --git a/src/pass/ir_mutator.cc b/src/pass/ir_mutator.cc index 72f118a4667f..c6b7e6b51c85 100644 --- a/src/pass/ir_mutator.cc +++ b/src/pass/ir_mutator.cc @@ -75,6 +75,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt) .DISPATCH_TO_MUTATE_STMT(Realize) .DISPATCH_TO_MUTATE_STMT(Store) .DISPATCH_TO_MUTATE_STMT(For) +.DISPATCH_TO_MUTATE_STMT(Allocate) .DISPATCH_TO_MUTATE_STMT(Free); Stmt IRMutator::Mutate_(const LetStmt *op, const Stmt& s) { diff --git a/src/pass/ir_util.h b/src/pass/ir_util.h index 794dcd820715..2fbff80995f6 100644 --- a/src/pass/ir_util.h +++ b/src/pass/ir_util.h @@ -45,6 +45,11 @@ inline Stmt MergeNest(std::vector nest, Stmt body) { body = Stmt(n); } else if (s.as()) { body = Block::make(s, body); + } else if (s.as()) { + auto n = std::make_shared(*s.as()); + CHECK(is_no_op(n->body)); + n->body = body; + body = Stmt(n); } else { LOG(FATAL) << "not supported nest type"; } diff --git a/src/pass/ir_visitor.cc b/src/pass/ir_visitor.cc index 4b8b005ddea5..5baaa851970e 100644 --- a/src/pass/ir_visitor.cc +++ b/src/pass/ir_visitor.cc @@ -59,6 +59,8 @@ inline void VisitRDom(const Array& rdom, IRVisitor* v) { TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) .DISPATCH_TO_VISIT(Variable) .DISPATCH_TO_VISIT(LetStmt) +.DISPATCH_TO_VISIT(AttrStmt) +.DISPATCH_TO_VISIT(IfThenElse) .DISPATCH_TO_VISIT(For) .DISPATCH_TO_VISIT(Allocate) .DISPATCH_TO_VISIT(Load) @@ -107,6 +109,14 @@ void IRVisitor::Visit_(const Store *op) { this->Visit(op->index); } +void IRVisitor::Visit_(const IfThenElse *op) { + this->Visit(op->condition); + this->Visit(op->then_case); + if (op->else_case.defined()) { + this->Visit(op->else_case); + } +} + void IRVisitor::Visit_(const Let *op) { this->Visit(op->value); this->Visit(op->body); @@ -200,11 +210,6 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) v->Visit(op->first); v->Visit(op->rest); }) -.set_dispatch([](const IfThenElse *op, IRVisitor* v) { - v->Visit(op->condition); - v->Visit(op->then_case); - v->Visit(op->else_case); - }) .set_dispatch([](const Evaluate *op, IRVisitor* v) { v->Visit(op->value); }); diff --git a/src/codegen/make_api.cc b/src/pass/make_api.cc similarity index 98% rename from src/codegen/make_api.cc rename to src/pass/make_api.cc index 3c1324a9aa6f..de35b54cc75d 100644 --- a/src/codegen/make_api.cc +++ b/src/pass/make_api.cc @@ -2,7 +2,7 @@ * Copyright (c) 2017 by Contributors * \file make_api.cc Build API function. */ -#include +#include #include #include @@ -10,11 +10,10 @@ #include #include -#include "../pass/ir_util.h" +#include "./ir_util.h" namespace tvm { -namespace codegen { -using namespace ir; +namespace ir { inline Expr TVMArrayGet(Type t, Var arr, intrinsic::TVMArrayFieldKind kind) { return Call::make( @@ -196,5 +195,5 @@ LoweredFunc MakeAPI(Stmt body, } return f; } -} // namespace codegen +} // namespace ir } // namespace tvm diff --git a/src/codegen/split_host_device.cc b/src/pass/split_host_device.cc similarity index 96% rename from src/codegen/split_host_device.cc rename to src/pass/split_host_device.cc index 213dd8a40dfc..186733fa2f71 100644 --- a/src/codegen/split_host_device.cc +++ b/src/pass/split_host_device.cc @@ -3,7 +3,6 @@ * \file split_host_device.cc * \brief Split device function from host. 
*/ -#include #include #include #include @@ -11,9 +10,7 @@ #include namespace tvm { -namespace codegen { - -using namespace ir; +namespace ir { // use/def analysis, also delete unreferenced lets class IRUseDefAnalysis : public IRMutator { @@ -161,7 +158,7 @@ class HostDeviceSplitter : public IRMutator { private: Stmt SplitDeviceFunc(Stmt body) { std::ostringstream os; - os << name_ << "_kernel" << device_funcs_.size(); + os << name_ << "__kernel" << device_funcs_.size(); std::shared_ptr n = std::make_shared(); // isolate the device function. IRUseDefAnalysis m; @@ -181,6 +178,7 @@ class HostDeviceSplitter : public IRMutator { } LoweredFunc f_device(n); Array call_args; + call_args.push_back(StringImm::make(f_device->name)); for (Var arg : n->args) { call_args.push_back(arg); } @@ -190,7 +188,8 @@ class HostDeviceSplitter : public IRMutator { } device_funcs_.emplace_back(f_device); return Evaluate::make(Call::make( - Int(32), f_device->name, call_args, Call::Extern, f_device)); + Int(32), intrinsic::tvm_call_device, + call_args, Call::Intrinsic)); } // function name @@ -214,5 +213,5 @@ Array SplitHostDevice(LoweredFunc func) { return HostDeviceSplitter().Split(func); } -} // namespace codegen +} // namespace ir } // namespace tvm diff --git a/src/pass/storage_flatten.cc b/src/pass/storage_flatten.cc index ab344fcc0f3c..5cfb8f9c5c2a 100644 --- a/src/pass/storage_flatten.cc +++ b/src/pass/storage_flatten.cc @@ -6,6 +6,8 @@ #include #include #include +#include "./ir_util.h" +#include "../runtime/thread_storage_scope.h" namespace tvm { namespace ir { @@ -45,10 +47,9 @@ namespace tvm { namespace ir { using Halide::Internal::Region; +using runtime::StorageScope; +using runtime::ThreadScope; -// inliner to inline a function -// the result may not be SSA, -// ConvertSSA need to be applied after this pass class StorageFlattener : public IRMutator { public: explicit StorageFlattener(Map extern_buffer) { @@ -59,9 +60,123 @@ class StorageFlattener : public IRMutator { buf_map_[TensorKey{kv.first->op, kv.first->value_index}] = e; } } - Expr Mutate(Expr expr) final { - expr = IRMutator::Mutate(expr); - const Call* op = expr.as(); + + Stmt Flatten(Stmt stmt) { + stmt = this->Mutate(stmt); + StorageScope key; key.rank = 0; + if (move_alloc_out_) { + StorageScope key; key.rank = 0; + stmt = MergeNest(allocs_[key], stmt); + } + return stmt; + } + + Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { + if (op->type_key == "realize_scope") { + storage_scope_[op->node.get()] = op->value.as()->value; + return this->Mutate(op->body); + } else if (op->type_key == "scope") { + IterVar iv(op->node.node_); + if (iv->thread_tag.length() != 0) { + ThreadScope ts = ThreadScope::make(iv->thread_tag); + curr_thread_scope_.push_back(ts); + Stmt stmt = IRMutator::Mutate_(op, s); + curr_thread_scope_.pop_back(); + op = stmt.as(); + + bool first_scope = true; + for (const ThreadScope& t : curr_thread_scope_) { + if (t.rank == ts.rank) first_scope = false; + } + if (first_scope && move_alloc_out_) { + StorageScope key; + key.rank = ts.rank + 1; + std::vector& vec = allocs_[key]; + if (vec.size() != 0) { + Stmt body = MergeNest(vec, op->body); + vec.clear(); + return AttrStmt::make( + op->node, op->type_key, op->value, body); + } + } + return stmt; + } + } + return IRMutator::Mutate_(op, s); + } + + Stmt Mutate_(const Provide* op, const Stmt& s) final { + Stmt stmt = IRMutator::Mutate_(op, s); + op = stmt.as(); + TensorKey key{op->func, op->value_index}; + auto it = buf_map_.find(key); + CHECK(it != buf_map_.end()) + << 
"Cannot find allocated buffer for " << key.f; + const BufferEntry& e = it->second; + CHECK(!e.released) + << "Read a buffer that is already out of scope"; + return e.buffer.MakeStore(e.RelIndex(op->args), op->value); + } + + Stmt Mutate_(const Realize* op, const Stmt& s) final { + TensorKey key{op->func, op->value_index}; + if (buf_map_.count(key)) { + CHECK(buf_map_.at(key).external); + return this->Mutate(op->body); + } else { + // create a buffer entry + // TODO(tqchen) allow permutation and inference of index dimension. + BufferEntry e; + e.bounds = op->bounds; + Array shape; + for (auto r : e.bounds) { + shape.push_back(r->extent); + } + e.buffer = Buffer(shape, op->type, key.GetName()); + + buf_map_[key] = e; + Stmt body = this->Mutate(op->body); + buf_map_[key].released = true; + // deduce current storage scope. + auto it = storage_scope_.find(op->func.get()); + CHECK(it != storage_scope_.end()); + StorageScope key; key.rank = 0; + const std::string& skey = it->second; + if (skey.length() == 0) { + if (curr_thread_scope_.size() != 0) { + key.rank = curr_thread_scope_.back().rank + 1; + } + } else { + key = StorageScope::make(skey); + } + + if (move_alloc_out_) { + allocs_[key].push_back( + AttrStmt::make( + e.buffer->data, "storage_scope", + StringImm::make(key.to_string()), + Evaluate::make(0))); + allocs_[key].push_back( + Allocate::make( + e.buffer->data, e.buffer->dtype, e.buffer->shape, + make_const(Bool(e.buffer->dtype.lanes()), true), + Evaluate::make(0))); + return body; + } else { + Stmt ret = Allocate::make( + e.buffer->data, e.buffer->dtype, e.buffer->shape, + make_const(Bool(e.buffer->dtype.lanes()), true), body); + ret = AttrStmt::make( + e.buffer->data, "storage_scope", + StringImm::make(key.to_string()), ret); + return ret; + } + } + } + + Expr Mutate_(const Call* op, const Expr& olde) final { + Expr expr = IRMutator::Mutate_(op, olde); + op = expr.as(); if (op != nullptr && op->call_type == Call::Halide) { TensorKey key{op->func, op->value_index}; auto it = buf_map_.find(key); @@ -76,17 +191,6 @@ class StorageFlattener : public IRMutator { } } - Stmt Mutate(Stmt stmt) final { - const Realize* realize = stmt.as(); - if (realize != nullptr) { - return HandleRealize(realize); - } else if (stmt.as()) { - return HandleProvide(stmt); - } else { - return IRMutator::Mutate(stmt); - } - } - private: // The buffer entry in the flatten map struct BufferEntry { @@ -113,54 +217,20 @@ class StorageFlattener : public IRMutator { } } }; - + // whether move allocation to the outmost scope as possible. + bool move_alloc_out_{true}; // The buffer assignment map std::unordered_map buf_map_; - - Stmt HandleRealize(const Realize* op) { - TensorKey key{op->func, op->value_index}; - if (buf_map_.count(key)) { - CHECK(buf_map_.at(key).external); - return this->Mutate(op->body); - } else { - // create a buffer entry - // TODO(tqchen) allow permutation and inference of index dimension. 
-      BufferEntry e;
-      e.bounds = op->bounds;
-      Array<Expr> shape;
-      for (auto r : e.bounds) {
-        shape.push_back(r->extent);
-      }
-      e.buffer = Buffer(shape, op->type, key.GetName());
-
-      buf_map_[key] = e;
-      Stmt body = this->Mutate(op->body);
-      buf_map_[key].released = true;
-
-      return Allocate::make(
-          e.buffer->data, e.buffer->dtype, e.buffer->shape,
-          make_const(Bool(e.buffer->dtype.lanes()), true), body);
-    }
-  }
-
-  Stmt HandleProvide(Stmt stmt) {
-    stmt = IRMutator::Mutate(stmt);
-    const Provide* op = stmt.as<Provide>();
-    TensorKey key{op->func, op->value_index};
-    auto it = buf_map_.find(key);
-    CHECK(it != buf_map_.end())
-        << "Cannot find allocated buffer for " << key.f;
-    const BufferEntry& e = it->second;
-    CHECK(!e.released)
-        << "Read a buffer that is already out of scope";
-    return e.buffer.MakeStore(e.RelIndex(op->args), op->value);
-  }
+  std::unordered_map<const Node*, std::string> storage_scope_;
+  // The current thread scope.
+  std::vector<ThreadScope> curr_thread_scope_;
+  // The allocations by rank
+  std::unordered_map<StorageScope, std::vector<Stmt> > allocs_;
 };
 
 Stmt StorageFlatten(Stmt stmt,
                     Map<Tensor, Buffer> extern_buffer) {
-  stmt = StorageFlattener(extern_buffer).Mutate(stmt);
+  stmt = StorageFlattener(extern_buffer).Flatten(stmt);
   return stmt;
 }
 
diff --git a/src/pass/storage_sync.cc b/src/pass/storage_sync.cc
new file mode 100644
index 000000000000..eac89c5a8577
--- /dev/null
+++ b/src/pass/storage_sync.cc
@@ -0,0 +1,283 @@
+/*!
+ *  Copyright (c) 2017 by Contributors
+ * \file storage_sync.cc
+ */
+#include <tvm/ir.h>
+#include <tvm/ir_mutator.h>
+#include <tvm/ir_visitor.h>
+#include <tvm/ir_pass.h>
+#include <unordered_map>
+#include <unordered_set>
+#include "./ir_util.h"
+#include "../runtime/thread_storage_scope.h"
+
+namespace tvm {
+namespace ir {
+
+using runtime::StorageScope;
+
+class StorageSyncPlanner : public IRVisitor {
+ public:
+  explicit StorageSyncPlanner(StorageScope sync_scope)
+    : sync_scope_(sync_scope) {}
+  // only intended to be used once.
+  // Returns the set of statements that need a sync inserted before them.
+  std::unordered_set<const Node*> Plan(Stmt stmt) {
+    CHECK_EQ(scope_.size(), 0U);
+    scope_.push_back(std::vector<StmtEntry>());
+    this->Visit(stmt);
+    this->PlanSync(false);
+    return std::move(syncs_inserted_);
+  }
+  void Visit_(const Load* op) final {
+    CHECK(allow_load_);
+    const Variable* buf = op->buffer_var.as<Variable>();
+    StorageScope s = GetScope(buf);
+    if (s == sync_scope_) {
+      curr_stmt_.access.emplace_back(
+          AccessEntry(buf, kRead, s));
+    }
+  }
+  void Visit_(const Store* op) final {
+    allow_load_ = true;
+    CHECK_EQ(curr_stmt_.access.size(), 0U);
+    curr_stmt_.stmt = op;
+    const Variable* buf = op->buffer_var.as<Variable>();
+    StorageScope s = GetScope(buf);
+    if (s == sync_scope_) {
+      curr_stmt_.access.emplace_back(
+          AccessEntry(buf, kWrite, s));
+    }
+    // traverse child
+    IRVisitor::Visit_(op);
+    // push to the scope
+    scope_.back().push_back(curr_stmt_);
+    // clear access entry.
+    curr_stmt_.access.clear();
+    allow_load_ = false;
+  }
+  void Visit_(const AttrStmt* op) final {
+    if (op->type_key == "storage_scope") {
+      const Variable* buf = op->node.as<Variable>();
+      storage_scope_[buf] =
+          StorageScope::make(op->value.as<StringImm>()->value);
+    }
+    IRVisitor::Visit_(op);
+  }
+  void Visit_(const For* op) final {
+    scope_.push_back(std::vector<StmtEntry>());
+    IRVisitor::Visit_(op);
+    StmtEntry s; s.stmt = op;
+    s.access = PlanSync(true);
+    scope_.pop_back();
+    scope_.back().emplace_back(std::move(s));
+  }
+  void Visit_(const Call* op) final {
+    if (op->is_intrinsic(Call::address_of)) {
+      const Load *l = op->args[0].as<Load>();
+      IRVisitor::Visit_(l);
+    } else {
+      IRVisitor::Visit_(op);
+    }
+  }
+  void Visit_(const IfThenElse* op) final {
+    ++condition_counter_;
+    this->Visit(op->condition);
+    scope_.push_back(std::vector<StmtEntry>());
+    this->Visit(op->then_case);
+
+    StmtEntry s; s.stmt = op;
+    s.access = PlanSync(false);
+    scope_.pop_back();
+    if (op->else_case.defined()) {
+      scope_.push_back(std::vector<StmtEntry>());
+      this->Visit(op->else_case);
+      auto v = PlanSync(false);
+      scope_.pop_back();
+      s.access.insert(s.access.end(), v.begin(), v.end());
+    }
+    scope_.back().emplace_back(std::move(s));
+    --condition_counter_;
+  }
+
+ private:
+  // Storage access type
+  enum AccessType {
+    kRead,
+    kWrite,
+    kSync
+  };
+  // The access entry
+  struct AccessEntry {
+    /*! \brief The buffer variable, if any */
+    const Variable* buffer{nullptr};
+    /*! \brief The type of access */
+    AccessType type;
+    /*! \brief The storage scope */
+    StorageScope scope;
+    // constructor
+    AccessEntry() {}
+    AccessEntry(const Variable* buffer,
+                AccessType type,
+                StorageScope scope)
+        : buffer(buffer), type(type), scope(scope) {}
+  };
+  // The statement entry
+  struct StmtEntry {
+    // the associated statement.
+    const Node* stmt;
+    std::vector<AccessEntry> access;
+  };
+  // Get current storage scope.
+  StorageScope GetScope(const Variable* buf) const {
+    auto it = storage_scope_.find(buf);
+    StorageScope s; s.rank = 0;
+    if (it == storage_scope_.end()) return s;
+    return it->second;
+  }
+  // Plan the sync
+  std::vector<AccessEntry> PlanSync(bool is_loop) {
+    // unsynced reads and writes
+    std::vector<AccessEntry> reads;
+    std::vector<AccessEntry> writes;
+    const std::vector<StmtEntry>& seq = scope_.back();
+
+    // if it is a loop, rotate twice to model the wrap-around effect of the loop.
+    size_t max_seq = seq.size();
+    if (is_loop) max_seq *= 2;
+    // simulation-based approach to find dependencies
+    for (size_t i = 0; i < max_seq; ++i) {
+      const StmtEntry& s = seq[i % seq.size()];
+      // check if sync before statement is needed.
+      bool sync_before_stmt = (syncs_inserted_.count(s.stmt) != 0);
+      // Apply the syncs added already.
+      if (sync_before_stmt) {
+        reads.clear();
+        writes.clear();
+      }
+      for (const AccessEntry& acc : s.access) {
+        if (acc.type == kRead) {
+          if (FindConflict(writes, acc)) {
+            sync_before_stmt = true; break;
+          }
+        } else if (acc.type == kWrite) {
+          if (FindConflict(reads, acc)) {
+            sync_before_stmt = true; break;
+          }
+        } else if (acc.type == kSync) {
+          reads.clear(); writes.clear();
+        }
+      }
+      // If a sync is inserted, remove the now-irrelevant entries.
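+      // (the barrier guarantees earlier reads/writes complete before later ones start)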
+      if (sync_before_stmt) {
+        reads.clear(); writes.clear();
+      }
+      // Add the read/write of the current statement
+      for (const AccessEntry& acc : s.access) {
+        if (acc.type == kRead) {
+          reads.push_back(acc);
+        } else if (acc.type == kWrite) {
+          writes.push_back(acc);
+        } else if (acc.type == kSync) {
+          reads.clear(); writes.clear();
+        }
+      }
+      if (sync_before_stmt) {
+        CHECK_EQ(condition_counter_, 0)
+            << "Cannot insert syncs inside condition";
+        syncs_inserted_.insert(s.stmt);
+      }
+    }
+    // return the exposed entries, remove unnecessary ones.
+    int sync_count = 0;
+    // head are before first sync, tail are after last sync
+    std::vector<AccessEntry> head, tail;
+    for (const StmtEntry& s : seq) {
+      if (syncs_inserted_.count(s.stmt)) {
+        if (sync_count != 0) {
+          tail.clear();
+        } else {
+          head.push_back(AccessEntry(nullptr, kSync, sync_scope_));
+        }
+        ++sync_count;
+      }
+      for (const AccessEntry& acc : s.access) {
+        if (acc.type == kSync) {
+          if (sync_count != 0) {
+            tail.clear();
+          } else {
+            head.push_back(AccessEntry(nullptr, kSync, sync_scope_));
+          }
+          ++sync_count;
+        } else {
+          if (sync_count != 0) {
+            tail.push_back(acc);
+          } else {
+            head.push_back(acc);
+          }
+        }
+      }
+    }
+    head.insert(head.end(), tail.begin(), tail.end());
+    return head;
+  }
+  // find a conflicting entry in vec.
+  static bool FindConflict(const std::vector<AccessEntry>& vec,
+                           const AccessEntry& e) {
+    for (const AccessEntry& x : vec) {
+      if (x.buffer == e.buffer) return true;
+    }
+    return false;
+  }
+  // Whether we are inside a condition.
+  int condition_counter_{0};
+  // whether load is enabled.
+  bool allow_load_{false};
+  // the current stmt entry being built.
+  StmtEntry curr_stmt_;
+  // stack of access scopes
+  std::vector<std::vector<StmtEntry> > scope_;
+  // The storage scope of each buffer
+  std::unordered_map<const Variable*, StorageScope> storage_scope_;
+  // The syncs inserted before each statement
+  std::unordered_set<const Node*> syncs_inserted_;
+  // The sync scope we care about.
+ StorageScope sync_scope_; +}; + +class StorageSyncInserter : public IRMutator { + public: + StorageSyncInserter(StorageScope sync_scope, + std::unordered_set syncs) + : sync_scope_(sync_scope), syncs_(syncs) {} + + Stmt Mutate(Stmt stmt) final { + stmt = IRMutator::Mutate(stmt); + if (syncs_.count(stmt.get())) { + stmt = Block::make( + Evaluate::make( + Call::make(Int(32), intrinsic::tvm_storage_sync, + {StringImm::make(sync_scope_.to_string())}, + Call::Intrinsic)), + stmt); + } + return stmt; + } + + StorageScope sync_scope_; + std::unordered_set syncs_; +}; + +Stmt StorageSync(Stmt stmt, std::string storage_scope) { + StorageScope sync_scope = StorageScope::make(storage_scope); + auto syncs = StorageSyncPlanner(sync_scope).Plan(stmt); + return StorageSyncInserter(sync_scope, syncs).Mutate(stmt); +} + +LoweredFunc StorageSync(LoweredFunc f, std::string storage_scope) { + auto n = std::make_shared(*f.operator->()); + n->body = StorageSync(f->body, storage_scope); + return LoweredFunc(n); +} + +} // namespace ir +} // namespace tvm diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc index 83ead0847598..962254dd76d0 100644 --- a/src/runtime/cuda/cuda_module.cc +++ b/src/runtime/cuda/cuda_module.cc @@ -13,7 +13,7 @@ #include #include "./cuda_common.h" #include "../void_addr_args.h" -#include "../thread_axis_args.h" +#include "../thread_storage_scope.h" namespace tvm { namespace runtime { diff --git a/src/runtime/opencl/opencl_module.cc b/src/runtime/opencl/opencl_module.cc index 64bff819401c..98e889311ca2 100644 --- a/src/runtime/opencl/opencl_module.cc +++ b/src/runtime/opencl/opencl_module.cc @@ -11,7 +11,7 @@ #include #include #include "../void_addr_args.h" -#include "../thread_axis_args.h" +#include "../thread_storage_scope.h" namespace tvm { namespace runtime { @@ -87,13 +87,13 @@ class OpenCLWrappedFunc { ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); cl_uint work_dim = static_cast(thread_axis_cfg_.work_dim()); for (cl_uint i = 0; i < work_dim; ++i) { - wl.work_size[i + 3] *= wl.work_size[i]; + wl.work_size[i] *= wl.work_size[i + 3]; } // launch kernel OPENCL_CALL(clEnqueueNDRangeKernel( queue, kernel, work_dim, nullptr, - wl.work_size + 3, wl.work_size, + wl.work_size + 3, 0, nullptr, nullptr)); } diff --git a/src/runtime/thread_axis_args.h b/src/runtime/thread_axis_args.h deleted file mode 100644 index 96b34eaddece..000000000000 --- a/src/runtime/thread_axis_args.h +++ /dev/null @@ -1,106 +0,0 @@ -/*! - * Copyright (c) 2017 by Contributors - * \file thread_axis_args.h - * \brief Extract thread axis configuration from TVMArgs. - */ -#ifndef TVM_RUNTIME_THREAD_AXIS_ARGS_H_ -#define TVM_RUNTIME_THREAD_AXIS_ARGS_H_ - -#include -#include - -namespace tvm { -namespace runtime { - -/*! \brief workload speccification */ -struct ThreadWorkLoad { - // array, first three are thread configuration. - size_t work_size[6]; - /*! - * \param i The block dimension. - * \return i-th block dim - */ - inline size_t block_dim(size_t i) const { - return work_size[i]; - } - /*! - * \param i The grid dimension. - * \return i-th grid dim - */ - inline size_t grid_dim(size_t i) const { - return work_size[i + 3]; - } -}; -/*! 
\brief Thread axis configuration */ -class ThreadAxisConfig { - public: - void Init(size_t base, - const std::vector& thread_axis_tags) { - base_ = base; - std::vector filled(6, false); - for (size_t i = 0; i < thread_axis_tags.size(); ++i) { - const std::string& tag = thread_axis_tags[i]; - if (tag == "threadIdx.x") { - arg_index_map_.push_back(0); - filled[0] = true; - } else if (tag == "threadIdx.y") { - arg_index_map_.push_back(1); - filled[1] = true; - } else if (tag == "threadIdx.z") { - arg_index_map_.push_back(2); - filled[2] = true; - } else if (tag == "blockIdx.x") { - arg_index_map_.push_back(3 + 0); - filled[3] = true; - } else if (tag == "blockIdx.y") { - arg_index_map_.push_back(3 + 1); - filled[3 + 1] = true; - } else if (tag == "blockIdx.z") { - arg_index_map_.push_back(3 + 2); - filled[3 + 2] = true; - } else { - LOG(FATAL) << "do not known thread_tag=" << tag; - } - } - work_dim_ = 3; - for (int i = 0; i < 3; ++i) { - if (!filled[i]) { - for (int j = i; j < 3; ++j) { - CHECK(!filled[j] && !filled[j + 3]) - << "Invalid thread group configuration"; - } - work_dim_ = i; - break; - } else { - CHECK(filled[i]) - << "Must have both threadIdx and blockIdx"; - } - } - } - // extract workload from arguments. - ThreadWorkLoad Extract(TVMArgs x) const { - ThreadWorkLoad w; - std::fill(w.work_size, w.work_size + 6, 1); - for (size_t i = 0; i < arg_index_map_.size(); ++i) { - w.work_size[arg_index_map_[i]] = - static_cast(x.values[base_ + i].v_int64); - } - return w; - } - // return the work dim - size_t work_dim() const { - return work_dim_; - } - - private: - /*! \brief base axis */ - size_t base_; - /*! \brief The worker dimension */ - size_t work_dim_; - /*! \brief The index mapping. */ - std::vector arg_index_map_; -}; - -} // namespace runtime -} // namespace tvm -#endif // TVM_RUNTIME_THREAD_AXIS_ARGS_H_ diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h new file mode 100644 index 000000000000..436fe015ad81 --- /dev/null +++ b/src/runtime/thread_storage_scope.h @@ -0,0 +1,161 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file thread_storage_scope.h + * \brief Extract thread axis configuration from TVMArgs. + */ +#ifndef TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_ +#define TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_ + +#include +#include +#include + +namespace tvm { +namespace runtime { + +/*! \brief class to represent storage scope */ +struct StorageScope { + /*! \brief The rank of the storage */ + int rank{0}; + // comparator + inline bool operator==(const StorageScope& other) const { + return rank == other.rank; + } + inline std::string to_string() const { + switch (rank) { + case 0: return "global"; + case 1: return "shared"; + case 2: return "local"; + default: LOG(FATAL) << "unknown storage scope"; return ""; + } + } + /*! + * \brief make storage scope from string + * \param s The string to be parsed. + * \return The storage scope. + */ + static StorageScope make(const std::string& s) { + StorageScope r; + if (s == "global") { + r.rank = 0; + } else if (s == "shared") { + r.rank = 1; + } else if (s == "local") { + r.rank = 2; + } else { + LOG(FATAL) << "unknown storage scope " << s; + } + return r; + } +}; + +/*! \brief class to represent thread scope */ +struct ThreadScope { + /*! \brief The rank of thread scope */ + int rank{0}; + /*! \brief the dimension index under the rank */ + int dim_index{0}; + /*! + * \brief make storage scope from string + * \param s The string to be parsed. + * \return The storage scope. 
+   */
+  static ThreadScope make(const std::string& s) {
+    ThreadScope r;
+    if (s.compare(0, 9, "blockIdx.") == 0) {
+      r.rank = 0;
+      r.dim_index = static_cast<int>(s[9] - 'x');
+    } else if (s.compare(0, 10, "threadIdx.") == 0) {
+      r.rank = 1;
+      r.dim_index = static_cast<int>(s[10] - 'x');
+    } else {
+      LOG(FATAL) << "Unknown thread scope " << s;
+    }
+    return r;
+  }
+};
+
+
+/*! \brief workload specification */
+struct ThreadWorkLoad {
+  // array, first three are thread configuration.
+  size_t work_size[6];
+  /*!
+   * \param i The block dimension.
+   * \return i-th block dim
+   */
+  inline size_t block_dim(size_t i) const {
+    return work_size[i + 3];
+  }
+  /*!
+   * \param i The grid dimension.
+   * \return i-th grid dim
+   */
+  inline size_t grid_dim(size_t i) const {
+    return work_size[i];
+  }
+};
+/*! \brief Thread axis configuration */
+class ThreadAxisConfig {
+ public:
+  void Init(size_t base,
+            const std::vector<std::string>& thread_axis_tags) {
+    base_ = base;
+    std::vector<bool> filled(6, false);
+    for (size_t i = 0; i < thread_axis_tags.size(); ++i) {
+      const std::string& tag = thread_axis_tags[i];
+      ThreadScope ts = ThreadScope::make(tag);
+      arg_index_map_.push_back(ts.rank * 3 + ts.dim_index);
+      filled[ts.rank * 3 + ts.dim_index] = true;
+    }
+    work_dim_ = 3;
+    for (int i = 0; i < 3; ++i) {
+      if (!filled[i]) {
+        for (int j = i; j < 3; ++j) {
+          CHECK(!filled[j] && !filled[j + 3])
+              << "Invalid thread group configuration";
+        }
+        work_dim_ = i;
+        break;
+      } else {
+        CHECK(filled[i])
+            << "Must have both threadIdx and blockIdx";
+      }
+    }
+  }
+  // extract workload from arguments.
+  ThreadWorkLoad Extract(TVMArgs x) const {
+    ThreadWorkLoad w;
+    std::fill(w.work_size, w.work_size + 6, 1);
+    for (size_t i = 0; i < arg_index_map_.size(); ++i) {
+      w.work_size[arg_index_map_[i]] =
+          static_cast<size_t>(x.values[base_ + i].v_int64);
+    }
+    return w;
+  }
+  // return the work dim
+  size_t work_dim() const {
+    return work_dim_;
+  }
+
+ private:
+  /*! \brief base axis */
+  size_t base_;
+  /*! \brief The worker dimension */
+  size_t work_dim_;
+  /*! \brief The index mapping.
*/ + std::vector arg_index_map_; +}; + +} // namespace runtime +} // namespace tvm + +namespace std { +template <> +struct hash<::tvm::runtime::StorageScope> { + std::size_t operator()(const ::tvm::runtime::StorageScope& k) const { + return static_cast(k.rank); + } +}; +} // namespace std +#endif // TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_ diff --git a/src/schedule/bound.cc b/src/schedule/bound.cc index 36532aa419d7..706550843326 100644 --- a/src/schedule/bound.cc +++ b/src/schedule/bound.cc @@ -9,6 +9,7 @@ #include #include "./int_set.h" #include "./graph.h" +#include "../runtime/thread_storage_scope.h" namespace tvm { namespace schedule { @@ -181,24 +182,13 @@ BoundProp(const Array& post_order, // check if scope -bool ScopeRelax(const IterVar& iv, const std::string& scope) { +inline bool ScopeRelax(const IterVar& iv, const std::string& scope) { + using runtime::ThreadScope; + using runtime::StorageScope; if (iv->thread_tag.length() == 0) return false; if (scope.length() == 0) return false; - static std::unordered_map scope_rank{ - {"global", 0}, - {"shared", 1}, - {"local", 2} - }; - static std::unordered_map thread_tag_rank{ - {"blockIdx.x", 0}, - {"blockIdx.y", 0}, - {"blockIdx.z", 0}, - {"threadIdx.x", 1}, - {"threadIdx.y", 1}, - {"threadIdx.z", 1} - }; - return scope_rank.at(scope) <= thread_tag_rank.at(iv->thread_tag); + return StorageScope::make(scope).rank <= ThreadScope::make(iv->thread_tag).rank; } void InferBound(const Stage& stage, @@ -248,7 +238,7 @@ void InferBound(const Stage& stage, } auto result = BoundProp(post_order, &bp_state); - // Set relaxation + // Set relaxation for the threads in parent. Map relax_set; Stage s = stage; while (s->attach_type == kScope) { @@ -259,6 +249,7 @@ void InferBound(const Stage& stage, } } } + for (auto iv : stage->op->root_iter_vars()) { CHECK(result.count(iv)); CHECK(!rmap->count(iv)); diff --git a/src/schedule/compute_expr.h b/src/schedule/compute_expr.h index 0feb582fcec2..ee1947b61039 100644 --- a/src/schedule/compute_expr.h +++ b/src/schedule/compute_expr.h @@ -32,7 +32,7 @@ template inline bool GetConst(Expr e, T* out); template<> -bool GetConst(Expr e, int64_t *out) { +inline bool GetConst(Expr e, int64_t *out) { if (e.type().is_vector()) return false; const int64_t *v = as_const_int(e); if (v) { @@ -42,7 +42,7 @@ bool GetConst(Expr e, int64_t *out) { } } template<> -bool GetConst(Expr e, uint64_t *out) { +inline bool GetConst(Expr e, uint64_t *out) { if (e.type().is_vector()) return false; const uint64_t *v = as_const_uint(e); if (v) { @@ -77,7 +77,7 @@ template<> inline Expr ComputeExpr(Expr a, Expr b) { if (is_zero(b)) return a; TVM_CONST_PROPAGATION(sub, -); - return ir::Add::make(a, b); + return ir::Sub::make(a, b); } template<> @@ -91,7 +91,7 @@ inline Expr ComputeExpr(Expr a, Expr b) { template<> inline Expr ComputeExpr(Expr a, Expr b) { if (is_one(b)) return a; - return ir::Mul::make(a, b); + return ir::Div::make(a, b); } template<> diff --git a/src/schedule/schedule_ops.cc b/src/schedule/schedule_ops.cc index 61f0347bcd2b..e1390b5891f8 100644 --- a/src/schedule/schedule_ops.cc +++ b/src/schedule/schedule_ops.cc @@ -11,6 +11,7 @@ #include "../pass/ir_util.h" #include "./int_set.h" #include "./graph.h" +#include "./compute_expr.h" namespace tvm { namespace schedule { @@ -47,6 +48,49 @@ void PassDownFlag(const Stage& s, } } +/*! + * \brief message passing to find if boundary checking on IterVar is needed. + * \param s The stage to be used. 
+ * \param p_state The message passing state
+ *     IterVar->flag
+ */
+void PassUpBoundCheck(const Stage& s,
+                      const Map<IterVar, Range>& dom_map,
+                      std::unordered_map<IterVar, bool>* p_state) {
+  auto& state = *p_state;
+  using Halide::Internal::can_prove;
+  for (size_t i = s->relations.size(); i != 0; --i) {
+    IterVarRelation rel = s->relations[i - 1];
+    if (rel.as<SplitNode>()) {
+      const SplitNode* s = rel.as<SplitNode>();
+      bool outer = state.at(s->outer);
+      bool inner = state.at(s->inner);
+      Expr factor = dom_map.at(s->inner)->extent;
+      Expr step = dom_map.at(s->outer)->extent;
+
+      if (outer || inner) {
+        state[s->parent] = true;
+      } else {
+        if (can_prove(dom_map.at(s->parent)->extent == factor * step)) {
+          state[s->parent] = false;
+        } else {
+          state[s->parent] = true;
+        }
+      }
+    } else if (rel.as<FuseNode>()) {
+      const FuseNode* s = rel.as<FuseNode>();
+      bool fused = state.at(s->fused);
+      state[s->outer] = fused;
+      state[s->inner] = fused;
+    } else if (rel.as<RebaseNode>()) {
+      const RebaseNode* s = rel.as<RebaseNode>();
+      state[s->parent] = state.at(s->rebased);
+    } else {
+      LOG(FATAL) << "unknown relation type";
+    }
+  }
+}
+
 /*!
  * \brief use message passing to calculate the assignment of each Var inside the loop body.
  * \param s The schedule to be used.
@@ -107,8 +151,9 @@ MakeLoopNest(const Stage& sch,
              const Map<IterVar, Range>& dom_map,
              size_t begin_loop,
              bool reduce_init_loop,
-             std::unordered_map<IterVar, Expr>* p_value_map,
-             const std::unordered_map<IterVar, bool>& skip_iter) {
+             const std::unordered_map<IterVar, bool>& bound_state,
+             const std::unordered_map<IterVar, bool>& skip_iter,
+             std::unordered_map<IterVar, Expr>* p_value_map) {
   auto leaf_iter_vars = sch->leaf_iter_vars;
   Stmt no_op = Evaluate::make(0);
   // create the loop nest
@@ -167,6 +212,21 @@ MakeLoopNest(const Stage& sch,
   }
   // message passing to get offset of root iter vars.
   PassUpOffset(sch, dom_map, &value_map);
+
+  // insert conditions
+  for (IterVar iv : sch->op->root_iter_vars()) {
+    if (skip_iter.count(iv)) continue;
+    Range dom = dom_map.at(iv);
+    if (bound_state.at(iv)) {
+      Expr condition = ComputeExpr<ir::Sub>(value_map.at(iv), dom->min) < dom->extent;
+      nest.back().emplace_back(IfThenElse::make(condition, no_op));
+    }
+    CHECK(iv->dom.defined());
+    if (!reduce_init_loop && !iv->dom.same_as(dom)) {
+      Expr condition = ComputeExpr<ir::Sub>(value_map.at(iv), iv->dom->min) < iv->dom->extent;
+      nest.back().emplace_back(IfThenElse::make(condition, no_op));
+    }
+  }
   return nest;
 }
 
@@ -175,7 +235,16 @@ Stmt MakeLoop(const Stage& s,
               Stmt provide, Stmt init) {
   std::unordered_map<IterVar, Expr> value_map;
-  auto nest = MakeLoopNest(s, dom_map, 0, false, &value_map, {});
+  // bound check state.
+  std::unordered_map<IterVar, bool> bound_state;
+  for (IterVar iv : s->leaf_iter_vars) {
+    bound_state[iv] = false;
+  }
+  PassUpBoundCheck(s, dom_map, &bound_state);
+  auto nest = MakeLoopNest(s, dom_map, 0, false,
+                           bound_state, {}, &value_map);
+
   provide = Substitute(provide, value_map);
   if (init.defined()) {
     // try to find the location to insert the initialization.
@@ -204,13 +273,13 @@ Stmt MakeLoop(
   }
   // skip loops that do not relate to axis.
   std::unordered_map<IterVar, bool> skip_iter;
-  for (size_t i = begin_loop; i < leaf_iter_vars.size(); ++i) {
-    auto iv = leaf_iter_vars[i];
-    int flag = reduce_state.at(iv);
-    if ((flag & 1) == 0) skip_iter[iv] = true;
+  for (auto kv : reduce_state) {
+    int flag = kv.second;
+    if ((flag & 1) == 0) skip_iter[kv.first] = true;
   }
   auto init_nest = MakeLoopNest(
-      s, dom_map, begin_loop, true, &init_value_map, skip_iter);
+      s, dom_map, begin_loop, true,
+      bound_state, skip_iter, &init_value_map);
   init = Substitute(init, init_value_map);
   init = MergeNest(init_nest, init);
   // common nest
@@ -250,7 +319,6 @@ Stmt MakeRealize(const ComputeOpNode* op,
 
 void MakeReduction(const ComputeOpNode* op,
                    const std::vector<Tensor>& tensors,
-                   const Map<IterVar, Range>& dom_map,
                    Stmt* init,
                    Stmt* provide) {
   Stmt no_op = Evaluate::make(0);
@@ -279,43 +347,49 @@ void MakeReduction(const ComputeOpNode* op,
   *provide = Provide::make(t->op, t->value_index, update_value, args);
 }
 
-Stmt MakePipeline(const Stage& sch,
+Stmt MakePipeline(const Stage& s,
                   const Map<IterVar, Range>& dom_map,
                   Stmt consumer) {
   std::vector<Tensor> tensors;
-  for (int i = 0; i < sch->op->num_outputs(); ++i) {
-    tensors.emplace_back(sch->op.output(i));
+  for (int i = 0; i < s->op->num_outputs(); ++i) {
+    tensors.emplace_back(s->op.output(i));
   }
   Stmt init, provide;
-  const ComputeOpNode* compute = sch->op.as<ComputeOpNode>();
+  const ComputeOpNode* compute = s->op.as<ComputeOpNode>();
   if (compute) {
     if (compute->reduce_axis.size() == 0) {
       provide = MakeProvide(compute, tensors);
     } else {
-      MakeReduction(compute, tensors, dom_map, &init, &provide);
+      MakeReduction(compute, tensors, &init, &provide);
     }
   } else {
-    LOG(FATAL) << "not supported op " << sch->op->type_key();
+    LOG(FATAL) << "not supported op " << s->op->type_key();
   }
-  Stmt producer = MakeLoop(sch, dom_map, provide, init);
-  producer = ProducerConsumer::make(sch->op, true, producer);
+  Stmt producer = MakeLoop(s, dom_map, provide, init);
+  producer = ProducerConsumer::make(s->op, true, producer);
   Stmt pipeline = producer;
   if (consumer.defined()) {
-    consumer = ProducerConsumer::make(sch->op, false, consumer);
+    consumer = ProducerConsumer::make(s->op, false, consumer);
     pipeline = Block::make(producer, consumer);
   }
-  if (sch->op.as<ComputeOpNode>()) {
-    return MakeRealize(sch->op.as<ComputeOpNode>(),
-                       dom_map, tensors, pipeline);
+  if (s->op.as<ComputeOpNode>()) {
+    pipeline = MakeRealize(s->op.as<ComputeOpNode>(),
+                           dom_map, tensors, pipeline);
   } else {
     LOG(FATAL) << "not supported op";
     return Stmt();
   }
+  // use attribute to mark scope of the operation.
+  pipeline = AttrStmt::make(
+      s->op, "realize_scope",
+      StringImm::make(s->scope),
+      pipeline);
+  return pipeline;
 }
 
 // inject the operator's realization on the stmt.
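
[Note for reviewers] The effect of PassUpBoundCheck is easiest to see from
Python: when a split factor does not provably divide the extent it splits,
ScheduleOps now guards the loop body with an if-condition instead of letting
the recovered index run past the bounds. A minimal sketch using the pass
entry points exercised by the tests below (the printed IR is indicative
only; its exact form depends on simplification):

    import tvm

    n = tvm.Var('n')  # symbolic extent: can_prove(n == outer * 8) fails
    A = tvm.placeholder((n,), name='A')
    B = tvm.compute((n,), lambda i: A[i] + 1, name='B')

    s = tvm.Schedule(B.op)
    xo, xi = s[B].split(B.op.axis[0], factor=8)

    bounds = tvm.schedule.InferBound(s)
    stmt = tvm.schedule.ScheduleOps(s, bounds)
    # The Provide for B is now wrapped in a guard of the form
    # ((xo * 8) + xi) - 0 < n, built with ComputeExpr<ir::Sub>.
    print(stmt)
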
diff --git a/tests/python/integration/test_gemm.py b/tests/python/integration/test_gemm.py new file mode 100644 index 000000000000..ac5c5c2c4b66 --- /dev/null +++ b/tests/python/integration/test_gemm.py @@ -0,0 +1,87 @@ +import tvm +import numpy as np + +def test_gemm(): + # graph + nn = 1235 + n = tvm.Var('n') + #n = tvm.convert(nn) + m = n + l = n + A = tvm.placeholder((n, l), name='A') + B = tvm.placeholder((m, l), name='B') + AA = tvm.compute(A.shape, lambda *i : A(*i), name="AA") + BB = tvm.compute(B.shape, lambda *i : B(*i), name="BB") + k = tvm.IterVar((0, l), name='k') + CC = tvm.compute( + (n, m), + lambda ii, jj: tvm.sum(AA[ii, k] * BB[jj, k], axis=k), + name='CC') + C = tvm.compute(CC.shape, lambda *i: CC(*i), name="C") + + # schedule + s = tvm.Schedule(C.op) + xtile, ytile = 32, 32 + s[AA].set_scope("shared") + #s[CC].set_scope("global") + s[BB].set_scope("shared") + + scale = 8 + num_thread = 8 + block_factor = scale * num_thread + block_x = tvm.IterVar(thread_tag="blockIdx.x") + thread_x = tvm.IterVar((0, num_thread), thread_tag="threadIdx.x") + block_y = tvm.IterVar(thread_tag="blockIdx.y") + thread_y = tvm.IterVar((0, num_thread), thread_tag="threadIdx.y") + + _, yi = s[C].split(C.op.axis[0], factor=block_factor, outer=block_y) + _, xi = s[C].split(C.op.axis[1], factor=block_factor, outer=block_x) + s[C].reorder(block_y, block_x, yi, xi) + _, yi = s[C].split(yi, outer=thread_y) + _, xi = s[C].split(xi, outer=thread_x) + s[C].reorder(thread_y, thread_x, yi, xi) + yo, xo = CC.op.axis + s[CC].reorder(k, yo, xo) + + s[CC].compute_at(s[C], thread_x) + s[AA].compute_at(s[CC], k) + s[BB].compute_at(s[CC], k) + + _, xi = s[AA].split(s[AA].op.axis[0], outer=thread_y) + _, xi = s[AA].split(xi, outer=thread_x) + _, xi = s[BB].split(s[BB].op.axis[0], outer=thread_y) + _, xi = s[BB].split(xi, outer=thread_x) + + # lowering test + s.normalize() + + def check_device(target): + codes = [] + f = tvm.build(s, [A, B, C], target, record_codes=codes) + for c in codes[1:]: + print(c) + if target == "cuda": + ctx = tvm.gpu(0) + else: + ctx = tvm.cl(0) + if not ctx.enabled: + return + # launch the kernel. 
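+        # nn = 1235 is not a multiple of the 8x8 thread tiling
+        # (block_factor = 64), so a correct result depends on the
+        # boundary conditions ScheduleOps inserts for the split axes.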
+ n = nn + m = n + l = n + a_np = np.random.uniform(size=(n, l)).astype(A.dtype) + b_np = np.random.uniform(size=(m, l)).astype(B.dtype) + a = tvm.nd.array(a_np, ctx) + b = tvm.nd.array(b_np, ctx) + c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx) + f(a, b, c) + np.testing.assert_allclose( + c.asnumpy(), np.dot(a_np, b_np.T), rtol=1e-5) + + tvm.init_opencl() + check_device("cuda") + check_device("opencl") + +if __name__ == "__main__": + test_gemm() diff --git a/tests/python/unittest/test_codegen_device.py b/tests/python/unittest/test_codegen_device.py index 56a9c29d8c2a..5d48ed8b453d 100644 --- a/tests/python/unittest/test_codegen_device.py +++ b/tests/python/unittest/test_codegen_device.py @@ -24,8 +24,8 @@ def test_add_pipeline(): Cb = tvm.Buffer(C.shape, C.dtype, name='C') stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb}) stmt = tvm.ir_pass.Simplify(stmt) - fapi = tvm.codegen.MakeAPI(stmt, "myadd", [Ab, Bb, Cb], 3) - fsplits = tvm.codegen.SplitHostDevice(fapi) + fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Cb], 3) + fsplits = tvm.ir_pass.SplitHostDevice(fapi) def check_cuda(): output_ssa = False diff --git a/tests/python/unittest/test_codegen_makeapi.py b/tests/python/unittest/test_pass_makeapi.py similarity index 92% rename from tests/python/unittest/test_codegen_makeapi.py rename to tests/python/unittest/test_pass_makeapi.py index 689556db9f28..a862eca1ca59 100644 --- a/tests/python/unittest/test_codegen_makeapi.py +++ b/tests/python/unittest/test_pass_makeapi.py @@ -18,7 +18,7 @@ def test_makeapi(): stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb}) num_packed_args = 2 - f = tvm.codegen.MakeAPI(stmt, "myadd", [n, Ab, Bb, Cb], num_packed_args) + f = tvm.ir_pass.MakeAPI(stmt, "myadd", [n, Ab, Bb, Cb], num_packed_args) assert(f.handle_data_type[Ab.data].dtype == Ab.dtype) assert(len(f.args) == 5) output_ssa = False diff --git a/tests/python/unittest/test_pass_storage_sync.py b/tests/python/unittest/test_pass_storage_sync.py new file mode 100644 index 000000000000..53b0b66af228 --- /dev/null +++ b/tests/python/unittest/test_pass_storage_sync.py @@ -0,0 +1,31 @@ +import tvm + +def test_storage_sync(): + m = tvm.Var('m') + l = tvm.Var('l') + A = tvm.placeholder((m, l), name='A') + + A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1') + A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2') + + s = tvm.Schedule(A2.op) + block_x = tvm.IterVar(thread_tag="blockIdx.x") + xo, xi = s[A2].split(A2.op.axis[0], factor=8, outer=block_x) + s[A1].compute_at(s[A2], xo) + s[A1].set_scope("shared") + + bounds = tvm.schedule.InferBound(s) + assert isinstance(bounds, tvm.collections.Map) + stmt = tvm.schedule.ScheduleOps(s, bounds) + + Ab = tvm.Buffer(A.shape, A.dtype, name='A') + A2b = tvm.Buffer(A2.shape, A2.dtype, name='A2') + stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}) + f = tvm.ir_pass.MakeAPI(stmt, "test", [Ab, A2b], 2) + flist = tvm.ir_pass.SplitHostDevice(f) + f = flist[1] + f = tvm.ir_pass.StorageSync(f, "shared") + print(f.body) + +if __name__ == "__main__": + test_storage_sync() diff --git a/tests/python/unittest/test_runtime_stack_vm.py b/tests/python/unittest/test_runtime_stack_vm.py index 363473661a3a..2de2da544036 100644 --- a/tests/python/unittest/test_runtime_stack_vm.py +++ b/tests/python/unittest/test_runtime_stack_vm.py @@ -16,9 +16,7 @@ def tvm_call_back_get_shape(shape0): n = tvm.Var('n') Ab = tvm.Buffer((n, ), tvm.float32) stmt = tvm.make.Evaluate(tvm_call_global("tvm_call_back_get_shape", Ab.shape[0])) - print(stmt) - fapi 
= tvm.codegen.MakeAPI(stmt, "print_shape", [Ab], 1) - print(fapi.body) + fapi = tvm.ir_pass.MakeAPI(stmt, "print_shape", [Ab], 1) f = tvm.codegen.BuildStackVM(fapi) f(a) @@ -41,8 +39,7 @@ def test_stack_vm_loop(): tvm.make.Load(dtype, Ab.data, i) + 1, i + 1), tvm.make.Evaluate(tvm_call_global("tvm_stack_vm_print", i)))) - print(stmt) - fapi = tvm.codegen.MakeAPI(stmt, "ramp", [Ab], 1) + fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 1) f = tvm.codegen.BuildStackVM(fapi) a = tvm.nd.array(np.zeros(10, dtype=dtype)) f(a) @@ -63,8 +60,7 @@ def test_stack_vm_cond(): tvm.make.Load(dtype, Ab.data, i) + 1, i + 1), tvm.make.Store(Ab.data, tvm.make.Load(dtype, Ab.data, i) + 2, i + 1))) - print(stmt) - fapi = tvm.codegen.MakeAPI(stmt, "test", [Ab], 1) + fapi = tvm.ir_pass.MakeAPI(stmt, "test", [Ab], 1) f = tvm.codegen.BuildStackVM(fapi) a = tvm.nd.array(np.zeros(10, dtype=dtype)) f(a)
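
[Note for reviewers] The scope/thread rank tables deleted from bound.cc live
on as StorageScope and ThreadScope in src/runtime/thread_storage_scope.h.
The relaxation rule they encode, restated in Python for reference (an
illustrative sketch only, not part of the patch):

    # Rank tables, copied from the code removed in bound.cc.
    STORAGE_RANK = {"global": 0, "shared": 1, "local": 2}
    THREAD_RANK = {"blockIdx.x": 0, "blockIdx.y": 0, "blockIdx.z": 0,
                   "threadIdx.x": 1, "threadIdx.y": 1, "threadIdx.z": 1}

    def scope_relax(thread_tag, storage_scope):
        # Relax a thread axis (treat it as covering its whole range) when
        # the buffer is visible across the threads that axis distinguishes:
        # "shared" relaxes threadIdx.*, "global" also relaxes blockIdx.*.
        return STORAGE_RANK[storage_scope] <= THREAD_RANK[thread_tag]

    assert scope_relax("threadIdx.x", "shared")      # shared across threads
    assert not scope_relax("blockIdx.x", "shared")   # distinct per block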