Skip to content

Commit

Permalink
MVKShaderLibrary: Handle specializtion with macros
Browse files Browse the repository at this point in the history
The converted MSL may use macro instead of function constants to realize
spirv specialization constant for various reasons (e.g. when the constant
is used as array size).

In this case, we should define the macros at shader compilation stage and
generate different variants of the metal shader library depending on
macro-value mapping to make specializtion work properly when we cannot rely
on metal's specialization. We use the information from SPIRV-Cross to decide
the usage of macros and store variants of MTLLibrary according to macro values.
  • Loading branch information
dboyan committed Feb 10, 2025
1 parent 94ab701 commit ffdccab
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 16 deletions.
32 changes: 27 additions & 5 deletions MoltenVK/MoltenVK/GPUObjects/MVKShaderModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,17 @@ typedef struct MVKMTLFunction {
/** A MVKMTLFunction indicating an invalid MTLFunction. The mtlFunction member is nil. */
const MVKMTLFunction MVKMTLFunctionNull(nil, mvk::SPIRVToMSLConversionResultInfo(), MTLSizeMake(1, 1, 1));

/** Wraps a single MTLLibrary. */
/**
* Wraps a single MTLLibrary or a set of MTLLibrary variants with macro-based specialization
*
* The latter case is used when Vulkan specialization constants cannot be realized with
* Metal function constants. Those specialization constants are turned into macros, and
* when specialized, we have to *recompile* the MTLLibrary from source.
*
* To keep the details transparent to users, when specialization on macro occurs,
* MVKShaderLibrary creates specialized variants (each one also a MVKShaderLibrary) behind
* the scene and cache them in a map according to the macro-value mapping.
*/
class MVKShaderLibrary : public MVKBaseDeviceObject {

public:
Expand Down Expand Up @@ -84,9 +94,14 @@ class MVKShaderLibrary : public MVKBaseDeviceObject {
MVKShaderLibrary(MVKVulkanAPIDeviceObject* owner,
const mvk::SPIRVToMSLConversionResult& conversionResult);

/**
* When specializationMacroDef is not null, creates a macro-specialized library
* specializationMacroDef contains (specialization id, value) mappings, should be sorted
*/
MVKShaderLibrary(MVKVulkanAPIDeviceObject* owner,
const mvk::SPIRVToMSLConversionResultInfo& resultInfo,
const MVKCompressor<std::string> compressedMSL);
const MVKCompressor<std::string> compressedMSL,
const std::vector<std::pair<uint32_t, int32_t>>* specializationMacroDef = nullptr);

MVKShaderLibrary(MVKVulkanAPIDeviceObject* owner,
const void* mslCompiledCodeData,
Expand All @@ -108,15 +123,21 @@ class MVKShaderLibrary : public MVKBaseDeviceObject {
MVKShaderModule* shaderModule);
void handleCompilationError(NSError* err, const char* opDesc);
MTLFunctionConstant* getFunctionConstant(NSArray<MTLFunctionConstant*>* mtlFCs, NSUInteger mtlFCID);
void compileLibrary(const std::string& msl);
void compileLibrary(const std::string& msl,
const std::vector<std::pair<uint32_t, int32_t> >* specializationMacroDef = nullptr);
void compressMSL(const std::string& msl);
void decompressMSL(std::string& msl);
MVKCompressor<std::string>& getCompressedMSL() { return _compressedMSL; }

MVKVulkanAPIDeviceObject* _owner;
id<MTLLibrary> _mtlLibrary;
MVKCompressor<std::string> _compressedMSL;
mvk::SPIRVToMSLConversionResultInfo _shaderConversionResultInfo;
mvk::SPIRVToMSLConversionResultInfo _shaderConversionResultInfo;

/** When true, representing a library created with source, but never specialized */
bool _maySpecializeWithMacro;
/** Can only be populated when _maySpecializeWithMacro is true */
std::map<std::vector<std::pair<uint32_t, int32_t>>, MVKShaderLibrary *> _specializationVariants;
};


Expand Down Expand Up @@ -260,7 +281,8 @@ class MVKShaderLibraryCompiler : public MVKMetalCompiler {
* nanoseconds, an error will be generated and logged, and nil will be returned.
*/
id<MTLLibrary> newMTLLibrary(NSString* mslSourceCode,
const mvk::SPIRVToMSLConversionResultInfo& shaderConversionResults);
const mvk::SPIRVToMSLConversionResultInfo& shaderConversionResults,
const std::vector<std::pair<std::string, int32_t>>& macroDef);


#pragma mark Construction
Expand Down
86 changes: 75 additions & 11 deletions MoltenVK/MoltenVK/GPUObjects/MVKShaderModule.mm
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,40 @@ static uint32_t getWorkgroupDimensionSize(const SPIRVWorkgroupSizeDimension& wgD

if ( !_mtlLibrary ) { return MVKMTLFunctionNull; }

id<MTLLibrary> lib = _mtlLibrary;

if (pSpecializationInfo && _maySpecializeWithMacro) {
vector<pair<uint32_t, int32_t>> spec_list;
for (uint32_t specIdx = 0; specIdx < pSpecializationInfo->mapEntryCount; specIdx++) {
const VkSpecializationMapEntry* pMapEntry = &pSpecializationInfo->pMapEntries[specIdx];
uint32_t const_id = pMapEntry->constantID;
int32_t spec_val = *(int32_t *)((char *)pSpecializationInfo->pData + pMapEntry->offset);
if (_shaderConversionResultInfo.specializationMacros.find(const_id) != _shaderConversionResultInfo.specializationMacros.end()) {
spec_list.push_back(make_pair(const_id, spec_val));
}
}

if (!spec_list.empty()) {
// Sort the specialization list before it is used as a key to index the variants
std::sort(spec_list.begin(), spec_list.end());
auto entry = _specializationVariants.find(spec_list);
if (entry != _specializationVariants.end()) {
lib = entry->second->_mtlLibrary;
} else {
MVKShaderLibrary *new_mvklib = new MVKShaderLibrary(_owner, _shaderConversionResultInfo, _compressedMSL, &spec_list);
_specializationVariants[spec_list] = new_mvklib;
lib = new_mvklib->_mtlLibrary;
}
}
}


@synchronized (getMTLDevice()) {
@autoreleasepool {
NSString* mtlFuncName = @(_shaderConversionResultInfo.entryPoint.mtlFunctionName.c_str());

uint64_t startTime = pShaderFeedback ? mvkGetTimestamp() : getPerformanceTimestamp();
id<MTLFunction> mtlFunc = [[_mtlLibrary newFunctionWithName: mtlFuncName] autorelease];
id<MTLFunction> mtlFunc = [[lib newFunctionWithName: mtlFuncName] autorelease];
addPerformanceInterval(getPerformanceStats().shaderCompilation.functionRetrieval, startTime);
if (pShaderFeedback) {
if (mtlFunc) {
Expand Down Expand Up @@ -120,7 +148,7 @@ static uint32_t getWorkgroupDimensionSize(const SPIRVWorkgroupSizeDimension& wgD
if (pShaderFeedback) {
startTime = mvkGetTimestamp();
}
mtlFunc = [fs.newMTLFunction(_mtlLibrary, mtlFuncName, mtlFCVals) autorelease];
mtlFunc = [fs.newMTLFunction(lib, mtlFuncName, mtlFCVals) autorelease];
if (pShaderFeedback) {
pShaderFeedback->duration += mvkGetElapsedNanoseconds(startTime);
}
Expand Down Expand Up @@ -169,7 +197,8 @@ static uint32_t getWorkgroupDimensionSize(const SPIRVWorkgroupSizeDimension& wgD
MVKShaderLibrary::MVKShaderLibrary(MVKVulkanAPIDeviceObject* owner,
const SPIRVToMSLConversionResult& conversionResult) :
MVKBaseDeviceObject(owner->getDevice()),
_owner(owner) {
_owner(owner),
_maySpecializeWithMacro(true) {

_shaderConversionResultInfo = conversionResult.resultInfo;
compressMSL(conversionResult.msl);
Expand All @@ -178,21 +207,36 @@ static uint32_t getWorkgroupDimensionSize(const SPIRVWorkgroupSizeDimension& wgD

MVKShaderLibrary::MVKShaderLibrary(MVKVulkanAPIDeviceObject* owner,
const SPIRVToMSLConversionResultInfo& resultInfo,
const MVKCompressor<std::string> compressedMSL) :
const MVKCompressor<std::string> compressedMSL,
const std::vector<std::pair<uint32_t, int32_t> >* specializationMacroDef) :
MVKBaseDeviceObject(owner->getDevice()),
_owner(owner) {
_owner(owner),
_maySpecializeWithMacro(specializationMacroDef == nullptr) {

_shaderConversionResultInfo = resultInfo;
_compressedMSL = compressedMSL;
string msl;
decompressMSL(msl);
compileLibrary(msl);
compileLibrary(msl, specializationMacroDef);
}

void MVKShaderLibrary::compileLibrary(const string& msl) {
void MVKShaderLibrary::compileLibrary(const string& msl,
const vector<pair<uint32_t, int32_t> >* specializationMacroDef) {
MVKShaderLibraryCompiler* slc = new MVKShaderLibraryCompiler(_owner);
NSString* nsSrc = [[NSString alloc] initWithUTF8String: msl.c_str()]; // temp retained
_mtlLibrary = slc->newMTLLibrary(nsSrc, _shaderConversionResultInfo); // retained

// If specialization macro is used, translate the id to macro name and pass it to compiler
vector<pair<string,int32_t>> macro_def;
if (specializationMacroDef) {
for (auto& def: *specializationMacroDef) {
const auto& macro_name_iter = _shaderConversionResultInfo.specializationMacros.find(def.first);
if (macro_name_iter != _shaderConversionResultInfo.specializationMacros.end()) {
macro_def.push_back(make_pair(macro_name_iter->second, def.second));
}
}
}

_mtlLibrary = slc->newMTLLibrary(nsSrc, _shaderConversionResultInfo, macro_def); // retained
[nsSrc release]; // release temp string
slc->destroy();
}
Expand All @@ -201,7 +245,8 @@ static uint32_t getWorkgroupDimensionSize(const SPIRVWorkgroupSizeDimension& wgD
const void* mslCompiledCodeData,
size_t mslCompiledCodeLength) :
MVKBaseDeviceObject(owner->getDevice()),
_owner(owner) {
_owner(owner),
_maySpecializeWithMacro(false) {

uint64_t startTime = getPerformanceTimestamp();
@autoreleasepool {
Expand All @@ -219,7 +264,9 @@ static uint32_t getWorkgroupDimensionSize(const SPIRVWorkgroupSizeDimension& wgD

MVKShaderLibrary::MVKShaderLibrary(const MVKShaderLibrary& other) :
MVKBaseDeviceObject(other._device),
_owner(other._owner) {
_owner(other._owner),
_maySpecializeWithMacro(other._maySpecializeWithMacro),
_specializationVariants(other._specializationVariants) {

_mtlLibrary = [other._mtlLibrary retain];
_shaderConversionResultInfo = other._shaderConversionResultInfo;
Expand Down Expand Up @@ -255,6 +302,10 @@ static uint32_t getWorkgroupDimensionSize(const SPIRVWorkgroupSizeDimension& wgD

MVKShaderLibrary::~MVKShaderLibrary() {
[_mtlLibrary release];

for (auto& item: _specializationVariants) {
delete item.second;
}
}


Expand Down Expand Up @@ -499,14 +550,27 @@ static uint32_t getWorkgroupDimensionSize(const SPIRVWorkgroupSizeDimension& wgD
#pragma mark MVKShaderLibraryCompiler

id<MTLLibrary> MVKShaderLibraryCompiler::newMTLLibrary(NSString* mslSourceCode,
const SPIRVToMSLConversionResultInfo& shaderConversionResults) {
const SPIRVToMSLConversionResultInfo& shaderConversionResults,
const vector<pair<string, int32_t>>& specializationMacroDef) {
unique_lock<mutex> lock(_completionLock);

compile(lock, ^{
auto mtlDev = getMTLDevice();
@synchronized (mtlDev) {
auto mtlCompileOptions = getDevice()->getMTLCompileOptions(shaderConversionResults.entryPoint.supportsFastMath,
shaderConversionResults.isPositionInvariant);
if (!specializationMacroDef.empty()) {
size_t macro_count = specializationMacroDef.size();
NSString *macro_names[macro_count];
NSNumber *macro_values[macro_count];
for (uint32_t i = 0; i < specializationMacroDef.size(); i++) {
macro_names[i] = @(specializationMacroDef[i].first.c_str());
macro_values[i] = @(specializationMacroDef[i].second);
}
mtlCompileOptions.preprocessorMacros = [NSDictionary dictionaryWithObjects: macro_values
forKeys: macro_names
count: macro_count];
}
MVKLogInfoIf(getMVKConfig().debugMode, "Compiling Metal shader%s.", mtlCompileOptions.fastMathEnabled ? " with FastMath enabled" : "");
[mtlDev newLibraryWithSource: mslSourceCode
options: mtlCompileOptions
Expand Down

0 comments on commit ffdccab

Please sign in to comment.