Skip to content

Commit

Permalink
Fix validation of array length.
Browse files Browse the repository at this point in the history
  • Loading branch information
Dejan Mircevski committed Apr 4, 2016
1 parent 6fa3f8a commit 3fb2676
Show file tree
Hide file tree
Showing 4 changed files with 159 additions and 33 deletions.
5 changes: 3 additions & 2 deletions include/spirv-tools/libspirv.h
Original file line number Diff line number Diff line change
Expand Up @@ -372,10 +372,11 @@ spv_result_t spvBinaryToText(const spv_const_context context,
// pointer.
void spvBinaryDestroy(spv_binary binary);

// Validates a SPIR-V binary for correctness.
// Validates a SPIR-V binary for correctness. Any errors will be written into
// *diagnostic.
spv_result_t spvValidate(const spv_const_context context,
const spv_const_binary binary,
spv_diagnostic* pDiagnostic);
spv_diagnostic* diagnostic);

// Creates a diagnostic object. The position parameter specifies the location in
// the text/binary stream. The message parameter, copied into the diagnostic
Expand Down
3 changes: 1 addition & 2 deletions source/opcode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,13 +199,12 @@ int32_t spvOpcodeIsConstant(const SpvOp opcode) {
case SpvOpConstant:
case SpvOpConstantComposite:
case SpvOpConstantSampler:
// case SpvOpConstantNull:
case SpvOpConstantNull:
case SpvOpSpecConstantTrue:
case SpvOpSpecConstantFalse:
case SpvOpSpecConstant:
case SpvOpSpecConstantComposite:
// case SpvOpSpecConstantOp:
case SpvOpSpecConstantOp:
return true;
default:
return false;
Expand Down
59 changes: 37 additions & 22 deletions source/validate_id.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,26 @@ bool idUsage::isValid<SpvOpTypeSampler>(const spv_instruction_t*,
return true;
}

// True if the integer constant is > 0. constWords are words of the
// constant-defining instruction (either OpConstant or
// OpSpecConstant). typeWords are the words of the constant's-type-defining
// OpTypeInt.
bool aboveZero(const std::vector<uint32_t>& constWords,
const std::vector<uint32_t>& typeWords) {
const uint32_t width = typeWords[2];
const bool is_signed = typeWords[3];
const uint32_t loWord = constWords[3];
if (width > 32) {
// The spec currently doesn't allow integers wider than 64 bits.
const uint32_t hiWord = constWords[4]; // Must exist, per spec.
if (is_signed && (hiWord >> 31)) return false;
return loWord | hiWord;
} else {
if (is_signed && (loWord >> 31)) return false;
return loWord > 0;
}
}

template <>
bool idUsage::isValid<SpvOpTypeArray>(const spv_instruction_t* inst,
const spv_opcode_desc) {
Expand All @@ -278,8 +298,7 @@ bool idUsage::isValid<SpvOpTypeArray>(const spv_instruction_t* inst,
}
auto lengthIndex = 3;
auto length = usedefs_.FindDef(inst->words[lengthIndex]);
if (!length.first || (SpvOpConstant != length.second.opcode &&
SpvOpSpecConstant != length.second.opcode)) {
if (!length.first || !spvOpcodeIsConstant(length.second.opcode)) {
DIAG(lengthIndex) << "OpTypeArray Length <id> '" << inst->words[lengthIndex]
<< "' is not a scalar constant type.";
return false;
Expand All @@ -294,27 +313,23 @@ bool idUsage::isValid<SpvOpTypeArray>(const spv_instruction_t* inst,
<< "' is not a constant integer type.";
return false;
}
if (4 == constInst.size()) {
spvCheck(1 > constInst[3], DIAG(lengthIndex)
<< "OpTypeArray Length <id> '"
<< inst->words[lengthIndex]
<< "' value must be at least 1.";
return false);
} else if (5 == constInst.size()) {
uint64_t value = constInst[3] | ((uint64_t)constInst[4]) << 32;
bool signedness = constResultType.second.words[3] != 0;
if (signedness) {
spvCheck(1 > (int64_t)value, DIAG(lengthIndex)
<< "OpTypeArray Length <id> '"
<< inst->words[lengthIndex]
<< "' value must be at least 1.";
return false);
} else {
spvCheck(1 > value, DIAG(lengthIndex) << "OpTypeArray Length <id> '"
<< inst->words[lengthIndex]
<< "' value must be at least 1.";
return false);

switch (length.second.opcode) {
case SpvOpSpecConstant:
case SpvOpConstant:
if (aboveZero(length.second.words, constResultType.second.words)) break;
// Else fall through!
case SpvOpConstantNull: {
DIAG(lengthIndex) << "OpTypeArray Length <id> '"
<< inst->words[lengthIndex]
<< "' default value must be at least 1.";
return false;
}
case SpvOpSpecConstantOp:
// Assume it's OK, rather than try to evaluate the operation.
break;
default:
assert(0 && "bug in spvOpcodeIsConstant() or result type isn't int");
}
return true;
}
Expand Down
125 changes: 118 additions & 7 deletions test/ValidateID.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@
// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
// MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS.

#include <sstream>
#include <string>

#include "UnitSPIRV.h"
#include "TestFixture.h"

// NOTE: The tests in this file are ONLY testing ID usage, there for the input
// SPIR-V does not follow the logical layout rules from the spec in all cases in
Expand All @@ -36,6 +37,11 @@

namespace {

using ::testing::ValuesIn;
using std::ostringstream;
using std::string;
using std::vector;

class ValidateID : public ::testing::Test {
public:
virtual void TearDown() { spvBinaryDestroy(binary); }
Expand Down Expand Up @@ -69,6 +75,8 @@ const char kOpenCLMemoryModel64[] = R"(
OpMemoryModel Physical64 OpenCL
)";

// TODO(dekimir): this can be removed by adding a method to ValidateID akin to
// OpTypeArrayLengthTest::Val().
#define CHECK(str, expected) \
spv_diagnostic diagnostic; \
spv_context context = spvContextCreate(SPV_ENV_UNIVERSAL_1_0); \
Expand All @@ -78,7 +86,7 @@ const char kOpenCLMemoryModel64[] = R"(
if (error) { \
spvDiagnosticPrint(diagnostic); \
spvDiagnosticDestroy(diagnostic); \
ASSERT_EQ(SPV_SUCCESS, error); \
ASSERT_EQ(SPV_SUCCESS, error) << shader; \
} \
spv_result_t result = spvValidate(context, get_const_binary(), &diagnostic); \
if (SPV_SUCCESS != result) { \
Expand Down Expand Up @@ -371,19 +379,122 @@ TEST_F(ValidateID, OpTypeArrayGood) {
%3 = OpTypeArray %1 %2)";
CHECK(spirv, SPV_SUCCESS);
}

TEST_F(ValidateID, OpTypeArrayElementTypeBad) {
const char* spirv = R"(
%1 = OpTypeInt 32 0
%2 = OpConstant %1 1
%3 = OpTypeArray %2 %2)";
CHECK(spirv, SPV_ERROR_INVALID_ID);
}
TEST_F(ValidateID, OpTypeArrayLengthBad) {

// Signed or unsigned.
enum Signed { kSigned, kUnsigned };

// Creates an assembly snippet declaring OpTypeArray with the given length.
string MakeArrayLength(const string& len, Signed isSigned, int width) {
ostringstream ss;
ss << kGLSL450MemoryModel;
ss << " %t = OpTypeInt " << width << (isSigned == kSigned ? " 1" : " 0")
<< " %l = OpConstant %t " << len << " %a = OpTypeArray %t %l";
return ss.str();
}

// Tests OpTypeArray. Parameter is the width (in bits) of the array-length's
// type.
class OpTypeArrayLengthTest
: public spvtest::TextToBinaryTestBase<::testing::TestWithParam<int>> {
protected:
OpTypeArrayLengthTest()
: position_{0, 0, 0}, diagnostic_(spvDiagnosticCreate(&position_, "")) {}

~OpTypeArrayLengthTest() { spvDiagnosticDestroy(diagnostic_); }

// Runs spvValidate() on v, printing any errors via spvDiagnosticPrint().
spv_result_t Val(const SpirvVector& v) {
spv_const_binary_t cbinary{v.data(), v.size()};
const auto status = spvValidate(context, &cbinary, &diagnostic_);
if (status != SPV_SUCCESS) {
spvDiagnosticPrint(diagnostic_);
}
return status;
}

private:
spv_position_t position_; // For creating diagnostic_.
spv_diagnostic diagnostic_;
};

TEST_P(OpTypeArrayLengthTest, LengthPositive) {
const int width = GetParam();
EXPECT_EQ(SPV_SUCCESS,
Val(CompileSuccessfully(MakeArrayLength("1", kSigned, width))));
EXPECT_EQ(SPV_SUCCESS,
Val(CompileSuccessfully(MakeArrayLength("1", kUnsigned, width))));
EXPECT_EQ(SPV_SUCCESS,
Val(CompileSuccessfully(MakeArrayLength("2", kSigned, width))));
EXPECT_EQ(SPV_SUCCESS,
Val(CompileSuccessfully(MakeArrayLength("2", kUnsigned, width))));
EXPECT_EQ(SPV_SUCCESS,
Val(CompileSuccessfully(MakeArrayLength("55", kSigned, width))));
EXPECT_EQ(SPV_SUCCESS,
Val(CompileSuccessfully(MakeArrayLength("55", kUnsigned, width))));
const string fpad(width / 4 - 1, 'F');
EXPECT_EQ(
SPV_SUCCESS,
Val(CompileSuccessfully(MakeArrayLength("0x7" + fpad, kSigned, width))));
EXPECT_EQ(SPV_SUCCESS, Val(CompileSuccessfully(
MakeArrayLength("0xF" + fpad, kUnsigned, width))));
}

TEST_P(OpTypeArrayLengthTest, LengthZero) {
const int width = GetParam();
EXPECT_EQ(SPV_ERROR_INVALID_ID,
Val(CompileSuccessfully(MakeArrayLength("0", kSigned, width))));
EXPECT_EQ(SPV_ERROR_INVALID_ID,
Val(CompileSuccessfully(MakeArrayLength("0", kUnsigned, width))));
}

TEST_P(OpTypeArrayLengthTest, LengthNegative) {
const int width = GetParam();
EXPECT_EQ(SPV_ERROR_INVALID_ID,
Val(CompileSuccessfully(MakeArrayLength("-1", kSigned, width))));
EXPECT_EQ(SPV_ERROR_INVALID_ID,
Val(CompileSuccessfully(MakeArrayLength("-2", kSigned, width))));
EXPECT_EQ(SPV_ERROR_INVALID_ID,
Val(CompileSuccessfully(MakeArrayLength("-123", kSigned, width))));
const string neg_max = "0x8" + string(width / 4 - 1, '0');
EXPECT_EQ(SPV_ERROR_INVALID_ID,
Val(CompileSuccessfully(MakeArrayLength(neg_max, kSigned, width))));
}

INSTANTIATE_TEST_CASE_P(Widths, OpTypeArrayLengthTest,
ValuesIn(vector<int>{8, 16, 32, 48, 64}));

TEST_F(ValidateID, OpTypeArrayLengthNull) {
const char* spirv = R"(
%i32 = OpTypeInt 32 1
%len = OpConstantNull %i32
%ary = OpTypeArray %i32 %len)";
CHECK(spirv, SPV_ERROR_INVALID_ID);
}

TEST_F(ValidateID, OpTypeArrayLengthSpecConst) {
const char* spirv = R"(
%i32 = OpTypeInt 32 1
%len = OpSpecConstant %i32 2
%ary = OpTypeArray %i32 %len)";
CHECK(spirv, SPV_SUCCESS);
}

TEST_F(ValidateID, OpTypeArrayLengthSpecConstOp) {
const char* spirv = R"(
%1 = OpTypeInt 32 0
%2 = OpConstant %1 0
%3 = OpTypeArray %1 %2)";
CHECK(spirv, SPV_ERROR_INVALID_ID);
%i32 = OpTypeInt 32 1
%c1 = OpConstant %i32 1
%c2 = OpConstant %i32 2
%len = OpSpecConstantOp %i32 IAdd %c1 %c2
%ary = OpTypeArray %i32 %len)";
CHECK(spirv, SPV_SUCCESS);
}

TEST_F(ValidateID, OpTypeRuntimeArrayGood) {
Expand Down

0 comments on commit 3fb2676

Please sign in to comment.