Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix validation of array length. #169

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions source/opcode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -448,13 +448,12 @@ int32_t spvOpcodeIsConstant(const SpvOp opcode) {
case SpvOpConstant:
case SpvOpConstantComposite:
case SpvOpConstantSampler:
// case SpvOpConstantNull:
case SpvOpConstantNull:
case SpvOpSpecConstantTrue:
case SpvOpSpecConstantFalse:
case SpvOpSpecConstant:
case SpvOpSpecConstantComposite:
// case SpvOpSpecConstantOp:
case SpvOpSpecConstantOp:
return true;
default:
return false;
Expand Down
63 changes: 41 additions & 22 deletions source/validate_id.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,30 @@ bool idUsage::isValid<SpvOpTypeSampler>(const spv_instruction_t*,
return true;
}

// True if the integer constant is > 0. constWords are words of the
// constant-defining instruction (either OpConstant or
// OpSpecConstant). typeWords are the words of the constant's-type-defining
// OpTypeInt.
bool aboveZero(const std::vector<uint32_t>& constWords,
const std::vector<uint32_t>& typeWords) {
const auto width = typeWords[2];
const bool is_signed = typeWords[3];
if (width == 64) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The equality test is not right.
The assembler supports bit widths up to and including 64 bits.
So this code is wrong when using 48 bits, for example. The 48 bit case is tested elsewhere in the assembler and disassembler. You may as well handle that case correctly here too.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The SPIR-V spec doesn't allow 48-bit integers -- the "Data rules" part of section 2.16.1 allows only 32-bit integers plus whatever's enabled by capabilities. And capabilities only exist for Int64, Int16, and Int8.

if (is_signed) {
int64_t value = constWords[3] | (uint64_t{constWords[4]} << 32);
return value > 0;
} else {
uint64_t value = constWords[3] | (uint64_t{constWords[4]} << 32);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An unsigned value is non-zero if any of its bits are non-zero. So you could have saved some work here. Or you could have been a little more DRY by sharing this computation with the signed case, and then had the signed case do the constructor-cast.

return value > 0;
}
} else { // Per spec, must be 32 bits or less, sign-extended.
if (is_signed)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code could be more DRY by putting the constWords[3] value into a variable with a nicer name.

return (constWords[3] > 0) && !(constWords[3] >> 31);
else
return constWords[3] > 0;
}
}

template <>
bool idUsage::isValid<SpvOpTypeArray>(const spv_instruction_t* inst,
const spv_opcode_desc) {
Expand All @@ -278,8 +302,7 @@ bool idUsage::isValid<SpvOpTypeArray>(const spv_instruction_t* inst,
}
auto lengthIndex = 3;
auto length = usedefs_.FindDef(inst->words[lengthIndex]);
if (!length.first || (SpvOpConstant != length.second.opcode &&
SpvOpSpecConstant != length.second.opcode)) {
if (!length.first || !spvOpcodeIsConstant(length.second.opcode)) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is too permissive given the code that later checks the value.
The later code uses the lengths of the instructions and maps that to an assumption about it being a SpvOpConstant or SpvOpSpecConstant.
Should have a TODO saying the remaining cases should be handled (or punted entirely)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That comment no longer applies now that the later code switches on opcode.

DIAG(lengthIndex) << "OpTypeArray Length <id> '" << inst->words[lengthIndex]
<< "' is not a scalar constant type.";
return false;
Expand All @@ -294,27 +317,23 @@ bool idUsage::isValid<SpvOpTypeArray>(const spv_instruction_t* inst,
<< "' is not a constant integer type.";
return false;
}
if (4 == constInst.size()) {
spvCheck(1 > constInst[3], DIAG(lengthIndex)
<< "OpTypeArray Length <id> '"
<< inst->words[lengthIndex]
<< "' value must be at least 1.";
return false);
} else if (5 == constInst.size()) {
uint64_t value = constInst[3] | ((uint64_t)constInst[4]) << 32;
bool signedness = constResultType.second.words[3] != 0;
if (signedness) {
spvCheck(1 > (int64_t)value, DIAG(lengthIndex)
<< "OpTypeArray Length <id> '"
<< inst->words[lengthIndex]
<< "' value must be at least 1.";
return false);
} else {
spvCheck(1 > value, DIAG(lengthIndex) << "OpTypeArray Length <id> '"
<< inst->words[lengthIndex]
<< "' value must be at least 1.";
return false);

switch (length.second.opcode) {
case SpvOpSpecConstant:
case SpvOpConstant:
if (aboveZero(length.second.words, constResultType.second.words)) break;
// Else fall through!
case SpvOpConstantNull: {
DIAG(lengthIndex) << "OpTypeArray Length <id> '"
<< inst->words[lengthIndex]
<< "' default value must be at least 1.";
return false;
}
case SpvOpSpecConstantOp:
// Assume it's OK, rather than try to evaluate the operation.
break;
default:
assert(0 && "bug in spvOpcodeIsConstant() or result type isn't int");
}
return true;
}
Expand Down
94 changes: 89 additions & 5 deletions test/ValidateID.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
// MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS.

#include <sstream>
#include <string>

#include "UnitSPIRV.h"
Expand All @@ -36,6 +37,9 @@

namespace {

using std::ostringstream;
using std::string;

class ValidateID : public ::testing::Test {
public:
virtual void TearDown() { spvBinaryDestroy(binary); }
Expand Down Expand Up @@ -78,7 +82,7 @@ const char kOpenCLMemoryModel64[] = R"(
if (error) { \
spvDiagnosticPrint(diagnostic); \
spvDiagnosticDestroy(diagnostic); \
ASSERT_EQ(SPV_SUCCESS, error); \
ASSERT_EQ(SPV_SUCCESS, error) << shader; \
} \
spv_result_t result = spvValidate(context, get_const_binary(), &diagnostic); \
if (SPV_SUCCESS != result) { \
Expand Down Expand Up @@ -371,21 +375,101 @@ TEST_F(ValidateID, OpTypeArrayGood) {
%3 = OpTypeArray %1 %2)";
CHECK(spirv, SPV_SUCCESS);
}

TEST_F(ValidateID, OpTypeArrayElementTypeBad) {
const char* spirv = R"(
%1 = OpTypeInt 32 0
%2 = OpConstant %1 1
%3 = OpTypeArray %2 %2)";
CHECK(spirv, SPV_ERROR_INVALID_ID);
}
TEST_F(ValidateID, OpTypeArrayLengthBad) {

// Signed or unsigned.
enum Signed { kSigned, kUnsigned };

// Creates an assembly snippet declaring OpTypeArray with the given length.
string MakeArrayLength(const string& len, Signed isSigned, int width) {
ostringstream ss;
ss << " %t = OpTypeInt " << width << (isSigned == kSigned ? " 1" : " 0")
<< " %l = OpConstant %t " << len << " %a = OpTypeArray %t %l";
return ss.str();
}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test the 48 bit case.

TEST_F(ValidateID, OpTypeArrayLength0) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This next run of tests are all invalid cases.
How about some positive test cases for lengths that are constant integers: signed case, and also signed/unsigned wider than 32 bits (e.g. 48 and 64)

CHECK(MakeArrayLength("0", kSigned, 32), SPV_ERROR_INVALID_ID);
}

TEST_F(ValidateID, OpTypeArrayLength0U) {
CHECK(MakeArrayLength("0", kUnsigned, 32), SPV_ERROR_INVALID_ID);
}

TEST_F(ValidateID, OpTypeArrayLengthNegative1) {
CHECK(MakeArrayLength("-1", kSigned, 32), SPV_ERROR_INVALID_ID);
}

TEST_F(ValidateID, OpTypeArrayLengthNegative2) {
CHECK(MakeArrayLength("-2", kSigned, 32), SPV_ERROR_INVALID_ID);
}

TEST_F(ValidateID, OpTypeArrayLengthNegative123) {
CHECK(MakeArrayLength("-123", kSigned, 32), SPV_ERROR_INVALID_ID);
}

TEST_F(ValidateID, OpTypeArrayLengthNegativeMax) {
CHECK(MakeArrayLength("0x80000000", kSigned, 32), SPV_ERROR_INVALID_ID);
}

TEST_F(ValidateID, OpTypeArrayLength64Bit0) {
CHECK(MakeArrayLength("0", kSigned, 64), SPV_ERROR_INVALID_ID);
}

TEST_F(ValidateID, OpTypeArrayLength64Bit0U) {
CHECK(MakeArrayLength("0", kUnsigned, 64), SPV_ERROR_INVALID_ID);
}

TEST_F(ValidateID, OpTypeArrayLength64BitNegative1) {
CHECK(MakeArrayLength("-1", kSigned, 64), SPV_ERROR_INVALID_ID);
}

TEST_F(ValidateID, OpTypeArrayLength64BitNegative2) {
CHECK(MakeArrayLength("-2", kSigned, 64), SPV_ERROR_INVALID_ID);
}

TEST_F(ValidateID, OpTypeArrayLength64BitNegative123) {
CHECK(MakeArrayLength("-123", kSigned, 64), SPV_ERROR_INVALID_ID);
}

TEST_F(ValidateID, OpTypeArrayLength64BitNegativeMax) {
CHECK(MakeArrayLength("0x8000000000000000", kSigned, 64),
SPV_ERROR_INVALID_ID);
}

TEST_F(ValidateID, OpTypeArrayLengthNull) {
const char* spirv = R"(
%1 = OpTypeInt 32 0
%2 = OpConstant %1 0
%3 = OpTypeArray %1 %2)";
%i32 = OpTypeInt 32 1
%len = OpConstantNull %i32
%ary = OpTypeArray %i32 %len)";
CHECK(spirv, SPV_ERROR_INVALID_ID);
}

TEST_F(ValidateID, OpTypeArrayLengthSpecConst) {
const char* spirv = R"(
%i32 = OpTypeInt 32 1
%len = OpSpecConstant %i32 2
%ary = OpTypeArray %i32 %len)";
CHECK(spirv, SPV_SUCCESS);
}

TEST_F(ValidateID, OpTypeArrayLengthSpecConstOp) {
const char* spirv = R"(
%i32 = OpTypeInt 32 1
%c1 = OpConstant %i32 1
%c2 = OpConstant %i32 2
%len = OpSpecConstantOp %i32 IAdd %c1 %c2
%ary = OpTypeArray %i32 %len)";
CHECK(spirv, SPV_SUCCESS);
}

TEST_F(ValidateID, OpTypeRuntimeArrayGood) {
const char* spirv = R"(
%1 = OpTypeInt 32 0
Expand Down