-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[CUDA] Initial support for dynamic shared memory #8466
Changes from all commits
f709892
f20780b
44ca4bf
d9afa37
2c81667
51e87cb
95e1b81
719f40d
6d9c7c4
e0fbac2
d781526
933f9c5
509a8c1
0ea0962
84666a4
5bcfacd
3ac8401
283e04c
4389ccb
2085bfe
94b1a78
64a9d5e
c333cdc
19ec309
8deab0b
7cbc700
1151a52
587e5b6
ffc138a
b8c05a5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,7 +19,7 @@ | |
|
||
/*! | ||
* \file thread_storage_scope.h | ||
* \brief Extract thread axis configuration from TVMArgs. | ||
* \brief Extract launch parameters configuration from TVMArgs. | ||
*/ | ||
#ifndef TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_ | ||
#define TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_ | ||
|
@@ -29,6 +29,8 @@ | |
#include <string> | ||
#include <vector> | ||
|
||
#include "meta_data.h" | ||
|
||
namespace tvm { | ||
namespace runtime { | ||
|
||
|
@@ -182,6 +184,8 @@ struct ThreadScope { | |
struct ThreadWorkLoad { | ||
// array, first three are thread configuration. | ||
size_t work_size[6]; | ||
// Dynamic shared memory allocation size in bytes. | ||
size_t dyn_shmem_size{0}; | ||
/*! | ||
* \param i The block dimension. | ||
* \return i-th block dim | ||
|
@@ -193,17 +197,23 @@ struct ThreadWorkLoad { | |
*/ | ||
inline size_t grid_dim(size_t i) const { return work_size[i]; } | ||
}; | ||
/*! \brief Thread axis configuration */ | ||
class ThreadAxisConfig { | ||
/*! \brief Launch parameters configuration */ | ||
class LaunchParamConfig { | ||
public: | ||
void Init(size_t base, const std::vector<std::string>& thread_axis_tags) { | ||
void Init(size_t base, const std::vector<std::string>& launch_param_tags) { | ||
base_ = base; | ||
std::vector<bool> filled(6, false); | ||
for (size_t i = 0; i < thread_axis_tags.size(); ++i) { | ||
const std::string& tag = thread_axis_tags[i]; | ||
ThreadScope ts = ThreadScope::Create(tag); | ||
arg_index_map_.push_back(ts.rank * 3 + ts.dim_index); | ||
filled[ts.rank * 3 + ts.dim_index] = true; | ||
for (size_t i = 0; i < launch_param_tags.size(); ++i) { | ||
const std::string& tag = launch_param_tags[i]; | ||
if (tag == kUseDynamicSharedMemoryTag) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. let us always assert this is the last tag that follows the arg_index map. Since it can be strange if the tag appears in the middle. We can also handle a generalized case where the arg indicates the location of the parameter. In that case if the tag is There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added an assert |
||
ICHECK_EQ(i, launch_param_tags.size() - 1) | ||
<< "kUseDynamicSharedMemoryTag should be the last tag in launch_param_tags."; | ||
use_dyn_shared_memory_ = true; | ||
} else { | ||
ThreadScope ts = ThreadScope::Create(tag); | ||
arg_index_map_.push_back(ts.rank * 3 + ts.dim_index); | ||
filled[ts.rank * 3 + ts.dim_index] = true; | ||
} | ||
} | ||
work_dim_ = 1; | ||
for (int i = 0; i < 3; ++i) { | ||
|
@@ -223,6 +233,9 @@ class ThreadAxisConfig { | |
w.work_size[arg_index_map_[i]] = size; | ||
} | ||
} | ||
if (use_dyn_shared_memory_) { | ||
w.dyn_shmem_size = static_cast<size_t>(x.values[base_ + arg_index_map_.size()].v_int64); | ||
} | ||
return w; | ||
} | ||
// return the work dim | ||
|
@@ -235,6 +248,8 @@ class ThreadAxisConfig { | |
size_t work_dim_; | ||
/*! \brief The index mapping. */ | ||
std::vector<uint32_t> arg_index_map_; | ||
/*! \brief Whether or not use dynamic shared memory. */ | ||
bool use_dyn_shared_memory_{false}; | ||
}; | ||
|
||
} // namespace runtime | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let us introduce a special meta-data to indicate that dynamic shared memory is used. This is to make sure the calling convention is backward compatible when dyn shared memory is not provided.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done