From c6f62aafc91e2600ed7772597fd4238c924c2a1b Mon Sep 17 00:00:00 2001
From: Chris Sullivan
Date: Fri, 20 Aug 2021 16:08:07 -0700
Subject: [PATCH] [Texture support][Part 1] TIR lowering and OpenCL support (#7686)

* Add support for kTexture storage rank.
* Add scaffolding for texture_flatten pass.
* Add scaffolding for texture allocation.
* Implement 2d texture flattening to builtin tir.text2d_alloca.
* Lower BufferStore/Load to builtin texture store/load.
* Add vectorizable attribute to texture load and store.
* Support auto-vectorization on the innermost (RGBA) axis.
* Add read/write_imagef opencl codegen for builtin texture load/store.
* Add TextureType support.
* Add InferTextureAccess pass to deduce __read_only and __write_only access qualifiers for texture vars. Also refactor use of restrict keyword to be var dependent.
* Implement texture allocation as external function in TIR lowering.
* Remove commented lines.
* Add nd->2d texture flattening.
* Bug fixes in opencl codegen (row<>col, access quals.)
* Improve texture codegen by explicitly allocating local vector for the texture load. Also support indexing individual elements of the RGBA vector.
* Remove automatic vectorization code as it is no longer needed.
* Improve SSA local use when storing texture read to scalar buffer.
* Define texture flattening convention such that the outer Nd-1 axes are stored as rows, and the last axis is stored as columns.
* Add tir lowering and opencl codegen support for float16 textures.
* Disable SSA when texture load is immediately casted.
* Allow RGBA extent to be of length 1.
* Add pass to forward externally allocated textures in place of textures realized from cache_read. Fix to better follow indexing spec.
* Add buffer_common.h to house buffer offset simplification routines.
* More refactor and clean up in texture lowering.
* Add IsTextureType to tir and allow buffer var type annotation to be TextureType in addition to PointerType.
* Bug fix in texture access qualifier inference pass
* Step toward handling external texture buffer forwarding when external buffer is not stored directly to cache_read realized buffer. For example when it is conditionally stored via an IfThenElse node when padding is used.
* [Part 2/3] Support texture:weight lowering convention for externally provided texture buffers. Need to propagate this to allocated textures when cache_read(texture) is used for weights.
* Bug fix in texture access qualifier inference pass
* Tighten constraint on external buffer forwarding -- cache_read(texture) cancellation -- to avoid incorrect programs. Currently only forward through if_then_else node and direct external loads. For if_then_else, still need proper analysis of structural equality between buffers and access patterns to determine if an external buffer can replace the texture buffer realized via cache_read.
* Use texture lowering convention from texture runtime util.
* Use updated texture lowering utilities
* Use inherited visitor overloads in texture flattener.
* Add check in codegen for float/half until read/write_image codegen supports other types.
* Rename tir texture builtins
* Remove codegen and tir runtime dependence on TVMBackendAlloc/FreeTexture.
* Dispatch texture allocas via target specialized tir.tvm_call_packed
* Remove kTexture scope and use kGlobal with texture tag.
* Remove TextureType.
* Remove TextureType from OpenCL codegen.
* Remove TextureType from TIR lowering.
* Remove dependency on MergeMulMod.
* Revert "Add buffer_common.h to house buffer offset simplification routines." This reverts commit 027628259229aaee051dbf1dfbed4e63ef820544. * Prune include list * Add more documentation to texture flattening. * Add TextureFlatten transform to refactored tvm lower API. * Apply clang formatting. * Blacken python APIs. * Apply cpplint changes. * Attempt to extract storage scope from pointer scope. * Remove ExternalBufferForwarding (cache_read cancellation) for now. * Apply MyPy. * Clang format * Only visit RealizeBuffer body for texture storage. * Fix bad merge. * Utilize OpenCL preprocessor to switch between sampler-less and codegen provided sampler for texture reads depending on whether the opencl runtime is 2.0 compliant. * Add texture codegen test example. * Refactor tests to use pytest parameterization. Blacken tests. * Respond to CRs. --- include/tvm/tir/builtin.h | 14 + include/tvm/tir/transform.h | 9 + python/tvm/tir/transform/transform.py | 15 + src/driver/driver_api.cc | 1 + src/target/source/codegen_c.cc | 10 +- src/target/source/codegen_c.h | 2 + src/target/source/codegen_opencl.cc | 211 ++- src/target/source/codegen_opencl.h | 22 +- src/te/operation/op_utils.cc | 6 +- src/tir/op/builtin.cc | 11 + src/tir/transforms/lower_tvm_builtin.cc | 41 + src/tir/transforms/texture_flatten.cc | 205 +++ src/tir/transforms/vectorize_loop.cc | 14 + .../test_target_texture_codegen_opencl.py | 1400 +++++++++++++++++ 14 files changed, 1952 insertions(+), 9 deletions(-) create mode 100644 src/tir/transforms/texture_flatten.cc create mode 100644 tests/python/unittest/test_target_texture_codegen_opencl.py diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index 61280d33f1df..86857a33cdf4 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -600,6 +600,20 @@ TVM_DLL const Op& vectorcombine(); * \brief atomic add instruction, corresponding e.g. to atomicAdd in CUDA */ TVM_DLL const Op& atomic_add(); +/*! + * \brief Create a texture 2d memory allocation + */ +TVM_DLL const Op& texture2d_alloca(); + +/*! + * \brief Store to texture 2d memory + */ +TVM_DLL const Op& texture2d_store(); + +/*! + * \brief Load from texture 2d memory + */ +TVM_DLL const Op& texture2d_load(); /*! \brief The kind of structure field info used in intrinsic */ enum TVMStructFieldKind : int { diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index c0fa62d7caf0..b5998874f7e3 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -437,6 +437,15 @@ TVM_DLL Pass LowerMatchBuffer(); */ TVM_DLL Pass FlattenBuffer(); +/* + * \brief Flatten the multi-dimensional read/write + * to two dimensional texture Load/Store and realize + * texture buffer allocations. + * + * \return The Pass + */ +TVM_DLL Pass TextureFlatten(); + /*! * \brief Unify all the thread bindings for "blockIdx.x/y/z", "threadIdx.x/y/z", and * "vthread.x/y/z". Before the unification, two vars that are bound to a thread axis (e.g., diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 74dafa4157d7..2183319a006f 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -95,6 +95,21 @@ def StorageFlatten(cache_line_size, create_bound_attribute: bool = False): return _ffi_api.StorageFlatten(cache_line_size, create_bound_attribute) # type: ignore +def TextureFlatten(): + """Flatten the multi-dimensional read/write to 2D. 
+ + + Parameters + ---------- + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.TextureFlatten() # type: ignore + + def InjectCopyIntrin(pragma_key: str, fintrin): """Inject virtual thread loops. diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index ff00e68d91f0..bfea3e7b67c0 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -215,6 +215,7 @@ Array CreatePassList(bool disable_loop_partition) { // PHASE 1 pass_list.push_back(tir::transform::InjectPrefetch()); + pass_list.push_back(tir::transform::TextureFlatten()); pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers)); pass_list.push_back(tir::transform::LowerInitBlock()); pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation()); diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index f676f0f598d8..a311111532c8 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -106,8 +106,8 @@ void CodeGenC::AddFunction(const PrimFunc& f) { } } - if (no_alias && restrict_keyword_.length() != 0) { - stream << ' ' << restrict_keyword_; + if (no_alias) { + PrintRestrict(v, stream); } } else { PrintType(GetType(v), stream); @@ -1018,6 +1018,12 @@ void CodeGenC::PrintVecElemLoadExpr(DataType t, int i, const std::string& value, return; } +void CodeGenC::PrintRestrict(const Var& v, std::ostream& os) { + if (restrict_keyword_.length() != 0) { + os << ' ' << restrict_keyword_; + } +} + static bool CheckOutermostBracketMatch(const std::string& s) { if (!s.empty() && s.front() == '(' && s.back() == ')') { size_t len = s.size(); diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index 6ebade7191f2..299f7e0a9cef 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -200,6 +200,8 @@ class CodeGenC : public ExprFunctor, virtual std::string CastFromTo(std::string value, DataType from, DataType target); // Get load of single element with expression virtual void PrintVecElemLoadExpr(DataType t, int i, const std::string& value, std::ostream& os); + // Print restrict keyword for a given Var if applicable + virtual void PrintRestrict(const Var& v, std::ostream& os); protected: // Print reference to struct location diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index edb614d9c122..7abff36a3ddb 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -27,18 +27,63 @@ #include #include "../../runtime/opencl/opencl_module.h" +#include "../../runtime/texture.h" #include "../../runtime/thread_storage_scope.h" #include "../build_common.h" namespace tvm { namespace codegen { -CodeGenOpenCL::CodeGenOpenCL() { restrict_keyword_ = "restrict"; } +class InferTextureAccess : public StmtExprVisitor { + public: + static constexpr const uint8_t kReadAccess = 1; + static constexpr const uint8_t kWriteAccess = 2; + + InferTextureAccess() {} + std::unordered_map Infer(const Stmt& n) { + StmtExprVisitor::VisitStmt(n); + std::unordered_map storage_scope_qualifiers; + for (auto& texture : var_access_map_) { + if (texture.second == kReadAccess) { + storage_scope_qualifiers.insert({texture.first, "texture_read"}); + } else if (texture.second == kWriteAccess) { + storage_scope_qualifiers.insert({texture.first, "texture_write"}); + } else if (texture.second == (kReadAccess | kWriteAccess)) { + storage_scope_qualifiers.insert({texture.first, ""}); + } + } + return storage_scope_qualifiers; + 
} + void VisitExpr_(const CallNode* op) { + if (op->op.same_as(builtin::texture2d_load())) { + var_access_map_[op->args[0].as()] |= kReadAccess; + } else if (op->op.same_as(builtin::texture2d_store())) { + var_access_map_[op->args[0].as()] |= kWriteAccess; + } else { + StmtExprVisitor::VisitExpr_(op); + } + StmtExprVisitor::VisitExpr_(op); + } + + private: + std::unordered_map var_access_map_; +}; + +CodeGenOpenCL::CodeGenOpenCL() { + // Set OpenCL specific restrict keyword + restrict_keyword_ = "restrict"; +} void CodeGenOpenCL::InitFuncState(const PrimFunc& f) { CodeGenC::InitFuncState(f); + this->SetTextureScope(InferTextureAccess().Infer(f->body)); for (Var arg : f->params) { - if (arg.dtype().is_handle()) { + auto ptr_type = arg->type_annotation.as(); + if (ptr_type && runtime::IsTextureStorage(std::string(ptr_type->storage_scope))) { + // Storage scope qualifiers for textures are inferred + // and set prior to function codegen. + continue; + } else if (arg.dtype().is_handle()) { alloc_storage_scope_[arg.get()] = "global"; } } @@ -75,6 +120,40 @@ std::string CodeGenOpenCL::Finish() { decl_stream << "#pragma OPENCL EXTENSION cl_khr_global_int32_base_atomics : enable\n" "#pragma OPENCL EXTENSION cl_khr_global_int32_extended_atomics : enable\n\n"; } + + // Enable OpenCL 1.2 sampler-less texture reads, but utilize + // provided sampler in OpenCL 2.0. + if (enable_compliant_texture_reads_) { + // TODO(csullivan, lunderberg): Extend device attribute querying to support remote devices + // generically through the device API such that a target can be created from a specific device's + // attributes and utilized during codegen. Potential generlization of #8127 (c02cafb) for remote + // devices. + // + // E.g. Only provide an image sampler when the local or remote device supports OpenCL 2.0, + // see below for context. + // + // For backwards compatibility with OpenCL 1.2, sampler-less read_image calls are used. + // By default in sampler-less read_image calls OpenCL defaults to + // sampler_ = "CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST"; + // See section 6.12.14.3 Built-in Image Sampler-less Read Functions in the OpenCL 1.2 + // specification. For OpenCL 2.0 it can be preferable to use, + // sampler_ = "CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST"; + // For now we rely on OpenCL preprocessor directives to utilize the correct behavior + // depending on the OpenCL version detected at OpenCL compile time. 
+ decl_stream << "#ifdef __OPENCL_VERSION__\n" + << "#if __OPENCL_VERSION__ == CL_VERSION_2_0\n" + << "#define READ_IMAGEH(image, sampler, coord) " + << "read_imageh(image, sampler, coord)\n" + << "#define READ_IMAGEF(image, sampler, coord) " + << "read_imagef(image, sampler, coord)\n" + << "#else\n" + << "#define READ_IMAGEH(image, sampler, coord) " + << "read_imageh(image, coord)\n" + << "#define READ_IMAGEF(image, sampler, coord) " + << "read_imagef(image, coord)\n" + << "#endif\n" + << "#endif\n\n"; + } return CodeGenC::Finish(); } @@ -162,6 +241,23 @@ void CodeGenOpenCL::PrintType(DataType t, std::ostream& os) { // NOLINT(*) LOG(FATAL) << "Cannot convert type " << t << " to OpenCL type"; } +void CodeGenOpenCL::PrintType(const Type& type, std::ostream& os) { // NOLINT(*) + if (auto* ptr = type.as()) { + return PrintType(ptr->dtype, os); + } else if (auto* ptr = type.as()) { + if (runtime::IsTextureStorage(std::string(ptr->storage_scope))) { + os << "image2d_t"; + } else { + PrintType(ptr->element_type, os); + os << '*'; + } + } else if (IsVoidType(type)) { + os << "void"; + } else { + LOG(FATAL) << "Type " << type << " does not have a corresponding C Type"; + } +} + void CodeGenOpenCL::PrintVecAddr(const VarNode* buffer, DataType t, PrimExpr base, std::ostream& os) { // NOLINT(*) if (!HandleTypeMatch(buffer, t.element_of())) { @@ -210,6 +306,19 @@ void CodeGenOpenCL::PrintStorageScope(const std::string& scope, std::ostream& os os << "__global "; } else if (scope == "shared") { os << "__local "; + } else if (scope == "texture_read") { + os << "__read_only "; + } else if (scope == "texture_write") { + os << "__write_only "; + } +} + +void CodeGenOpenCL::PrintRestrict(const Var& v, std::ostream& os) { + // Apply restrict qualifer for non-texture types only + if (auto* ptr = v->type_annotation.as()) { + if (!runtime::IsTextureStorage(std::string(ptr->storage_scope))) { + os << ' ' << restrict_keyword_; + } } } @@ -229,6 +338,39 @@ std::string CodeGenOpenCL::CastFromTo(std::string value, DataType from, DataType return os.str(); } +void CodeGenOpenCL::VisitStmt_(const StoreNode* op) { + if (auto call = op->value.as()) { + if (call->op.same_as(builtin::texture2d_load())) { + need_texture_ssa_ = false; + // If storing a texture load into a buffer, don't use an + // intermediate local unless the buffer allocation is a + // single element selected from the texture read. + auto it = allocation_size_.find(op->buffer_var.get()); + if (it != allocation_size_.end() && it->second == 1) { + need_texture_ssa_ = true; + } + } + } + CodeGenC::VisitStmt_(op); + need_texture_ssa_ = true; +} + +void CodeGenOpenCL::VisitExpr_(const CastNode* op, std::ostream& os) { + if (auto call = op->value.as()) { + if (call->op.same_as(builtin::texture2d_load())) { + need_texture_ssa_ = false; + } + } + CodeGenC::VisitExpr_(op, os); + need_texture_ssa_ = true; +} + +void CodeGenOpenCL::VisitStmt_(const AllocateNode* op) { + allocation_size_.insert( + {op->buffer_var.get(), op->constant_allocation_size() * op->dtype.lanes()}); + CodeGenC::VisitStmt_(op); +} + void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) { if (op->op.same_as(builtin::address_of())) { // Overload tvm_address_of to add storage scope (e.g. __global). 
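[Editor's note] The two integer coordinates consumed by the texture2d_store/texture2d_load lowering in the next hunk are produced by the TextureFlatten pass later in this patch (GetTextureAccessArgs/SimplifyOffset). Below is a minimal Python sketch of that packing; the helper names and the separator value are illustrative stand-ins for runtime::DefaultTextureLayoutSeparator, which is not shown in this excerpt.

def pack_rowmajor(extents, indices):
    # Row-major packing, mirroring SimplifyOffset in texture_flatten.cc.
    offset = 0
    for extent, index in zip(extents, indices):
        offset = offset * extent + index
    return offset


def flatten_texture_access(shape, indices, separator):
    # Axes in [0, separator) pack into one image coordinate, axes in
    # [separator, ndim - 1) pack into the other, and the innermost axis is
    # kept as the RGBA lane of the texel. Which packed offset becomes the
    # image x versus y coordinate follows the lowering convention in
    # runtime/texture.h, which is not part of this excerpt.
    first = pack_rowmajor(shape[:separator], indices[:separator])
    second = pack_rowmajor(shape[separator:-1], indices[separator:-1])
    return first, second, indices[-1]


# e.g. a rank-5 NCHWc-style access with an illustrative separator of 3:
# flatten_texture_access((1, 8, 56, 56, 4), (0, 3, 20, 13, 2), separator=3)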
@@ -243,6 +385,64 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) { os << " *)" << this->GetVarID(load->buffer_var.get()) << " + "; this->PrintExpr(load->index, os); os << ')'; + } else if (op->op.same_as(builtin::texture2d_store())) { + auto* ptr_type = op->args[0].as()->type_annotation.as(); + ICHECK(ptr_type != nullptr) << "Texture Var's must be of PointerType"; + ICHECK(runtime::IsTextureStorage(std::string(ptr_type->storage_scope))) + << "builtin::texture2d_store() only supports storing to texture buffers"; + DataType buffer_type = ptr_type->element_type.as()->dtype; + if (buffer_type.is_float16()) { + os << "write_imageh("; + } else if (buffer_type.is_float()) { + os << "write_imagef("; + } else { + LOG(FATAL) << "Unsupported type: " << buffer_type + << ", currently only float and half are supported for image2d OpenCL codegen."; + } + this->PrintExpr(op->args[0], os); + os << ", "; + os << "(int2)("; + this->PrintExpr(op->args[1], os); + os << ", "; + this->PrintExpr(op->args[2], os); + os << "), "; + this->PrintExpr(op->args[3], os); + os << ")"; + } else if (op->op.same_as(builtin::texture2d_load())) { + enable_compliant_texture_reads_ = true; + std::stringstream ss; + if (op->dtype.is_float16()) { + ss << "READ_IMAGEH("; + } else if (op->dtype.is_float()) { + ss << "READ_IMAGEF("; + } else { + LOG(FATAL) << "Unsupported type: " << op->dtype + << ", currently only float and half are supported for image2d OpenCL codegen."; + } + this->PrintExpr(op->args[0], ss); + ss << ", "; + ss << "CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST, "; + ss << "((int2)("; + this->PrintExpr(op->args[1], ss); + ss << ", "; + this->PrintExpr(op->args[2], ss); + ss << ")))"; + + // Only use local SSA if texture is not already being stored + if (need_texture_ssa_) { + std::string rhs = SSAGetID(ss.str(), op->dtype.with_lanes(4)); + if (op->args.back().as()) { + os << rhs; + } else { + os << "(("; + this->PrintType(op->dtype.with_lanes(1), os); + os << "*)&" << rhs << ")["; + this->PrintExpr(op->args.back(), os); + os << "]"; + } + } else { + os << ss.str(); + } } else if (op->op.same_as(builtin_call_extern_)) { auto func = Downcast(op->args[0]); // Enable atomics extension if used. 
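[Editor's note] With the texture load/store lowering above in place, a condensed way to exercise this codegen path is to stage a small elementwise compute through the "global.texture" scope and dump the generated OpenCL, mirroring the rank-3 test added at the bottom of this patch. This is only a sketch and assumes a TVM build with the OpenCL target enabled; it inspects source only and does not execute on a device.

import tvm
from tvm import te

shape = (32, 32, 4)
X = te.placeholder(shape, name="X", dtype="float32")
Y = te.compute(shape, lambda i, j, k: X[i, j, k] + 1, name="Y")
s = te.create_schedule(Y.op)

# Stage the input through a texture; the innermost (RGBA) axis is vectorized.
Xt = s.cache_read(X, "global.texture", [Y])
x, y, c = s[Xt].op.axis
s[Xt].bind(x, te.thread_axis("blockIdx.x"))
s[Xt].bind(y, te.thread_axis("threadIdx.x"))
s[Xt].vectorize(c)

x, y, c = s[Y].op.axis
xo, yo, xi, yi = s[Y].tile(x, y, 4, 4)
s[Y].bind(xo, te.thread_axis("blockIdx.x"))
s[Y].bind(yo, te.thread_axis("threadIdx.x"))
s[Y].vectorize(c)

mod = tvm.build(s, [X, Y], target="opencl")
# Expect __read_only/__write_only image2d_t parameters together with
# READ_IMAGEF and write_imagef calls in the generated kernels.
print(mod.imported_modules[0].get_source())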
@@ -280,6 +480,13 @@ void CodeGenOpenCL::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // N } } +void CodeGenOpenCL::SetTextureScope( + const std::unordered_map& scope) { // NOLINT(*) + for (auto& texture : scope) { + alloc_storage_scope_.insert(texture); + } +} + runtime::Module BuildOpenCL(IRModule mod, Target target) { using tvm::runtime::Registry; bool output_ssa = false; diff --git a/src/target/source/codegen_opencl.h b/src/target/source/codegen_opencl.h index 32102fec22b9..a8c293c03056 100644 --- a/src/target/source/codegen_opencl.h +++ b/src/target/source/codegen_opencl.h @@ -27,6 +27,7 @@ #include #include +#include #include "codegen_c.h" @@ -45,18 +46,24 @@ class CodeGenOpenCL final : public CodeGenC { void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) void PrintStorageSync(const CallNode* op) final; // NOLINT(*) void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) + void PrintType(const Type& type, std::ostream& os) final; // NOLINT(*) std::string GetVecLoad(DataType t, const VarNode* buffer, PrimExpr base) final; void PrintVecStore(const VarNode* buffer, DataType t, PrimExpr base, const std::string& value) final; // NOLINT(*) // the address of load/store void PrintVecAddr(const VarNode* buffer, DataType t, PrimExpr base, - std::ostream& os); // NOLINT(*) - std::string CastFromTo(std::string value, DataType from, DataType target); // NOLINT(*) + std::ostream& os); // NOLINT(*) + void PrintRestrict(const Var& v, std::ostream& os) final; // NOLINT(*) + std::string CastFromTo(std::string value, DataType from, DataType target); // NOLINT(*) + void SetTextureScope(const std::unordered_map&); // NOLINT(*) // overload visitor - void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*) + void VisitStmt_(const AllocateNode* op) final; // NOLINT(*) void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const CastNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; // NOLINT(*) + void VisitStmt_(const StoreNode* op) final; // NOLINT(*) private: // whether enable fp16 and fp64 extension @@ -64,6 +71,15 @@ class CodeGenOpenCL final : public CodeGenC { bool enable_fp64_{false}; // Whether to enable atomics extension. bool enable_atomics_{false}; + // Whether to enable sampler or sampler-less texture reads, + // where the choice depends on the OpenCL version used. + bool enable_compliant_texture_reads_{false}; + // Key to disable use of texture SSA in certain scenarios. For example, + // when loaded value is stored directly to a user declared l-value buffer + bool need_texture_ssa_{true}; + // Mapping from buffer to allocation size. + // Useful to track when a scalar store of a vectorized texture load is required. 
+ std::unordered_map allocation_size_; }; } // namespace codegen diff --git a/src/te/operation/op_utils.cc b/src/te/operation/op_utils.cc index b3897e142545..ddc78866ae02 100644 --- a/src/te/operation/op_utils.cc +++ b/src/te/operation/op_utils.cc @@ -156,10 +156,12 @@ std::vector > MakeLoopNest(const Stage& stage, nest[i + 1].emplace_back(AttrStmt(bind_iv, tir::attr::thread_extent, dom->extent, no_op)); if (!debug_keep_trivial_loop && is_one(dom->extent)) { value_map[iv] = dom->min; + } else if (stage->scope == "") { + value_map[iv] = var; } else { runtime::ThreadScope ts = runtime::ThreadScope::Create(bind_iv->thread_tag); - if (stage->scope == "" || - static_cast(runtime::StorageScope::Create(stage->scope).rank) <= ts.rank) { + runtime::StorageScope ss = runtime::StorageScope::Create(stage->scope); + if (static_cast(ss.rank) <= ts.rank) { value_map[iv] = var; } else if (stage->scope == "warp" && ts.rank == 1) { // To determine whether a thread index is inside or outside a warp, we need diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index f0ca04cbd5fd..c593cbf7290c 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -246,6 +246,17 @@ TIR_DEFINE_BUILTIN_FUNC(vectorcombine) TIR_DEFINE_BUILTIN_FUNC(atomic_add) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_BUILTIN_FUNC(texture2d_alloca) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_BUILTIN_FUNC(texture2d_store) + .set_attr("TVectorizable", true) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_BUILTIN_FUNC(texture2d_load) + .set_attr("TVectorizable", true) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + } // namespace builtin } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index 8b70817398e4..f5a553aa0598 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -98,6 +98,15 @@ class BuiltinLower : public StmtExprMutator { } } + Stmt VisitStmt_(const LetStmtNode* op) final { + if (const CallNode* call = op->value.as()) { + if (call->op.same_as(builtin::texture2d_alloca())) { + return StmtExprMutator::VisitStmt(MakeTextureAlloc(op, call)); + } + } + return StmtExprMutator::VisitStmt_(op); + } + Stmt VisitStmt_(const AllocateNode* op) { // Lower allocate to device allocate when needed. 
Stmt stmt = StmtExprMutator::VisitStmt_(op); @@ -341,6 +350,38 @@ class BuiltinLower : public StmtExprMutator { return Call(op->dtype, builtin::tvm_call_trace_packed_lowered(), packed_args); } + Stmt MakeTextureAlloc(const LetStmtNode* let, const CallNode* call) { + ICHECK(device_type_.defined()) << "Unknown device type in current IR"; + ICHECK(device_id_.defined()) << "Unknown device id in current IR"; + Stmt throw_last_error = Evaluate(Call(DataType::Int(32), builtin::tvm_throw_last_error(), {})); + + Stmt body = SeqStmt( + {IfThenElse(Call(DataType::Bool(1), builtin::isnullptr(), {let->var}), throw_last_error), + let->body}); + DataType dtype = + let->var->type_annotation.as()->element_type.as()->dtype; + + std::string fdevapi_prefix = "device_api."; + fdevapi_prefix += runtime::DeviceName(device_type_.as()->value); + Call call_packed = + Call(let->var.dtype(), builtin::tvm_call_packed(), + {StringImm(fdevapi_prefix + ".AllocTexture"), cast(DataType::Int(32), device_type_), + cast(DataType::Int(32), device_id_), cast(DataType::UInt(64), call->args[0]), + cast(DataType::UInt(64), call->args[1]), IntImm(DataType::Int(32), dtype.code()), + IntImm(DataType::Int(32), dtype.bits())}); + + Stmt alloca = LetStmt(let->var, call_packed, body); + + Call free_op = + Call(DataType::Int(32), builtin::tvm_call_packed(), + {StringImm(fdevapi_prefix + ".FreeTexture"), cast(DataType::Int(32), device_type_), + cast(DataType::Int(32), device_id_), let->var}); + + Stmt free_stmt = IfThenElse(free_op != make_zero(DataType::Int(32)), throw_last_error); + body = SeqStmt({alloca, free_stmt}); + return body; + } + private: bool IsArrayHandle(const PrimExpr& arg) { // specially set array handle. diff --git a/src/tir/transforms/texture_flatten.cc b/src/tir/transforms/texture_flatten.cc new file mode 100644 index 000000000000..7dc800737944 --- /dev/null +++ b/src/tir/transforms/texture_flatten.cc @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file texture_flatten.cc + * \brief Flattens texture storage from multi-dimensional array + * to 2D (width, height) buffer access + */ + +#include +#include +#include +#include +#include +#include + +#include + +#include "../../arith/ir_visitor_with_analyzer.h" +#include "../../runtime/texture.h" +#include "../../runtime/thread_storage_scope.h" + +namespace tvm { +namespace tir { +using runtime::ApplyTexture2DFlattening; +using runtime::DefaultTextureLayoutSeparator; +using runtime::IsTextureStorage; + +class TextureLoweringBase : public StmtExprMutator { + public: + explicit TextureLoweringBase(const Map& extern_buffer_map, + IRVisitorWithAnalyzer* bound_analyzer) + : bound_analyzer_{bound_analyzer} { + for (auto kv : extern_buffer_map) { + extern_buf_.insert(kv.second); + } + } + + inline PrimExpr SimplifyOffset(const Array& shape, const Array& index) const { + PrimExpr base = make_const(DataType::Int(32), 0); + ICHECK_EQ(shape.size(), index.size()); + if (index.size() > 0) { + PrimExpr offset = index[0]; + for (size_t i = 1; i < index.size(); ++i) { + offset = bound_analyzer_->Simplify(offset * shape[i] + index[i]); + } + base = base + offset; + } + return base; + } + + protected: + std::string GetStorageScope(const Buffer& buffer) { + auto* ptr = buffer->data->type_annotation.as(); + ICHECK(ptr) << "Buffer Var's type annotation must be of PointerType"; + return ptr->storage_scope; + } + + // Set of all external input and output buffers + std::unordered_set extern_buf_; + // Bound analzer + IRVisitorWithAnalyzer* bound_analyzer_; +}; + +// Lower Nd storage access to 2d texture access using lowering convention +// specified by the buffers storage scope. +class TextureFlattener : public TextureLoweringBase { + public: + using StmtExprMutator::VisitStmt_; + explicit TextureFlattener(const Map& extern_buffer_map, + IRVisitorWithAnalyzer* bound_analyzer) + : TextureLoweringBase(extern_buffer_map, bound_analyzer) {} + + Stmt VisitStmt_(const BufferRealizeNode* op) final { + if (extern_buf_.count(op->buffer)) { + return this->VisitStmt(op->body); + } + + std::string storage_scope = GetStorageScope(op->buffer); + Var buffer_var(op->buffer->data->name_hint, + PointerType(PrimType(op->buffer->dtype), String(storage_scope))); + let_binding_.insert({op->buffer->data, buffer_var}); + + Stmt stmt = StmtExprMutator::VisitStmt_(op); + op = stmt.as(); + + // Rewrite any buffer realizations with storage scope to 2d texture allocations + if (IsTextureStorage(storage_scope)) { + Stmt body = this->VisitStmt(op->body); + ICHECK(op->bounds.size() >= 3) << "Only 2d RGBA texture is currently supported"; + int vec_length = static_cast(op->bounds.back()->extent.as()->value); + ICHECK(vec_length == 4 || vec_length == 1) + << "Inner dimension of texture must be vector of length 1 or 4 (RGBA)"; + + struct ShapeFromRange { + const Array& bounds; + PrimExpr operator[](size_t i) const { return bounds[i]->extent; } + }; + size_t axis = DefaultTextureLayoutSeparator(op->bounds.size(), storage_scope); + auto texture = + ApplyTexture2DFlattening(ShapeFromRange{op->bounds}, op->bounds.size(), axis); + Array args = {texture.width, texture.height}; + stmt = LetStmt(buffer_var, Call(buffer_var.dtype(), builtin::texture2d_alloca(), args), body); + } + + return stmt; + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + Stmt stmt = StmtExprMutator::VisitStmt_(op); + op = stmt.as(); + std::string storage_scope = GetStorageScope(op->buffer); + // Lower to two dimensional access + if (IsTextureStorage(storage_scope)) { + 
Array args = GetTextureAccessArgs(op, op->buffer); + args.push_back(op->value); + stmt = Evaluate(Call(args[0]->dtype, builtin::texture2d_store(), args)); + } + + return stmt; + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + PrimExpr expr = StmtExprMutator::VisitExpr_(op); + op = expr.as(); + // Lower to two dimensional access + std::string storage_scope = GetStorageScope(op->buffer); + if (IsTextureStorage(storage_scope)) { + Array args = GetTextureAccessArgs(op, op->buffer); + args.push_back(op->indices.back()); + expr = Call(op->buffer->dtype, builtin::texture2d_load(), args); + } + + return expr; + } + + protected: + template + Array GetTextureAccessArgs(const T* op, const Buffer& buffer) { + Array args; + if (let_binding_.count(op->buffer->data)) { + args.push_back(let_binding_[op->buffer->data]); + } else { + args.push_back(buffer->data); + } + Array row_dims, row_indices, col_dims, col_indices; + for (size_t i = 0; i < op->buffer->shape.size() - 1; i++) { + if (i < DefaultTextureLayoutSeparator(op->buffer->shape.size(), GetStorageScope(buffer))) { + col_dims.push_back(op->buffer->shape[i]); + col_indices.push_back(op->indices[i]); + } else { + row_dims.push_back(op->buffer->shape[i]); + row_indices.push_back(op->indices[i]); + } + } + PrimExpr row_offset = SimplifyOffset(row_dims, row_indices); + PrimExpr col_offset = SimplifyOffset(col_dims, col_indices); + args.push_back(row_offset); + args.push_back(col_offset); + return args; + } + + // Bindings to new texture vars with texture pointer scope + std::unordered_map let_binding_; +}; + +PrimFunc TextureFlatten(PrimFunc func) { + auto fptr = func.CopyOnWrite(); + IRVisitorWithAnalyzer bound_analyzer; + bound_analyzer(fptr->body); + fptr->body = TextureFlattener(fptr->buffer_map, &bound_analyzer)(std::move(fptr->body)); + return func; +} + +namespace transform { + +Pass TextureFlatten() { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + return TextureFlatten(std::move(f)); + }; + return CreatePrimFuncPass(pass_func, 0, "tir.TextureFlatten", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.TextureFlatten").set_body_typed(TextureFlatten); + +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index 64956bc8ee54..cd2d230f5775 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -265,6 +265,20 @@ class Vectorizer : public StmtMutator, public ExprFunctorop.same_as(builtin::if_then_else())) { return MutateIfThenElseExpr_(op); + } else if (op->op.same_as(builtin::texture2d_load())) { + int lane = 0; + Array fcd = MutateArray({op->args.back()}, &lane); + auto new_args = op->args; + new_args.pop_back(); + new_args.push_back(fcd[0]); + return Call(op->dtype.with_lanes(4), op->op, new_args); + } else if (op->op.same_as(builtin::texture2d_store())) { + int lane = 0; + // Vectorize the value to store + Array value{op->args.back()}; + Array mutated_value = MutateArray(value, &lane); + Array new_args{op->args[0], op->args[1], op->args[2], mutated_value[0]}; + return Call(op->dtype.with_lanes(lane), op->op, new_args); } auto* op_ptr = op->op.as(); bool vectorizable = op_ptr && op_vectorizable_.get(GetRef(op_ptr), false); diff --git a/tests/python/unittest/test_target_texture_codegen_opencl.py b/tests/python/unittest/test_target_texture_codegen_opencl.py new file mode 100644 index 000000000000..03944c85ade5 --- /dev/null +++ 
b/tests/python/unittest/test_target_texture_codegen_opencl.py @@ -0,0 +1,1400 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# 'License'); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import sys + +import numpy as np +import pytest + +import tvm +from tvm import autotvm +from tvm import te +from tvm.topi import testing +from tvm.topi.utils import get_const_tuple, simplify +from tvm.topi import nn + + +def compute_plus_one_rank3(shape): + X = te.placeholder(shape, name="X", dtype="float32") + Y = te.compute(shape, lambda i, j, k: X[i, j, k] + 1, name="Compute_Y") + return X, Y + + +def schedule_plus_one_rank3(X, Y): + s = te.create_schedule(Y.op) + # Xt = s.cache_read(X, "texture", [Y]) + # Xt = s.cache_read(X, "global", [Y]) + Xt = s.cache_read(X, "global.texture", [Y]) + + # copy to texture stage + x, y, c = s[Xt].op.axis + s[Xt].bind(x, te.thread_axis("blockIdx.x")) + s[Xt].bind(y, te.thread_axis("threadIdx.x")) + s[Xt].vectorize(c) + + # the compute stage + x, y, c = s[Y].op.axis + xo, yo, xi, yi = s[Y].tile(x, y, 4, 4) + s[Y].bind(xo, te.thread_axis("blockIdx.x")) + s[Y].bind(yo, te.thread_axis("threadIdx.x")) + s[Y].vectorize(c) + return s + + +def compute_plus_one_rank5(shape): + X = te.placeholder(shape, name="X", dtype="float32") + Y = te.compute(shape, lambda i, j, k, l, m: X[i, j, k, l, m] + 1, name="Compute_Y") + return X, Y + + +def schedule_plus_one_rank5(X, Y): + s = te.create_schedule(Y.op) + Xt = s.cache_read(X, "global.texture", [Y]) + + # copy to texture stage + a, b, c, d, e = s[Xt].op.axis + abc = s[Xt].fuse(a, b, c) + s[Xt].bind(abc, te.thread_axis("blockIdx.x")) + s[Xt].bind(d, te.thread_axis("threadIdx.x")) + s[Xt].vectorize(e) + + # the compute stage + a, b, c, d, e = s[Y].op.axis + abc = s[Y].fuse(a, b, c) + xo, yo, xi, yi = s[Y].tile(abc, d, 4, 4) + s[Y].bind(xo, te.thread_axis("blockIdx.x")) + s[Y].bind(yo, te.thread_axis("threadIdx.x")) + s[Y].vectorize(e) + return s + + +def compute_matmul(shape): + A = te.placeholder(shape, name="A", dtype="float32") + B = te.placeholder(shape, name="B", dtype="float32") + k = te.reduce_axis((0, shape[1]), name="k") + C = te.compute( + (shape[0] * shape[2], shape[0] * shape[2]), + lambda i, j: te.sum( + A[i // shape[2], k, i % shape[2]].astype("float32") + * B[j // shape[2], k, j % shape[2]].astype("float32"), + axis=[k], + ), + name="Compute_MatMul", + ) + return A, B, C + + +def schedule_matmul(A, B, C, local=False): + s = te.create_schedule(C.op) + At = s.cache_read(A, "global.texture", [C]) + Bt = s.cache_read(B, "global.texture", [C]) + if local: + Al = s.cache_read(At, "local", [C]) + Bl = s.cache_read(Bt, "local", [C]) + Cl = s.cache_write(C, "local") + + bx = te.thread_axis("blockIdx.x") + tx = te.thread_axis("threadIdx.x") + + def copy_to_texture(stage): + _io, _k, _ii = s[stage].op.axis + s[stage].vectorize(_ii) + s[stage].bind(_io, bx) 
+ s[stage].bind(_k, tx) + + copy_to_texture(At) + copy_to_texture(Bt) + + # copy to global stage + _i, _j = s[C].op.axis + xo, yo, xi, yi = s[C].tile(_i, _j, 4, 4) + s[C].unroll(xi) + s[C].vectorize(yi) + s[C].bind(xo, te.thread_axis("blockIdx.x")) + s[C].bind(yo, te.thread_axis("threadIdx.x")) + + # the compute stage + s[Cl].compute_at(s[C], yo) + (_k,) = Cl.op.reduce_axis + _x, _y = s[Cl].op.axis + s[Cl].reorder(_k, _x, _y) + s[Cl].unroll(_x) + s[Cl].vectorize(_y) + + if local: + s[Al].compute_at(s[Cl], _k) + s[Al].vectorize(s[Al].op.axis[-1]) + s[Bl].compute_at(s[Cl], _k) + s[Bl].vectorize(s[Bl].op.axis[-1]) + + return s + + +def compute_matmul_inner(shape): + A = te.placeholder(shape, name="A", dtype="float32") + B = te.placeholder(shape, name="B", dtype="float32") + k = te.reduce_axis((0, shape[1] * shape[2]), name="k") + # (M, K) x (N, K) + # (32, 256) x (32, 256) + # (32, 64, 4) x (32, 64, 4) + C = te.compute( + (shape[0], shape[0]), + lambda i, j: te.sum( + A[i, k // shape[2], k % shape[2]].astype("float32") + * B[j, k // shape[2], k % shape[2]].astype("float32"), + axis=[k], + ), + name="Compute_MatMul", + ) + return A, B, C + + +def schedule_matmul_inner(A, B, C, local=False): + s = te.create_schedule(C.op) + At = s.cache_read(A, "global.texture", [C]) + Bt = s.cache_read(B, "global.texture", [C]) + if local: + Al = s.cache_read(At, "local", [C]) + Bl = s.cache_read(Bt, "local", [C]) + Cl = s.cache_write(C, "local") + + bx = te.thread_axis("blockIdx.x") + tx = te.thread_axis("threadIdx.x") + + def copy_to_texture(stage): + _i, _ko, _ki = s[stage].op.axis + s[stage].vectorize(_ki) + s[stage].bind(_i, bx) + s[stage].bind(_ko, tx) + + copy_to_texture(At) + copy_to_texture(Bt) + + # copy to global stage + _i, _j = s[C].op.axis + xo, yo, xi, yi = s[C].tile(_i, _j, 4, 4) + s[C].unroll(xi) + s[C].vectorize(yi) + s[C].bind(xo, te.thread_axis("blockIdx.x")) + s[C].bind(yo, te.thread_axis("threadIdx.x")) + + # the compute stage + s[Cl].compute_at(s[C], yo) + (_k,) = Cl.op.reduce_axis + _x, _y = s[Cl].op.axis + s[Cl].reorder(_x, _y, _k) + s[Cl].unroll(_x) + # TODO(csullivan): consider whether the below error is worth resolving + # s[Cl].vectorize(_y) # error + + if local: + s[Al].compute_at(s[Cl], _x) + s[Al].vectorize(s[Al].op.axis[-1]) + s[Bl].compute_at(s[Cl], _x) + s[Bl].vectorize(s[Bl].op.axis[-1]) + + return s + + +def compute_matmul_vector_accumulator(shapeA, shapeB): + # A x B + # (K/4, M, K%4) x (K, N/4, N%4) = (M, N) + # (32, 64, 4) x (128, 16, 4) = (64, 64) + A = te.placeholder(shapeA, name="A", dtype="float32") + B = te.placeholder(shapeB, name="B", dtype="float32") + k = te.reduce_axis((0, shapeB[0]), name="k") + C = te.compute( + (shapeA[1], shapeB[1] * shapeB[2]), + lambda i, j: te.sum( + A[k // shapeA[-1], i, k % shapeA[-1]].astype("float32") + * B[k, j // shapeB[-1], j % shapeB[-1]].astype("float32"), + axis=[k], + ), + name="Compute_MatMul", + ) + return A, B, C + + +def schedule_matmul_vector_accumulator(A, B, C, local=False): + s = te.create_schedule(C.op) + At = s.cache_read(A, "global.texture", [C]) + Bt = s.cache_read(B, "global.texture", [C]) + if local: + Al = s.cache_read(At, "local", [C]) + Bl = s.cache_read(Bt, "local", [C]) + Cl = s.cache_write(C, "local") + + def copy_to_texture(stage): + _y, _x, _v = s[stage].op.axis + # TODO(csullivan): removing this vectorize results in numerical errors, autovectorize + s[stage].vectorize(_v) + s[stage].bind(_y, te.thread_axis("blockIdx.x")) + s[stage].bind(_x, te.thread_axis("threadIdx.x")) + + copy_to_texture(At) + 
copy_to_texture(Bt) + + # copy to global stage + _i, _j = s[C].op.axis + xo, yo, xi, yi = s[C].tile(_i, _j, 4, 4) + s[C].unroll(xi) + s[C].vectorize(yi) + s[C].bind(xo, te.thread_axis("blockIdx.x")) + s[C].bind(yo, te.thread_axis("threadIdx.x")) + + # the compute stage + s[Cl].compute_at(s[C], yo) + (_k,) = Cl.op.reduce_axis + _a, _b = s[Cl].op.axis + _ko, _ki = s[Cl].split(_k, factor=4) + s[Cl].reorder(_ko, _a, _ki, _b) + s[Cl].unroll(_ki) + s[Cl].unroll(_a) + s[Cl].vectorize(_b) + + if local: + s[Al].compute_at(s[Cl], _a) + _aa, _ka, _ba = s[Al].op.axis + # TODO(csullivan)[BEFORE PR]: removing this vectorize command causes a crash. This needs to be autovectorized. + s[Al].vectorize(_ba) + s[Bl].compute_at(s[Cl], _ko) + _ab, _kb, _bb = s[Bl].op.axis + s[Bl].vectorize(_bb) + s[Bl].unroll(_ab) + + return s + + +def compute_conv2d_1x1_NCHWc_RSCKk(input_shape, filter_shape): + # conv2d( [N, C, H, W, c] , [1, 1, C, K, k] + data = te.placeholder(input_shape, name="data", dtype="float32") + filt = te.placeholder(filter_shape, name="filter", dtype="float32") + c = te.reduce_axis((0, input_shape[1]), name="C") + c4 = te.reduce_axis((0, input_shape[-1]), name="c4") + kh = te.reduce_axis((0, filter_shape[0]), name="kh") + kw = te.reduce_axis((0, filter_shape[1]), name="kw") + conv = te.compute( + (input_shape[0], filter_shape[-2], input_shape[2], input_shape[3], filter_shape[-1]), + lambda n, ko, i, j, ki: te.sum( + data[n, c, i, j, c4].astype("float32") + * filt[kh, kw, c * input_shape[-1] + c4, ko, ki].astype("float32"), + axis=[kh, kw, c, c4], + ), + # name="Compute_conv2d_1x1_NCHWc_RSCKk", + name="conv2d_1x1", + ) + return data, filt, conv + + +def schedule_conv2d_1x1_NCHWc_RSCKk(data, filt, conv): + # inputs: (1, 128//4, 56, 56, 4), (1, 1, 128, 128//4, 4) + # outputs: + s = te.create_schedule(conv.op) + A, B, C = data, filt, conv + At = s.cache_read(A, "global.texture", [C]) + Bt = s.cache_read(B, "global.texture", [C]) + Al = s.cache_read(At, "local", [C]) + Bl = s.cache_read(Bt, "local", [C]) + Cl = s.cache_write(C, "local") + + def copy_to_texture(stage): + axes = s[stage].op.axis + fused = s[stage].fuse(*axes[:-1]) + block, thread = s[stage].split(fused, factor=32) + s[stage].vectorize(axes[-1]) + s[stage].bind(block, te.thread_axis("blockIdx.x")) + s[stage].bind(thread, te.thread_axis("threadIdx.x")) + + copy_to_texture(At) + copy_to_texture(Bt) + + _n, _ko, _h, _w, _ki = s[C].op.axis + s[C].vectorize(_ki) + s[C].bind(_n, te.thread_axis("blockIdx.x")) + s[C].bind(_ko, te.thread_axis("threadIdx.x")) + + s[Cl].compute_at(s[C], _w) + _nl, _kol, _hl, _wl, _kil = s[Cl].op.axis + _khl, _kwl, _cl, _cl4 = s[Cl].op.reduce_axis + _clo, _cli = s[Cl].split(_cl, factor=4) + s[Cl].reorder(_clo, _cli, _cl4, _kil) + s[Cl].unroll(_cli) + s[Cl].unroll(_cl4) + s[Cl].vectorize(_kil) + + s[Al].compute_at(s[Cl], _cli) + s[Al].vectorize(s[Al].op.axis[-1]) + s[Bl].compute_at(s[Cl], _kwl) + s[Bl].vectorize(s[Bl].op.axis[-1]) + + return s + + +def compute_conv2d_1x1_WCHNc_CRSKk(input_shape, filter_shape): + # input_shape = [W, C, H, N, c] -> [W, C, H*N, c] + # filter_shape = [C, R, S, K, k] -> [C, R*S*K, k] + # output_shape: [WK, HN, k] -> [W, K, H, N, k] + data = te.placeholder(input_shape, name="data", dtype="float32") + filt = te.placeholder(filter_shape, name="filter", dtype="float32") + + packed_data = te.compute( + (input_shape[0], input_shape[1], input_shape[2] * input_shape[3], input_shape[4]), + lambda i, j, k, l: data[i, j, k // input_shape[3], k % input_shape[3], l], + name="packed_data", + ) + + # 
Logical transformation of Nd -> 3d tensor + # CRSKk -> C|RSK|k + # r = rsk // SK + # sk = rsk % SK + # s = sk // K == (rsk % SK) // K == (rsk // K) % S + # k = sk % K == (rsk % SK) % K == rsk % K + packed_filter = te.compute( + (filter_shape[0], filter_shape[1] * filter_shape[2] * filter_shape[3], filter_shape[4]), + lambda i, j, k: filt[ + i, + j // (filter_shape[3] * filter_shape[2]), + (j // filter_shape[3]) % filter_shape[2], + j % filter_shape[3], + k, + ], + name="packed_filter", + ) + + c = te.reduce_axis((0, input_shape[1]), name="C") + c4 = te.reduce_axis((0, input_shape[-1]), name="c4") + r = te.reduce_axis((0, filter_shape[1]), name="r") + s = te.reduce_axis((0, filter_shape[2]), name="s") + + conv = te.compute( + (input_shape[0], filter_shape[3], input_shape[2], input_shape[3], filter_shape[4]), + lambda w, ko, h, n, ki: te.sum( + packed_data[w, c, h * input_shape[3] + n, c4].astype("float32") + * packed_filter[ + c * input_shape[-1] + c4, ((r * filter_shape[2]) + s) * filter_shape[3] + ko, ki + ].astype("float32"), + axis=[r, s, c, c4], + ), + name="conv2d_1x1", + ) + return data, filt, packed_data, packed_filter, conv + + +def schedule_conv2d_1x1_WCHNc_CRSKk(data, filt, packed_data, packed_filter, conv): + # data: [W, C, H*N, c] + # filter: [C, R*S*K, k] + # output: [W, K, H, N, k] + + # conv2d( [N, C, H, W, c] , [1, 1, C, K, k] + # inputs: (1, 128//4, 56, 56, 4), (1, 1, 128, 128//4, 4) + + # data: (56, 128//4, 56*1, 4) = (56, 32, 56, 4) + # filt: (128, 1*1*128//4, 4) = (128, 32, 4) + # conv: (56, 32, 56, 1, 4) + + s = te.create_schedule(conv.op) + cfg = autotvm.get_config() + + s[packed_data].compute_inline() + s[packed_filter].compute_inline() + A, B, C = packed_data, packed_filter, conv + At = s.cache_read(A, "global.texture", [C]) + Bt = s.cache_read(B, "global.texture", [C]) + Al = s.cache_read(At, "local", [C]) + Bl = s.cache_read(Bt, "local", [C]) + Cl = s.cache_write(C, "local") + + def copy_to_texture(stage): + axes = s[stage].op.axis + fused = s[stage].fuse(*axes[:-1]) + block, thread = s[stage].split(fused, factor=32) + s[stage].vectorize(axes[-1]) + s[stage].bind(block, te.thread_axis("blockIdx.x")) + s[stage].bind(thread, te.thread_axis("threadIdx.x")) + + copy_to_texture(At) + copy_to_texture(Bt) + + _w, _ko, _h, _n, _ki = s[C].op.axis + kernel_scope, _n = s[C].split(_n, nparts=1) + + cfg.define_split("tile_f", _ko, num_outputs=4) + cfg.define_split("tile_w", _w, num_outputs=4) + cfg.define_split("tile_h", _h, num_outputs=4) + cfg.define_knob("auto_unroll_max_step", [0, 512, 1500]) + + bk, vk, tk, ki = cfg["tile_f"].apply(s, C, _ko) + bw, vw, tw, wi = cfg["tile_w"].apply(s, C, _w) + bh, vh, th, hi = cfg["tile_h"].apply(s, C, _h) + s[C].reorder(bh, _n, vh, th, hi) + bhn = s[C].fuse(bh, _n) + + s[C].bind(bk, te.thread_axis("blockIdx.z")) + s[C].bind(bhn, te.thread_axis("blockIdx.y")) + s[C].bind(bw, te.thread_axis("blockIdx.x")) + s[C].bind(vk, te.thread_axis("vthread")) + s[C].bind(vh, te.thread_axis("vthread")) + s[C].bind(vw, te.thread_axis("vthread")) + s[C].bind(tk, te.thread_axis("threadIdx.z")) + s[C].bind(th, te.thread_axis("threadIdx.y")) + s[C].bind(tw, te.thread_axis("threadIdx.x")) + s[C].reorder(bw, bk, bhn, vw, vk, vh, tw, tk, th, ki, hi, wi, _ki) + s[C].vectorize(_ki) + + # TODO(csullivan): Try uneven workgroup split + # _wo, _wi = s[C].split(_w, factor=4) + # #_hno, _hni = s[C].split(_hn, factor=8) + # #s[C].reorder(_wo, _wi, _ko, _hno, _hni, _ki) + # s[C].reorder(_wo, _ko, _hn, _ki, _wi) + # s[C].unroll(_wi) + + # # mace: + # # const int 
out_ch_blk = get_global_id(0); + # # const int out_w_blk = get_global_id(1); + # # const int out_hb = get_global_id(2); + + # bx = te.thread_axis("blockIdx.x") + # by = te.thread_axis("blockIdx.y") + # bz = te.thread_axis("blockIdx.z") + # s[C].bind(_ko, bx) + # s[C].bind(_wo, by) + # s[C].bind(_hn, bz) + + # s[Cl].compute_at(s[C], _hn) + s[Cl].compute_at(s[C], th) + + _wl, _kol, _hl, _nl, _kil = s[Cl].op.axis + _khl, _kwl, _cl, _cl4 = s[Cl].op.reduce_axis + + cfg.define_split("tile_c", _cl, num_outputs=2) + cfg.define_split("tile_kh", _khl, num_outputs=2) + cfg.define_split("tile_kw", _kwl, num_outputs=2) + + _clo, _cli = cfg["tile_c"].apply(s, Cl, _cl) + _khlo, _khli = cfg["tile_kh"].apply(s, Cl, _khl) + _kwlo, _kwli = cfg["tile_kw"].apply(s, Cl, _kwl) + # s[OL].reorder(rco, ryo, rxo, rci, ryi, rxi, n, f, y, x) + s[Cl].reorder(_clo, _khlo, _kwlo, _cli, _cl4, _khli, _kwli, _kol, _hl, _nl, _kil, _wl) + # s[Cl].reorder(_clo, _khlo, _kwlo, _cli, _cl4, _khli, _kwli) + # s[Cl].reorder(_cl, _cl4, _kil, _wl) + s[Cl].unroll(_cl4) + s[Cl].unroll(_wl) + s[Cl].vectorize(_kil) + + _wla, _cla, _hnla, _cl4a = s[Al].op.axis + s[Al].compute_at(s[Cl], _cli) + s[Al].vectorize(_cl4a) + s[Al].unroll(_wla) + + _clb, _rskolb, _kilb = s[Bl].op.axis + s[Bl].compute_at(s[Cl], _cli) + s[Bl].vectorize(_kilb) + s[Bl].unroll(_clb) + + s[C].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val) + + WO, K, HO, N, K4 = get_const_tuple(C.shape) + RSC, _, _ = get_const_tuple(B.shape) + cfg.add_flop(2 * N * K * K4 * HO * WO * RSC) + + return s + + +def compute_conv2d_NCHWc_KCRSk(Input, Filter, stride, padding, dilation, out_dtype=None): + """Convolution operator in NCHWc layout. """ + + if out_dtype is None: + out_dtype = Input.dtype + assert isinstance(stride, int) or len(stride) == 2 + assert isinstance(dilation, int) or len(dilation) == 2 + if isinstance(stride, int): + stride_h = stride_w = stride + else: + stride_h, stride_w = stride + + if isinstance(dilation, int): + dilation_h = dilation_w = dilation + else: + dilation_h, dilation_w = dilation + + batch, in_channel_chunk, in_height, in_width, in_channel_block = Input.shape + num_filter_chunk, channel, kernel_h, kernel_w, num_filter_block = Filter.shape + # compute the output shape + dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 + pad_top, pad_left, pad_down, pad_right = nn.get_pad_tuple( + padding, (dilated_kernel_h, dilated_kernel_w) + ) + + out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1) + out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1) + # compute graph + pad_before = [0, 0, pad_top, pad_left, 0] + pad_after = [0, 0, pad_down, pad_right, 0] + temp = nn.pad(Input, pad_before, pad_after, name="pad_temp") + + rcc = te.reduce_axis((0, in_channel_chunk), name="rc") + rcb = te.reduce_axis((0, in_channel_block), name="rc") + ry = te.reduce_axis((0, kernel_h), name="ry") + rx = te.reduce_axis((0, kernel_w), name="rx") + + # NCHWc x KCRSk + # texture: NCH|W|c + # texture: K|CRS|k + # c = crs//RS + # rs = crs % RS + # r = rs // W == (crs // S) % R + # s = rs % W == crs % S + Filter = te.compute( + (num_filter_chunk, channel * kernel_h * kernel_w, num_filter_block), + lambda ffc, crs, ffb: Filter[ + ffc, crs // (kernel_h * kernel_w), (crs // kernel_w) % kernel_h, crs % kernel_w, ffb + ], + name="packed_filter", + ) + return te.compute( + (batch, num_filter_chunk, out_height, out_width, num_filter_block), + 
lambda nn, ffc, yy, xx, ffb: te.sum( + temp[ + nn, rcc, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rcb + ].astype(out_dtype) + * Filter[ + ffc, ((rcc * in_channel_block + rcb) * kernel_h + ry) * kernel_w + rx, ffb + ].astype(out_dtype), + axis=[rcc, rcb, ry, rx], + ), + tag="conv2d_nchwc_kcrsk_texture", + ) + + +def schedule_conv2d_NCHWc_KCRSk(cfg, s, conv): + """schedule optimized for batch size = 1""" + + ##### space definition begin ##### + n, fc, y, x, fb = s[conv].op.axis + rcc, rcb, ry, rx = s[conv].op.reduce_axis + cfg.define_split("tile_fc", fc, num_outputs=4) + cfg.define_split("tile_y", y, num_outputs=4) + cfg.define_split("tile_x", x, num_outputs=4) + cfg.define_split("tile_rcc", rcc, num_outputs=2) + cfg.define_split("tile_ry", ry, num_outputs=2) + cfg.define_split("tile_rx", rx, num_outputs=2) + cfg.define_knob("auto_unroll_max_step", [0, 512, 1500]) + + pad_data, flattened_kernel = s[conv].op.input_tensors + kernel = s[flattened_kernel].op.input_tensors[0] + s[flattened_kernel].compute_inline() + + s[pad_data].compute_inline() + if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag: + s[kernel].compute_inline() + kernel = flattened_kernel + + if conv.op in s.outputs: + output = conv + OL = s.cache_write(conv, "local") + else: + output = s.outputs[0].output(0) + s[conv].set_scope("local") + OL = conv + + # create cache stage + AT = s.cache_read(pad_data, "global.texture", [OL]) + WT = s.cache_read(kernel, "global.texture", [OL]) + + def copy_to_texture(stage): + axes = s[stage].op.axis + fused = s[stage].fuse(*axes[:-1]) + block, thread = s[stage].split(fused, factor=32) + s[stage].vectorize(axes[-1]) + s[stage].bind(block, te.thread_axis("blockIdx.x")) + s[stage].bind(thread, te.thread_axis("threadIdx.x")) + + copy_to_texture(AT) + copy_to_texture(WT) + + AA = s.cache_read(AT, "shared", [OL]) + WW = s.cache_read(WT, "shared", [OL]) + + # tile and bind spatial axes + n, fc, y, x, fb = s[output].op.axis + + kernel_scope, n = s[output].split(n, nparts=1) + + bf, vf, tf, fi = cfg["tile_fc"].apply(s, output, fc) + by, vy, ty, yi = cfg["tile_y"].apply(s, output, y) + bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x) + + bf = s[output].fuse(n, bf) + s[output].bind(bf, te.thread_axis("blockIdx.z")) + s[output].bind(by, te.thread_axis("blockIdx.y")) + s[output].bind(bx, te.thread_axis("blockIdx.x")) + s[output].bind(vf, te.thread_axis("vthread")) + s[output].bind(vy, te.thread_axis("vthread")) + s[output].bind(vx, te.thread_axis("vthread")) + s[output].bind(tf, te.thread_axis("threadIdx.z")) + s[output].bind(ty, te.thread_axis("threadIdx.y")) + s[output].bind(tx, te.thread_axis("threadIdx.x")) + s[output].reorder(bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi, fb) + s[output].vectorize(fb) + s[OL].compute_at(s[output], tx) + + # tile reduction axes + n, fc, y, x, fb = s[OL].op.axis + + rcc, rcb, ry, rx = s[OL].op.reduce_axis + rco, rci = cfg["tile_rcc"].apply(s, OL, rcc) + ryo, ryi = cfg["tile_ry"].apply(s, OL, ry) + rxo, rxi = cfg["tile_rx"].apply(s, OL, rx) + + # TODO(csullivan): check position of rcb + s[OL].reorder(rco, ryo, rxo, rci, ryi, rxi, rcb, n, fc, y, x, fb) + s[OL].vectorize(fb) + s[OL].unroll(rcb) + + s[AA].compute_at(s[OL], rxo) + s[WW].compute_at(s[OL], rxo) + # cooperative fetching + for load in [AA, WW]: + if load == WW: + n, fyx, v = s[load].op.axis + fused = s[load].fuse(n, fyx) + else: + n, f, y, x, v = s[load].op.axis + fused = s[load].fuse(n, f, y, x) + tz, fused = s[load].split(fused, nparts=cfg["tile_fc"].size[2]) + 
ty, fused = s[load].split(fused, nparts=cfg["tile_y"].size[2]) + tx, fused = s[load].split(fused, nparts=cfg["tile_x"].size[2]) + s[load].bind(tz, te.thread_axis("threadIdx.z")) + s[load].bind(ty, te.thread_axis("threadIdx.y")) + s[load].bind(tx, te.thread_axis("threadIdx.x")) + s[load].vectorize(v) + + # unroll + s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val) + + N, OCC, OH, OW, OCB = get_const_tuple(output.shape) + _, ICKHKW, _ = get_const_tuple(kernel.shape) + + if isinstance(N, int): + cfg.add_flop(2 * N * OH * OW * OCC * OCB * ICKHKW) + + +def compute_conv2d_NCHWc_KCRSk_acc32(Input, Filter, stride, padding, dilation, out_dtype=None): + """Convolution operator in NCHWc layout. """ + + if out_dtype is None: + out_dtype = Input.dtype + assert isinstance(stride, int) or len(stride) == 2 + assert isinstance(dilation, int) or len(dilation) == 2 + if isinstance(stride, int): + stride_h = stride_w = stride + else: + stride_h, stride_w = stride + + if isinstance(dilation, int): + dilation_h = dilation_w = dilation + else: + dilation_h, dilation_w = dilation + + batch, in_channel_chunk, in_height, in_width, in_channel_block = Input.shape + num_filter_chunk, channel, kernel_h, kernel_w, num_filter_block = Filter.shape + # compute the output shape + dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 + pad_top, pad_left, pad_down, pad_right = nn.get_pad_tuple( + padding, (dilated_kernel_h, dilated_kernel_w) + ) + + out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1) + out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1) + # compute graph + pad_before = [0, 0, pad_top, pad_left, 0] + pad_after = [0, 0, pad_down, pad_right, 0] + temp = nn.pad(Input, pad_before, pad_after, name="pad_temp") + + rcc = te.reduce_axis((0, in_channel_chunk), name="rc") + rcb = te.reduce_axis((0, in_channel_block), name="rc") + ry = te.reduce_axis((0, kernel_h), name="ry") + rx = te.reduce_axis((0, kernel_w), name="rx") + + # NCHWc x KCRSk + # texture: NCH|W|c + # texture: K|CRS|k + # c = crs//RS + # rs = crs % RS + # r = rs // W == (crs // S) % R + # s = rs % W == crs % S + Filter = te.compute( + (num_filter_chunk, channel * kernel_h * kernel_w, num_filter_block), + lambda ffc, crs, ffb: Filter[ + ffc, crs // (kernel_h * kernel_w), (crs // kernel_w) % kernel_h, crs % kernel_w, ffb + ], + name="packed_filter", + ) + conv = te.compute( + (batch, num_filter_chunk, out_height, out_width, num_filter_block), + lambda nn, ffc, yy, xx, ffb: te.sum( + ( + temp[nn, rcc, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rcb] + * Filter[ffc, ((rcc * in_channel_block + rcb) * kernel_h + ry) * kernel_w + rx, ffb] + ).astype(out_dtype), + axis=[rcc, rcb, ry, rx], + ), + tag="conv2d_nchwc_kcrsk_texture", + ) + output = te.compute(conv.shape, lambda n, fc, y, x, fb: conv[n, fc, y, x, fb].astype("float32")) + return output + + +def schedule_conv2d_NCHWc_KCRSk_acc32(cfg, s, output): + """schedule optimized for batch size = 1""" + + conv = output.op.input_tensors[0] + + ##### space definition begin ##### + n, fc, y, x, fb = s[conv].op.axis + rcc, rcb, ry, rx = s[conv].op.reduce_axis + cfg.define_split("tile_fc", fc, num_outputs=4) + cfg.define_split("tile_y", y, num_outputs=4) + cfg.define_split("tile_x", x, num_outputs=4) + cfg.define_split("tile_rcc", rcc, num_outputs=2) + cfg.define_split("tile_ry", ry, num_outputs=2) + cfg.define_split("tile_rx", rx, 
num_outputs=2) + cfg.define_knob("auto_unroll_max_step", [0, 512, 1500]) + + pad_data, flattened_kernel = s[conv].op.input_tensors + kernel = s[flattened_kernel].op.input_tensors[0] + s[flattened_kernel].compute_inline() + + s[pad_data].compute_inline() + if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag: + s[kernel].compute_inline() + kernel = flattened_kernel + + if conv.op in s.outputs: + output = conv + OL = s.cache_write(conv, "local") + else: + output = s.outputs[0].output(0) + s[conv].set_scope("local") + OL = conv + + # create cache stage + AT = s.cache_read(pad_data, "global.texture", [OL]) + WT = s.cache_read(kernel, "global.texture", [OL]) + + def copy_to_texture(stage): + axes = s[stage].op.axis + fused = s[stage].fuse(*axes[:-1]) + block, thread = s[stage].split(fused, factor=32) + s[stage].vectorize(axes[-1]) + s[stage].bind(block, te.thread_axis("blockIdx.x")) + s[stage].bind(thread, te.thread_axis("threadIdx.x")) + + copy_to_texture(AT) + copy_to_texture(WT) + + AA = s.cache_read(AT, "shared", [OL]) + WW = s.cache_read(WT, "shared", [OL]) + + # tile and bind spatial axes + n, fc, y, x, fb = s[output].op.axis + + kernel_scope, n = s[output].split(n, nparts=1) + + bf, vf, tf, fi = cfg["tile_fc"].apply(s, output, fc) + by, vy, ty, yi = cfg["tile_y"].apply(s, output, y) + bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x) + + bf = s[output].fuse(n, bf) + s[output].bind(bf, te.thread_axis("blockIdx.z")) + s[output].bind(by, te.thread_axis("blockIdx.y")) + s[output].bind(bx, te.thread_axis("blockIdx.x")) + s[output].bind(vf, te.thread_axis("vthread")) + s[output].bind(vy, te.thread_axis("vthread")) + s[output].bind(vx, te.thread_axis("vthread")) + s[output].bind(tf, te.thread_axis("threadIdx.z")) + s[output].bind(ty, te.thread_axis("threadIdx.y")) + s[output].bind(tx, te.thread_axis("threadIdx.x")) + s[output].reorder(bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi, fb) + s[output].vectorize(fb) + + s[OL].compute_at(s[output], tx) + + # tile reduction axes + n, fc, y, x, fb = s[OL].op.axis + + rcc, rcb, ry, rx = s[OL].op.reduce_axis + rco, rci = cfg["tile_rcc"].apply(s, OL, rcc) + ryo, ryi = cfg["tile_ry"].apply(s, OL, ry) + rxo, rxi = cfg["tile_rx"].apply(s, OL, rx) + + # TODO(csullivan): check position of rcb + s[OL].reorder(rco, ryo, rxo, rci, ryi, rxi, rcb, n, fc, y, x, fb) + s[OL].vectorize(fb) + s[OL].unroll(rcb) + + s[AA].compute_at(s[OL], rxo) + s[WW].compute_at(s[OL], rxo) + # cooperative fetching + for load in [AA, WW]: + if load == WW: + n, fyx, v = s[load].op.axis + fused = s[load].fuse(n, fyx) + else: + n, f, y, x, v = s[load].op.axis + fused = s[load].fuse(n, f, y, x) + tz, fused = s[load].split(fused, nparts=cfg["tile_fc"].size[2]) + ty, fused = s[load].split(fused, nparts=cfg["tile_y"].size[2]) + tx, fused = s[load].split(fused, nparts=cfg["tile_x"].size[2]) + s[load].bind(tz, te.thread_axis("threadIdx.z")) + s[load].bind(ty, te.thread_axis("threadIdx.y")) + s[load].bind(tx, te.thread_axis("threadIdx.x")) + s[load].vectorize(v) + + # unroll + s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val) + + N, OCC, OH, OW, OCB = get_const_tuple(output.shape) + _, ICKHKW, _ = get_const_tuple(kernel.shape) + + if isinstance(N, int): + cfg.add_flop(2 * N * OH * OW * OCC * OCB * ICKHKW) + + +def compute_depthwise_conv2d_NCHWc_KCRSk_acc32( + Input, Filter, stride, padding, dilation, out_dtype=None +): + """Depthwise convolution operator in NCHWc layout. 
""" + if out_dtype is None: + out_dtype = Input.dtype + assert isinstance(stride, int) or len(stride) == 2 + assert isinstance(dilation, int) or len(dilation) == 2 + + if isinstance(stride, int): + stride_h = stride_w = stride + else: + stride_h, stride_w = stride + + if isinstance(dilation, int): + dilation_h = dilation_w = dilation + else: + dilation_h, dilation_w = dilation + + batch, channel_chunk, in_height, in_width, channel_block = Input.shape + _, channel_multiplier, kernel_h, kernel_w, _ = Filter.shape + + # compute the output shape + dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 + pad_top, pad_left, pad_down, pad_right = nn.get_pad_tuple( + padding, (dilated_kernel_h, dilated_kernel_w) + ) + out_channel_chunk = simplify(channel_chunk * channel_multiplier) + out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1) + out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1) + # compute graph + pad_before = [0, 0, pad_top, pad_left, 0] + pad_after = [0, 0, pad_down, pad_right, 0] + temp = nn.pad(Input, pad_before, pad_after, name="pad_temp") + + ry = te.reduce_axis((0, kernel_h), name="ry") + rx = te.reduce_axis((0, kernel_w), name="rx") + + # NCHWc x CMRSc = [N,(C//4)M,OH,OW, 4c] + # NCHWc x CMRS + # texture: NCH|W|c + # texture: C|MRS|c + # output: N + # m = mrs//RS + # rs = mrs % RS + # r = rs // W == (mrs // S) % R + # s = rs % W == mrs % S + Filter = te.compute( + (channel_chunk, channel_multiplier * kernel_h * kernel_w, channel_block), + lambda ffc, mrs, ffb: Filter[ + ffc, mrs // (kernel_h * kernel_w), (mrs // kernel_w) % kernel_h, mrs % kernel_w, ffb + ], + name="packed_filter", + ) + + conv = te.compute( + (batch, out_channel_chunk, out_height, out_width, channel_block), + lambda nn, ffc, yy, xx, ffb: te.sum( + ( + temp[ + nn, + ffc // channel_multiplier, + yy * stride_h + ry * dilation_h, + xx * stride_w + rx * dilation_w, + ffb, + ] + * Filter[ + ffc // channel_multiplier, + ((ffc % channel_multiplier) * kernel_h + ry) * kernel_w + rx, + ffb, + ] + ).astype(out_dtype), + axis=[ry, rx], + ), + tag="depthwise_conv2d_nchwc_kcrsk_texture", + ) + return te.compute( + conv.shape, lambda n, ffc, y, x, ffb: conv[n, ffc, y, x, ffb].astype("float32") + ) + + +def schedule_depthwise_conv2d_NCHWc_KCRSk_acc32(cfg, s, output): + """schedule optimized for batch size = 1""" + + conv = output.op.input_tensors[0] + + ##### space definition begin ##### + n, fc, y, x, fb = s[conv].op.axis + ry, rx = s[conv].op.reduce_axis + cfg.define_split("tile_fc", fc, num_outputs=4) + cfg.define_split("tile_y", y, num_outputs=4) + cfg.define_split("tile_x", x, num_outputs=4) + cfg.define_split("tile_ry", ry, num_outputs=2) + cfg.define_split("tile_rx", rx, num_outputs=2) + cfg.define_knob("auto_unroll_max_step", [0, 512, 1500]) + + pad_data, flattened_kernel = s[conv].op.input_tensors + kernel = s[flattened_kernel].op.input_tensors[0] + s[flattened_kernel].compute_inline() + + s[pad_data].compute_inline() + if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag: + s[kernel].compute_inline() + kernel = flattened_kernel + + if conv.op in s.outputs: + output = conv + OL = s.cache_write(conv, "local") + else: + output = s.outputs[0].output(0) + s[conv].set_scope("local") + OL = conv + + # create cache stage + AT = s.cache_read(pad_data, "global.texture", [OL]) + WT = s.cache_read(kernel, "global.texture", [OL]) + + def copy_to_texture(stage): + axes = s[stage].op.axis 
+ fused = s[stage].fuse(*axes[:-1]) + block, thread = s[stage].split(fused, factor=32) + s[stage].vectorize(axes[-1]) + s[stage].bind(block, te.thread_axis("blockIdx.x")) + s[stage].bind(thread, te.thread_axis("threadIdx.x")) + + copy_to_texture(AT) + copy_to_texture(WT) + + AA = s.cache_read(AT, "shared", [OL]) + WW = s.cache_read(WT, "shared", [OL]) + + # tile and bind spatial axes + n, fc, y, x, fb = s[output].op.axis + + kernel_scope, n = s[output].split(n, nparts=1) + + bf, vf, tf, fi = cfg["tile_fc"].apply(s, output, fc) + by, vy, ty, yi = cfg["tile_y"].apply(s, output, y) + bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x) + + bf = s[output].fuse(n, bf) + s[output].bind(bf, te.thread_axis("blockIdx.z")) + s[output].bind(by, te.thread_axis("blockIdx.y")) + s[output].bind(bx, te.thread_axis("blockIdx.x")) + s[output].bind(vf, te.thread_axis("vthread")) + s[output].bind(vy, te.thread_axis("vthread")) + s[output].bind(vx, te.thread_axis("vthread")) + s[output].bind(tf, te.thread_axis("threadIdx.z")) + s[output].bind(ty, te.thread_axis("threadIdx.y")) + s[output].bind(tx, te.thread_axis("threadIdx.x")) + s[output].reorder(bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi, fb) + s[output].vectorize(fb) + + s[OL].compute_at(s[output], tx) + + # tile reduction axes + n, fc, y, x, fb = s[OL].op.axis + + ry, rx = s[OL].op.reduce_axis + ryo, ryi = cfg["tile_ry"].apply(s, OL, ry) + rxo, rxi = cfg["tile_rx"].apply(s, OL, rx) + + s[OL].reorder(ryo, rxo, ryi, rxi, n, fc, y, x, fb) + s[OL].vectorize(fb) + # s[OL].unroll() + + s[AA].compute_at(s[OL], rxo) + s[WW].compute_at(s[OL], rxo) + # cooperative fetching + for load in [AA, WW]: + if load == WW: + n, fyx, v = s[load].op.axis + fused = s[load].fuse(n, fyx) + else: + n, f, y, x, v = s[load].op.axis + fused = s[load].fuse(n, f, y, x) + tz, fused = s[load].split(fused, nparts=cfg["tile_fc"].size[2]) + ty, fused = s[load].split(fused, nparts=cfg["tile_y"].size[2]) + tx, fused = s[load].split(fused, nparts=cfg["tile_x"].size[2]) + s[load].bind(tz, te.thread_axis("threadIdx.z")) + s[load].bind(ty, te.thread_axis("threadIdx.y")) + s[load].bind(tx, te.thread_axis("threadIdx.x")) + s[load].vectorize(v) + + # unroll + s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val) + + N, OCC, OH, OW, OCB = get_const_tuple(output.shape) + ICC, MKHKW, ICB = get_const_tuple(kernel.shape) + M = (OCC * OCB) // (ICC * ICB) + KHKW = MKHKW // M + + if isinstance(N, int): + cfg.add_flop(2 * N * OH * OW * OCC * OCB * KHKW) + + +def scheduler(compute, schedule, *args, **kwargs): + placeholders = compute(*args) + s = schedule(*placeholders, **kwargs) + return s, placeholders + + +def conv2d_1x1_NCHWc_RSCKk(input_shape, filter_shape): + placeholders = compute_conv2d_1x1_NCHWc_RSCKk(input_shape, filter_shape) + s = schedule_conv2d_1x1_NCHWc_RSCKk(*placeholders) + return s, placeholders + + +def conv2d_1x1_WCHNc_CRSKk(input_shape, filter_shape): + placeholders = compute_conv2d_1x1_WCHNc_CRSKk(input_shape, filter_shape) + s = schedule_conv2d_1x1_WCHNc_CRSKk(*placeholders) + return s, (placeholders[0], placeholders[1], placeholders[-1]) + + +def conv2d_NCHWc_KCRSk(input_shape, filter_shape): + data = te.placeholder(input_shape, name="data", dtype="float32") + filt = te.placeholder(filter_shape, name="filter", dtype="float32") + conv = compute_conv2d_NCHWc_KCRSk(data, filt, [1, 1], [0, 0], [1, 1], "float32") + cfg = autotvm.get_config() + s = te.create_schedule([x.op for x in [conv]]) + schedule_conv2d_NCHWc_KCRSk(cfg, s, conv) + return s, (data, 
filt, conv) + + +def conv2d_NCHWc_KCRSk_fp32_acc(input_shape, filter_shape): + data = te.placeholder(input_shape, name="data", dtype="float32") + filt = te.placeholder(filter_shape, name="filter", dtype="float32") + output = compute_conv2d_NCHWc_KCRSk_acc32(data, filt, [1, 1], [0, 0], [1, 1], "float32") + cfg = autotvm.get_config() + s = te.create_schedule([x.op for x in [output]]) + schedule_conv2d_NCHWc_KCRSk_acc32(cfg, s, output) + return s, (data, filt, output) + + +def depthwise_conv2d_NCHWc_KCRSk_acc32(input_shape, filter_shape): + data = te.placeholder(input_shape, name="data", dtype="float32") + filt = te.placeholder(filter_shape, name="filter", dtype="float32") + output = compute_depthwise_conv2d_NCHWc_KCRSk_acc32( + data, filt, [1, 1], [0, 0], [1, 1], "float32" + ) + cfg = autotvm.get_config() + s = te.create_schedule([x.op for x in [output]]) + schedule_depthwise_conv2d_NCHWc_KCRSk_acc32(cfg, s, output) + return s, (data, filt, output) + + +def ref_convolution(data, kernel, stride, pad): + import mxnet as mx + + groups = 1 + kernel_size = (kernel.shape[2], kernel.shape[3]) + num_filter = kernel.shape[0] + ref_res = mx.nd.Convolution( + data=mx.nd.array(data), + weight=mx.nd.array(kernel), + bias=None, + no_bias=True, + kernel=kernel_size, + stride=stride, + pad=pad, + num_filter=num_filter, + num_group=groups, + ) + return ref_res.asnumpy() + + +def ref_depthwise_convolution(data, kernel, stride, pad): + import mxnet as mx + + groups = kernel.shape[0] + kernel_size = (kernel.shape[2], kernel.shape[3]) + num_filter = kernel.shape[0] + multiplier = kernel.shape[1] + ref_res = mx.nd.Convolution( + data=mx.nd.array(data), + weight=mx.nd.array(kernel), + bias=None, + no_bias=True, + kernel=kernel_size, + stride=stride, + pad=pad, + num_filter=num_filter, + num_group=groups, + ) + return ref_res.asnumpy() + + +def validate(workload, target, dev, input_shapes, *args, **kwargs): + s, placeholders = workload(*input_shapes, *args, **kwargs) + func = tvm.driver.build(s, [*placeholders], target=target, name="TestFunction") + + args_tvm = [] + args_np = [] + for var in placeholders[:-1]: + var_np = np.random.uniform(size=[i.value for i in var.shape]).astype(var.dtype) + args_np.append(var_np) + args_tvm.append(tvm.nd.array(var_np, dev)) + args_tvm.append( + tvm.nd.array( + np.zeros([i.value for i in placeholders[-1].shape], dtype=placeholders[-1].dtype), dev + ) + ) + func(*args_tvm) + + if "plus_one" in workload.__name__: + np_result = args_np[0] + 1.0 + elif "matmul" in workload.__name__: + if "inner" in workload.__name__: + np_result = np.matmul( + args_np[0].reshape(32, 256), args_np[1].reshape(32, 256).transpose(1, 0) + ) + elif "accum" in workload.__name__: + np_result = np.matmul( + args_np[0].transpose((1, 0, 2)).reshape(64, 128), args_np[1].reshape(128, 64) + ) + else: + np_result = np.matmul( + args_np[0].transpose((0, 2, 1)).reshape(128, 64), + args_np[1].transpose(1, 0, 2).reshape(64, 128), + ) + elif "conv2d_1x1_NCHWc_RSCKk" in workload.__name__: + vec_length = args_np[1].shape[-1] + # nchwc -> nchw + args_np[0] = ( + args_np[0] + .transpose((0, 1, 4, 2, 3)) + .reshape( + args_np[0].shape[0], + args_np[0].shape[1] * args_np[0].shape[-1], + args_np[0].shape[2], + args_np[0].shape[3], + ) + ) + # rsckk -> rsck -> kcrs + args_np[1] = ( + args_np[1] + .reshape( + args_np[1].shape[0], + args_np[1].shape[1], + args_np[1].shape[2], + args_np[1].shape[3] * args_np[1].shape[4], + ) + .transpose((3, 2, 0, 1)) + ) + np_result = testing.conv2d_nchw_python(args_np[0], args_np[1], 1, 0) + # 
nkhw -> nkhwk + np_result = np_result.reshape( + np_result.shape[0], + np_result.shape[1] // vec_length, + vec_length, + np_result.shape[2], + np_result.shape[3], + ).transpose(0, 1, 3, 4, 2) + elif "conv2d_1x1_WCHNc_CRSKk" in workload.__name__: + vec_length = args_np[1].shape[-1] + # wchnc -> nchw + args_np[0] = ( + args_np[0] + .transpose((3, 1, 4, 2, 0)) + .reshape( + args_np[0].shape[3], + args_np[0].shape[1] * args_np[0].shape[-1], + args_np[0].shape[2], + args_np[0].shape[0], + ) + ) + # crskk -> crsk -> kcrs + args_np[1] = ( + args_np[1] + .reshape( + args_np[1].shape[0], + args_np[1].shape[1], + args_np[1].shape[2], + args_np[1].shape[3] * args_np[1].shape[4], + ) + .transpose((3, 0, 1, 2)) + ) + np_result = testing.conv2d_nchw_python(args_np[0], args_np[1], 1, 0) + # nkhw -> nkkhw -> wkhnk + np_result = np_result.reshape( + np_result.shape[0], + np_result.shape[1] // vec_length, + vec_length, + np_result.shape[2], + np_result.shape[3], + ).transpose(4, 1, 3, 0, 2) + elif "NCHW_KCRS" in workload.__name__: + np_result = testing.conv2d_nchw_python(args_np[0], args_np[1], 1, 0) + elif "NCHWc_KCRSk" in workload.__name__: + vec_length = args_np[1].shape[-1] + # nchwc -> nchw + args_np[0] = ( + args_np[0] + .transpose((0, 1, 4, 2, 3)) + .reshape( + args_np[0].shape[0], + args_np[0].shape[1] * args_np[0].shape[-1], + args_np[0].shape[2], + args_np[0].shape[3], + ) + ) + # kcrsk/cmrsc -> kcrs/cmrs + args_np[1] = ( + args_np[1] + .transpose((0, 4, 1, 2, 3)) + .reshape( + args_np[1].shape[0] * args_np[1].shape[4], + args_np[1].shape[1], + args_np[1].shape[2], + args_np[1].shape[3], + ) + ) + if "depthwise" in workload.__name__: + # np_result = testing.depthwise_conv2d_python_nchw(args_np[0], args_np[1], 1, "VALID") + np_result = ref_depthwise_convolution(args_np[0], args_np[1], [], []) + else: + # np_result = testing.conv2d_nchw_python(args_np[0], args_np[1], 1, 0) + np_result = ref_convolution(args_np[0], args_np[1], [], []) + # nkhw -> nkhwk + np_result = np_result.reshape( + np_result.shape[0], + np_result.shape[1] // vec_length, + vec_length, + np_result.shape[2], + np_result.shape[3], + ).transpose(0, 1, 3, 4, 2) + np.testing.assert_allclose(args_tvm[-1].asnumpy(), np_result, rtol=1e-2, atol=1e-2) + + +class BaseSingleShapeValidator: + @tvm.testing.parametrize_targets("opencl") + def test_unary(self, test_func, input_shape, target, dev): + validate(test_func, target, dev, [input_shape]) + + +class TestPlusOneRank3(BaseSingleShapeValidator): + input_shape = tvm.testing.parameter((32, 32, 4)) + + def plus_one(input_shape): + return scheduler(compute_plus_one_rank3, schedule_plus_one_rank3, input_shape) + + test_func = tvm.testing.parameter(plus_one) + + +class TestPlusOneRank5(BaseSingleShapeValidator): + input_shape = tvm.testing.parameter((32, 2, 4, 4, 4)) + + def plus_one(input_shape): + return scheduler(compute_plus_one_rank5, schedule_plus_one_rank5, input_shape) + + test_func = tvm.testing.parameter(plus_one) + + +class TestMatmul: + input_shape = tvm.testing.parameter((32, 64, 4)) + local = tvm.testing.parameter(False, True) + + def matmul(input_shape, local): + return scheduler(compute_matmul, schedule_matmul, input_shape, local=local) + + def matmul_inner(input_shape, local): + return scheduler(compute_matmul_inner, schedule_matmul_inner, input_shape, local=local) + + test_func = tvm.testing.parameter(matmul, matmul_inner) + + @tvm.testing.parametrize_targets("opencl") + def test_matmul(self, test_func, input_shape, local, target, dev): + validate(test_func, target, dev, 
[input_shape], local=local) + + +class TestMatmulVectorAccumulator: + shapeA = tvm.testing.parameter((32, 64, 4)) + shapeB = tvm.testing.parameter((128, 16, 4)) + local = tvm.testing.parameter(False, True) + + def matmul_vector_accumulator(shapeA, shapeB, local): + return scheduler( + compute_matmul_vector_accumulator, + schedule_matmul_vector_accumulator, + shapeA, + shapeB, + local=local, + ) + + test_func = tvm.testing.parameter(matmul_vector_accumulator) + + @tvm.testing.parametrize_targets("opencl") + def test_matmul_vec_acc(self, test_func, shapeA, shapeB, local, target, dev): + validate(test_func, target, dev, [shapeA, shapeB], local=local) + + +class BaseConv2DValidator: + @tvm.testing.parametrize_targets("opencl") + def test_conv2d(self, test_func, input_shapes, target, dev): + validate(test_func, target, dev, input_shapes) + + +class TestConv2dNCHWcRSCKk(BaseConv2DValidator): + input_shapes = tvm.testing.parameter([(1, 32, 56, 56, 4), (1, 1, 128, 32, 4)]) + test_func = tvm.testing.parameter(conv2d_1x1_NCHWc_RSCKk) + + +class TestConv2dWCHNcCRSKk(BaseConv2DValidator): + input_shapes = tvm.testing.parameter([(56, 32, 56, 1, 4), (128, 1, 1, 32, 4)]) + test_func = tvm.testing.parameter(conv2d_1x1_WCHNc_CRSKk) + + +class TestConv2dNCHWcKCRSk(BaseConv2DValidator): + input_shapes = tvm.testing.parameter( + [(1, 32, 56, 56, 4), (32, 128, 1, 1, 4)], [(1, 32, 112, 112, 4), (32, 128, 3, 3, 4)] + ) + test_func = tvm.testing.parameter(conv2d_NCHWc_KCRSk, conv2d_NCHWc_KCRSk_fp32_acc) + + +class TestDepthwiseConv2dNCHWcKCRSk(BaseConv2DValidator): + input_shapes = tvm.testing.parameter([(1, 24, 257, 257, 4), (24, 1, 3, 3, 4)]) + test_func = tvm.testing.parameter(depthwise_conv2d_NCHWc_KCRSk_acc32) + + +if __name__ == "__main__": + sys.exit(pytest.main(sys.argv))
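As a usage sketch (editor's illustration, not part of the patch), the texture-scheduled workloads above can be compiled without touching a device and the generated OpenCL inspected for the image intrinsics that the "global.texture" cache stages lower to. The helper conv2d_NCHWc_KCRSk_fp32_acc and the shapes come from the tests above; the get_source() dump and the expectation of seeing read_imagef/write_imagef in the kernel are assumptions about how one would eyeball the emitted code with an OpenCL-enabled TVM build.

# Standalone sketch (illustration only, assuming the helpers defined in this
# test file are importable and TVM was built with the OpenCL codegen enabled):
# build one texture workload for OpenCL and print the device source.
import tvm

s, (data, filt, output) = conv2d_NCHWc_KCRSk_fp32_acc(
    (1, 32, 56, 56, 4), (32, 128, 1, 1, 4)  # NCHWc data and KCRSk filter shapes from the tests
)
mod = tvm.driver.build(s, [data, filt, output], target="opencl", name="conv2d_texture")
# The OpenCL kernel lives in the imported device module; texture loads/stores are
# expected to show up as read_imagef/write_imagef on image2d_t arguments.
print(mod.imported_modules[0].get_source())

Running the tests themselves goes through the pytest entry point above, e.g. python -m pytest -k TestConv2dNCHWcKCRSk pointed at this file on a machine that exposes an OpenCL device; the opencl-parametrized cases are skipped when no such device is available.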