[CUDA] Initial support for dynamic shared memory #8466

Merged: 30 commits, merged on Jul 22, 2021

Changes from all commits (30 commits):
f709892
send dyn shmem size to runtime
masahi Jun 21, 2021
f20780b
add dyn shared storage scope
masahi Jun 21, 2021
44ca4bf
associate buffer var and its storage scope in split_host_device
masahi Jun 21, 2021
d9afa37
tried NVPTX but failed with INVALID_PTX error
Jun 27, 2021
2c81667
test stub
Jun 27, 2021
51e87cb
dynamic shmem reduce working
Jun 27, 2021
95e1b81
log2 issue fixed
Jun 28, 2021
719f40d
nvptx working
Jun 28, 2021
6d9c7c4
refactor llvm shmem allocation
Jun 28, 2021
e0fbac2
make linkage argument
Jun 28, 2021
d781526
support rocm too
Jun 28, 2021
933f9c5
send dyn shmem param to hip runtime
Jun 28, 2021
509a8c1
remove alloc map from split_host_device.cc
Jun 29, 2021
0ea0962
remove attr::storage_scope from split_host_device
Jul 14, 2021
84666a4
lint fix
Jul 14, 2021
5bcfacd
formatting
Jul 14, 2021
3ac8401
update calling convention doc
Jul 14, 2021
283e04c
minor update to test
Jul 14, 2021
4389ccb
remove log
masahi Jul 14, 2021
2085bfe
remove kDynShared, dyn.shared -> shared.dyn
Jul 17, 2021
94b1a78
support backward compat
Jul 17, 2021
64a9d5e
update json/binary reader/writer
Jul 17, 2021
c333cdc
thread_axis_tags -> launch_param_tags
Jul 20, 2021
19ec309
ThreadAxisConfig -> LaunchParamConfig
Jul 20, 2021
8deab0b
remove use_dynamic_shared_memory from FunctionInfo meta data
Jul 20, 2021
7cbc700
revert change in test_tir_ir_builder.py
Jul 20, 2021
1151a52
make sure kUseDynamicSharedMemoryTag is the last tag
Jul 20, 2021
587e5b6
remove continue
Jul 20, 2021
ffc138a
update doc string following name change
Jul 20, 2021
b8c05a5
more comment update following name change
Jul 21, 2021
4 changes: 2 additions & 2 deletions docs/dev/codebase_walkthrough.rst
@@ -183,7 +183,7 @@ The first time you invoke the compiled module with ``fadd(a, b, c)``, ``GetFunct
auto it = fmap_.find(name);
const FunctionInfo& info = it->second;
CUDAWrappedFunc f;
f.Init(this, sptr_to_self, name, info.arg_types.size(), info.thread_axis_tags);
f.Init(this, sptr_to_self, name, info.arg_types.size(), info.launch_param_tags);
return PackFuncVoidAddr(f, info.arg_types);
}

@@ -204,7 +204,7 @@ The ``PackedFunc``'s overloaded ``operator()`` will be called, which in turn cal
fcache_[device_id] = m_->GetFunc(device_id, func_name_);
}
CUstream strm = static_cast<CUstream>(CUDAThreadEntry::ThreadLocal()->stream);
ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
ThreadWorkLoad wl = launch_param_config_.Extract(args);
CUresult result = cuLaunchKernel(
fcache_[device_id],
wl.grid_dim(0),
11 changes: 10 additions & 1 deletion include/tvm/tir/function.h
@@ -240,17 +240,26 @@ namespace attr {
*
* Call(f,
* [arg1, arg2, ..., arg_n,
* work_size_1, work_size_2, ... work_size_m])
* work_size_1, work_size_2, ... work_size_m, dyn_shmem_size])
[Review comment — Member]
Let us introduce a special meta-data to indicate that dynamic shared memory is used. This is to make sure the calling convention is backward compatible when dyn shared memory is not provided.

    constexpr const char* kDeviceUseDynSharedMemory = "tir.device_use_dyn_shared_memory";

[Reply — Member (author)]
Done.

*
* Here n = len(arg), m = len(work_size) = len(device_thread_axis).
*
* When kDeviceUseDynSharedMemory is not set, the dyn_shmem_size argument is omitted.
*
* The list of device_thread_axis indicates how the work_size arguments
* are bound to the corresponding threads.
*
* \sa tvm::CallingConv::kDeviceKernelLaunch
*/
constexpr const char* kDeviceThreadAxis = "tir.device_thread_axis";

/*!
* \brief Whether or not to use dynamic shared memory.
*
* Type: Integer
*/
constexpr const char* kDeviceUseDynSharedMemory = "tir.device_use_dyn_shared_memory";

/*!
* \brief Whether to set noalias rule on the function arguments.
*
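To make the updated calling convention above concrete, here is a minimal standalone sketch (plain C++, not TVM code; the argument names, launch dimensions, and byte count are invented) of the value layout for a kernel with three data arguments, device_thread_axis = [blockIdx.x, threadIdx.x], and kDeviceUseDynSharedMemory set:

    // Call(f, [arg1, ..., arg_n, work_size_1, ..., work_size_m, dyn_shmem_size])
    // Here n = 3 and m = 2, so the packed call carries 3 + 2 + 1 values.
    #include <cstdio>
    #include <string>
    #include <vector>

    int main() {
      std::vector<std::string> data_args = {"A", "B", "C"};  // arg1..arg_n (n = 3)
      std::vector<long long> trailing = {
          256,      // work_size_1 -> gridDim.x  (blockIdx.x extent)
          128,      // work_size_2 -> blockDim.x (threadIdx.x extent)
          128 * 4   // dyn_shmem_size in bytes (e.g. 128 floats)
      };
      std::printf("%zu data args, gridDim.x=%lld, blockDim.x=%lld, dyn_shmem=%lld bytes\n",
                  data_args.size(), trailing[0], trailing[1], trailing[2]);
      // If kDeviceUseDynSharedMemory were not set, the final value would simply be absent,
      // which is what keeps the convention backward compatible.
      return 0;
    }
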
14 changes: 7 additions & 7 deletions src/runtime/cuda/cuda_module.cc
@@ -153,12 +153,12 @@ class CUDAWrappedFunc {
public:
// initialize the CUDA function.
void Init(CUDAModuleNode* m, ObjectPtr<Object> sptr, const std::string& func_name,
size_t num_void_args, const std::vector<std::string>& thread_axis_tags) {
size_t num_void_args, const std::vector<std::string>& launch_param_tags) {
m_ = m;
sptr_ = sptr;
func_name_ = func_name;
std::fill(fcache_.begin(), fcache_.end(), nullptr);
thread_axis_cfg_.Init(num_void_args, thread_axis_tags);
launch_param_config_.Init(num_void_args, launch_param_tags);
}
// invoke the function with void arguments
void operator()(TVMArgs args, TVMRetValue* rv, void** void_args) const {
@@ -168,10 +168,10 @@ class CUDAWrappedFunc {
fcache_[device_id] = m_->GetFunc(device_id, func_name_);
}
CUstream strm = static_cast<CUstream>(CUDAThreadEntry::ThreadLocal()->stream);
ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
ThreadWorkLoad wl = launch_param_config_.Extract(args);
CUresult result = cuLaunchKernel(fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1),
wl.grid_dim(2), wl.block_dim(0), wl.block_dim(1),
wl.block_dim(2), 0, strm, void_args, nullptr);
wl.block_dim(2), wl.dyn_shmem_size, strm, void_args, nullptr);
if (result != CUDA_SUCCESS && result != CUDA_ERROR_DEINITIALIZED) {
const char* msg;
cuGetErrorName(result, &msg);
@@ -201,8 +201,8 @@ class CUDAWrappedFunc {
// Device function cache per device.
// mark as mutable, to enable lazy initialization
mutable std::array<CUfunction, kMaxNumGPUs> fcache_;
// thread axis configuration
ThreadAxisConfig thread_axis_cfg_;
// launch parameters configuration
LaunchParamConfig launch_param_config_;
};

class CUDAPrepGlobalBarrier {
@@ -241,7 +241,7 @@ PackedFunc CUDAModuleNode::GetFunction(const std::string& name,
if (it == fmap_.end()) return PackedFunc();
const FunctionInfo& info = it->second;
CUDAWrappedFunc f;
f.Init(this, sptr_to_self, name, info.arg_types.size(), info.thread_axis_tags);
f.Init(this, sptr_to_self, name, info.arg_types.size(), info.launch_param_tags);
return PackFuncVoidAddr(f, info.arg_types);
}

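For readers less familiar with the CUDA side, the standalone sketch below (plain CUDA runtime API, not TVM's driver-API path shown above; kernel name, sizes, and data are invented) shows what the dynamic shared-memory size passed at launch actually controls: it sizes the extern __shared__ buffer visible to the kernel, which is exactly the role wl.dyn_shmem_size now plays as the sharedMemBytes argument of cuLaunchKernel.

    #include <cstdio>
    #include <cuda_runtime.h>

    __global__ void block_sum(const float* in, float* out, int n) {
      extern __shared__ float buf[];  // sized by the 3rd launch-configuration argument
      int tid = threadIdx.x;
      buf[tid] = (tid < n) ? in[tid] : 0.0f;
      __syncthreads();
      for (int s = blockDim.x / 2; s > 0; s >>= 1) {  // tree reduction within the block
        if (tid < s) buf[tid] += buf[tid + s];
        __syncthreads();
      }
      if (tid == 0) *out = buf[0];
    }

    int main() {
      const int n = 128;
      float *in, *out;
      cudaMallocManaged(&in, n * sizeof(float));
      cudaMallocManaged(&out, sizeof(float));
      for (int i = 0; i < n; ++i) in[i] = 1.0f;
      size_t dyn_shmem_size = n * sizeof(float);  // analogous to wl.dyn_shmem_size above
      block_sum<<<1, n, dyn_shmem_size>>>(in, out, n);
      cudaDeviceSynchronize();
      std::printf("sum = %f\n", *out);  // expect 128
      cudaFree(in);
      cudaFree(out);
      return 0;
    }

The ROCm path further down makes the analogous change, forwarding wl.dyn_shmem_size to hipModuleLaunchKernel instead of a hard-coded 0.
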
10 changes: 6 additions & 4 deletions src/runtime/file_utils.cc
@@ -43,7 +43,7 @@ void FunctionInfo::Save(dmlc::JSONWriter* writer) const {
writer->BeginObject();
writer->WriteObjectKeyValue("name", name);
writer->WriteObjectKeyValue("arg_types", sarg_types);
writer->WriteObjectKeyValue("thread_axis_tags", thread_axis_tags);
writer->WriteObjectKeyValue("launch_param_tags", launch_param_tags);
writer->EndObject();
}

@@ -52,7 +52,9 @@ void FunctionInfo::Load(dmlc::JSONReader* reader) {
std::vector<std::string> sarg_types;
helper.DeclareField("name", &name);
helper.DeclareField("arg_types", &sarg_types);
helper.DeclareField("thread_axis_tags", &thread_axis_tags);
helper.DeclareOptionalField("launch_param_tags", &launch_param_tags);
helper.DeclareOptionalField("thread_axis_tags",
&launch_param_tags); // for backward compatibility
helper.ReadAllFields(reader);
arg_types.resize(sarg_types.size());
for (size_t i = 0; i < arg_types.size(); ++i) {
@@ -63,13 +65,13 @@ void FunctionInfo::Load(dmlc::JSONReader* reader) {
void FunctionInfo::Save(dmlc::Stream* writer) const {
writer->Write(name);
writer->Write(arg_types);
writer->Write(thread_axis_tags);
writer->Write(launch_param_tags);
}

bool FunctionInfo::Load(dmlc::Stream* reader) {
if (!reader->Read(&name)) return false;
if (!reader->Read(&arg_types)) return false;
if (!reader->Read(&thread_axis_tags)) return false;
if (!reader->Read(&launch_param_tags)) return false;
return true;
}

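A hedged sketch of what this serialization change looks like in practice. It assumes the file is compiled inside the TVM source tree (so that "meta_data.h" and dmlc/json.h resolve) and linked against the TVM runtime; the kernel name and tag list are invented. New metadata is written under "launch_param_tags", while an old artifact that still stores "thread_axis_tags" loads into the same field:

    #include <dmlc/json.h>
    #include <iostream>
    #include <sstream>
    #include "meta_data.h"

    int main() {
      using tvm::runtime::FunctionInfo;

      // New-style metadata: the dyn-shmem tag rides along in launch_param_tags.
      FunctionInfo info;
      info.name = "default_function_kernel0";  // hypothetical kernel name
      info.launch_param_tags = {"blockIdx.x", "threadIdx.x", "tir.use_dyn_shared_memory"};
      std::ostringstream os;
      dmlc::JSONWriter writer(&os);
      info.Save(&writer);
      // Expected shape: {"name": ..., "arg_types": [...], "launch_param_tags": [...]}
      std::cout << os.str() << "\n";

      // Old-style metadata still loads: "thread_axis_tags" maps onto launch_param_tags.
      std::istringstream is(
          R"({"name": "old_kernel", "arg_types": ["float32"], "thread_axis_tags": ["threadIdx.x"]})");
      dmlc::JSONReader reader(&is);
      FunctionInfo old_info;
      old_info.Load(&reader);
      std::cout << old_info.launch_param_tags[0] << "\n";  // prints: threadIdx.x
      return 0;
    }
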
5 changes: 4 additions & 1 deletion src/runtime/meta_data.h
@@ -99,11 +99,14 @@ Module MetadataModuleCreate(
const std::unordered_map<std::string, NDArray>& metadata,
const std::unordered_map<std::string, std::vector<std::string>>& sym_vars);

/*! \brief A tag to specify whether or not dynamic shared memory is used */
constexpr const char* kUseDynamicSharedMemoryTag = "tir.use_dyn_shared_memory";

/*! \brief function information needed by device */
struct FunctionInfo {
std::string name;
std::vector<DLDataType> arg_types;
std::vector<std::string> thread_axis_tags;
std::vector<std::string> launch_param_tags;

void Save(dmlc::JSONWriter* writer) const;
void Load(dmlc::JSONReader* reader);
12 changes: 6 additions & 6 deletions src/runtime/metal/metal_module.mm
@@ -178,15 +178,15 @@ void SaveToBinary(dmlc::Stream* stream) final {
// initialize the METAL function.
void Init(MetalModuleNode* m, ObjectPtr<Object> sptr, const std::string& func_name,
size_t num_buffer_args, size_t num_pack_args,
const std::vector<std::string>& thread_axis_tags) {
const std::vector<std::string>& launch_param_tags) {
w_ = metal::MetalWorkspace::Global();
m_ = m;
sptr_ = sptr;
func_name_ = func_name;
num_buffer_args_ = num_buffer_args;
num_pack_args_ = num_pack_args;
std::fill(scache_.begin(), scache_.end(), (id<MTLComputePipelineState>)nil);
thread_axis_cfg_.Init(num_buffer_args + num_pack_args, thread_axis_tags);
launch_param_config_.Init(num_buffer_args + num_pack_args, launch_param_tags);
metal::MetalThreadEntry* t = metal::MetalThreadEntry::ThreadLocal();
int dev_id = t->device.device_id;
scache_[dev_id] = m->GetPipelineState(dev_id, func_name);
@@ -201,7 +201,7 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) cons
if (scache_[device_id] == nil) {
scache_[device_id] = m_->GetPipelineState(device_id, func_name_);
}
ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
ThreadWorkLoad wl = launch_param_config_.Extract(args);
int blockSize = wl.block_dim(0) * wl.block_dim(1) * wl.block_dim(2);
auto maxTotalThreadsPerThreadgroup = scache_[device_id].maxTotalThreadsPerThreadgroup;
CHECK_LE(blockSize, maxTotalThreadsPerThreadgroup);
@@ -242,8 +242,8 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) cons
// Device state cache per device.
// mark as mutable, to enable lazy initialization
mutable std::array<id<MTLComputePipelineState>, kMetalMaxNumDevice> scache_;
// thread axis configuration
ThreadAxisConfig thread_axis_cfg_;
// launch parameters configuration
LaunchParamConfig launch_param_config_;
};

PackedFunc MetalModuleNode::GetFunction(const std::string& name,
Expand All @@ -261,7 +261,7 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) cons
MetalWrappedFunc f;
size_t num_buffer_args = NumBufferArgs(info.arg_types);
f.Init(this, sptr_to_self, name, num_buffer_args, info.arg_types.size() - num_buffer_args,
info.thread_axis_tags);
info.launch_param_tags);
pf = PackFuncNonBufferArg(f, info.arg_types);
};
return pf;
14 changes: 7 additions & 7 deletions src/runtime/opencl/opencl_module.cc
@@ -40,14 +40,14 @@ class OpenCLWrappedFunc {
// initialize the OpenCL function.
void Init(OpenCLModuleNode* m, ObjectPtr<Object> sptr, OpenCLModuleNode::KTRefEntry entry,
std::string func_name, std::vector<size_t> arg_size,
const std::vector<std::string>& thread_axis_tags) {
const std::vector<std::string>& launch_param_tags) {
w_ = m->GetGlobalWorkspace();
m_ = m;
sptr_ = sptr;
entry_ = entry;
func_name_ = func_name;
arg_size_ = arg_size;
thread_axis_cfg_.Init(arg_size.size(), thread_axis_tags);
launch_param_config_.Init(arg_size.size(), launch_param_tags);
}
// invoke the function with void arguments
void operator()(TVMArgs args, TVMRetValue* rv, void** void_args) const {
@@ -73,8 +73,8 @@ class OpenCLWrappedFunc {
OPENCL_CALL(clSetKernelArg(kernel, i, arg_size_[i], arg));
}
cl_command_queue queue = w_->GetQueue(t->device);
ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
cl_uint work_dim = static_cast<cl_uint>(thread_axis_cfg_.work_dim());
ThreadWorkLoad wl = launch_param_config_.Extract(args);
cl_uint work_dim = static_cast<cl_uint>(launch_param_config_.work_dim());
for (cl_uint i = 0; i < work_dim; ++i) {
wl.work_size[i] *= wl.work_size[i + 3];
}
@@ -96,8 +96,8 @@ class OpenCLWrappedFunc {
std::string func_name_;
// convert code for void argument
std::vector<size_t> arg_size_;
// thread axis config
ThreadAxisConfig thread_axis_cfg_;
// launch parameters config
LaunchParamConfig launch_param_config_;
};

OpenCLModuleNode::~OpenCLModuleNode() {
@@ -148,7 +148,7 @@ PackedFunc OpenCLModuleNode::GetFunction(const std::string& name,
}
}
// initialize the wrapped func.
f.Init(this, sptr_to_self, kid_map_.at(name), name, arg_size, info.thread_axis_tags);
f.Init(this, sptr_to_self, kid_map_.at(name), name, arg_size, info.launch_param_tags);
return PackFuncVoidAddr(f, info.arg_types);
}

19 changes: 10 additions & 9 deletions src/runtime/rocm/rocm_module.cc
@@ -147,12 +147,12 @@ class ROCMWrappedFunc {
public:
// initialize the ROCM function.
void Init(ROCMModuleNode* m, ObjectPtr<Object> sptr, const std::string& func_name,
size_t num_void_args, const std::vector<std::string>& thread_axis_tags) {
size_t num_void_args, const std::vector<std::string>& launch_param_tags) {
m_ = m;
sptr_ = sptr;
func_name_ = func_name;
std::fill(fcache_.begin(), fcache_.end(), nullptr);
thread_axis_cfg_.Init(num_void_args, thread_axis_tags);
launch_param_config_.Init(num_void_args, launch_param_tags);
}
// invoke the function with void arguments
void operator()(TVMArgs args, TVMRetValue* rv, void* packed_args, size_t packed_nbytes) const {
@@ -164,13 +164,14 @@ class ROCMWrappedFunc {

hipStream_t strm = static_cast<hipStream_t>(ROCMThreadEntry::ThreadLocal()->stream);

ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
ThreadWorkLoad wl = launch_param_config_.Extract(args);
void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, packed_args, HIP_LAUNCH_PARAM_BUFFER_SIZE,
&packed_nbytes, HIP_LAUNCH_PARAM_END};
// HIP supports only extra_args.
ROCM_DRIVER_CALL(hipModuleLaunchKernel(
fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2), wl.block_dim(0),
wl.block_dim(1), wl.block_dim(2), 0, strm, nullptr, reinterpret_cast<void**>(&config)));
ROCM_DRIVER_CALL(hipModuleLaunchKernel(fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1),
wl.grid_dim(2), wl.block_dim(0), wl.block_dim(1),
wl.block_dim(2), wl.dyn_shmem_size, strm, nullptr,
reinterpret_cast<void**>(&config)));
}

private:
@@ -183,8 +184,8 @@ class ROCMWrappedFunc {
// Device function cache per device.
// mark as mutable, to enable lazy initialization
mutable std::array<hipFunction_t, kMaxNumGPUs> fcache_;
// thread axis configuration
ThreadAxisConfig thread_axis_cfg_;
// launch parameters configuration
LaunchParamConfig launch_param_config_;
};

PackedFunc ROCMModuleNode::GetFunction(const std::string& name,
@@ -195,7 +196,7 @@ PackedFunc ROCMModuleNode::GetFunction(const std::string& name,
if (it == fmap_.end()) return PackedFunc();
const FunctionInfo& info = it->second;
ROCMWrappedFunc f;
f.Init(this, sptr_to_self, name, info.arg_types.size(), info.thread_axis_tags);
f.Init(this, sptr_to_self, name, info.arg_types.size(), info.launch_param_tags);
return PackFuncPackedArg(f, info.arg_types);
}

33 changes: 24 additions & 9 deletions src/runtime/thread_storage_scope.h
@@ -19,7 +19,7 @@

/*!
* \file thread_storage_scope.h
* \brief Extract thread axis configuration from TVMArgs.
* \brief Extract launch parameters configuration from TVMArgs.
*/
#ifndef TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_
#define TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_
@@ -29,6 +29,8 @@
#include <string>
#include <vector>

#include "meta_data.h"

namespace tvm {
namespace runtime {

@@ -182,6 +184,8 @@ struct ThreadScope {
struct ThreadWorkLoad {
// array, first three are thread configuration.
size_t work_size[6];
// Dynamic shared memory allocation size in bytes.
size_t dyn_shmem_size{0};
/*!
* \param i The block dimension.
* \return i-th block dim
@@ -193,17 +197,23 @@
*/
inline size_t grid_dim(size_t i) const { return work_size[i]; }
};
/*! \brief Thread axis configuration */
class ThreadAxisConfig {
/*! \brief Launch parameters configuration */
class LaunchParamConfig {
public:
void Init(size_t base, const std::vector<std::string>& thread_axis_tags) {
void Init(size_t base, const std::vector<std::string>& launch_param_tags) {
base_ = base;
std::vector<bool> filled(6, false);
for (size_t i = 0; i < thread_axis_tags.size(); ++i) {
const std::string& tag = thread_axis_tags[i];
ThreadScope ts = ThreadScope::Create(tag);
arg_index_map_.push_back(ts.rank * 3 + ts.dim_index);
filled[ts.rank * 3 + ts.dim_index] = true;
for (size_t i = 0; i < launch_param_tags.size(); ++i) {
const std::string& tag = launch_param_tags[i];
if (tag == kUseDynamicSharedMemoryTag) {
[Review comment — Member]
Let us always assert this is the last tag that follows the arg_index map, since it can be strange if the tag appears in the middle. We can also handle a generalized case where the arg indicates the location of the parameter; in that case, if the tags are [dyn_mem, threadIdx.x], then x.values[base_ + 0] would correspond to the dynamic shared memory size.

[Reply — Member (author)]
Added an assert.

ICHECK_EQ(i, launch_param_tags.size() - 1)
<< "kUseDynamicSharedMemoryTag should be the last tag in launch_param_tags.";
use_dyn_shared_memory_ = true;
} else {
ThreadScope ts = ThreadScope::Create(tag);
arg_index_map_.push_back(ts.rank * 3 + ts.dim_index);
filled[ts.rank * 3 + ts.dim_index] = true;
}
}
work_dim_ = 1;
for (int i = 0; i < 3; ++i) {
@@ -223,6 +233,9 @@ class ThreadAxisConfig {
w.work_size[arg_index_map_[i]] = size;
}
}
if (use_dyn_shared_memory_) {
w.dyn_shmem_size = static_cast<size_t>(x.values[base_ + arg_index_map_.size()].v_int64);
}
return w;
}
// return the work dim
@@ -235,6 +248,8 @@ class ThreadAxisConfig {
size_t work_dim_;
/*! \brief The index mapping. */
std::vector<uint32_t> arg_index_map_;
/*! \brief Whether or not to use dynamic shared memory. */
bool use_dyn_shared_memory_{false};
};

} // namespace runtime
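The mapping performed by LaunchParamConfig can be summarized with a simplified, self-contained model (no TVM types; the tag strings mirror the real ones, everything else is illustrative): thread tags fill the six work_size slots, and the dynamic shared-memory tag, which must come last, consumes one extra trailing value — the same contract the ICHECK above enforces.

    #include <cassert>
    #include <cstdint>
    #include <cstdio>
    #include <string>
    #include <vector>

    struct WorkLoad {
      size_t work_size[6] = {1, 1, 1, 1, 1, 1};  // gridDim.{x,y,z}, blockDim.{x,y,z}
      size_t dyn_shmem_size = 0;
    };

    // Simplified stand-in for LaunchParamConfig::Init + Extract.
    WorkLoad Extract(const std::vector<std::string>& tags, const std::vector<int64_t>& values) {
      WorkLoad w;
      size_t vi = 0;
      for (size_t i = 0; i < tags.size(); ++i) {
        if (tags[i] == "tir.use_dyn_shared_memory") {
          assert(i == tags.size() - 1 && "dyn-shmem tag must be the last launch param tag");
          w.dyn_shmem_size = static_cast<size_t>(values[vi++]);
        } else {
          // "blockIdx.{x,y,z}" -> slots 0..2, "threadIdx.{x,y,z}" -> slots 3..5
          bool is_block = tags[i].rfind("blockIdx.", 0) == 0;
          size_t dim = static_cast<size_t>(tags[i].back() - 'x');
          w.work_size[(is_block ? 0 : 3) + dim] = static_cast<size_t>(values[vi++]);
        }
      }
      return w;
    }

    int main() {
      WorkLoad w = Extract({"blockIdx.x", "threadIdx.x", "tir.use_dyn_shared_memory"},
                           {256, 128, 128 * 4});
      std::printf("grid=%zu block=%zu dyn_shmem=%zu bytes\n", w.work_size[0], w.work_size[3],
                  w.dyn_shmem_size);
      return 0;
    }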