ThreadAxisConfig -> LaunchParamConfig
masa committed Jul 20, 2021
1 parent 2fae2a7 commit 06187d6
Showing 8 changed files with 18 additions and 18 deletions.
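For readers skimming the diff: `LaunchParamConfig` (previously `ThreadAxisConfig`) is the helper each GPU backend uses to turn the trailing launch-parameter arguments of a packed call into grid/block extents and, optionally, a dynamic shared-memory size. Below is a minimal, self-contained sketch of that Init/Extract call pattern using simplified stand-in types; the real definitions live in `src/runtime/thread_storage_scope.h`, and the fixed tag ordering assumed here is an illustration, not TVM's actual tag handling.

```cpp
// Minimal, self-contained model of the Init/Extract call pattern in this diff.
// This is NOT TVM's implementation; the real LaunchParamConfig and ThreadWorkLoad
// live in src/runtime/thread_storage_scope.h.
#include <array>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

// Grid/block extents recovered from a packed call, as consumed by the backends in
// this commit (e.g. cuLaunchKernel uses wl.grid_dim(i), wl.block_dim(i), wl.dyn_shmem_size).
struct ThreadWorkLoad {
  std::array<size_t, 6> work_size{};  // [0..2] grid dims, [3..5] block dims
  size_t dyn_shmem_size = 0;
  size_t grid_dim(size_t i) const { return work_size[i]; }
  size_t block_dim(size_t i) const { return work_size[i + 3]; }
};

// Stand-in for the renamed class. It assumes the launch-parameter tags arrive in the
// order blockIdx.x/y/z, threadIdx.x/y/z, optionally followed by a dynamic-shared-memory
// size; the real class handles arbitrary tag orderings.
class LaunchParamConfig {
 public:
  void Init(size_t base, const std::vector<std::string>& launch_param_tags,
            bool use_dyn_shared_memory = false) {
    base_ = base;  // index of the first launch-parameter argument
    num_axes_ = launch_param_tags.size() - (use_dyn_shared_memory ? 1 : 0);
    use_dyn_shared_memory_ = use_dyn_shared_memory;
  }
  // The real Extract() reads TVMArgs; a plain vector of extents stands in here.
  ThreadWorkLoad Extract(const std::vector<int64_t>& args) const {
    ThreadWorkLoad wl;
    wl.work_size.fill(1);
    for (size_t i = 0; i < num_axes_; ++i) {
      wl.work_size[i] = static_cast<size_t>(args[base_ + i]);
    }
    if (use_dyn_shared_memory_) {
      wl.dyn_shmem_size = static_cast<size_t>(args[base_ + num_axes_]);
    }
    return wl;
  }

 private:
  size_t base_ = 0;
  size_t num_axes_ = 0;
  bool use_dyn_shared_memory_ = false;
};

int main() {
  // Mirrors a wrapped function from the diff: two kernel arguments, then the launch params.
  LaunchParamConfig launch_param_config_;
  launch_param_config_.Init(/*base=*/2, {"blockIdx.x", "blockIdx.y", "blockIdx.z",
                                         "threadIdx.x", "threadIdx.y", "threadIdx.z"});
  std::vector<int64_t> args = {0, 0, 64, 32, 1, 128, 1, 1};  // grid (64,32,1), block (128,1,1)
  ThreadWorkLoad wl = launch_param_config_.Extract(args);
  std::printf("grid=(%zu,%zu,%zu) block=(%zu,%zu,%zu)\n", wl.grid_dim(0), wl.grid_dim(1),
              wl.grid_dim(2), wl.block_dim(0), wl.block_dim(1), wl.block_dim(2));
  return 0;
}
```

The CUDA, Metal, OpenCL, ROCm, and Vulkan wrappers changed in this commit all follow the same Init-then-Extract shape; only the launch API fed by the resulting `ThreadWorkLoad` differs.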
2 changes: 1 addition & 1 deletion docs/dev/codebase_walkthrough.rst
@@ -204,7 +204,7 @@ The ``PackedFunc``'s overloaded ``operator()`` will be called, which in turn cal
fcache_[device_id] = m_->GetFunc(device_id, func_name_);
}
CUstream strm = static_cast<CUstream>(CUDAThreadEntry::ThreadLocal()->stream);
- ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
+ ThreadWorkLoad wl = launch_param_config_.Extract(args);
CUresult result = cuLaunchKernel(
fcache_[device_id],
wl.grid_dim(0),
6 changes: 3 additions & 3 deletions src/runtime/cuda/cuda_module.cc
@@ -159,7 +159,7 @@ class CUDAWrappedFunc {
sptr_ = sptr;
func_name_ = func_name;
std::fill(fcache_.begin(), fcache_.end(), nullptr);
- thread_axis_cfg_.Init(num_void_args, launch_param_tags, use_dyn_shared_memory);
+ launch_param_config_.Init(num_void_args, launch_param_tags, use_dyn_shared_memory);
}
// invoke the function with void arguments
void operator()(TVMArgs args, TVMRetValue* rv, void** void_args) const {
@@ -169,7 +169,7 @@ class CUDAWrappedFunc {
fcache_[device_id] = m_->GetFunc(device_id, func_name_);
}
CUstream strm = static_cast<CUstream>(CUDAThreadEntry::ThreadLocal()->stream);
- ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
+ ThreadWorkLoad wl = launch_param_config_.Extract(args);
CUresult result = cuLaunchKernel(fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1),
wl.grid_dim(2), wl.block_dim(0), wl.block_dim(1),
wl.block_dim(2), wl.dyn_shmem_size, strm, void_args, nullptr);
@@ -203,7 +203,7 @@ class CUDAWrappedFunc {
// mark as mutable, to enable lazy initialization
mutable std::array<CUfunction, kMaxNumGPUs> fcache_;
// thread axis configuration
- ThreadAxisConfig thread_axis_cfg_;
+ LaunchParamConfig launch_param_config_;
};

class CUDAPrepGlobalBarrier {
6 changes: 3 additions & 3 deletions src/runtime/metal/metal_module.mm
@@ -186,7 +186,7 @@ void Init(MetalModuleNode* m, ObjectPtr<Object> sptr, const std::string& func_na
num_buffer_args_ = num_buffer_args;
num_pack_args_ = num_pack_args;
std::fill(scache_.begin(), scache_.end(), (id<MTLComputePipelineState>)nil);
- thread_axis_cfg_.Init(num_buffer_args + num_pack_args, launch_param_tags);
+ launch_param_config_.Init(num_buffer_args + num_pack_args, launch_param_tags);
metal::MetalThreadEntry* t = metal::MetalThreadEntry::ThreadLocal();
int dev_id = t->device.device_id;
scache_[dev_id] = m->GetPipelineState(dev_id, func_name);
@@ -201,7 +201,7 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) cons
if (scache_[device_id] == nil) {
scache_[device_id] = m_->GetPipelineState(device_id, func_name_);
}
- ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
+ ThreadWorkLoad wl = launch_param_config_.Extract(args);
int blockSize = wl.block_dim(0) * wl.block_dim(1) * wl.block_dim(2);
auto maxTotalThreadsPerThreadgroup = scache_[device_id].maxTotalThreadsPerThreadgroup;
CHECK_LE(blockSize, maxTotalThreadsPerThreadgroup);
@@ -243,7 +243,7 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) cons
// mark as mutable, to enable lazy initialization
mutable std::array<id<MTLComputePipelineState>, kMetalMaxNumDevice> scache_;
// thread axis configuration
- ThreadAxisConfig thread_axis_cfg_;
+ LaunchParamConfig launch_param_config_;
};

PackedFunc MetalModuleNode::GetFunction(const std::string& name,
8 changes: 4 additions & 4 deletions src/runtime/opencl/opencl_module.cc
@@ -47,7 +47,7 @@ class OpenCLWrappedFunc {
entry_ = entry;
func_name_ = func_name;
arg_size_ = arg_size;
- thread_axis_cfg_.Init(arg_size.size(), launch_param_tags);
+ launch_param_config_.Init(arg_size.size(), launch_param_tags);
}
// invoke the function with void arguments
void operator()(TVMArgs args, TVMRetValue* rv, void** void_args) const {
@@ -73,8 +73,8 @@ class OpenCLWrappedFunc {
OPENCL_CALL(clSetKernelArg(kernel, i, arg_size_[i], arg));
}
cl_command_queue queue = w_->GetQueue(t->device);
- ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
- cl_uint work_dim = static_cast<cl_uint>(thread_axis_cfg_.work_dim());
+ ThreadWorkLoad wl = launch_param_config_.Extract(args);
+ cl_uint work_dim = static_cast<cl_uint>(launch_param_config_.work_dim());
for (cl_uint i = 0; i < work_dim; ++i) {
wl.work_size[i] *= wl.work_size[i + 3];
}
@@ -97,7 +97,7 @@ class OpenCLWrappedFunc {
// convert code for void argument
std::vector<size_t> arg_size_;
// thread axis config
- ThreadAxisConfig thread_axis_cfg_;
+ LaunchParamConfig launch_param_config_;
};

OpenCLModuleNode::~OpenCLModuleNode() {
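One detail in the OpenCL hunk above: `ThreadWorkLoad.work_size` keeps grid extents in slots 0–2 and block extents in slots 3–5 (see `grid_dim` in the `src/runtime/thread_storage_scope.h` hunk later in this diff), while OpenCL's enqueue call expects the *global* work size per dimension, so the loop multiplies the two together. A tiny illustration with made-up extents:

```cpp
// Illustration of the work-size fold in the OpenCL hunk, with made-up extents.
#include <cstddef>
#include <cstdio>

int main() {
  // grid = (64, 32, 1) in slots 0..2, block = (128, 1, 1) in slots 3..5
  size_t work_size[6] = {64, 32, 1, 128, 1, 1};
  for (int i = 0; i < 3; ++i) {
    work_size[i] *= work_size[i + 3];  // global size = grid * block per dimension
  }
  // work_size[0..2] is now (8192, 32, 1): the global NDRange handed to OpenCL,
  // while work_size[3..5] stays (128, 1, 1) as the local (work-group) size.
  std::printf("global = (%zu, %zu, %zu)\n", work_size[0], work_size[1], work_size[2]);
  return 0;
}
```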
6 changes: 3 additions & 3 deletions src/runtime/rocm/rocm_module.cc
@@ -153,7 +153,7 @@ class ROCMWrappedFunc {
sptr_ = sptr;
func_name_ = func_name;
std::fill(fcache_.begin(), fcache_.end(), nullptr);
- thread_axis_cfg_.Init(num_void_args, launch_param_tags, use_dyn_shared_memory);
+ launch_param_config_.Init(num_void_args, launch_param_tags, use_dyn_shared_memory);
}
// invoke the function with void arguments
void operator()(TVMArgs args, TVMRetValue* rv, void* packed_args, size_t packed_nbytes) const {
@@ -165,7 +165,7 @@ class ROCMWrappedFunc {

hipStream_t strm = static_cast<hipStream_t>(ROCMThreadEntry::ThreadLocal()->stream);

- ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
+ ThreadWorkLoad wl = launch_param_config_.Extract(args);
void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, packed_args, HIP_LAUNCH_PARAM_BUFFER_SIZE,
&packed_nbytes, HIP_LAUNCH_PARAM_END};
// HIP supports only extra_args.
@@ -186,7 +186,7 @@ class ROCMWrappedFunc {
// mark as mutable, to enable lazy initialization
mutable std::array<hipFunction_t, kMaxNumGPUs> fcache_;
// thread axis configuration
- ThreadAxisConfig thread_axis_cfg_;
+ LaunchParamConfig launch_param_config_;
};

PackedFunc ROCMModuleNode::GetFunction(const std::string& name,
2 changes: 1 addition & 1 deletion src/runtime/thread_storage_scope.h
@@ -196,7 +196,7 @@ struct ThreadWorkLoad {
inline size_t grid_dim(size_t i) const { return work_size[i]; }
};
/*! \brief Thread axis configuration */
- class ThreadAxisConfig {
+ class LaunchParamConfig {
public:
void Init(size_t base, const std::vector<std::string>& launch_param_tags,
bool use_dyn_shared_memory = false) {
4 changes: 2 additions & 2 deletions src/runtime/vulkan/vulkan_wrapped_func.cc
@@ -39,7 +39,7 @@ void VulkanWrappedFunc::Init(VulkanModuleNode* m, ObjectPtr<Object> sptr,
func_name_ = func_name;
num_buffer_args_ = num_buffer_args;
num_pack_args_ = num_pack_args;
- thread_axis_cfg_.Init(num_buffer_args + num_pack_args, launch_param_tags);
+ launch_param_config_.Init(num_buffer_args + num_pack_args, launch_param_tags);
}

void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv,
@@ -50,7 +50,7 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv,
scache_[device_id] = m_->GetPipeline(device_id, func_name_, num_pack_args_);
}
const auto& pipeline = scache_[device_id];
- ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
+ ThreadWorkLoad wl = launch_param_config_.Extract(args);
std::vector<VkDescriptorBufferInfo> descriptor_buffers;
descriptor_buffers.resize(num_buffer_args_);
for (size_t i = 0; i < num_buffer_args_; ++i) {
2 changes: 1 addition & 1 deletion src/runtime/vulkan/vulkan_wrapped_func.h
@@ -76,7 +76,7 @@ class VulkanWrappedFunc {
// Device state cache per device.
// mark as mutable, to enable lazy initialization
// thread axis configuration
- ThreadAxisConfig thread_axis_cfg_;
+ LaunchParamConfig launch_param_config_;

mutable std::array<std::shared_ptr<VulkanPipeline>, kVulkanMaxNumDevice> scache_;
};