From 3ca393faea8b04541fa8dcb8effb33ed2441cf46 Mon Sep 17 00:00:00 2001 From: Tex Riddell Date: Fri, 1 Dec 2023 15:19:16 -0800 Subject: [PATCH] Check for WaveMatrix support; always declare check feature struct for now WaveMatrix wasn't checking WaveMMATier before running tests (assuming it was always supported for SM 6.8). This change fixes this. D3D12_SDK_VERSION check for WAVE_MMA feature structure and enum definitions assumed they will be defined in SDK version 613. This isn't accurate, so this block of local definitions will always be enabled until we have the correct version in the future. --- .../unittests/HLSLExec/ExecutionTest.cpp | 22 ++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp index 263a4ca869..19d5f47c96 100644 --- a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp +++ b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp @@ -758,6 +758,12 @@ class ExecutionTest { return nullptr; } + if (!DoesDeviceSupportWaveMatrix(pDevice)) { + LogCommentFmt(L"WaveMatrix not supported on this device."); + WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); + return nullptr; + } + CComPtr pStream; ReadHlslDataIntoNewStream(L"ShaderOpArith.xml", &pStream); @@ -1630,6 +1636,19 @@ class ExecutionTest { #endif } + bool DoesDeviceSupportWaveMatrix(ID3D12Device *pDevice) { +#if defined(NTDDI_WIN10_FE) && WDK_NTDDI_VERSION >= NTDDI_WIN10_FE + D3D12_FEATURE_DATA_D3D12_OPTIONS9 O9; + if (FAILED(pDevice->CheckFeatureSupport( + (D3D12_FEATURE)D3D12_FEATURE_D3D12_OPTIONS9, &O9, sizeof(O9)))) + return false; + return O9.WaveMMATier >= D3D12_WAVE_MMA_TIER_1_0; +#else + UNREFERENCED_PARAMETER(pDevice); + return false; +#endif + } + bool DoesDeviceSupportAdvancedTexOps(ID3D12Device *pDevice) { #if defined(NTDDI_WIN10_CU) && WDK_NTDDI_VERSION >= NTDDI_WIN10_CU D3D12_FEATURE_DATA_D3D12_OPTIONS14 O14; @@ -9002,7 +9021,8 @@ void LoadStoreMat(int M, int N, bool LEFT, int MEM_TYPE, uint32_t K, uint32_t k, } // define WAVE_MMA types if building with SDK that does not support it yet -#if !defined(D3D12_SDK_VERSION) || (D3D12_SDK_VERSION < 613) +// For now: Force this on, until we know the version. +#if 1 // !defined(D3D12_SDK_VERSION) || (D3D12_SDK_VERSION < 613) typedef enum D3D12_WAVE_MMA_INPUT_DATATYPE { D3D12_WAVE_MMA_INPUT_DATATYPE_INVALID = 0, D3D12_WAVE_MMA_INPUT_DATATYPE_BYTE =