Skip to content

Commit

Permalink
MVKShaderLibrary: Handle specialization with macros
Browse files Browse the repository at this point in the history
The converted MSL may use macros instead of function constants to implement
SPIR-V specialization constants for various reasons (e.g. when the constant
is used as an array size).

In this case, we should define the macros at the shader compilation stage and
generate different variants of the Metal shader library depending on the
macro-value mapping, so that specialization works properly when we cannot rely
on Metal's specialization. We use the information from SPIRV-Cross to determine
which macros are in use, and store variants of MTLLibrary keyed by macro values.
  • Loading branch information
dboyan committed Feb 17, 2025
1 parent 70a3a15 commit bfb76a6
Show file tree
Hide file tree
Showing 2 changed files with 195 additions and 25 deletions.
54 changes: 49 additions & 5 deletions MoltenVK/MoltenVK/GPUObjects/MVKShaderModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,38 @@ typedef struct MVKMTLFunction {
/** A MVKMTLFunction indicating an invalid MTLFunction. The mtlFunction member is nil. */
const MVKMTLFunction MVKMTLFunctionNull(nil, mvk::SPIRVToMSLConversionResultInfo(), MTLSizeMake(1, 1, 1));

/** Wraps a single MTLLibrary. */
typedef struct MVKShaderMacroValue {
	/**
	 * Raw storage for the value of one specialization constant that is realized as a macro.
	 * Every member aliases the same bytes; which member is meaningful depends on
	 * the constant's type and on the `size` field below.
	 */
	union {
		int8_t si8;
		uint8_t ui8;
		int16_t si16;
		uint16_t ui16;
		int32_t si32;
		uint32_t ui32;
		int64_t si64;
		uint64_t ui64;
		float f32;
		double f64;
	} value;

	/** Number of bytes of `value` that were actually supplied for the constant. */
	size_t size;

	/**
	 * Strict ordering so that this type can be used inside sorted containers and map keys.
	 *
	 * Orders first by the full 64-bit view of the storage, then by `size` as a tie-breaker.
	 * NOTE(review): because the whole 64-bit view is compared, bytes beyond `size` take part
	 * in the comparison; this presumes they have been zeroed before the value was filled in.
	 */
	inline bool operator<(const MVKShaderMacroValue& other) const {
		if (value.ui64 != other.value.ui64) { return value.ui64 < other.value.ui64; }
		return size < other.size;
	}
} MVKShaderMacroValue;

/**
* Wraps a single MTLLibrary or a set of MTLLibrary variants with macro-based specialization
*
* The latter case is used when Vulkan specialization constants cannot be realized with
* Metal function constants. Those specialization constants are turned into macros, and
* when specialized, we have to *recompile* the MTLLibrary from source.
*
* To keep the details transparent to users, when specialization on macros occurs,
* MVKShaderLibrary creates specialized variants (each one also a MVKShaderLibrary) behind
* the scenes and caches them in a map keyed by the macro-value mapping.
*/
class MVKShaderLibrary : public MVKBaseDeviceObject {

public:
Expand Down Expand Up @@ -84,9 +115,14 @@ class MVKShaderLibrary : public MVKBaseDeviceObject {
MVKShaderLibrary(MVKVulkanAPIDeviceObject* owner,
const mvk::SPIRVToMSLConversionResult& conversionResult);

/**
* When specializationMacroDef is not null, creates a macro-specialized library.
* specializationMacroDef contains (specialization id, value) mappings and should be sorted.
*/
MVKShaderLibrary(MVKVulkanAPIDeviceObject* owner,
const mvk::SPIRVToMSLConversionResultInfo& resultInfo,
const MVKCompressor<std::string> compressedMSL);
const MVKCompressor<std::string> compressedMSL,
const std::vector<std::pair<uint32_t, MVKShaderMacroValue>>* specializationMacroDef = nullptr);

MVKShaderLibrary(MVKVulkanAPIDeviceObject* owner,
const void* mslCompiledCodeData,
Expand All @@ -108,15 +144,21 @@ class MVKShaderLibrary : public MVKBaseDeviceObject {
MVKShaderModule* shaderModule);
void handleCompilationError(NSError* err, const char* opDesc);
MTLFunctionConstant* getFunctionConstant(NSArray<MTLFunctionConstant*>* mtlFCs, NSUInteger mtlFCID);
void compileLibrary(const std::string& msl);
void compileLibrary(const std::string& msl,
const std::vector<std::pair<uint32_t, MVKShaderMacroValue> >* specializationMacroDef = nullptr);
void compressMSL(const std::string& msl);
void decompressMSL(std::string& msl);
MVKCompressor<std::string>& getCompressedMSL() { return _compressedMSL; }

MVKVulkanAPIDeviceObject* _owner;
id<MTLLibrary> _mtlLibrary;
MVKCompressor<std::string> _compressedMSL;
mvk::SPIRVToMSLConversionResultInfo _shaderConversionResultInfo;
mvk::SPIRVToMSLConversionResultInfo _shaderConversionResultInfo;

/** When true, representing a library created with source, but never specialized */
bool _maySpecializeWithMacro;
/** Can only be populated when _maySpecializeWithMacro is true */
std::map<std::vector<std::pair<uint32_t, MVKShaderMacroValue>>, MVKShaderLibrary *> _specializationVariants;
};


Expand Down Expand Up @@ -260,7 +302,8 @@ class MVKShaderLibraryCompiler : public MVKMetalCompiler {
* nanoseconds, an error will be generated and logged, and nil will be returned.
*/
id<MTLLibrary> newMTLLibrary(NSString* mslSourceCode,
const mvk::SPIRVToMSLConversionResultInfo& shaderConversionResults);
const mvk::SPIRVToMSLConversionResultInfo& shaderConversionResults,
const std::vector<std::pair<mvk::MSLSpecializationMacroInfo, MVKShaderMacroValue>>& macroDef);


#pragma mark Construction
Expand All @@ -273,6 +316,7 @@ class MVKShaderLibraryCompiler : public MVKMetalCompiler {
~MVKShaderLibraryCompiler() override;

protected:
NSNumber *getMacroValue(const mvk::MSLSpecializationMacroInfo& info, const MVKShaderMacroValue& value);
bool compileComplete(id<MTLLibrary> mtlLibrary, NSError *error);
void handleError() override;

Expand Down
166 changes: 146 additions & 20 deletions MoltenVK/MoltenVK/GPUObjects/MVKShaderModule.mm
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,47 @@ static uint32_t getWorkgroupDimensionSize(const SPIRVWorkgroupSizeDimension& wgD

if ( !_mtlLibrary ) { return MVKMTLFunctionNull; }

id<MTLLibrary> lib = _mtlLibrary;

// If specialization happens on constants mapped to macro, find or compile a library variant
// with proper macro definition instead of the "generic" library
if (pSpecializationInfo && _maySpecializeWithMacro) {
// Create the list of macro-value mapping
vector<pair<uint32_t, MVKShaderMacroValue>> spec_list;
for (uint32_t specIdx = 0; specIdx < pSpecializationInfo->mapEntryCount; specIdx++) {
const VkSpecializationMapEntry* pMapEntry = &pSpecializationInfo->pMapEntries[specIdx];
uint32_t const_id = pMapEntry->constantID;
MVKShaderMacroValue macro_value = {};
size_t size = min(pMapEntry->size, sizeof(macro_value.value));

memcpy(&macro_value.value, (char *)pSpecializationInfo->pData + pMapEntry->offset, size);
macro_value.size = size;
if (_shaderConversionResultInfo.specializationMacros.find(const_id) != _shaderConversionResultInfo.specializationMacros.end()) {
spec_list.push_back(make_pair(const_id, macro_value));
}
}

if (!spec_list.empty()) {
// Sort the specialization list before it is used as a key to index the variants
std::sort(spec_list.begin(), spec_list.end());
auto entry = _specializationVariants.find(spec_list);
if (entry != _specializationVariants.end()) {
lib = entry->second->_mtlLibrary;
} else {
MVKShaderLibrary *new_mvklib = new MVKShaderLibrary(_owner, _shaderConversionResultInfo, _compressedMSL, &spec_list);
_specializationVariants[spec_list] = new_mvklib;
lib = new_mvklib->_mtlLibrary;
}
}
}


@synchronized (getMTLDevice()) {
@autoreleasepool {
NSString* mtlFuncName = @(_shaderConversionResultInfo.entryPoint.mtlFunctionName.c_str());

uint64_t startTime = pShaderFeedback ? mvkGetTimestamp() : getPerformanceTimestamp();
id<MTLFunction> mtlFunc = [[_mtlLibrary newFunctionWithName: mtlFuncName] autorelease];
id<MTLFunction> mtlFunc = [[lib newFunctionWithName: mtlFuncName] autorelease];
addPerformanceInterval(getPerformanceStats().shaderCompilation.functionRetrieval, startTime);
if (pShaderFeedback) {
if (mtlFunc) {
Expand Down Expand Up @@ -120,7 +155,7 @@ static uint32_t getWorkgroupDimensionSize(const SPIRVWorkgroupSizeDimension& wgD
if (pShaderFeedback) {
startTime = mvkGetTimestamp();
}
mtlFunc = [fs.newMTLFunction(_mtlLibrary, mtlFuncName, mtlFCVals) autorelease];
mtlFunc = [fs.newMTLFunction(lib, mtlFuncName, mtlFCVals) autorelease];
if (pShaderFeedback) {
pShaderFeedback->duration += mvkGetElapsedNanoseconds(startTime);
}
Expand Down Expand Up @@ -169,7 +204,8 @@ static uint32_t getWorkgroupDimensionSize(const SPIRVWorkgroupSizeDimension& wgD
MVKShaderLibrary::MVKShaderLibrary(MVKVulkanAPIDeviceObject* owner,
const SPIRVToMSLConversionResult& conversionResult) :
MVKBaseDeviceObject(owner->getDevice()),
_owner(owner) {
_owner(owner),
_maySpecializeWithMacro(true) {

_shaderConversionResultInfo = conversionResult.resultInfo;
compressMSL(conversionResult.msl);
Expand All @@ -178,21 +214,36 @@ static uint32_t getWorkgroupDimensionSize(const SPIRVWorkgroupSizeDimension& wgD

MVKShaderLibrary::MVKShaderLibrary(MVKVulkanAPIDeviceObject* owner,
const SPIRVToMSLConversionResultInfo& resultInfo,
const MVKCompressor<std::string> compressedMSL) :
const MVKCompressor<std::string> compressedMSL,
const vector<pair<uint32_t, MVKShaderMacroValue> >* specializationMacroDef) :
MVKBaseDeviceObject(owner->getDevice()),
_owner(owner) {
_owner(owner),
_maySpecializeWithMacro(specializationMacroDef == nullptr) {

_shaderConversionResultInfo = resultInfo;
_compressedMSL = compressedMSL;
string msl;
decompressMSL(msl);
compileLibrary(msl);
compileLibrary(msl, specializationMacroDef);
}

void MVKShaderLibrary::compileLibrary(const string& msl) {
void MVKShaderLibrary::compileLibrary(const string& msl,
const vector<pair<uint32_t, MVKShaderMacroValue> >* specializationMacroDef) {
MVKShaderLibraryCompiler* slc = new MVKShaderLibraryCompiler(_owner);
NSString* nsSrc = [[NSString alloc] initWithUTF8String: msl.c_str()]; // temp retained
_mtlLibrary = slc->newMTLLibrary(nsSrc, _shaderConversionResultInfo); // retained

// If specialization macro is used, translate the id to macro information and pass it to compiler
vector<pair<MSLSpecializationMacroInfo, MVKShaderMacroValue>> macro_def;
if (specializationMacroDef) {
for (auto& def: *specializationMacroDef) {
const auto& macro_name_iter = _shaderConversionResultInfo.specializationMacros.find(def.first);
if (macro_name_iter != _shaderConversionResultInfo.specializationMacros.end()) {
macro_def.push_back(make_pair(macro_name_iter->second, def.second));
}
}
}

_mtlLibrary = slc->newMTLLibrary(nsSrc, _shaderConversionResultInfo, macro_def); // retained
[nsSrc release]; // release temp string
slc->destroy();
}
Expand All @@ -201,7 +252,8 @@ static uint32_t getWorkgroupDimensionSize(const SPIRVWorkgroupSizeDimension& wgD
const void* mslCompiledCodeData,
size_t mslCompiledCodeLength) :
MVKBaseDeviceObject(owner->getDevice()),
_owner(owner) {
_owner(owner),
_maySpecializeWithMacro(false) {

uint64_t startTime = getPerformanceTimestamp();
@autoreleasepool {
Expand All @@ -219,7 +271,9 @@ static uint32_t getWorkgroupDimensionSize(const SPIRVWorkgroupSizeDimension& wgD

MVKShaderLibrary::MVKShaderLibrary(const MVKShaderLibrary& other) :
MVKBaseDeviceObject(other._device),
_owner(other._owner) {
_owner(other._owner),
_maySpecializeWithMacro(other._maySpecializeWithMacro),
_specializationVariants(other._specializationVariants) {

_mtlLibrary = [other._mtlLibrary retain];
_shaderConversionResultInfo = other._shaderConversionResultInfo;
Expand Down Expand Up @@ -255,6 +309,10 @@ static uint32_t getWorkgroupDimensionSize(const SPIRVWorkgroupSizeDimension& wgD

/**
 * Releases the retained MTLLibrary and deletes every macro-specialized variant
 * library held in _specializationVariants.
 */
MVKShaderLibrary::~MVKShaderLibrary() {
	[_mtlLibrary release];

	// The variant map stores raw pointers to variants allocated with new, so each one is deleted here.
	// NOTE(review): the copy constructor copies _specializationVariants by value, so a copy
	// and its source hold the same variant pointers; confirm both cannot be destroyed,
	// otherwise the variants would be deleted twice.
	for (auto variantIter = _specializationVariants.begin();
		 variantIter != _specializationVariants.end();
		 ++variantIter) {
		delete variantIter->second;
	}
}


Expand Down Expand Up @@ -499,27 +557,95 @@ static uint32_t getWorkgroupDimensionSize(const SPIRVWorkgroupSizeDimension& wgD
#pragma mark MVKShaderLibraryCompiler

id<MTLLibrary> MVKShaderLibraryCompiler::newMTLLibrary(NSString* mslSourceCode,
const SPIRVToMSLConversionResultInfo& shaderConversionResults) {
const SPIRVToMSLConversionResultInfo& shaderConversionResults,
const vector<pair<MSLSpecializationMacroInfo, MVKShaderMacroValue>>& specializationMacroDef) {
unique_lock<mutex> lock(_completionLock);

compile(lock, ^{
auto mtlDev = getMTLDevice();
@synchronized (mtlDev) {
auto mtlCompileOptions = getDevice()->getMTLCompileOptions(shaderConversionResults.entryPoint.supportsFastMath,
shaderConversionResults.isPositionInvariant);
MVKLogInfoIf(getMVKConfig().debugMode, "Compiling Metal shader%s.", mtlCompileOptions.fastMathEnabled ? " with FastMath enabled" : "");
[mtlDev newLibraryWithSource: mslSourceCode
options: mtlCompileOptions
completionHandler: ^(id<MTLLibrary> mtlLib, NSError* error) {
bool isLate = compileComplete(mtlLib, error);
if (isLate) { destroy(); }
}];
@autoreleasepool {
auto mtlCompileOptions = getDevice()->getMTLCompileOptions(shaderConversionResults.entryPoint.supportsFastMath,
shaderConversionResults.isPositionInvariant);
if (!specializationMacroDef.empty()) {
size_t macro_count = specializationMacroDef.size();
NSString *macro_names[macro_count];
NSNumber *macro_values[macro_count];
for (uint32_t i = 0; i < specializationMacroDef.size(); i++) {
macro_names[i] = @(specializationMacroDef[i].first.name.c_str());
macro_values[i] = getMacroValue(specializationMacroDef[i].first, specializationMacroDef[i].second);
}
mtlCompileOptions.preprocessorMacros = [NSDictionary dictionaryWithObjects: macro_values
forKeys: macro_names
count: macro_count];
}
MVKLogInfoIf(getMVKConfig().debugMode, "Compiling Metal shader%s.", mtlCompileOptions.fastMathEnabled ? " with FastMath enabled" : "");
[mtlDev newLibraryWithSource: mslSourceCode
options: mtlCompileOptions
completionHandler: ^(id<MTLLibrary> mtlLib, NSError* error) {
bool isLate = compileComplete(mtlLib, error);
if (isLate) { destroy(); }
}];
}
}
});

return [_mtlLibrary retain];
}

/**
 * Converts the raw value of a macro-realized specialization constant into an NSNumber.
 *
 * The union member that is read is selected from the macro's type information
 * (info.isFloat, info.isSigned) together with the byte size recorded in the value.
 *
 * @param info   Type information for the macro (floating-point vs. integer, signedness).
 * @param value  Raw value storage and the number of bytes that were supplied.
 * @return An autoreleased NSNumber wrapping the value. Floating-point values of any size
 *         other than sizeof(double) are read as float; integer values whose size is not
 *         1, 2, 4 or 8 bytes are read as 32-bit integers.
 */
NSNumber *MVKShaderLibraryCompiler::getMacroValue(const MSLSpecializationMacroInfo& info,
												  const MVKShaderMacroValue& value) {
	const auto& raw = value.value;

	// Floating point: only an exact double-sized value is read as double.
	if (info.isFloat) {
		return (value.size == sizeof(double)) ? [NSNumber numberWithDouble: raw.f64]
											  : [NSNumber numberWithFloat: raw.f32];
	}

	// Signed integers: unrecognized sizes are treated the same as 4 bytes.
	if (info.isSigned) {
		switch (value.size) {
			case 1:  return [NSNumber numberWithChar: raw.si8];
			case 2:  return [NSNumber numberWithShort: raw.si16];
			case 8:  return [NSNumber numberWithLongLong: raw.si64];
			case 4:
			default: return [NSNumber numberWithInt: raw.si32];
		}
	}

	// Unsigned integers: unrecognized sizes are treated the same as 4 bytes.
	switch (value.size) {
		case 1:  return [NSNumber numberWithUnsignedChar: raw.ui8];
		case 2:  return [NSNumber numberWithUnsignedShort: raw.ui16];
		case 8:  return [NSNumber numberWithUnsignedLongLong: raw.ui64];
		case 4:
		default: return [NSNumber numberWithUnsignedInt: raw.ui32];
	}
}

void MVKShaderLibraryCompiler::handleError() {
if (_mtlLibrary) {
MVKLogInfo("%s compilation succeeded with warnings (Error code %li):\n%s", _compilerType.c_str(),
Expand Down

0 comments on commit bfb76a6

Please sign in to comment.