Skip to content

Commit

Permalink
nvapi-d3d12: Use NvapiD3d12Device with its cache for OMM calls
Browse files Browse the repository at this point in the history
  • Loading branch information
Saancreed committed Mar 15, 2024
1 parent e67fb34 commit 57766a8
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 76 deletions.
115 changes: 110 additions & 5 deletions src/d3d12/nvapi_d3d12_device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,91 @@ namespace dxvk {
return cubinDevice != nullptr;
}

bool NvapiD3d12Device::AreOpacityMicromapsSupported(ID3D12Device* device) {
auto ommDevice = GetOmmDevice(device);
return ommDevice != nullptr;
}

std::optional<NvAPI_Status> NvapiD3d12Device::SetCreatePipelineStateOptions(ID3D12Device5* device, const NVAPI_D3D12_SET_CREATE_PIPELINE_STATE_OPTIONS_PARAMS* params) {
auto ommDevice = GetOmmDevice(device);
if (ommDevice == nullptr)
return std::nullopt;

return static_cast<NvAPI_Status>(ommDevice->SetCreatePipelineStateOptions(params));
}

std::optional<NvAPI_Status> NvapiD3d12Device::CheckDriverMatchingIdentifierEx(ID3D12Device5* device, NVAPI_CHECK_DRIVER_MATCHING_IDENTIFIER_EX_PARAMS* params) {
auto ommDevice = GetOmmDevice(device);
if (ommDevice == nullptr)
return std::nullopt;

return static_cast<NvAPI_Status>(ommDevice->CheckDriverMatchingIdentifierEx(params));
}

std::optional<NvAPI_Status> NvapiD3d12Device::GetRaytracingAccelerationStructurePrebuildInfoEx(ID3D12Device5* device, NVAPI_GET_RAYTRACING_ACCELERATION_STRUCTURE_PREBUILD_INFO_EX_PARAMS* params) {
auto ommDevice = GetOmmDevice(device);
if (ommDevice == nullptr)
return std::nullopt;

return static_cast<NvAPI_Status>(ommDevice->GetRaytracingAccelerationStructurePrebuildInfoEx(params));
}

std::optional<NvAPI_Status> NvapiD3d12Device::GetRaytracingOpacityMicromapArrayPrebuildInfo(ID3D12Device5* device, NVAPI_GET_RAYTRACING_OPACITY_MICROMAP_ARRAY_PREBUILD_INFO_PARAMS* params) {
auto ommDevice = GetOmmDevice(device);
if (ommDevice == nullptr)
return std::nullopt;

return static_cast<NvAPI_Status>(ommDevice->GetRaytracingOpacityMicromapArrayPrebuildInfo(params));
}

std::optional<NvAPI_Status> NvapiD3d12Device::BuildRaytracingAccelerationStructureEx(ID3D12GraphicsCommandList4* commandList, const NVAPI_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_EX_PARAMS* params) {
auto commandListExt = GetCommandListExt(commandList);
if (!commandListExt.has_value())
return std::nullopt;

auto commandListVer = commandListExt.value();
if (commandListVer.InterfaceVersion < 2)
return std::nullopt;

return static_cast<NvAPI_Status>(commandListVer.CommandListExt->BuildRaytracingAccelerationStructureEx(params));
}

std::optional<NvAPI_Status> NvapiD3d12Device::BuildRaytracingOpacityMicromapArray(ID3D12GraphicsCommandList4* commandList, NVAPI_BUILD_RAYTRACING_OPACITY_MICROMAP_ARRAY_PARAMS* params) {
auto commandListExt = GetCommandListExt(commandList);
if (!commandListExt.has_value())
return std::nullopt;

auto commandListVer = commandListExt.value();
if (commandListVer.InterfaceVersion < 2)
return std::nullopt;

return static_cast<NvAPI_Status>(commandListVer.CommandListExt->BuildRaytracingOpacityMicromapArray(params));
}

std::optional<NvAPI_Status> NvapiD3d12Device::RelocateRaytracingOpacityMicromapArray(ID3D12GraphicsCommandList4* commandList, const NVAPI_RELOCATE_RAYTRACING_OPACITY_MICROMAP_ARRAY_PARAMS* params) {
auto commandListExt = GetCommandListExt(commandList);
if (!commandListExt.has_value())
return std::nullopt;

auto commandListVer = commandListExt.value();
if (commandListVer.InterfaceVersion < 2)
return std::nullopt;

return static_cast<NvAPI_Status>(commandListVer.CommandListExt->RelocateRaytracingOpacityMicromapArray(params));
}

std::optional<NvAPI_Status> NvapiD3d12Device::EmitRaytracingOpacityMicromapArrayPostbuildInfo(ID3D12GraphicsCommandList4* commandList, const NVAPI_EMIT_RAYTRACING_OPACITY_MICROMAP_ARRAY_POSTBUILD_INFO_PARAMS* params) {
auto commandListExt = GetCommandListExt(commandList);
if (!commandListExt.has_value())
return std::nullopt;

auto commandListVer = commandListExt.value();
if (commandListVer.InterfaceVersion < 2)
return std::nullopt;

return static_cast<NvAPI_Status>(commandListVer.CommandListExt->EmitRaytracingOpacityMicromapArrayPostbuildInfo(params));
}

// We are going to have single map for storing devices with extensions D3D12_VK_NVX_BINARY_IMPORT & D3D12_VK_NVX_IMAGE_VIEW_HANDLE.
// These are specific to NVIDIA and both of these extensions goes together.
Com<ID3D12DeviceExt> NvapiD3d12Device::GetCubinDevice(ID3D12Device* device) {
Expand All @@ -136,15 +221,29 @@ namespace dxvk {
if (it != m_cubinDeviceMap.end())
return it->second;

auto cubinDevice = GetDeviceExt(device, D3D12_VK_NVX_BINARY_IMPORT);
auto cubinDevice = GetDeviceExt<ID3D12DeviceExt>(device, D3D12_VK_NVX_BINARY_IMPORT);
if (cubinDevice != nullptr)
m_cubinDeviceMap.emplace(device, cubinDevice.ptr());

return cubinDevice;
}

Com<ID3D12DeviceExt> NvapiD3d12Device::GetDeviceExt(ID3D12Device* device, D3D12_VK_EXTENSION extension) {
Com<ID3D12DeviceExt> deviceExt;
Com<ID3D12DeviceExt1> NvapiD3d12Device::GetOmmDevice(ID3D12Device* device) {
std::scoped_lock lock(m_ommDeviceMutex);
auto it = m_ommDeviceMap.find(device);
if (it != m_ommDeviceMap.end())
return it->second;

auto ommDevice = GetDeviceExt<ID3D12DeviceExt1>(device, D3D12_VK_EXT_OPACITY_MICROMAP);
if (ommDevice != nullptr)
m_ommDeviceMap.emplace(device, ommDevice.ptr());

return ommDevice;
}

template <typename T>
Com<T> NvapiD3d12Device::GetDeviceExt(ID3D12Device* device, D3D12_VK_EXTENSION extension) {
Com<T> deviceExt;
if (FAILED(device->QueryInterface(IID_PPV_ARGS(&deviceExt))))
return nullptr;

Expand Down Expand Up @@ -176,15 +275,21 @@ namespace dxvk {
if (it != m_commandListMap.end())
return it->second;

Com<ID3D12GraphicsCommandListExt2> commandListExt2 = nullptr;
if (SUCCEEDED(commandList->QueryInterface(IID_PPV_ARGS(&commandListExt2)))) {
NvapiD3d12Device::CommandListExtWithVersion cmdListVer{commandListExt2.ptr(), 2};
return std::make_optional(m_commandListMap.emplace(commandList, cmdListVer).first->second);
}

Com<ID3D12GraphicsCommandListExt1> commandListExt1 = nullptr;
if (SUCCEEDED(commandList->QueryInterface(IID_PPV_ARGS(&commandListExt1)))) {
NvapiD3d12Device::CommandListExtWithVersion cmdListVer{commandListExt1.ptr(), 1};
NvapiD3d12Device::CommandListExtWithVersion cmdListVer{reinterpret_cast<ID3D12GraphicsCommandListExt2*>(commandListExt1.ptr()), 1};
return std::make_optional(m_commandListMap.emplace(commandList, cmdListVer).first->second);
}

Com<ID3D12GraphicsCommandListExt> commandListExt = nullptr;
if (SUCCEEDED(commandList->QueryInterface(IID_PPV_ARGS(&commandListExt)))) {
NvapiD3d12Device::CommandListExtWithVersion cmdListVer{reinterpret_cast<ID3D12GraphicsCommandListExt1*>(commandListExt.ptr()), 0};
NvapiD3d12Device::CommandListExtWithVersion cmdListVer{reinterpret_cast<ID3D12GraphicsCommandListExt2*>(commandListExt.ptr()), 0};
return std::make_optional(m_commandListMap.emplace(commandList, cmdListVer).first->second);
}

Expand Down
19 changes: 17 additions & 2 deletions src/d3d12/nvapi_d3d12_device.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace dxvk {
class NvapiD3d12Device {

struct CommandListExtWithVersion {
ID3D12GraphicsCommandListExt1* CommandListExt;
ID3D12GraphicsCommandListExt2* CommandListExt;
uint32_t InterfaceVersion;
};

Expand All @@ -29,22 +29,37 @@ namespace dxvk {
static bool CaptureUAVInfo(ID3D12Device* device, NVAPI_UAV_INFO* uavInfo);
static bool IsFatbinPTXSupported(ID3D12Device* device);

static bool AreOpacityMicromapsSupported(ID3D12Device* device);
static std::optional<NvAPI_Status> SetCreatePipelineStateOptions(ID3D12Device5* device, const NVAPI_D3D12_SET_CREATE_PIPELINE_STATE_OPTIONS_PARAMS* params);
static std::optional<NvAPI_Status> CheckDriverMatchingIdentifierEx(ID3D12Device5* device, NVAPI_CHECK_DRIVER_MATCHING_IDENTIFIER_EX_PARAMS* params);
static std::optional<NvAPI_Status> GetRaytracingAccelerationStructurePrebuildInfoEx(ID3D12Device5* device, NVAPI_GET_RAYTRACING_ACCELERATION_STRUCTURE_PREBUILD_INFO_EX_PARAMS* params);
static std::optional<NvAPI_Status> GetRaytracingOpacityMicromapArrayPrebuildInfo(ID3D12Device5* device, NVAPI_GET_RAYTRACING_OPACITY_MICROMAP_ARRAY_PREBUILD_INFO_PARAMS* params);
static std::optional<NvAPI_Status> BuildRaytracingAccelerationStructureEx(ID3D12GraphicsCommandList4* commandList, const NVAPI_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_EX_PARAMS* params);
static std::optional<NvAPI_Status> BuildRaytracingOpacityMicromapArray(ID3D12GraphicsCommandList4* commandList, NVAPI_BUILD_RAYTRACING_OPACITY_MICROMAP_ARRAY_PARAMS* params);
static std::optional<NvAPI_Status> RelocateRaytracingOpacityMicromapArray(ID3D12GraphicsCommandList4* commandList, const NVAPI_RELOCATE_RAYTRACING_OPACITY_MICROMAP_ARRAY_PARAMS* params);
static std::optional<NvAPI_Status> EmitRaytracingOpacityMicromapArrayPostbuildInfo(ID3D12GraphicsCommandList4* commandList, const NVAPI_EMIT_RAYTRACING_OPACITY_MICROMAP_ARRAY_POSTBUILD_INFO_PARAMS* params);

static void ClearCacheMaps();

private:
inline static std::unordered_map<ID3D12Device*, ID3D12DeviceExt1*> m_ommDeviceMap;
inline static std::unordered_map<ID3D12Device*, ID3D12DeviceExt*> m_cubinDeviceMap;
inline static std::unordered_map<ID3D12CommandQueue*, ID3D12CommandQueueExt*> m_commandQueueMap;
inline static std::unordered_map<ID3D12GraphicsCommandList*, CommandListExtWithVersion> m_commandListMap;
inline static std::unordered_map<NVDX_ObjectHandle, NvU32> m_cubinSmemMap;

inline static std::mutex m_commandListMutex;
inline static std::mutex m_commandQueueMutex;
inline static std::mutex m_ommDeviceMutex;
inline static std::mutex m_cubinDeviceMutex;
inline static std::mutex m_cubinSmemMutex;

[[nodiscard]] static Com<ID3D12DeviceExt1> GetOmmDevice(ID3D12Device* device);
[[nodiscard]] static Com<ID3D12DeviceExt> GetCubinDevice(ID3D12Device* device);
[[nodiscard]] static Com<ID3D12DeviceExt> GetDeviceExt(ID3D12Device* device, D3D12_VK_EXTENSION extension);
[[nodiscard]] static Com<ID3D12CommandQueueExt> GetCommandQueueExt(ID3D12CommandQueue* commandQueue);
[[nodiscard]] static std::optional<CommandListExtWithVersion> GetCommandListExt(ID3D12GraphicsCommandList* commandList);

template <typename T>
[[nodiscard]] static Com<T> GetDeviceExt(ID3D12Device* device, D3D12_VK_EXTENSION extension);
};
}
Loading

0 comments on commit 57766a8

Please sign in to comment.