From 99166bccf368c5784cc30c5d9df1e94dcd63081e Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 10 Mar 2021 11:33:03 +0900 Subject: [PATCH 1/3] add HW param for VK --- src/auto_scheduler/search_task.cc | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/auto_scheduler/search_task.cc b/src/auto_scheduler/search_task.cc index 22c2893141cf..d405a9a82664 100755 --- a/src/auto_scheduler/search_task.cc +++ b/src/auto_scheduler/search_task.cc @@ -106,6 +106,14 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target auto target_device = target->GetAttr("device", ""); LOG(FATAL) << "No default hardware parameters for opencl target device: " << target_device; } + } else if (device_type == kDLVulkan) { + int max_shared_memory_per_block = 48000; + int max_local_memory_per_block = INT32_MAX; // skip the check on local memory + int max_threads_per_block = 256; + int warp_size = 64; + int max_vthread_extent = warp_size / 4; + return HardwareParams(-1, 16, 64, max_shared_memory_per_block, max_local_memory_per_block, + max_threads_per_block, max_vthread_extent, warp_size); } else { LOG(FATAL) << "No default hardware parameters for target: " << target; } From e4dd4b7a80c94401fabc36a3c917e34727d7d2e8 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 10 Mar 2021 13:38:59 +0900 Subject: [PATCH 2/3] query warp size properly --- src/auto_scheduler/search_task.cc | 22 ++++++++++++++++++---- src/runtime/vulkan/vulkan.cc | 25 ++++++++++++++++--------- 2 files changed, 34 insertions(+), 13 deletions(-) diff --git a/src/auto_scheduler/search_task.cc b/src/auto_scheduler/search_task.cc index d405a9a82664..00363609afbc 100755 --- a/src/auto_scheduler/search_task.cc +++ b/src/auto_scheduler/search_task.cc @@ -107,10 +107,24 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target LOG(FATAL) << "No default hardware parameters for opencl target device: " << target_device; } } else if (device_type == kDLVulkan) { - int max_shared_memory_per_block = 48000; - int max_local_memory_per_block = INT32_MAX; // skip the check on local memory - int max_threads_per_block = 256; - int warp_size = 64; + auto ctx = TVMContext{static_cast(device_type), 0}; + auto device_name = "device_api.vulkan"; + auto func = tvm::runtime::Registry::Get(device_name); + ICHECK(func != nullptr) << "Cannot find Vulkan device_api in registry"; + auto device_api = static_cast(((*func)()).operator void*()); + + tvm::runtime::TVMRetValue ret; + device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kMaxSharedMemoryPerBlock, &ret); + int max_shared_memory_per_block = ret; + + int max_local_memory_per_block = INT32_MAX; + + device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kMaxThreadsPerBlock, &ret); + int max_threads_per_block = ret; + + device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kWarpSize, &ret); + int warp_size = ret; + int max_vthread_extent = warp_size / 4; return HardwareParams(-1, 16, 64, max_shared_memory_per_block, max_local_memory_per_block, max_threads_per_block, max_vthread_extent, warp_size); diff --git a/src/runtime/vulkan/vulkan.cc b/src/runtime/vulkan/vulkan.cc index 794f3c570f96..ff1b82f930d7 100644 --- a/src/runtime/vulkan/vulkan.cc +++ b/src/runtime/vulkan/vulkan.cc @@ -367,28 +367,37 @@ void VulkanDeviceAPI::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* } ICHECK_LT(index, context_.size()) << "Invalid device id " << index; const auto& vctx = context(index); + VkPhysicalDeviceProperties phy_prop; + vkGetPhysicalDeviceProperties(vctx.phy_device, &phy_prop); + switch (kind) { case kMaxThreadsPerBlock: { - VkPhysicalDeviceProperties phy_prop; - vkGetPhysicalDeviceProperties(vctx.phy_device, &phy_prop); int64_t value = phy_prop.limits.maxComputeWorkGroupInvocations; *rv = value; break; } case kMaxSharedMemoryPerBlock: { - VkPhysicalDeviceProperties phy_prop; - vkGetPhysicalDeviceProperties(vctx.phy_device, &phy_prop); int64_t value = phy_prop.limits.maxComputeSharedMemorySize; *rv = value; break; } case kWarpSize: { - *rv = 1; + VkPhysicalDeviceSubgroupProperties subgroup_prop; + subgroup_prop.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES; + subgroup_prop.pNext = NULL; + + VkPhysicalDeviceProperties2 phy_prop2; + phy_prop2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2; + phy_prop2.pNext = &subgroup_prop; + + vkGetPhysicalDeviceProperties2(vctx.phy_device, &phy_prop2); + int64_t subgroup_size = subgroup_prop.subgroupSize; + ICHECK(subgroup_size >= 1); + + *rv = subgroup_size; break; } case kComputeVersion: { - VkPhysicalDeviceProperties phy_prop; - vkGetPhysicalDeviceProperties(vctx.phy_device, &phy_prop); int64_t value = phy_prop.apiVersion; std::ostringstream os; os << VK_VERSION_MAJOR(value) << "." << VK_VERSION_MINOR(value) << "." @@ -405,8 +414,6 @@ void VulkanDeviceAPI::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* case kExist: break; case kMaxThreadDimensions: { - VkPhysicalDeviceProperties phy_prop; - vkGetPhysicalDeviceProperties(vctx.phy_device, &phy_prop); int64_t dims[3]; dims[0] = phy_prop.limits.maxComputeWorkGroupSize[0]; dims[1] = phy_prop.limits.maxComputeWorkGroupSize[1]; From 88ddaa4df8ae79376836c947e6e26e50bedc4c2b Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 10 Mar 2021 16:53:54 +0900 Subject: [PATCH 3/3] guard against warp_size < 4 case --- src/auto_scheduler/search_task.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/auto_scheduler/search_task.cc b/src/auto_scheduler/search_task.cc index 00363609afbc..f25e581dbf24 100755 --- a/src/auto_scheduler/search_task.cc +++ b/src/auto_scheduler/search_task.cc @@ -125,7 +125,8 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kWarpSize, &ret); int warp_size = ret; - int max_vthread_extent = warp_size / 4; + int max_vthread_extent = std::max(1, warp_size / 4); + return HardwareParams(-1, 16, 64, max_shared_memory_per_block, max_local_memory_per_block, max_threads_per_block, max_vthread_extent, warp_size); } else {