diff --git a/include/spirv-tools/libspirv.h b/include/spirv-tools/libspirv.h index 7de0750f3b..57515a2b1d 100644 --- a/include/spirv-tools/libspirv.h +++ b/include/spirv-tools/libspirv.h @@ -383,10 +383,11 @@ spv_result_t spvBinaryToText(const spv_const_context context, // pointer. void spvBinaryDestroy(spv_binary binary); -// Validates a SPIR-V binary for correctness. +// Validates a SPIR-V binary for correctness. Any errors will be written into +// *diagnostic. spv_result_t spvValidate(const spv_const_context context, const spv_const_binary binary, - spv_diagnostic* pDiagnostic); + spv_diagnostic* diagnostic); // Creates a diagnostic object. The position parameter specifies the location in // the text/binary stream. The message parameter, copied into the diagnostic diff --git a/source/opcode.cpp b/source/opcode.cpp index 2421610a59..61ba9a4762 100644 --- a/source/opcode.cpp +++ b/source/opcode.cpp @@ -448,13 +448,12 @@ int32_t spvOpcodeIsConstant(const SpvOp opcode) { case SpvOpConstant: case SpvOpConstantComposite: case SpvOpConstantSampler: - // case SpvOpConstantNull: case SpvOpConstantNull: case SpvOpSpecConstantTrue: case SpvOpSpecConstantFalse: case SpvOpSpecConstant: case SpvOpSpecConstantComposite: - // case SpvOpSpecConstantOp: + case SpvOpSpecConstantOp: return true; default: return false; diff --git a/source/validate_id.cpp b/source/validate_id.cpp index 513f2575cc..4efcc7efce 100644 --- a/source/validate_id.cpp +++ b/source/validate_id.cpp @@ -264,6 +264,26 @@ bool idUsage::isValid(const spv_instruction_t*, return true; } +// True if the integer constant is > 0. constWords are words of the +// constant-defining instruction (either OpConstant or +// OpSpecConstant). typeWords are the words of the constant's-type-defining +// OpTypeInt. +bool aboveZero(const std::vector& constWords, + const std::vector& typeWords) { + const uint32_t width = typeWords[2]; + const bool is_signed = typeWords[3]; + const uint32_t loWord = constWords[3]; + if (width > 32) { + // The spec currently doesn't allow integers wider than 64 bits. + const uint32_t hiWord = constWords[4]; // Must exist, per spec. + if (is_signed && (hiWord >> 31)) return false; + return loWord | hiWord; + } else { + if (is_signed && (loWord >> 31)) return false; + return loWord > 0; + } +} + template <> bool idUsage::isValid(const spv_instruction_t* inst, const spv_opcode_desc) { @@ -278,8 +298,7 @@ bool idUsage::isValid(const spv_instruction_t* inst, } auto lengthIndex = 3; auto length = usedefs_.FindDef(inst->words[lengthIndex]); - if (!length.first || (SpvOpConstant != length.second.opcode && - SpvOpSpecConstant != length.second.opcode)) { + if (!length.first || !spvOpcodeIsConstant(length.second.opcode)) { DIAG(lengthIndex) << "OpTypeArray Length '" << inst->words[lengthIndex] << "' is not a scalar constant type."; return false; @@ -294,27 +313,23 @@ bool idUsage::isValid(const spv_instruction_t* inst, << "' is not a constant integer type."; return false; } - if (4 == constInst.size()) { - spvCheck(1 > constInst[3], DIAG(lengthIndex) - << "OpTypeArray Length '" - << inst->words[lengthIndex] - << "' value must be at least 1."; - return false); - } else if (5 == constInst.size()) { - uint64_t value = constInst[3] | ((uint64_t)constInst[4]) << 32; - bool signedness = constResultType.second.words[3] != 0; - if (signedness) { - spvCheck(1 > (int64_t)value, DIAG(lengthIndex) - << "OpTypeArray Length '" - << inst->words[lengthIndex] - << "' value must be at least 1."; - return false); - } else { - spvCheck(1 > value, DIAG(lengthIndex) << "OpTypeArray Length '" - << inst->words[lengthIndex] - << "' value must be at least 1."; - return false); + + switch (length.second.opcode) { + case SpvOpSpecConstant: + case SpvOpConstant: + if (aboveZero(length.second.words, constResultType.second.words)) break; + // Else fall through! + case SpvOpConstantNull: { + DIAG(lengthIndex) << "OpTypeArray Length '" + << inst->words[lengthIndex] + << "' default value must be at least 1."; + return false; } + case SpvOpSpecConstantOp: + // Assume it's OK, rather than try to evaluate the operation. + break; + default: + assert(0 && "bug in spvOpcodeIsConstant() or result type isn't int"); } return true; } diff --git a/test/ValidateID.cpp b/test/ValidateID.cpp index e29672fb2e..03a7ce3f0f 100644 --- a/test/ValidateID.cpp +++ b/test/ValidateID.cpp @@ -24,9 +24,10 @@ // TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE // MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS. +#include #include -#include "UnitSPIRV.h" +#include "TestFixture.h" // NOTE: The tests in this file are ONLY testing ID usage, there for the input // SPIR-V does not follow the logical layout rules from the spec in all cases in @@ -36,6 +37,11 @@ namespace { +using ::testing::ValuesIn; +using std::ostringstream; +using std::string; +using std::vector; + class ValidateID : public ::testing::Test { public: virtual void TearDown() { spvBinaryDestroy(binary); } @@ -69,6 +75,8 @@ const char kOpenCLMemoryModel64[] = R"( OpMemoryModel Physical64 OpenCL )"; +// TODO(dekimir): this can be removed by adding a method to ValidateID akin to +// OpTypeArrayLengthTest::Val(). #define CHECK(str, expected) \ spv_diagnostic diagnostic; \ spv_context context = spvContextCreate(SPV_ENV_UNIVERSAL_1_0); \ @@ -78,7 +86,7 @@ const char kOpenCLMemoryModel64[] = R"( if (error) { \ spvDiagnosticPrint(diagnostic); \ spvDiagnosticDestroy(diagnostic); \ - ASSERT_EQ(SPV_SUCCESS, error); \ + ASSERT_EQ(SPV_SUCCESS, error) << shader; \ } \ spv_result_t result = spvValidate(context, get_const_binary(), &diagnostic); \ if (SPV_SUCCESS != result) { \ @@ -371,6 +379,7 @@ TEST_F(ValidateID, OpTypeArrayGood) { %3 = OpTypeArray %1 %2)"; CHECK(spirv, SPV_SUCCESS); } + TEST_F(ValidateID, OpTypeArrayElementTypeBad) { const char* spirv = R"( %1 = OpTypeInt 32 0 @@ -378,12 +387,114 @@ TEST_F(ValidateID, OpTypeArrayElementTypeBad) { %3 = OpTypeArray %2 %2)"; CHECK(spirv, SPV_ERROR_INVALID_ID); } -TEST_F(ValidateID, OpTypeArrayLengthBad) { + +// Signed or unsigned. +enum Signed { kSigned, kUnsigned }; + +// Creates an assembly snippet declaring OpTypeArray with the given length. +string MakeArrayLength(const string& len, Signed isSigned, int width) { + ostringstream ss; + ss << kGLSL450MemoryModel; + ss << " %t = OpTypeInt " << width << (isSigned == kSigned ? " 1" : " 0") + << " %l = OpConstant %t " << len << " %a = OpTypeArray %t %l"; + return ss.str(); +} + +// Tests OpTypeArray. Parameter is the width (in bits) of the array-length's +// type. +class OpTypeArrayLengthTest + : public spvtest::TextToBinaryTestBase<::testing::TestWithParam> { + protected: + OpTypeArrayLengthTest() + : position_{0, 0, 0}, diagnostic_(spvDiagnosticCreate(&position_, "")) {} + + ~OpTypeArrayLengthTest() { spvDiagnosticDestroy(diagnostic_); } + + // Runs spvValidate() on v, printing any errors via spvDiagnosticPrint(). + spv_result_t Val(const SpirvVector& v) { + spv_const_binary_t cbinary{v.data(), v.size()}; + const auto status = spvValidate(context, &cbinary, &diagnostic_); + if (status != SPV_SUCCESS) { + spvDiagnosticPrint(diagnostic_); + } + return status; + } + + private: + spv_position_t position_; // For creating diagnostic_. + spv_diagnostic diagnostic_; +}; + +TEST_P(OpTypeArrayLengthTest, LengthPositive) { + const int width = GetParam(); + EXPECT_EQ(SPV_SUCCESS, + Val(CompileSuccessfully(MakeArrayLength("1", kSigned, width)))); + EXPECT_EQ(SPV_SUCCESS, + Val(CompileSuccessfully(MakeArrayLength("1", kUnsigned, width)))); + EXPECT_EQ(SPV_SUCCESS, + Val(CompileSuccessfully(MakeArrayLength("2", kSigned, width)))); + EXPECT_EQ(SPV_SUCCESS, + Val(CompileSuccessfully(MakeArrayLength("2", kUnsigned, width)))); + EXPECT_EQ(SPV_SUCCESS, + Val(CompileSuccessfully(MakeArrayLength("55", kSigned, width)))); + EXPECT_EQ(SPV_SUCCESS, + Val(CompileSuccessfully(MakeArrayLength("55", kUnsigned, width)))); + const string fpad(width / 4 - 1, 'F'); + EXPECT_EQ( + SPV_SUCCESS, + Val(CompileSuccessfully(MakeArrayLength("0x7" + fpad, kSigned, width)))); + EXPECT_EQ(SPV_SUCCESS, Val(CompileSuccessfully( + MakeArrayLength("0xF" + fpad, kUnsigned, width)))); +} + +TEST_P(OpTypeArrayLengthTest, LengthZero) { + const int width = GetParam(); + EXPECT_EQ(SPV_ERROR_INVALID_ID, + Val(CompileSuccessfully(MakeArrayLength("0", kSigned, width)))); + EXPECT_EQ(SPV_ERROR_INVALID_ID, + Val(CompileSuccessfully(MakeArrayLength("0", kUnsigned, width)))); +} + +TEST_P(OpTypeArrayLengthTest, LengthNegative) { + const int width = GetParam(); + EXPECT_EQ(SPV_ERROR_INVALID_ID, + Val(CompileSuccessfully(MakeArrayLength("-1", kSigned, width)))); + EXPECT_EQ(SPV_ERROR_INVALID_ID, + Val(CompileSuccessfully(MakeArrayLength("-2", kSigned, width)))); + EXPECT_EQ(SPV_ERROR_INVALID_ID, + Val(CompileSuccessfully(MakeArrayLength("-123", kSigned, width)))); + const string neg_max = "0x8" + string(width / 4 - 1, '0'); + EXPECT_EQ(SPV_ERROR_INVALID_ID, + Val(CompileSuccessfully(MakeArrayLength(neg_max, kSigned, width)))); +} + +INSTANTIATE_TEST_CASE_P(Widths, OpTypeArrayLengthTest, + ValuesIn(vector{8, 16, 32, 48, 64})); + +TEST_F(ValidateID, OpTypeArrayLengthNull) { + const char* spirv = R"( +%i32 = OpTypeInt 32 1 +%len = OpConstantNull %i32 +%ary = OpTypeArray %i32 %len)"; + CHECK(spirv, SPV_ERROR_INVALID_ID); +} + +TEST_F(ValidateID, OpTypeArrayLengthSpecConst) { + const char* spirv = R"( +%i32 = OpTypeInt 32 1 +%len = OpSpecConstant %i32 2 +%ary = OpTypeArray %i32 %len)"; + CHECK(spirv, SPV_SUCCESS); +} + +TEST_F(ValidateID, OpTypeArrayLengthSpecConstOp) { const char* spirv = R"( -%1 = OpTypeInt 32 0 -%2 = OpConstant %1 0 -%3 = OpTypeArray %1 %2)"; - CHECK(spirv, SPV_ERROR_INVALID_ID); +%i32 = OpTypeInt 32 1 +%c1 = OpConstant %i32 1 +%c2 = OpConstant %i32 2 +%len = OpSpecConstantOp %i32 IAdd %c1 %c2 +%ary = OpTypeArray %i32 %len)"; + CHECK(spirv, SPV_SUCCESS); } TEST_F(ValidateID, OpTypeRuntimeArrayGood) {