Skip to content

Commit

Permalink
[Vulkan] Remove some interface block decoration (apache#8102)
Browse files Browse the repository at this point in the history
* Remove block decorator for shared/local variables

* Fix lint
  • Loading branch information
llehtahw authored and Trevor Morris committed Jun 17, 2021
1 parent 82ff03d commit 9025cc7
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 17 deletions.
33 changes: 18 additions & 15 deletions src/target/spirv/ir_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,9 @@ SType IRBuilder::GetPointerType(const SType& value_type, spv::StorageClass stora
return t;
}

SType IRBuilder::GetStructArrayType(const SType& value_type, uint32_t num_elems) {
auto key = std::make_pair(value_type.id, num_elems);
SType IRBuilder::GetStructArrayType(const SType& value_type, uint32_t num_elems,
bool interface_block) {
auto key = std::make_tuple(value_type.id, num_elems, interface_block);
auto it = struct_array_type_tbl_.find(key);
if (it != struct_array_type_tbl_.end()) {
return it->second;
Expand Down Expand Up @@ -171,17 +172,19 @@ SType IRBuilder::GetStructArrayType(const SType& value_type, uint32_t num_elems)
.AddSeq(struct_type, 0, spv::DecorationOffset, 0)
.Commit(&decorate_);

// Runtime array are always decorated as Block or BufferBlock
// (shader storage buffer)
if (spirv_support_.supports_storage_buffer_storage_class) {
// If SPIRV 1.3+, or with extension
// SPV_KHR_storage_buffer_storage_class, BufferBlock is
// deprecated.
extensions_used_.insert("SPV_KHR_storage_buffer_storage_class");
this->Decorate(spv::OpDecorate, struct_type, spv::DecorationBlock);
} else {
if (num_elems == 0) {
this->Decorate(spv::OpDecorate, struct_type, spv::DecorationBufferBlock);
if (interface_block) {
// Runtime array are always decorated as Block or BufferBlock
// (shader storage buffer)
if (spirv_support_.supports_storage_buffer_storage_class) {
// If SPIRV 1.3+, or with extension
// SPV_KHR_storage_buffer_storage_class, BufferBlock is
// deprecated.
extensions_used_.insert("SPV_KHR_storage_buffer_storage_class");
this->Decorate(spv::OpDecorate, struct_type, spv::DecorationBlock);
} else {
if (num_elems == 0) {
this->Decorate(spv::OpDecorate, struct_type, spv::DecorationBufferBlock);
}
}
}
struct_array_type_tbl_[key] = struct_type;
Expand Down Expand Up @@ -224,7 +227,7 @@ Value IRBuilder::BufferArgument(const SType& value_type, uint32_t descriptor_set
storage_class = spv::StorageClassUniform;
}

SType sarr_type = GetStructArrayType(value_type, 0);
SType sarr_type = GetStructArrayType(value_type, 0, true);
SType ptr_type = GetPointerType(sarr_type, storage_class);
Value val = NewValue(ptr_type, kStructArrayPtr);

Expand Down Expand Up @@ -335,7 +338,7 @@ void IRBuilder::SetLocalSize(const Value& func, uint32_t local_size[3]) {
Value IRBuilder::Allocate(const SType& value_type, uint32_t num_elems,
spv::StorageClass storage_class) {
ICHECK_NE(num_elems, 0U);
SType sarr_type = GetStructArrayType(value_type, num_elems);
SType sarr_type = GetStructArrayType(value_type, num_elems, false);
SType ptr_type = GetPointerType(sarr_type, storage_class);
Value val = NewValue(ptr_type, kStructArrayPtr);
if (storage_class == spv::StorageClassFunction) {
Expand Down
7 changes: 5 additions & 2 deletions src/target/spirv/ir_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include <unordered_map>
#include <utility>
#include <vector>
#include <tuple>
#include <spirv.hpp>
// clang-format on

Expand Down Expand Up @@ -432,10 +433,12 @@ class IRBuilder {
* \param value_type the content value type.
* \param num_elems number of elements in array
* num_elems = 0 means runtime array with BufferBlock Decoration
* \param interface_block if this array type for interface blocks(input, output, uniform,
* storage buffer).
*
* \return The corresponding spirv type.
*/
SType GetStructArrayType(const SType& value_type, uint32_t num_elems);
SType GetStructArrayType(const SType& value_type, uint32_t num_elems, bool interface_block);
/*!
* \brief Get a struct array access with a given index.
* \param ptr_type The pointer type.
Expand Down Expand Up @@ -634,7 +637,7 @@ class IRBuilder {
/*! \brief map from type code to the type */
std::unordered_map<uint32_t, SType> pod_type_tbl_;
/*! \brief map from value to array type */
std::map<std::pair<uint32_t, uint32_t>, SType> struct_array_type_tbl_;
std::map<std::tuple<uint32_t, uint32_t, bool>, SType> struct_array_type_tbl_;
/*! \brief map from value to its pointer type */
std::map<std::pair<uint32_t, spv::StorageClass>, SType> pointer_type_tbl_;
/*! \brief map from constant int to its value */
Expand Down
3 changes: 3 additions & 0 deletions tests/python/integration/test_ewise.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ def check_device(device, host="stackvm"):
if not tvm.testing.device_enabled(host):
return
dev = tvm.device(device, 0)
if not tvm.testing.device_enabled(device):
print("skip because %s is not enabled.." % device)
return
fexp = tvm.build(s, [A, B], device, host, name="myexp")
dev = tvm.device(device, 0)
# launch the kernel.
Expand Down

0 comments on commit 9025cc7

Please sign in to comment.