diff --git a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp index 263a4ca869..19d5f47c96 100644 --- a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp +++ b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp @@ -758,6 +758,12 @@ class ExecutionTest { return nullptr; } + if (!DoesDeviceSupportWaveMatrix(pDevice)) { + LogCommentFmt(L"WaveMatrix not supported on this device."); + WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); + return nullptr; + } + CComPtr pStream; ReadHlslDataIntoNewStream(L"ShaderOpArith.xml", &pStream); @@ -1630,6 +1636,19 @@ class ExecutionTest { #endif } + bool DoesDeviceSupportWaveMatrix(ID3D12Device *pDevice) { +#if defined(NTDDI_WIN10_FE) && WDK_NTDDI_VERSION >= NTDDI_WIN10_FE + D3D12_FEATURE_DATA_D3D12_OPTIONS9 O9; + if (FAILED(pDevice->CheckFeatureSupport( + (D3D12_FEATURE)D3D12_FEATURE_D3D12_OPTIONS9, &O9, sizeof(O9)))) + return false; + return O9.WaveMMATier >= D3D12_WAVE_MMA_TIER_1_0; +#else + UNREFERENCED_PARAMETER(pDevice); + return false; +#endif + } + bool DoesDeviceSupportAdvancedTexOps(ID3D12Device *pDevice) { #if defined(NTDDI_WIN10_CU) && WDK_NTDDI_VERSION >= NTDDI_WIN10_CU D3D12_FEATURE_DATA_D3D12_OPTIONS14 O14; @@ -9002,7 +9021,8 @@ void LoadStoreMat(int M, int N, bool LEFT, int MEM_TYPE, uint32_t K, uint32_t k, } // define WAVE_MMA types if building with SDK that does not support it yet -#if !defined(D3D12_SDK_VERSION) || (D3D12_SDK_VERSION < 613) +// For now: Force this on, until we know the version. +#if 1 // !defined(D3D12_SDK_VERSION) || (D3D12_SDK_VERSION < 613) typedef enum D3D12_WAVE_MMA_INPUT_DATATYPE { D3D12_WAVE_MMA_INPUT_DATATYPE_INVALID = 0, D3D12_WAVE_MMA_INPUT_DATATYPE_BYTE =