Skip to content

Commit

Permalink
[Ansor] Add HW param for Vulkan tuning (apache#7626)
Browse files Browse the repository at this point in the history
* add HW param for VK

* query warp size properly

* guard against warp_size < 4 case

Co-authored-by: Masahiro Masuda <masahi129@gmail.com>
  • Loading branch information
2 people authored and Trevor Morris committed May 6, 2021
1 parent 88219b6 commit 6400d0c
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 9 deletions.
23 changes: 23 additions & 0 deletions src/auto_scheduler/search_task.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,29 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target
auto target_device = target->GetAttr<String>("device", "");
LOG(FATAL) << "No default hardware parameters for opencl target device: " << target_device;
}
} else if (device_type == kDLVulkan) {
auto ctx = TVMContext{static_cast<DLDeviceType>(device_type), 0};
auto device_name = "device_api.vulkan";
auto func = tvm::runtime::Registry::Get(device_name);
ICHECK(func != nullptr) << "Cannot find Vulkan device_api in registry";
auto device_api = static_cast<tvm::runtime::DeviceAPI*>(((*func)()).operator void*());

tvm::runtime::TVMRetValue ret;
device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kMaxSharedMemoryPerBlock, &ret);
int max_shared_memory_per_block = ret;

int max_local_memory_per_block = INT32_MAX;

device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kMaxThreadsPerBlock, &ret);
int max_threads_per_block = ret;

device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kWarpSize, &ret);
int warp_size = ret;

int max_vthread_extent = std::max(1, warp_size / 4);

return HardwareParams(-1, 16, 64, max_shared_memory_per_block, max_local_memory_per_block,
max_threads_per_block, max_vthread_extent, warp_size);
} else {
LOG(FATAL) << "No default hardware parameters for target: " << target;
}
Expand Down
25 changes: 16 additions & 9 deletions src/runtime/vulkan/vulkan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -367,28 +367,37 @@ void VulkanDeviceAPI::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue*
}
ICHECK_LT(index, context_.size()) << "Invalid device id " << index;
const auto& vctx = context(index);
VkPhysicalDeviceProperties phy_prop;
vkGetPhysicalDeviceProperties(vctx.phy_device, &phy_prop);

switch (kind) {
case kMaxThreadsPerBlock: {
VkPhysicalDeviceProperties phy_prop;
vkGetPhysicalDeviceProperties(vctx.phy_device, &phy_prop);
int64_t value = phy_prop.limits.maxComputeWorkGroupInvocations;
*rv = value;
break;
}
case kMaxSharedMemoryPerBlock: {
VkPhysicalDeviceProperties phy_prop;
vkGetPhysicalDeviceProperties(vctx.phy_device, &phy_prop);
int64_t value = phy_prop.limits.maxComputeSharedMemorySize;
*rv = value;
break;
}
case kWarpSize: {
*rv = 1;
VkPhysicalDeviceSubgroupProperties subgroup_prop;
subgroup_prop.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES;
subgroup_prop.pNext = NULL;

VkPhysicalDeviceProperties2 phy_prop2;
phy_prop2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;
phy_prop2.pNext = &subgroup_prop;

vkGetPhysicalDeviceProperties2(vctx.phy_device, &phy_prop2);
int64_t subgroup_size = subgroup_prop.subgroupSize;
ICHECK(subgroup_size >= 1);

*rv = subgroup_size;
break;
}
case kComputeVersion: {
VkPhysicalDeviceProperties phy_prop;
vkGetPhysicalDeviceProperties(vctx.phy_device, &phy_prop);
int64_t value = phy_prop.apiVersion;
std::ostringstream os;
os << VK_VERSION_MAJOR(value) << "." << VK_VERSION_MINOR(value) << "."
Expand All @@ -405,8 +414,6 @@ void VulkanDeviceAPI::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue*
case kExist:
break;
case kMaxThreadDimensions: {
VkPhysicalDeviceProperties phy_prop;
vkGetPhysicalDeviceProperties(vctx.phy_device, &phy_prop);
int64_t dims[3];
dims[0] = phy_prop.limits.maxComputeWorkGroupSize[0];
dims[1] = phy_prop.limits.maxComputeWorkGroupSize[1];
Expand Down

0 comments on commit 6400d0c

Please sign in to comment.