Skip to content

Commit

Permalink
Reflection interface now accepts a DXIL program header. (microsoft#3946)
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-yang authored Sep 13, 2021
1 parent 29a5af5 commit 7914122
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 67 deletions.
91 changes: 57 additions & 34 deletions lib/HLSL/DxilContainerReflection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ class DxilModuleReflection {
void CreateReflectionObjectForResource(DxilResourceBase *R);

HRESULT LoadRDAT(const DxilPartHeader *pPart);
HRESULT LoadModule(const DxilPartHeader *pPart);
HRESULT LoadProgramHeader(const DxilProgramHeader *pProgramHeader);

// Common code
ID3D12ShaderReflectionConstantBuffer* _GetConstantBufferByIndex(UINT Index);
Expand Down Expand Up @@ -180,7 +180,7 @@ class DxilShaderReflection : public DxilModuleReflection, public ID3D12ShaderRef
return hr;
}

HRESULT Load(const DxilPartHeader *pModulePart, const DxilPartHeader *pRDATPart);
HRESULT Load(const DxilProgramHeader *pProgramHeader, const DxilPartHeader *pRDATPart);

// ID3D12ShaderReflection
STDMETHODIMP GetDesc(THIS_ _Out_ D3D12_SHADER_DESC *pDesc);
Expand Down Expand Up @@ -245,7 +245,7 @@ class DxilLibraryReflection : public DxilModuleReflection, public ID3D12LibraryR
return DoBasicQueryInterface<ID3D12LibraryReflection>(this, iid, ppvObject);
}

HRESULT Load(const DxilPartHeader *pModulePart, const DxilPartHeader *pDXILPart);
HRESULT Load(const DxilProgramHeader *pProgramHeader, const DxilPartHeader *pRDATPart);

// ID3D12LibraryReflection
STDMETHOD(GetDesc)(THIS_ _Out_ D3D12_LIBRARY_DESC * pDesc);
Expand All @@ -254,28 +254,69 @@ class DxilLibraryReflection : public DxilModuleReflection, public ID3D12LibraryR
};

namespace hlsl {
HRESULT CreateDxilShaderReflection(const DxilPartHeader *pModulePart, const DxilPartHeader *pRDATPart, REFIID iid, void **ppvObject) {

HRESULT CreateDxilShaderReflection(const DxilProgramHeader *pProgramHeader, const DxilPartHeader *pRDATPart, REFIID iid, void **ppvObject) {
if (!ppvObject)
return E_INVALIDARG;
CComPtr<DxilShaderReflection> pReflection = DxilShaderReflection::Alloc(DxcGetThreadMallocNoRef());
IFROOM(pReflection.p);
PublicAPI api = DxilShaderReflection::IIDToAPI(iid);
pReflection->SetPublicAPI(api);
// pRDATPart to be used for transition.
IFR(pReflection->Load(pModulePart, pRDATPart));
IFR(pReflection->Load(pProgramHeader, pRDATPart));
IFR(pReflection.p->QueryInterface(iid, ppvObject));
return S_OK;
}
HRESULT CreateDxilLibraryReflection(const DxilPartHeader *pModulePart, const DxilPartHeader *pRDATPart, REFIID iid, void **ppvObject) {

HRESULT CreateDxilLibraryReflection(const DxilProgramHeader *pProgramHeader, const DxilPartHeader *pRDATPart, REFIID iid, void **ppvObject) {
if (!ppvObject)
return E_INVALIDARG;
CComPtr<DxilLibraryReflection> pReflection = DxilLibraryReflection::Alloc(DxcGetThreadMallocNoRef());
IFROOM(pReflection.p);
// pRDATPart used for resource usage per-function.
IFR(pReflection->Load(pModulePart, pRDATPart));
IFR(pReflection->Load(pProgramHeader, pRDATPart));
IFR(pReflection.p->QueryInterface(iid, ppvObject));
return S_OK;
}

HRESULT CreateDxilShaderOrLibraryReflectionFromProgramHeader(const DxilProgramHeader *pProgramHeader, const DxilPartHeader *pRDATPart, REFIID iid, void **ppvObject) {
// Detect whether library, or if unrecognized program version.
DXIL::ShaderKind SK = GetVersionShaderType(pProgramHeader->ProgramVersion);
if (!(SK < DXIL::ShaderKind::Invalid))
return E_INVALIDARG;
bool bIsLibrary = DXIL::ShaderKind::Library == SK;

if (bIsLibrary) {
IFR(hlsl::CreateDxilLibraryReflection(pProgramHeader, pRDATPart, iid, ppvObject));
} else {
IFR(hlsl::CreateDxilShaderReflection(pProgramHeader, pRDATPart, iid, ppvObject));
}
return S_OK;
}

bool IsValidReflectionModulePart(DxilFourCC fourCC) {
return fourCC == DFCC_DXIL || fourCC == DFCC_ShaderDebugInfoDXIL || fourCC == DFCC_ShaderStatistics;
}

HRESULT CreateDxilShaderOrLibraryReflectionFromModulePart(const DxilPartHeader *pModulePart, const DxilPartHeader *pRDATPart, REFIID iid, void **ppvObject) {
if (!pModulePart)
return E_INVALIDARG;

if (!IsValidReflectionModulePart((DxilFourCC)pModulePart->PartFourCC))
return E_INVALIDARG;

const DxilProgramHeader *pProgramHeader =
reinterpret_cast<const DxilProgramHeader*>(GetDxilPartData(pModulePart));
if (!IsValidDxilProgramHeader(pProgramHeader, pModulePart->PartSize))
return E_INVALIDARG;

// If bitcode is too small, it's probably been stripped, and we cannot create reflection with it.
if (pModulePart->PartSize - pProgramHeader->BitcodeHeader.BitcodeOffset < 4)
return DXC_E_MISSING_PART;

return CreateDxilShaderOrLibraryReflectionFromProgramHeader(pProgramHeader, pRDATPart, iid, ppvObject);
}

}

_Use_decl_annotations_
Expand Down Expand Up @@ -366,10 +407,8 @@ HRESULT DxilContainerReflection::GetPartReflection(UINT32 idx, REFIID iid, void
if (!IsLoaded()) return E_NOT_VALID_STATE;
if (idx >= m_pHeader->PartCount) return E_BOUNDS;
const DxilPartHeader *pPart = GetDxilContainerPart(m_pHeader, idx);
if (pPart->PartFourCC != DFCC_DXIL && pPart->PartFourCC != DFCC_ShaderDebugInfoDXIL &&
pPart->PartFourCC != DFCC_ShaderStatistics) {
if (!hlsl::IsValidReflectionModulePart((hlsl::DxilFourCC)pPart->PartFourCC))
return E_NOTIMPL;
}

// Use DFCC_ShaderStatistics for reflection instead of DXIL part, until switch
// to using RDAT for reflection instead of module.
Expand All @@ -391,21 +430,10 @@ HRESULT DxilContainerReflection::GetPartReflection(UINT32 idx, REFIID iid, void
}
}

const DxilProgramHeader *pProgramHeader =
reinterpret_cast<const DxilProgramHeader*>(GetDxilPartData(pPart));
if (!IsValidDxilProgramHeader(pProgramHeader, pPart->PartSize)) {
return E_INVALIDARG;
}

DxcThreadMalloc TM(m_pMalloc);
HRESULT hr = S_OK;

DXIL::ShaderKind SK = GetVersionShaderType(pProgramHeader->ProgramVersion);
if (SK == DXIL::ShaderKind::Library) {
IFC(hlsl::CreateDxilLibraryReflection(pPart, pRDATPart, iid, ppvObject));
} else {
IFC(hlsl::CreateDxilShaderReflection(pPart, pRDATPart, iid, ppvObject));
}
IFC(hlsl::CreateDxilShaderOrLibraryReflectionFromModulePart(pPart, pRDATPart, iid, ppvObject));

Cleanup:
return hr;
Expand Down Expand Up @@ -2058,14 +2086,11 @@ HRESULT DxilModuleReflection::LoadRDAT(const DxilPartHeader *pPart) {
return S_OK;
}

HRESULT DxilModuleReflection::LoadModule(const DxilPartHeader *pShaderPart) {
if (pShaderPart == nullptr)
return E_INVALIDARG;
const char *pData = GetDxilPartData(pShaderPart);
HRESULT DxilModuleReflection::LoadProgramHeader(const DxilProgramHeader *pProgramHeader) {
try {
const char *pBitcode;
uint32_t bitcodeLength;
GetDxilProgramBitcode((DxilProgramHeader *)pData, &pBitcode, &bitcodeLength);
GetDxilProgramBitcode((DxilProgramHeader *)pProgramHeader, &pBitcode, &bitcodeLength);
std::unique_ptr<MemoryBuffer> pMemBuffer =
MemoryBuffer::getMemBufferCopy(StringRef(pBitcode, bitcodeLength));
bool bBitcodeLoadError = false;
Expand Down Expand Up @@ -2093,12 +2118,11 @@ HRESULT DxilModuleReflection::LoadModule(const DxilPartHeader *pShaderPart) {
return S_OK;
}
CATCH_CPP_RETURN_HRESULT();
};
}

HRESULT DxilShaderReflection::Load(const DxilPartHeader *pModulePart,
const DxilPartHeader *pRDATPart) {
HRESULT DxilShaderReflection::Load(const DxilProgramHeader *pProgramHeader, const DxilPartHeader *pRDATPart) {
IFR(LoadRDAT(pRDATPart));
IFR(LoadModule(pModulePart));
IFR(LoadProgramHeader(pProgramHeader));

try {
// Set cbuf usage.
Expand Down Expand Up @@ -2744,10 +2768,9 @@ void DxilLibraryReflection::SetCBufferUsage() {

// ID3D12LibraryReflection

HRESULT DxilLibraryReflection::Load(const DxilPartHeader *pModulePart,
const DxilPartHeader *pRDATPart) {
HRESULT DxilLibraryReflection::Load(const DxilProgramHeader *pProgramHeader, const DxilPartHeader *pRDATPart) {
IFR(LoadRDAT(pRDATPart));
IFR(LoadModule(pModulePart));
IFR(LoadProgramHeader(pProgramHeader));

try {
AddResourceDependencies();
Expand Down
40 changes: 7 additions & 33 deletions tools/clang/tools/dxcompiler/dxclibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ using namespace hlsl;
#ifdef _WIN32
// Temporary: Define these here until a better header location is found.
namespace hlsl {
HRESULT CreateDxilShaderReflection(const DxilPartHeader *pModulePart, const DxilPartHeader *pRDATPart, REFIID iid, void **ppvObject);
HRESULT CreateDxilLibraryReflection(const DxilPartHeader *pModulePart, const DxilPartHeader *pRDATPart, REFIID iid, void **ppvObject);
HRESULT CreateDxilShaderOrLibraryReflectionFromProgramHeader(const DxilProgramHeader *pProgramHeader, const DxilPartHeader *pRDATPart, REFIID iid, void **ppvObject);
HRESULT CreateDxilShaderOrLibraryReflectionFromModulePart(const DxilPartHeader *pModulePart, const DxilPartHeader *pRDATPart, REFIID iid, void **ppvObject);
}
#endif

Expand Down Expand Up @@ -401,6 +401,10 @@ class DxcUtils : public IDxcUtils {
pModulePart = pStatsPart ? pStatsPart : pDebugDXILPart ? pDebugDXILPart : pDXILPart;
if (nullptr == pModulePart)
return DXC_E_MISSING_PART;
} else if (hlsl::IsValidDxilProgramHeader((const hlsl::DxilProgramHeader *)pData->Ptr, pData->Size)) {

return hlsl::CreateDxilShaderOrLibraryReflectionFromProgramHeader((const hlsl::DxilProgramHeader *)pData->Ptr, pRDATPart, iid, ppvReflection);

} else {
// Not a container, try a statistics part that holds a valid program part.
// In the future, this will just be the RDAT part.
Expand All @@ -424,37 +428,7 @@ class DxcUtils : public IDxcUtils {
}
}

bool bIsLibrary = false;

if (pModulePart) {
if (pModulePart->PartFourCC != DFCC_DXIL &&
pModulePart->PartFourCC != DFCC_ShaderDebugInfoDXIL &&
pModulePart->PartFourCC != DFCC_ShaderStatistics) {
return E_INVALIDARG;
}
const DxilProgramHeader *pProgramHeader =
reinterpret_cast<const DxilProgramHeader*>(GetDxilPartData(pModulePart));
if (!IsValidDxilProgramHeader(pProgramHeader, pModulePart->PartSize))
return E_INVALIDARG;

// If bitcode is too small, it's probably been stripped, and we cannot create reflection with it.
if (pModulePart->PartSize - pProgramHeader->BitcodeHeader.BitcodeOffset < 4)
return DXC_E_MISSING_PART;

// Detect whether library, or if unrecognized program version.
DXIL::ShaderKind SK = GetVersionShaderType(pProgramHeader->ProgramVersion);
if (!(SK < DXIL::ShaderKind::Invalid))
return E_INVALIDARG;
bIsLibrary = DXIL::ShaderKind::Library == SK;
}

if (bIsLibrary) {
IFR(hlsl::CreateDxilLibraryReflection(pModulePart, pRDATPart, iid, ppvReflection));
} else {
IFR(hlsl::CreateDxilShaderReflection(pModulePart, pRDATPart, iid, ppvReflection));
}

return S_OK;
return hlsl::CreateDxilShaderOrLibraryReflectionFromModulePart(pModulePart, pRDATPart, iid, ppvReflection);
}
CATCH_CPP_RETURN_HRESULT();
#else
Expand Down
61 changes: 61 additions & 0 deletions tools/clang/unittests/HLSL/CompilerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ class CompilerTest : public ::testing::Test {
TEST_METHOD(CompileWhenDebugWorksThenStripDebug)
TEST_METHOD(CompileWhenWorksThenAddRemovePrivate)
TEST_METHOD(CompileThenAddCustomDebugName)
TEST_METHOD(CompileThenTestReflectionWithProgramHeader)
TEST_METHOD(CompileThenTestPdbUtils)
TEST_METHOD(CompileThenTestPdbUtilsWarningOpt)
TEST_METHOD(CompileThenTestPdbInPrivate)
Expand Down Expand Up @@ -1631,6 +1632,66 @@ void CompilerTest::TestPdbUtils(bool bSlim, bool bSourceInDebugModule, bool bStr
}
}

TEST_F(CompilerTest, CompileThenTestReflectionWithProgramHeader) {
CComPtr<IDxcCompiler> pCompiler;
CComPtr<IDxcBlobEncoding> pSource;
CComPtr<IDxcOperationResult> pOperationResult;

const char* source = R"x(
cbuffer cb : register(b1) {
float foo;
};
[RootSignature("CBV(b1)")]
float4 main(float a : A) : SV_Target {
return a + foo;
}
)x";
std::string included_File = "#define ZERO 0";

VERIFY_SUCCEEDED(CreateCompiler(&pCompiler));
CreateBlobFromText(source, &pSource);

const WCHAR * args[] = {
L"-Zi",
};

VERIFY_SUCCEEDED(pCompiler->Compile(pSource, L"source.hlsl", L"main",
L"ps_6_0", args, _countof(args), nullptr, 0, nullptr, &pOperationResult));

HRESULT CompileStatus = S_OK;
VERIFY_SUCCEEDED(pOperationResult->GetStatus(&CompileStatus));
VERIFY_SUCCEEDED(CompileStatus);

CComPtr<IDxcResult> pResult;
VERIFY_SUCCEEDED(pOperationResult.QueryInterface(&pResult));

CComPtr<IDxcBlob> pPdbBlob;
VERIFY_SUCCEEDED(pResult->GetOutput(DXC_OUT_PDB, IID_PPV_ARGS(&pPdbBlob), nullptr));

CComPtr<IDxcContainerReflection> pContainerReflection;
VERIFY_SUCCEEDED(m_dllSupport.CreateInstance(CLSID_DxcContainerReflection, &pContainerReflection));

VERIFY_SUCCEEDED(pContainerReflection->Load(pPdbBlob));
UINT32 index = 0;
VERIFY_SUCCEEDED(pContainerReflection->FindFirstPartKind(hlsl::DFCC_ShaderDebugInfoDXIL, &index));

CComPtr<IDxcBlob> pDebugDxilBlob;
VERIFY_SUCCEEDED(pContainerReflection->GetPartContent(index, &pDebugDxilBlob));

CComPtr<IDxcUtils> pUtils;
VERIFY_SUCCEEDED(m_dllSupport.CreateInstance(CLSID_DxcUtils, &pUtils));

DxcBuffer buf = {};
buf.Ptr = pDebugDxilBlob->GetBufferPointer();
buf.Size = pDebugDxilBlob->GetBufferSize();

CComPtr<ID3D12ShaderReflection> pReflection;
VERIFY_SUCCEEDED(pUtils->CreateReflection(&buf, IID_PPV_ARGS(&pReflection)));

ID3D12ShaderReflectionConstantBuffer *cb = pReflection->GetConstantBufferByName("cb");
VERIFY_IS_TRUE(cb != nullptr);
}

TEST_F(CompilerTest, CompileThenTestPdbUtils) {
if (m_ver.SkipDxilVersion(1, 5)) return;
TestPdbUtils(/*bSlim*/true, /*bSourceInDebugModule*/false, /*strip*/true); // Slim PDB, where source info is stored in its own part, and debug module is NOT present
Expand Down

0 comments on commit 7914122

Please sign in to comment.