diff --git a/src/target/spirv/build_vulkan.cc b/src/target/spirv/build_vulkan.cc index 0e2cc04347ce..d24bf3c02186 100644 --- a/src/target/spirv/build_vulkan.cc +++ b/src/target/spirv/build_vulkan.cc @@ -40,11 +40,10 @@ class SPIRVTools { ~SPIRVTools() { spvContextDestroy(ctx_); } std::string BinaryToText(const std::vector& bin) { spv_text text = nullptr; - spv_diagnostic diagnostic; + spv_diagnostic diagnostic = nullptr; spv_const_binary_t spv_bin{bin.data(), bin.size()}; - spv_result_t res; - res = + spv_result_t res = spvBinaryToText(ctx_, spv_bin.code, spv_bin.wordCount, SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES | SPV_BINARY_TO_TEXT_OPTION_INDENT, &text, &diagnostic); @@ -53,12 +52,25 @@ class SPIRVTools { << " column=" << diagnostic->position.column << " index=" << diagnostic->position.index << " error:" << diagnostic->error; + spvDiagnosticDestroy(diagnostic); std::string ret(text->str); spvTextDestroy(text); return ret; } + void ValidateShader(const std::vector& bin) { + spv_const_binary_t spv_bin{bin.data(), bin.size()}; + + spv_diagnostic diagnostic = nullptr; + spv_result_t res = spvValidate(ctx_, &spv_bin, &diagnostic); + + ICHECK_EQ(res, SPV_SUCCESS) << " index=" << diagnostic->position.index + << " error:" << diagnostic->error; + + spvDiagnosticDestroy(diagnostic); + } + private: spv_context ctx_; }; @@ -92,6 +104,8 @@ runtime::Module BuildSPIRV(IRModule mod, Target target, bool webgpu_restriction) VulkanShader shader = cg.BuildFunction(f, entry); + spirv_tools.ValidateShader(shader.data); + if (webgpu_restriction) { for (auto param : f->params) { ICHECK(param.dtype().is_handle()) << "WebGPU does not yet support non-buffer arguments";