diff --git a/docs/dev/codebase_walkthrough.rst b/docs/dev/codebase_walkthrough.rst index 60ab5e5ae9d2..efc8b32832c0 100644 --- a/docs/dev/codebase_walkthrough.rst +++ b/docs/dev/codebase_walkthrough.rst @@ -183,7 +183,7 @@ The first time you invoke the compiled module with ``fadd(a, b, c)``, ``GetFunct auto it = fmap_.find(name); const FunctionInfo& info = it->second; CUDAWrappedFunc f; - f.Init(this, sptr_to_self, name, info.arg_types.size(), info.thread_axis_tags); + f.Init(this, sptr_to_self, name, info.arg_types.size(), info.launch_param_tags); return PackFuncVoidAddr(f, info.arg_types); } @@ -204,7 +204,7 @@ The ``PackedFunc``'s overloaded ``operator()`` will be called, which in turn cal fcache_[device_id] = m_->GetFunc(device_id, func_name_); } CUstream strm = static_cast(CUDAThreadEntry::ThreadLocal()->stream); - ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); + ThreadWorkLoad wl = launch_param_config_.Extract(args); CUresult result = cuLaunchKernel( fcache_[device_id], wl.grid_dim(0), diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index 25ed2f9ae8d1..55f4fc62649c 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -240,10 +240,12 @@ namespace attr { * * Call(f, * [arg1, arg2, ..., arg_n, - * work_size_1, work_size_2, ... work_size_m]) + * work_size_1, work_size_2, ... work_size_m, dyn_shmem_size]) * * Here n = len(arg), m = len(work_size) = len(device_thread_axis). * + * When kDeviceUseDynSharedMemory is not set, dyn_shmem_size argument is omitted. + * * The list of device_thread_axis indicates how can be bind the * work_size arguments to the corresponding threads. * @@ -251,6 +253,13 @@ namespace attr { */ constexpr const char* kDeviceThreadAxis = "tir.device_thread_axis"; +/*! + * \brief Whether or not use dynamic shared memory. + * + * Type: Integer + */ +constexpr const char* kDeviceUseDynSharedMemory = "tir.device_use_dyn_shared_memory"; + /*! * \brief Whether to set noalias rule on the function arguments. * diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc index a877bc634300..7d6879a62aba 100644 --- a/src/runtime/cuda/cuda_module.cc +++ b/src/runtime/cuda/cuda_module.cc @@ -153,12 +153,12 @@ class CUDAWrappedFunc { public: // initialize the CUDA function. void Init(CUDAModuleNode* m, ObjectPtr sptr, const std::string& func_name, - size_t num_void_args, const std::vector& thread_axis_tags) { + size_t num_void_args, const std::vector& launch_param_tags) { m_ = m; sptr_ = sptr; func_name_ = func_name; std::fill(fcache_.begin(), fcache_.end(), nullptr); - thread_axis_cfg_.Init(num_void_args, thread_axis_tags); + launch_param_config_.Init(num_void_args, launch_param_tags); } // invoke the function with void arguments void operator()(TVMArgs args, TVMRetValue* rv, void** void_args) const { @@ -168,10 +168,10 @@ class CUDAWrappedFunc { fcache_[device_id] = m_->GetFunc(device_id, func_name_); } CUstream strm = static_cast(CUDAThreadEntry::ThreadLocal()->stream); - ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); + ThreadWorkLoad wl = launch_param_config_.Extract(args); CUresult result = cuLaunchKernel(fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2), wl.block_dim(0), wl.block_dim(1), - wl.block_dim(2), 0, strm, void_args, nullptr); + wl.block_dim(2), wl.dyn_shmem_size, strm, void_args, nullptr); if (result != CUDA_SUCCESS && result != CUDA_ERROR_DEINITIALIZED) { const char* msg; cuGetErrorName(result, &msg); @@ -201,8 +201,8 @@ class CUDAWrappedFunc { // Device function cache per device. // mark as mutable, to enable lazy initialization mutable std::array fcache_; - // thread axis configuration - ThreadAxisConfig thread_axis_cfg_; + // launch parameters configuration + LaunchParamConfig launch_param_config_; }; class CUDAPrepGlobalBarrier { @@ -241,7 +241,7 @@ PackedFunc CUDAModuleNode::GetFunction(const std::string& name, if (it == fmap_.end()) return PackedFunc(); const FunctionInfo& info = it->second; CUDAWrappedFunc f; - f.Init(this, sptr_to_self, name, info.arg_types.size(), info.thread_axis_tags); + f.Init(this, sptr_to_self, name, info.arg_types.size(), info.launch_param_tags); return PackFuncVoidAddr(f, info.arg_types); } diff --git a/src/runtime/file_utils.cc b/src/runtime/file_utils.cc index 32dd1d8020c9..35832e83f59c 100644 --- a/src/runtime/file_utils.cc +++ b/src/runtime/file_utils.cc @@ -43,7 +43,7 @@ void FunctionInfo::Save(dmlc::JSONWriter* writer) const { writer->BeginObject(); writer->WriteObjectKeyValue("name", name); writer->WriteObjectKeyValue("arg_types", sarg_types); - writer->WriteObjectKeyValue("thread_axis_tags", thread_axis_tags); + writer->WriteObjectKeyValue("launch_param_tags", launch_param_tags); writer->EndObject(); } @@ -52,7 +52,9 @@ void FunctionInfo::Load(dmlc::JSONReader* reader) { std::vector sarg_types; helper.DeclareField("name", &name); helper.DeclareField("arg_types", &sarg_types); - helper.DeclareField("thread_axis_tags", &thread_axis_tags); + helper.DeclareOptionalField("launch_param_tags", &launch_param_tags); + helper.DeclareOptionalField("thread_axis_tags", + &launch_param_tags); // for backward compatibility helper.ReadAllFields(reader); arg_types.resize(sarg_types.size()); for (size_t i = 0; i < arg_types.size(); ++i) { @@ -63,13 +65,13 @@ void FunctionInfo::Load(dmlc::JSONReader* reader) { void FunctionInfo::Save(dmlc::Stream* writer) const { writer->Write(name); writer->Write(arg_types); - writer->Write(thread_axis_tags); + writer->Write(launch_param_tags); } bool FunctionInfo::Load(dmlc::Stream* reader) { if (!reader->Read(&name)) return false; if (!reader->Read(&arg_types)) return false; - if (!reader->Read(&thread_axis_tags)) return false; + if (!reader->Read(&launch_param_tags)) return false; return true; } diff --git a/src/runtime/meta_data.h b/src/runtime/meta_data.h index e3ec155dc291..002012a1e1cc 100644 --- a/src/runtime/meta_data.h +++ b/src/runtime/meta_data.h @@ -99,11 +99,14 @@ Module MetadataModuleCreate( const std::unordered_map& metadata, const std::unordered_map>& sym_vars); +/*! \brief A tag to specify whether or not dynamic shared memory is used */ +constexpr const char* kUseDynamicSharedMemoryTag = "tir.use_dyn_shared_memory"; + /*! \brief function information needed by device */ struct FunctionInfo { std::string name; std::vector arg_types; - std::vector thread_axis_tags; + std::vector launch_param_tags; void Save(dmlc::JSONWriter* writer) const; void Load(dmlc::JSONReader* reader); diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm index 88501880557e..1e81ac1bbb34 100644 --- a/src/runtime/metal/metal_module.mm +++ b/src/runtime/metal/metal_module.mm @@ -178,7 +178,7 @@ void SaveToBinary(dmlc::Stream* stream) final { // initialize the METAL function. void Init(MetalModuleNode* m, ObjectPtr sptr, const std::string& func_name, size_t num_buffer_args, size_t num_pack_args, - const std::vector& thread_axis_tags) { + const std::vector& launch_param_tags) { w_ = metal::MetalWorkspace::Global(); m_ = m; sptr_ = sptr; @@ -186,7 +186,7 @@ void Init(MetalModuleNode* m, ObjectPtr sptr, const std::string& func_na num_buffer_args_ = num_buffer_args; num_pack_args_ = num_pack_args; std::fill(scache_.begin(), scache_.end(), (id)nil); - thread_axis_cfg_.Init(num_buffer_args + num_pack_args, thread_axis_tags); + launch_param_config_.Init(num_buffer_args + num_pack_args, launch_param_tags); metal::MetalThreadEntry* t = metal::MetalThreadEntry::ThreadLocal(); int dev_id = t->device.device_id; scache_[dev_id] = m->GetPipelineState(dev_id, func_name); @@ -201,7 +201,7 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) cons if (scache_[device_id] == nil) { scache_[device_id] = m_->GetPipelineState(device_id, func_name_); } - ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); + ThreadWorkLoad wl = launch_param_config_.Extract(args); int blockSize = wl.block_dim(0) * wl.block_dim(1) * wl.block_dim(2); auto maxTotalThreadsPerThreadgroup = scache_[device_id].maxTotalThreadsPerThreadgroup; CHECK_LE(blockSize, maxTotalThreadsPerThreadgroup); @@ -242,8 +242,8 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) cons // Device state cache per device. // mark as mutable, to enable lazy initialization mutable std::array, kMetalMaxNumDevice> scache_; - // thread axis configuration - ThreadAxisConfig thread_axis_cfg_; + // launch parameters configuration + LaunchParamConfig launch_param_config_; }; PackedFunc MetalModuleNode::GetFunction(const std::string& name, @@ -261,7 +261,7 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) cons MetalWrappedFunc f; size_t num_buffer_args = NumBufferArgs(info.arg_types); f.Init(this, sptr_to_self, name, num_buffer_args, info.arg_types.size() - num_buffer_args, - info.thread_axis_tags); + info.launch_param_tags); pf = PackFuncNonBufferArg(f, info.arg_types); }; return pf; diff --git a/src/runtime/opencl/opencl_module.cc b/src/runtime/opencl/opencl_module.cc index 4040d82b33e7..f6c7f6232819 100644 --- a/src/runtime/opencl/opencl_module.cc +++ b/src/runtime/opencl/opencl_module.cc @@ -40,14 +40,14 @@ class OpenCLWrappedFunc { // initialize the OpenCL function. void Init(OpenCLModuleNode* m, ObjectPtr sptr, OpenCLModuleNode::KTRefEntry entry, std::string func_name, std::vector arg_size, - const std::vector& thread_axis_tags) { + const std::vector& launch_param_tags) { w_ = m->GetGlobalWorkspace(); m_ = m; sptr_ = sptr; entry_ = entry; func_name_ = func_name; arg_size_ = arg_size; - thread_axis_cfg_.Init(arg_size.size(), thread_axis_tags); + launch_param_config_.Init(arg_size.size(), launch_param_tags); } // invoke the function with void arguments void operator()(TVMArgs args, TVMRetValue* rv, void** void_args) const { @@ -73,8 +73,8 @@ class OpenCLWrappedFunc { OPENCL_CALL(clSetKernelArg(kernel, i, arg_size_[i], arg)); } cl_command_queue queue = w_->GetQueue(t->device); - ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); - cl_uint work_dim = static_cast(thread_axis_cfg_.work_dim()); + ThreadWorkLoad wl = launch_param_config_.Extract(args); + cl_uint work_dim = static_cast(launch_param_config_.work_dim()); for (cl_uint i = 0; i < work_dim; ++i) { wl.work_size[i] *= wl.work_size[i + 3]; } @@ -96,8 +96,8 @@ class OpenCLWrappedFunc { std::string func_name_; // convert code for void argument std::vector arg_size_; - // thread axis config - ThreadAxisConfig thread_axis_cfg_; + // launch parameters config + LaunchParamConfig launch_param_config_; }; OpenCLModuleNode::~OpenCLModuleNode() { @@ -148,7 +148,7 @@ PackedFunc OpenCLModuleNode::GetFunction(const std::string& name, } } // initialize the wrapped func. - f.Init(this, sptr_to_self, kid_map_.at(name), name, arg_size, info.thread_axis_tags); + f.Init(this, sptr_to_self, kid_map_.at(name), name, arg_size, info.launch_param_tags); return PackFuncVoidAddr(f, info.arg_types); } diff --git a/src/runtime/rocm/rocm_module.cc b/src/runtime/rocm/rocm_module.cc index 567557c56794..487ad23e16b9 100644 --- a/src/runtime/rocm/rocm_module.cc +++ b/src/runtime/rocm/rocm_module.cc @@ -147,12 +147,12 @@ class ROCMWrappedFunc { public: // initialize the ROCM function. void Init(ROCMModuleNode* m, ObjectPtr sptr, const std::string& func_name, - size_t num_void_args, const std::vector& thread_axis_tags) { + size_t num_void_args, const std::vector& launch_param_tags) { m_ = m; sptr_ = sptr; func_name_ = func_name; std::fill(fcache_.begin(), fcache_.end(), nullptr); - thread_axis_cfg_.Init(num_void_args, thread_axis_tags); + launch_param_config_.Init(num_void_args, launch_param_tags); } // invoke the function with void arguments void operator()(TVMArgs args, TVMRetValue* rv, void* packed_args, size_t packed_nbytes) const { @@ -164,13 +164,14 @@ class ROCMWrappedFunc { hipStream_t strm = static_cast(ROCMThreadEntry::ThreadLocal()->stream); - ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); + ThreadWorkLoad wl = launch_param_config_.Extract(args); void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, packed_args, HIP_LAUNCH_PARAM_BUFFER_SIZE, &packed_nbytes, HIP_LAUNCH_PARAM_END}; // HIP supports only extra_args. - ROCM_DRIVER_CALL(hipModuleLaunchKernel( - fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2), wl.block_dim(0), - wl.block_dim(1), wl.block_dim(2), 0, strm, nullptr, reinterpret_cast(&config))); + ROCM_DRIVER_CALL(hipModuleLaunchKernel(fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1), + wl.grid_dim(2), wl.block_dim(0), wl.block_dim(1), + wl.block_dim(2), wl.dyn_shmem_size, strm, nullptr, + reinterpret_cast(&config))); } private: @@ -183,8 +184,8 @@ class ROCMWrappedFunc { // Device function cache per device. // mark as mutable, to enable lazy initialization mutable std::array fcache_; - // thread axis configuration - ThreadAxisConfig thread_axis_cfg_; + // launch parameters configuration + LaunchParamConfig launch_param_config_; }; PackedFunc ROCMModuleNode::GetFunction(const std::string& name, @@ -195,7 +196,7 @@ PackedFunc ROCMModuleNode::GetFunction(const std::string& name, if (it == fmap_.end()) return PackedFunc(); const FunctionInfo& info = it->second; ROCMWrappedFunc f; - f.Init(this, sptr_to_self, name, info.arg_types.size(), info.thread_axis_tags); + f.Init(this, sptr_to_self, name, info.arg_types.size(), info.launch_param_tags); return PackFuncPackedArg(f, info.arg_types); } diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h index 9d140aedd810..ac8260ffbe39 100644 --- a/src/runtime/thread_storage_scope.h +++ b/src/runtime/thread_storage_scope.h @@ -19,7 +19,7 @@ /*! * \file thread_storage_scope.h - * \brief Extract thread axis configuration from TVMArgs. + * \brief Extract launch parameters configuration from TVMArgs. */ #ifndef TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_ #define TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_ @@ -29,6 +29,8 @@ #include #include +#include "meta_data.h" + namespace tvm { namespace runtime { @@ -182,6 +184,8 @@ struct ThreadScope { struct ThreadWorkLoad { // array, first three are thread configuration. size_t work_size[6]; + // Dynamic shared memory allocation size in bytes. + size_t dyn_shmem_size{0}; /*! * \param i The block dimension. * \return i-th block dim @@ -193,17 +197,23 @@ struct ThreadWorkLoad { */ inline size_t grid_dim(size_t i) const { return work_size[i]; } }; -/*! \brief Thread axis configuration */ -class ThreadAxisConfig { +/*! \brief Launch parameters configuration */ +class LaunchParamConfig { public: - void Init(size_t base, const std::vector& thread_axis_tags) { + void Init(size_t base, const std::vector& launch_param_tags) { base_ = base; std::vector filled(6, false); - for (size_t i = 0; i < thread_axis_tags.size(); ++i) { - const std::string& tag = thread_axis_tags[i]; - ThreadScope ts = ThreadScope::Create(tag); - arg_index_map_.push_back(ts.rank * 3 + ts.dim_index); - filled[ts.rank * 3 + ts.dim_index] = true; + for (size_t i = 0; i < launch_param_tags.size(); ++i) { + const std::string& tag = launch_param_tags[i]; + if (tag == kUseDynamicSharedMemoryTag) { + ICHECK_EQ(i, launch_param_tags.size() - 1) + << "kUseDynamicSharedMemoryTag should be the last tag in launch_param_tags."; + use_dyn_shared_memory_ = true; + } else { + ThreadScope ts = ThreadScope::Create(tag); + arg_index_map_.push_back(ts.rank * 3 + ts.dim_index); + filled[ts.rank * 3 + ts.dim_index] = true; + } } work_dim_ = 1; for (int i = 0; i < 3; ++i) { @@ -223,6 +233,9 @@ class ThreadAxisConfig { w.work_size[arg_index_map_[i]] = size; } } + if (use_dyn_shared_memory_) { + w.dyn_shmem_size = static_cast(x.values[base_ + arg_index_map_.size()].v_int64); + } return w; } // return the work dim @@ -235,6 +248,8 @@ class ThreadAxisConfig { size_t work_dim_; /*! \brief The index mapping. */ std::vector arg_index_map_; + /*! \brief Whether or not use dynamic shared memory. */ + bool use_dyn_shared_memory_{false}; }; } // namespace runtime diff --git a/src/runtime/vulkan/vulkan_wrapped_func.cc b/src/runtime/vulkan/vulkan_wrapped_func.cc index 103b2aa7692c..0712f723bb64 100644 --- a/src/runtime/vulkan/vulkan_wrapped_func.cc +++ b/src/runtime/vulkan/vulkan_wrapped_func.cc @@ -33,13 +33,13 @@ namespace vulkan { void VulkanWrappedFunc::Init(VulkanModuleNode* m, ObjectPtr sptr, const std::string& func_name, size_t num_buffer_args, size_t num_pack_args, - const std::vector& thread_axis_tags) { + const std::vector& launch_param_tags) { m_ = m; sptr_ = sptr; func_name_ = func_name; num_buffer_args_ = num_buffer_args; num_pack_args_ = num_pack_args; - thread_axis_cfg_.Init(num_buffer_args + num_pack_args, thread_axis_tags); + launch_param_config_.Init(num_buffer_args + num_pack_args, launch_param_tags); } void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, @@ -50,7 +50,7 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, scache_[device_id] = m_->GetPipeline(device_id, func_name_, num_pack_args_); } const auto& pipeline = scache_[device_id]; - ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); + ThreadWorkLoad wl = launch_param_config_.Extract(args); std::vector descriptor_buffers; descriptor_buffers.resize(num_buffer_args_); for (size_t i = 0; i < num_buffer_args_; ++i) { @@ -197,7 +197,7 @@ PackedFunc VulkanModuleNode::GetFunction(const std::string& name, VulkanWrappedFunc f; size_t num_buffer_args = NumBufferArgs(info.arg_types); f.Init(this, sptr_to_self, name, num_buffer_args, info.arg_types.size() - num_buffer_args, - info.thread_axis_tags); + info.launch_param_tags); return PackFuncNonBufferArg(std::move(f), info.arg_types); } diff --git a/src/runtime/vulkan/vulkan_wrapped_func.h b/src/runtime/vulkan/vulkan_wrapped_func.h index a174f22eba59..cd4774bf0f5a 100644 --- a/src/runtime/vulkan/vulkan_wrapped_func.h +++ b/src/runtime/vulkan/vulkan_wrapped_func.h @@ -58,7 +58,7 @@ class VulkanWrappedFunc { public: void Init(VulkanModuleNode* m, ObjectPtr sptr, const std::string& func_name, size_t num_buffer_args, size_t num_pack_args, - const std::vector& thread_axis_tags); + const std::vector& launch_param_tags); void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) const; @@ -73,11 +73,10 @@ class VulkanWrappedFunc { size_t num_buffer_args_; // number of packed arguments. size_t num_pack_args_; + // launch parameters configuration + LaunchParamConfig launch_param_config_; // Device state cache per device. // mark as mutable, to enable lazy initialization - // thread axis configuration - ThreadAxisConfig thread_axis_cfg_; - mutable std::array, kVulkanMaxNumDevice> scache_; }; diff --git a/src/target/build_common.h b/src/target/build_common.h index d2fe6468eef8..c66c2b52822e 100644 --- a/src/target/build_common.h +++ b/src/target/build_common.h @@ -53,7 +53,12 @@ inline std::unordered_map ExtractFuncInfo(co if (auto opt = f->GetAttr>(tir::attr::kDeviceThreadAxis)) { auto thread_axis = opt.value(); for (size_t i = 0; i < thread_axis.size(); ++i) { - info.thread_axis_tags.push_back(thread_axis[i]->thread_tag); + info.launch_param_tags.push_back(thread_axis[i]->thread_tag); + } + } + if (auto opt = f->GetAttr(tir::attr::kDeviceUseDynSharedMemory)) { + if (opt.value()) { + info.launch_param_tags.push_back(runtime::kUseDynamicSharedMemoryTag); } } auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); diff --git a/src/target/llvm/codegen_amdgpu.cc b/src/target/llvm/codegen_amdgpu.cc index 9aec8f4e867b..7770e42086de 100644 --- a/src/target/llvm/codegen_amdgpu.cc +++ b/src/target/llvm/codegen_amdgpu.cc @@ -72,51 +72,45 @@ class CodeGenAMDGPU : public CodeGenLLVM { void VisitStmt_(const AllocateNode* op) final { ICHECK(!is_zero(op->condition)); llvm::Value* buf = nullptr; - - int32_t constant_size = op->constant_allocation_size(); - ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU"; - StorageInfo& info = alloc_storage_info_[op->buffer_var.get()]; - if (constant_size % 4 == 0 && info.alignment == 0) { - info.alignment = GetTempAllocaAlignment(op->dtype, constant_size); - } - // maximum necessary alignment in the AMD devices - if (info.alignment > 16) { - info.alignment = 16; - } auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); - if (storage_scope.rank == runtime::StorageRank::kLocal) { - // const int local_address_space = 5; - // TODO(tqchen): for higher version of LLVM, local address space can be set. - llvm::AllocaInst* alloca = WithFunctionEntry([&]() { - return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size)); - }); - if (alloca->getAlignment() < static_cast(info.alignment)) { -#if TVM_LLVM_VERSION >= 100 - alloca->setAlignment(llvm::Align(info.alignment)); -#else - alloca->setAlignment(info.alignment); -#endif - } - buf = alloca; + + if (storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn") { + LOG(WARNING) << "Dynamic shared memory support for rocm is experimental."; + buf = AllocateSharedMemory(op->dtype, 0, 3, std::min(info.alignment, 16), + llvm::GlobalValue::ExternalLinkage); } else { - ICHECK(storage_scope.rank == runtime::StorageRank::kShared) - << "Can only allocate shared or local memory inside kernel"; - // Shared memory: address space == 3 - const unsigned shared_address_space = 3; - llvm::Type* type = llvm::ArrayType::get(DTypeToLLVMType(op->dtype), constant_size); - // Allocate shared memory in global, address_space = 3 - llvm::GlobalVariable* global = new llvm::GlobalVariable( - *module_, type, false, llvm::GlobalValue::PrivateLinkage, nullptr, ".shared", nullptr, - llvm::GlobalValue::NotThreadLocal, shared_address_space); - if (global->getAlignment() < static_cast(info.alignment)) { + int32_t constant_size = op->constant_allocation_size(); + ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU"; + + if (constant_size % 4 == 0 && info.alignment == 0) { + info.alignment = GetTempAllocaAlignment(op->dtype, constant_size); + } + // maximum necessary alignment in the AMD devices + if (info.alignment > 16) { + info.alignment = 16; + } + if (storage_scope.rank == runtime::StorageRank::kLocal) { + // const int local_address_space = 5; + // TODO(tqchen): for higher version of LLVM, local address space can be set. + llvm::AllocaInst* alloca = WithFunctionEntry([&]() { + return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size)); + }); + if (alloca->getAlignment() < static_cast(info.alignment)) { #if TVM_LLVM_VERSION >= 100 - global->setAlignment(llvm::Align(info.alignment)); + alloca->setAlignment(llvm::Align(info.alignment)); #else - global->setAlignment(info.alignment); + alloca->setAlignment(info.alignment); #endif + } + buf = alloca; + } else { + ICHECK(storage_scope.rank == runtime::StorageRank::kShared) + << "Can only allocate shared or local memory inside kernel"; + // Shared memory: address space == 3 + buf = AllocateSharedMemory(op->dtype, constant_size, 3, info.alignment, + llvm::GlobalValue::PrivateLinkage); } - buf = global; } buf = builder_->CreatePointerCast( diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index bdae93b82aff..b83748b784b6 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -524,6 +524,22 @@ void CodeGenLLVM::GetAlignment(DataType t, const VarNode* buf_var, const PrimExp *p_alignment = align_bits / 8; } +llvm::GlobalVariable* CodeGenLLVM::AllocateSharedMemory(DataType dtype, size_t size, + unsigned int shared_address_space, + int alignment, + llvm::GlobalValue::LinkageTypes linkage) { + llvm::Type* type = llvm::ArrayType::get(DTypeToLLVMType(dtype), size); + llvm::GlobalVariable* global = + new llvm::GlobalVariable(*module_, type, false, linkage, nullptr, "shmem", nullptr, + llvm::GlobalValue::NotThreadLocal, shared_address_space); +#if TVM_LLVM_VERSION >= 100 + global->setAlignment(llvm::Align(alignment)); +#else + global->setAlignment(alignment); +#endif + return global; +} + std::unique_ptr CodeGenLLVM::CreateDebugInfo(llvm::Module* module) { #if TVM_LLVM_VERSION >= 100 auto debug_info = std::make_unique(); diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 810e59be7214..52c5b98a0025 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -292,6 +292,11 @@ class CodeGenLLVM : public ExprFunctor, const Var& loop_var, const Stmt& body); // add alias information. void AddAliasInfo(llvm::Instruction* load, const VarNode* buffer, PrimExpr index); + + llvm::GlobalVariable* AllocateSharedMemory(DataType dtype, size_t size, + unsigned int shared_address_space, int alignment, + llvm::GlobalValue::LinkageTypes linkage); + // The IRBuilder. using IRBuilder = llvm::IRBuilder; // The current function diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc index 43ea0e6b7ae9..15543eda423f 100644 --- a/src/target/llvm/codegen_nvptx.cc +++ b/src/target/llvm/codegen_nvptx.cc @@ -48,48 +48,44 @@ class CodeGenNVPTX : public CodeGenLLVM { void VisitStmt_(const AllocateNode* op) final { ICHECK(!is_zero(op->condition)); llvm::Value* buf = nullptr; - - int32_t constant_size = op->constant_allocation_size(); - ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU"; StorageInfo& info = alloc_storage_info_[op->buffer_var.get()]; - if (constant_size % 4 == 0 && info.alignment == 0) { - info.alignment = GetTempAllocaAlignment(op->dtype, constant_size); - } // maximum necessary alignment in the NV devices if (info.alignment > 16) { info.alignment = 16; } + auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); - if (storage_scope.rank == runtime::StorageRank::kLocal) { - // const int local_address_space = 5; - // TODO(tqchen): for higher version of LLVM, local address space can be set. - llvm::AllocaInst* alloca = WithFunctionEntry([&]() { - return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size)); - }); - if (alloca->getAlignment() < static_cast(info.alignment)) { -#if TVM_LLVM_VERSION >= 100 - alloca->setAlignment(llvm::Align(info.alignment)); -#else - alloca->setAlignment(info.alignment); -#endif - } - buf = alloca; - } else { - ICHECK(storage_scope.rank == runtime::StorageRank::kShared) - << "Can only allocate shared or local memory inside kernel"; + if (storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn") { // Shared memory: address space == 3 - const unsigned shared_address_space = 3; - llvm::Type* type = llvm::ArrayType::get(DTypeToLLVMType(op->dtype), constant_size); - // Allocate shared memory in global, address_space = 3 - llvm::GlobalVariable* global = new llvm::GlobalVariable( - *module_, type, false, llvm::GlobalValue::PrivateLinkage, nullptr, ".shared", nullptr, - llvm::GlobalValue::NotThreadLocal, shared_address_space); + buf = + AllocateSharedMemory(op->dtype, 0, 3, info.alignment, llvm::GlobalValue::ExternalLinkage); + } else { + int32_t constant_size = op->constant_allocation_size(); + ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU"; + + if (constant_size % 4 == 0 && info.alignment == 0) { + info.alignment = GetTempAllocaAlignment(op->dtype, constant_size); + } + if (storage_scope.rank == runtime::StorageRank::kLocal) { + // const int local_address_space = 5; + // TODO(tqchen): for higher version of LLVM, local address space can be set. + llvm::AllocaInst* alloca = WithFunctionEntry([&]() { + return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size)); + }); + if (alloca->getAlignment() < static_cast(info.alignment)) { #if TVM_LLVM_VERSION >= 100 - global->setAlignment(llvm::Align(info.alignment)); + alloca->setAlignment(llvm::Align(info.alignment)); #else - global->setAlignment(info.alignment); + alloca->setAlignment(info.alignment); #endif - buf = global; + } + buf = alloca; + } else { + ICHECK(storage_scope.rank == runtime::StorageRank::kShared) + << "Can only allocate shared or local memory inside kernel"; + buf = AllocateSharedMemory(op->dtype, constant_size, 3, info.alignment, + llvm::GlobalValue::PrivateLinkage); + } } buf = builder_->CreatePointerCast( diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index d7dcbec7ebe3..7897490730a3 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -525,6 +525,8 @@ void CodeGenCUDA::PrintStorageScope(const std::string& scope, std::ostream& os) "all global arrays as input instead"; if (scope == "shared") { os << "__shared__ "; + } else if (scope == "shared.dyn") { + os << "extern __shared__ "; } } @@ -703,9 +705,8 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) { std::string vid = AllocVarID(op->buffer_var.get()); this->PrintIndent(); - int32_t constant_size = op->constant_allocation_size(); - ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now"; std::string scope = GetPtrStorageScope(op->buffer_var); + const VarNode* buffer = op->buffer_var.as(); if (scope.find("wmma.") == 0) { if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") { ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Int(8) || @@ -719,19 +720,28 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) { op->dtype == DataType::Int(32)) << "Accumulator only support half, float and int type for now"; } - const VarNode* buffer = op->buffer_var.as(); - constant_size = GetWmmaFragmentSize(scope, buffer, constant_size); PrintWmmaScope(scope, op->dtype, buffer, stream); } else { PrintStorageScope(scope, stream); PrintType(op->dtype, stream); } - if ((op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) || - op->dtype == DataType::Int(1)) && - scope == "shared") { - constant_size = constant_size / (32 / op->dtype.bits()); + + if (scope == "shared.dyn") { + stream << ' ' << vid << "[];\n"; + } else { + int32_t constant_size = op->constant_allocation_size(); + ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now"; + + if (scope.find("wmma.") == 0) { + constant_size = GetWmmaFragmentSize(scope, buffer, constant_size); + } + if ((op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) || + op->dtype == DataType::Int(1)) && + scope == "shared") { + constant_size = constant_size / (32 / op->dtype.bits()); + } + stream << ' ' << vid << '[' << constant_size << "];\n"; } - stream << ' ' << vid << '[' << constant_size << "];\n"; RegisterHandleType(op->buffer_var.get(), op->dtype); this->PrintStmt(op->body); diff --git a/src/tir/transforms/lower_device_storage_access_info.cc b/src/tir/transforms/lower_device_storage_access_info.cc index 829b7d822d11..eafed837cee3 100644 --- a/src/tir/transforms/lower_device_storage_access_info.cc +++ b/src/tir/transforms/lower_device_storage_access_info.cc @@ -67,7 +67,7 @@ class StorageAccessInfoLower : public StmtExprMutator { StorageScope scope = StorageScope::Create(op->value.as()->value); StorageEntry e; e.scope = scope; - if (scope.tag.length() != 0) { + if (scope.tag.length() != 0 && scope.tag != ".dyn") { e.info = GetMemoryInfo(op->value.as()->value); ICHECK(e.info.defined()) << "Cannot find memory info of " << scope.to_string(); } diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index f01d98707586..795ae9d6a73a 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -33,6 +33,9 @@ #include +#include "../../runtime/thread_storage_scope.h" +#include "ir_utils.h" + namespace tvm { namespace tir { @@ -89,6 +92,17 @@ class VarUseDefAnalysis : public StmtExprMutator { Stmt VisitStmt_(const AllocateNode* op) final { this->HandleDef(op->buffer_var.get()); + auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); + if (storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn") { + ICHECK_EQ(use_dyn_shmem_, false) << "Only one dynamic shared memory allocation is allowed."; + ICHECK_GT(op->extents.size(), 0); + dyn_shmem_size_ = op->extents[0]; + for (size_t i = 1; i < op->extents.size(); ++i) { + dyn_shmem_size_ *= op->extents[i]; + } + dyn_shmem_size_ = dyn_shmem_size_ * (op->dtype.bytes()); + use_dyn_shmem_ = true; + } return StmtExprMutator::VisitStmt_(op); } @@ -175,6 +189,8 @@ class VarUseDefAnalysis : public StmtExprMutator { Array undefined_; Array thread_axis_; Array thread_extent_; + PrimExpr dyn_shmem_size_{0}; + bool use_dyn_shmem_{false}; std::unordered_map use_count_; std::unordered_map def_count_; @@ -262,6 +278,10 @@ class HostDeviceSplitter : public StmtMutator { WithAttr(std::move(device_func), tvm::attr::kGlobalSymbol, runtime::String(kernel_symbol)); device_func = WithAttr(std::move(device_func), tir::attr::kNoAlias, Integer(1)); device_func = WithAttr(std::move(device_func), tvm::attr::kTarget, device_target_); + if (m.use_dyn_shmem_) { + device_func = + WithAttr(std::move(device_func), tir::attr::kDeviceUseDynSharedMemory, Integer(1)); + } (*device_mod_)->Add(GlobalVar(kernel_symbol), device_func); // generate calls to the device function @@ -273,6 +293,9 @@ class HostDeviceSplitter : public StmtMutator { for (PrimExpr ext : m.thread_extent_) { call_args.push_back(ext); } + if (m.use_dyn_shmem_) { + call_args.push_back(m.dyn_shmem_size_); + } return Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), call_args)); } diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index 613d02614b39..b216b8b848db 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -512,7 +512,7 @@ class StoragePlanRewriter : public StmtExprMutator { // try to find merge, for tagged memory for (size_t i = 0; i < vec.size(); ++i) { StorageEntry* e = vec[i]; - if (e->scope.tag.length() != 0) { + if (e->scope.tag.length() != 0 && e->scope.tag != ".dyn") { ICHECK_NE(e->const_nbits, 0U) << "Special tagged memory must be const size"; for (size_t j = 0; j < i; ++j) { if (e->scope == vec[j]->scope) { @@ -546,7 +546,7 @@ class StoragePlanRewriter : public StmtExprMutator { make_const(DataType::Int(32), 1), e->allocs[0]->extents); e->new_alloc = Allocate(e->alloc_var, alloc_type, {sz}, e->allocs[0]->condition, Evaluate(0)); - if (e->scope.tag.length() != 0) { + if (e->scope.tag.length() != 0 && e->scope.tag != ".dyn") { MemoryInfo info = GetMemoryInfo(e->scope.to_string()); uint64_t total_elem = e->const_nbits / e->elem_type.bits(); ICHECK_LE(total_elem * e->elem_type.bits(), info->max_num_bits) @@ -587,7 +587,7 @@ class StoragePlanRewriter : public StmtExprMutator { combo_size = analyzer_.Simplify(combo_size); e->new_alloc = Allocate(e->alloc_var, alloc_type, {combo_size}, const_true(), Evaluate(0)); - if (e->scope.tag.length() != 0) { + if (e->scope.tag.length() != 0 && e->scope.tag != ".dyn") { MemoryInfo info = GetMemoryInfo(e->scope.to_string()); uint64_t total_elem = e->const_nbits / e->elem_type.bits(); ICHECK_LE(total_elem * e->elem_type.bits(), info->max_num_bits) diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py index 355d3abed559..0329134bb3fa 100644 --- a/tests/python/unittest/test_tir_ir_builder.py +++ b/tests/python/unittest/test_tir_ir_builder.py @@ -18,6 +18,7 @@ from tvm import te import numpy as np import tvm.testing +from tvm.topi.math import cast def test_for(): @@ -497,6 +498,62 @@ def check_target(target, ir): check_target("vulkan", searchsorted_ir_gpu) +@tvm.testing.requires_gpu +def test_dyn_shared(): + n = te.size_var("n") + dtype = "float32" + A = te.placeholder((n,), name="A") + + def test_device_ir(A, B): + n = A.shape[0] + ib = tvm.tir.ir_builder.create() + + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(tx, "thread_extent", n) + + temp = ib.allocate(dtype, (n,), scope="shared.dyn") # n is symbolic size + + Aptr = ib.buffer_ptr(A) + Bptr = ib.buffer_ptr(B) + + temp[tx] = Aptr[tx] + depth = tvm.tir.log2(cast(n, "float32")) + + with ib.for_range(0, depth) as i: + ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"]))) + d = n >> (i + 1) + with ib.if_scope(tx < d): + temp[tx] += temp[tx + d] + + Bptr[0] = temp[0] + return ib.get() + + B = te.extern( + (1,), + [A], + lambda ins, outs: test_device_ir(ins[0], outs[0]), + name="reduce", + dtype=dtype, + ) + s = te.create_schedule(B.op) + + def check_target(target): + if not tvm.testing.device_enabled(target): + return + + freduce = tvm.build(s, [A, B], target) + dev = tvm.device(target, 0) + + for n in [512, 1024]: + a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev) + b = tvm.nd.array(np.zeros(1, dtype=B.dtype), dev) + freduce(a, b) + tvm.testing.assert_allclose(b.numpy()[0], np.sum(a.numpy()), 1e-4, 1e-4) + + for target in ["cuda", "nvptx"]: + check_target(target) + + if __name__ == "__main__": test_prefetch() test_if() @@ -507,3 +564,4 @@ def check_target(target, ir): test_while_collatz() test_while_mandel() test_while_binary_search() + test_dyn_shared() diff --git a/web/src/webgpu.ts b/web/src/webgpu.ts index f12837f421f8..226797eb7d19 100644 --- a/web/src/webgpu.ts +++ b/web/src/webgpu.ts @@ -39,7 +39,7 @@ export async function detectGPUDevice(): Promise { interface FunctionInfo { name: string; arg_types: Array; - thread_axis_tags: Array; + launch_param_tags: Array; } /** @@ -114,8 +114,8 @@ export class WebGPUContext { const dispatchToDim: Array = []; - for (let i = 0; i < finfo.thread_axis_tags.length; ++i) { - const tag: string = finfo.thread_axis_tags[i]; + for (let i = 0; i < finfo.launch_param_tags.length; ++i) { + const tag: string = finfo.launch_param_tags[i]; if (tag.startsWith("blockIdx.")) { const target: number = tag.charCodeAt(tag.length - 1) - ("x".charCodeAt(0)); assert(target >= 0 && target < 3);