Skip to content

Commit

Permalink
Check for WaveMatrix support; always declare check feature struct for…
Browse files Browse the repository at this point in the history
… now

WaveMatrix wasn't checking WaveMMATier before running tests (assuming it was always supported for SM 6.8).
This change fixes this.

D3D12_SDK_VERSION check for WAVE_MMA feature structure and enum definitions assumed they will be defined in SDK version 613.
This isn't accurate, so this block of local definitions will always be enabled until we have the correct version in the future.
  • Loading branch information
tex3d committed Dec 1, 2023
1 parent 8a37a92 commit 3ca393f
Showing 1 changed file with 21 additions and 1 deletion.
22 changes: 21 additions & 1 deletion tools/clang/unittests/HLSLExec/ExecutionTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -758,6 +758,12 @@ class ExecutionTest {
return nullptr;
}

if (!DoesDeviceSupportWaveMatrix(pDevice)) {
LogCommentFmt(L"WaveMatrix not supported on this device.");
WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped);
return nullptr;
}

CComPtr<IStream> pStream;
ReadHlslDataIntoNewStream(L"ShaderOpArith.xml", &pStream);

Expand Down Expand Up @@ -1630,6 +1636,19 @@ class ExecutionTest {
#endif
}

bool DoesDeviceSupportWaveMatrix(ID3D12Device *pDevice) {
#if defined(NTDDI_WIN10_FE) && WDK_NTDDI_VERSION >= NTDDI_WIN10_FE
D3D12_FEATURE_DATA_D3D12_OPTIONS9 O9;
if (FAILED(pDevice->CheckFeatureSupport(
(D3D12_FEATURE)D3D12_FEATURE_D3D12_OPTIONS9, &O9, sizeof(O9))))
return false;
return O9.WaveMMATier >= D3D12_WAVE_MMA_TIER_1_0;
#else
UNREFERENCED_PARAMETER(pDevice);
return false;
#endif
}

bool DoesDeviceSupportAdvancedTexOps(ID3D12Device *pDevice) {
#if defined(NTDDI_WIN10_CU) && WDK_NTDDI_VERSION >= NTDDI_WIN10_CU
D3D12_FEATURE_DATA_D3D12_OPTIONS14 O14;
Expand Down Expand Up @@ -9002,7 +9021,8 @@ void LoadStoreMat(int M, int N, bool LEFT, int MEM_TYPE, uint32_t K, uint32_t k,
}

// define WAVE_MMA types if building with SDK that does not support it yet
#if !defined(D3D12_SDK_VERSION) || (D3D12_SDK_VERSION < 613)
// For now: Force this on, until we know the version.
#if 1 // !defined(D3D12_SDK_VERSION) || (D3D12_SDK_VERSION < 613)
typedef enum D3D12_WAVE_MMA_INPUT_DATATYPE {
D3D12_WAVE_MMA_INPUT_DATATYPE_INVALID = 0,
D3D12_WAVE_MMA_INPUT_DATATYPE_BYTE =
Expand Down

0 comments on commit 3ca393f

Please sign in to comment.